summaryrefslogtreecommitdiffstats
path: root/Wrappers
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-04-01 16:38:12 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2019-04-01 16:38:12 +0100
commit12ccc249a722a64c02d97e8e1513c065d4a7bf48 (patch)
tree09e176542065ce88b2a4ee8582b12c2afd1bc3ef /Wrappers
parentc3ac82e9f3beda552ee8d3e6ee35e4d768851fd7 (diff)
downloadframework-12ccc249a722a64c02d97e8e1513c065d4a7bf48.tar.gz
framework-12ccc249a722a64c02d97e8e1513c065d4a7bf48.tar.bz2
framework-12ccc249a722a64c02d97e8e1513c065d4a7bf48.tar.xz
framework-12ccc249a722a64c02d97e8e1513c065d4a7bf48.zip
updated example with PDHG algorithm class
Diffstat (limited to 'Wrappers')
-rwxr-xr-xWrappers/Python/wip/pdhg_TV_denoising.py38
1 files changed, 18 insertions, 20 deletions
diff --git a/Wrappers/Python/wip/pdhg_TV_denoising.py b/Wrappers/Python/wip/pdhg_TV_denoising.py
index 3819de5..a8e721f 100755
--- a/Wrappers/Python/wip/pdhg_TV_denoising.py
+++ b/Wrappers/Python/wip/pdhg_TV_denoising.py
@@ -19,10 +19,12 @@ from ccpi.optimisation.functions import ZeroFun, L2NormSquared, \
from skimage.util import random_noise
+
+
# ############################################################################
# Create phantom for TV denoising
-N = 200
+N = 600
data = np.zeros((N,N))
data[round(N/4):round(3*N/4),round(N/4):round(3*N/4)] = 0.5
data[round(N/8):round(7*N/8),round(3*N/8):round(5*N/8)] = 1
@@ -38,7 +40,7 @@ noisy_data = ImageData(n1)
#%%
# Regularisation Parameter
-alpha = 200
+alpha = 2
#method = input("Enter structure of PDHG (0=Composite or 1=NotComposite): ")
method = '0'
@@ -79,31 +81,27 @@ print ("normK", normK)
sigma = 1
tau = 1/(sigma*normK**2)
+pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma)
+pdhg.max_iteration = 2000
+pdhg.update_objective_interval = 10
+
+pdhg.run(2000)
-#%%
-## Number of iterations
-opt = {'niter':2000}
-##
-### Run algorithm
-result, total_time, objective = PDHG(f, g, operator, \
- tau = tau, sigma = sigma, opt = opt)
-#%%
-###Show results
-if isinstance(result, BlockDataContainer):
- sol = result.get_item(0).as_array()
-else:
- sol = result.as_array()
+
+sol = pdhg.get_output().as_array()
#sol = result.as_array()
#
+fig = plt.figure()
+plt.subplot(1,2,1)
+plt.imshow(noisy_data.as_array())
+#plt.colorbar()
+plt.subplot(1,2,2)
plt.imshow(sol)
-plt.colorbar()
+#plt.colorbar()
plt.show()
#
-###
-plt.imshow(noisy_data.as_array())
-plt.colorbar()
-plt.show()
+
##
plt.plot(np.linspace(0,N,N), data[int(N/2),:], label = 'GTruth')
plt.plot(np.linspace(0,N,N), sol[int(N/2),:], label = 'Recon')