summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rwxr-xr-xWrappers/Python/demos/PDHG_examples/GatherAll/PDHG_TV_Denoising.py187
1 files changed, 96 insertions, 91 deletions
diff --git a/Wrappers/Python/demos/PDHG_examples/GatherAll/PDHG_TV_Denoising.py b/Wrappers/Python/demos/PDHG_examples/GatherAll/PDHG_TV_Denoising.py
index 0f1effa..c472f36 100755
--- a/Wrappers/Python/demos/PDHG_examples/GatherAll/PDHG_TV_Denoising.py
+++ b/Wrappers/Python/demos/PDHG_examples/GatherAll/PDHG_TV_Denoising.py
@@ -24,15 +24,18 @@
Total Variation Denoising using PDHG algorithm:
-Problem: min_x, x>0 \alpha * ||\nabla x||_{2,1} + ||x-g||_{1}
+Problem: min_{u}, \alpha * ||\nabla u||_{2,1} + Fidelity(u, g)
\alpha: Regularization parameter
\nabla: Gradient operator
- g: Noisy Data with Salt & Pepper Noise
-
-
+ g: Noisy Data
+
+ Fidelity = 1) L2NormSquarred ( \frac{1}{2} * || u - g ||_{2}^{2} ) if Noise is Gaussian
+ 2) L1Norm ( ||u - g||_{1} )if Noise is Salt & Pepper
+ 3) Kullback Leibler (\int u - g * log(u) + Id_{u>0}) if Noise is Poisson
+
Method = 0 ( PDHG - split ) : K = [ \nabla,
Identity]
@@ -40,10 +43,14 @@ Problem: min_x, x>0 \alpha * ||\nabla x||_{2,1} + ||x-g||_{1}
Method = 1 (PDHG - explicit ): K = \nabla
+ Default: ROF denoising
+ noise = Gaussian
+ Fidelity = L2NormSquarred
+ method = 0
+
+
"""
-from ccpi.framework import ImageData, ImageGeometry
-
import numpy as np
import numpy
import matplotlib.pyplot as plt
@@ -65,7 +72,7 @@ else:
if len(sys.argv) > 1:
which_noise = int(sys.argv[1])
else:
- which_noise = 0
+ which_noise = 1
print ("Applying {} noise")
if len(sys.argv) > 2:
@@ -73,32 +80,27 @@ if len(sys.argv) > 2:
else:
method = '0'
print ("method ", method)
-# Create phantom for TV Salt & Pepper denoising
-N = 100
+
loader = TestData(data_dir=os.path.join(sys.prefix, 'share','ccpi'))
-data = loader.load(TestData.SIMPLE_PHANTOM_2D, size=(N,N))
-data = loader.load(TestData.PEPPERS, size=(N,N))
+data = loader.load(TestData.SHAPES)
ig = data.geometry
ag = ig
# Create noisy data.
-# Apply Salt & Pepper noise
-# gaussian
-# poisson
noises = ['gaussian', 'poisson', 's&p']
noise = noises[which_noise]
if noise == 's&p':
n1 = random_noise(data.as_array(), mode = noise, salt_vs_pepper = 0.9, amount=0.2)
elif noise == 'poisson':
- n1 = random_noise(data.as_array(), mode = noise, seed = 10)
+ scale = 5
+ n1 = random_noise( data.as_array()/scale, mode = noise, seed = 10)*scale
elif noise == 'gaussian':
n1 = random_noise(data.as_array(), mode = noise, seed = 10)
else:
raise ValueError('Unsupported Noise ', noise)
noisy_data = ig.allocate()
noisy_data.fill(n1)
-#noisy_data = ImageData(n1)
# Show Ground Truth and Noisy Data
plt.figure(figsize=(10,5))
@@ -112,8 +114,14 @@ plt.title('Noisy Data')
plt.colorbar()
plt.show()
-# Regularisation Parameter
-alpha = .2
+
+# Regularisation Parameter depending on the noise distribution
+if noise == 's&p':
+ alpha = 0.8
+elif noise == 'poisson':
+ alpha = 1
+elif noise == 'gaussian':
+ alpha = .3
# fidelity
if noise == 's&p':
@@ -121,30 +129,25 @@ if noise == 's&p':
elif noise == 'poisson':
f2 = KullbackLeibler(noisy_data)
elif noise == 'gaussian':
- f2 = L2NormSquared(b=noisy_data)
+ f2 = 0.5 * L2NormSquared(b=noisy_data)
if method == '0':
# Create operators
- op1 = Gradient(ig, correlation=Gradient.CORRELATION_SPACECHANNEL)
+ op1 = Gradient(ig, correlation=Gradient.CORRELATION_SPACE)
op2 = Identity(ig, ag)
# Create BlockOperator
operator = BlockOperator(op1, op2, shape=(2,1) )
# Create functions
- f1 = alpha * MixedL21Norm()
- #f2 = L1Norm(b = noisy_data)
- f = BlockFunction(f1, f2)
-
+ f = BlockFunction(alpha * MixedL21Norm(), f2)
g = ZeroFunction()
else:
- # Without the "Block Framework"
operator = Gradient(ig)
f = alpha * MixedL21Norm()
- #g = L1Norm(b = noisy_data)
g = f2
@@ -154,14 +157,15 @@ normK = operator.norm()
# Primal & dual stepsizes
sigma = 1
tau = 1/(sigma*normK**2)
-opt = {'niter':2000, 'memopt': True}
+
# Setup and run the PDHG algorithm
-pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma, memopt=True)
+pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma)
pdhg.max_iteration = 2000
-pdhg.update_objective_interval = 50
+pdhg.update_objective_interval = 100
pdhg.run(2000)
+
if data.geometry.channels > 1:
plt.figure(figsize=(20,15))
for row in range(data.geometry.channels):
@@ -179,11 +183,12 @@ if data.geometry.channels > 1:
plt.title('TV Reconstruction')
plt.colorbar()
plt.subplot(3,4,4+row*4)
- plt.plot(np.linspace(0,N,N), data.subset(channel=row).as_array()[int(N/2),:], label = 'GTruth')
- plt.plot(np.linspace(0,N,N), pdhg.get_output().subset(channel=row).as_array()[int(N/2),:], label = 'TV reconstruction')
+ plt.plot(np.linspace(0,ig.shape[1],ig.shape[1]), data.subset(channel=row).as_array()[int(N/2),:], label = 'GTruth')
+ plt.plot(np.linspace(0,ig.shape[1],ig.shape[1]), pdhg.get_output().subset(channel=row).as_array()[int(N/2),:], label = 'TV reconstruction')
plt.legend()
plt.title('Middle Line Profiles')
plt.show()
+
else:
plt.figure(figsize=(20,5))
plt.subplot(1,4,1)
@@ -199,68 +204,68 @@ else:
plt.title('TV Reconstruction')
plt.colorbar()
plt.subplot(1,4,4)
- plt.plot(np.linspace(0,N,N), data.as_array()[int(N/2),:], label = 'GTruth')
- plt.plot(np.linspace(0,N,N), pdhg.get_output().as_array()[int(N/2),:], label = 'TV reconstruction')
+ plt.plot(np.linspace(0,ig.shape[1],ig.shape[1]), data.as_array()[int(ig.shape[0]/2),:], label = 'GTruth')
+ plt.plot(np.linspace(0,ig.shape[1],ig.shape[1]), pdhg.get_output().as_array()[int(ig.shape[0]/2),:], label = 'TV reconstruction')
plt.legend()
plt.title('Middle Line Profiles')
plt.show()
-##%% Check with CVX solution
-
-from ccpi.optimisation.operators import SparseFiniteDiff
-
-try:
- from cvxpy import *
- cvx_not_installable = True
-except ImportError:
- cvx_not_installable = False
-
-
-if cvx_not_installable:
-
- ##Construct problem
- u = Variable(ig.shape)
-
- DY = SparseFiniteDiff(ig, direction=0, bnd_cond='Neumann')
- DX = SparseFiniteDiff(ig, direction=1, bnd_cond='Neumann')
-
- # Define Total Variation as a regulariser
- regulariser = alpha * sum(norm(vstack([DX.matrix() * vec(u), DY.matrix() * vec(u)]), 2, axis = 0))
- fidelity = pnorm( u - noisy_data.as_array(),1)
-
- # choose solver
- if 'MOSEK' in installed_solvers():
- solver = MOSEK
- else:
- solver = SCS
-
- obj = Minimize( regulariser + fidelity)
- prob = Problem(obj)
- result = prob.solve(verbose = True, solver = solver)
-
- diff_cvx = numpy.abs( pdhg.get_output().as_array() - u.value )
-
- plt.figure(figsize=(15,15))
- plt.subplot(3,1,1)
- plt.imshow(pdhg.get_output().as_array())
- plt.title('PDHG solution')
- plt.colorbar()
- plt.subplot(3,1,2)
- plt.imshow(u.value)
- plt.title('CVX solution')
- plt.colorbar()
- plt.subplot(3,1,3)
- plt.imshow(diff_cvx)
- plt.title('Difference')
- plt.colorbar()
- plt.show()
-
- plt.plot(np.linspace(0,N,N), pdhg.get_output().as_array()[int(N/2),:], label = 'PDHG')
- plt.plot(np.linspace(0,N,N), u.value[int(N/2),:], label = 'CVX')
- plt.legend()
- plt.title('Middle Line Profiles')
- plt.show()
-
- print('Primal Objective (CVX) {} '.format(obj.value))
- print('Primal Objective (PDHG) {} '.format(pdhg.objective[-1][0]))
+###%% Check with CVX solution
+#
+#from ccpi.optimisation.operators import SparseFiniteDiff
+#
+#try:
+# from cvxpy import *
+# cvx_not_installable = True
+#except ImportError:
+# cvx_not_installable = False
+#
+#
+#if cvx_not_installable:
+#
+# ##Construct problem
+# u = Variable(ig.shape)
+#
+# DY = SparseFiniteDiff(ig, direction=0, bnd_cond='Neumann')
+# DX = SparseFiniteDiff(ig, direction=1, bnd_cond='Neumann')
+#
+# # Define Total Variation as a regulariser
+# regulariser = alpha * sum(norm(vstack([DX.matrix() * vec(u), DY.matrix() * vec(u)]), 2, axis = 0))
+# fidelity = pnorm( u - noisy_data.as_array(),1)
+#
+# # choose solver
+# if 'MOSEK' in installed_solvers():
+# solver = MOSEK
+# else:
+# solver = SCS
+#
+# obj = Minimize( regulariser + fidelity)
+# prob = Problem(obj)
+# result = prob.solve(verbose = True, solver = solver)
+#
+# diff_cvx = numpy.abs( pdhg.get_output().as_array() - u.value )
+#
+# plt.figure(figsize=(15,15))
+# plt.subplot(3,1,1)
+# plt.imshow(pdhg.get_output().as_array())
+# plt.title('PDHG solution')
+# plt.colorbar()
+# plt.subplot(3,1,2)
+# plt.imshow(u.value)
+# plt.title('CVX solution')
+# plt.colorbar()
+# plt.subplot(3,1,3)
+# plt.imshow(diff_cvx)
+# plt.title('Difference')
+# plt.colorbar()
+# plt.show()
+#
+# plt.plot(np.linspace(0,N,N), pdhg.get_output().as_array()[int(N/2),:], label = 'PDHG')
+# plt.plot(np.linspace(0,N,N), u.value[int(N/2),:], label = 'CVX')
+# plt.legend()
+# plt.title('Middle Line Profiles')
+# plt.show()
+#
+# print('Primal Objective (CVX) {} '.format(obj.value))
+# print('Primal Objective (PDHG) {} '.format(pdhg.objective[-1][0]))