summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py60
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py5
-rwxr-xr-xWrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py25
-rwxr-xr-xWrappers/Python/wip/pdhg_TV_denoising.py146
-rw-r--r--Wrappers/Python/wip/pdhg_TV_denoising3D.py360
-rw-r--r--Wrappers/Python/wip/pdhg_TV_tomography2D.py47
6 files changed, 497 insertions, 146 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
index 439149c..5e92767 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
@@ -126,10 +126,6 @@ def PDHG_old(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):
show_iter = opt['show_iter'] if 'show_iter' in opt.keys() else False
stop_crit = opt['stop_crit'] if 'stop_crit' in opt.keys() else False
- if memopt:
- print ("memopt")
- else:
- print("no memopt")
x_old = operator.domain_geometry().allocate()
y_old = operator.range_geometry().allocate()
@@ -183,65 +179,13 @@ def PDHG_old(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):
g.proximal(x_tmp, tau, out = x)
- xbar = x - x_old
+ x.subtract(x_old, out=xbar)
xbar *= theta
xbar += x
-
-
+
x_old.fill(x)
y_old.fill(y)
-
-# pass
-#
-## # Gradient descent, Dual problem solution
-## y_tmp = y_old + sigma * operator.direct(xbar)
-# y_tmp = operator.direct(xbar)
-# y_tmp *= sigma
-# y_tmp +=y_old
-#
-# y = f.proximal_conjugate(y_tmp, sigma)
-## f.proximal_conjugate(y_tmp, sigma, out = y)
-#
-# # Gradient ascent, Primal problem solution
-## x_tmp = x_old - tau * operator.adjoint(y)
-#
-# x_tmp = operator.adjoint(y)
-# x_tmp *=-tau
-# x_tmp +=x_old
-#
-# x = g.proximal(x_tmp, tau)
-## g.proximal(x_tmp, tau, out = x)
-#
-# #Update
-## xbar = x + theta * (x - x_old)
-# xbar = x - x_old
-# xbar *= theta
-# xbar += x
-#
-# x_old = x
-# y_old = y
-#
-## operator.direct(xbar, out = y_tmp)
-## y_tmp *= sigma
-## y_tmp +=y_old
-# if isinstance(f, FunctionOperatorComposition):
-# p1 = f(x) + g(x)
-# else:
-# p1 = f(operator.direct(x)) + g(x)
-# d1 = -(f.convex_conjugate(y) + g(-1*operator.adjoint(y)))
-# pd1 = p1 - d1
-
-# primal.append(p1)
-# dual.append(d1)
-# pdgap.append(pd1)
-
-# if i%100==0:
-# print(p1, d1, pd1)
-# if isinstance(f, FunctionOperatorComposition):
-# p1 = f(x) + g(x)
-# else:
-
t_end = time.time()
diff --git a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
index 7397cfb..2d0a00a 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
@@ -116,9 +116,10 @@ class L2NormSquared(Function):
return x/(1 + tau/2)
else:
if self.b is not None:
- out.fill( (x - tau*self.b)/(1 + tau/2) )
+ x.subtract(tau*self.b, out=out)
+ out.divide(1+tau/2, out=out)
else:
- out.fill( x/(1 + tau/2) )
+ x.divide(1 + tau/2, out=out)
def __rmul__(self, scalar):
diff --git a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
index f524c5f..3541bc2 100755
--- a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
@@ -94,19 +94,22 @@ class MixedL21Norm(Function):
else:
if out is None:
-# tmp = [ el*el for el in x.containers]
-# res = sum(tmp).sqrt().maximum(1.0)
-# frac = [el/res for el in x.containers]
-# res = BlockDataContainer(*frac)
-# return res
-
- return x.divide(x.pnorm().maximum(1.0))
+ tmp = [ el*el for el in x.containers]
+ res = sum(tmp).sqrt().maximum(1.0)
+ frac = [el/res for el in x.containers]
+ return BlockDataContainer(*frac)
+
+ #TODO this is slow, why???
+# return x.divide(x.pnorm().maximum(1.0))
else:
-# res1 = functools.reduce(lambda a,b: a + b*b, x.containers, x.get_item(0) * 0 )
-# res = res1.sqrt().maximum(1.0)
-# x.divide(res, out=out)
- x.divide(x.pnorm().maximum(1.0), out=out)
+ res1 = functools.reduce(lambda a,b: a + b*b, x.containers, x.get_item(0) * 0 )
+ res = res1.sqrt().maximum(1.0)
+ x.divide(res, out=out)
+
+# x.divide(sum([el*el for el in x.containers]).sqrt().maximum(1.0), out=out)
+ #TODO this is slow, why ???
+# x.divide(x.norm().maximum(1.0), out=out)
def __rmul__(self, scalar):
diff --git a/Wrappers/Python/wip/pdhg_TV_denoising.py b/Wrappers/Python/wip/pdhg_TV_denoising.py
index d885bca..e142d94 100755
--- a/Wrappers/Python/wip/pdhg_TV_denoising.py
+++ b/Wrappers/Python/wip/pdhg_TV_denoising.py
@@ -27,7 +27,7 @@ def dt(steps):
# Create phantom for TV denoising
-N = 200
+N = 500
data = np.zeros((N,N))
data[round(N/4):round(3*N/4),round(N/4):round(3*N/4)] = 0.5
@@ -40,8 +40,8 @@ ag = ig
n1 = random_noise(data, mode = 'gaussian', mean=0, var = 0.05, seed=10)
noisy_data = ImageData(n1)
-#plt.imshow(noisy_data.as_array())
-#plt.show()
+plt.imshow(noisy_data.as_array())
+plt.show()
#%%
@@ -82,7 +82,6 @@ else:
# Compute operator Norm
normK = operator.norm()
-print ("normK", normK)
# Primal & dual stepsizes
sigma = 1
@@ -91,54 +90,113 @@ tau = 1/(sigma*normK**2)
opt = {'niter':2000}
opt1 = {'niter':2000, 'memopt': True}
-#t1 = timer()
-#res, time, primal, dual, pdgap = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt)
-#print(timer()-t1)
-#
-#print("with memopt \n")
-#
-#t2 = timer()
-#res1, time1, primal1, dual1, pdgap1 = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt1)
-#print(timer()-t2)
-
-pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma)
-pdhg.max_iteration = 2000
-pdhg.update_objective_interval = 100
-
+t1 = timer()
+res, time, primal, dual, pdgap = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt)
+t2 = timer()
-pdhgo = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma, memopt=True)
-pdhgo.max_iteration = 2000
-pdhgo.update_objective_interval = 100
-steps = [timer()]
-pdhgo.run(2000)
-steps.append(timer())
-t1 = dt(steps)
-
-pdhg.run(2000)
-steps.append(timer())
-t2 = dt(steps)
+t3 = timer()
+res1, time1, primal1, dual1, pdgap1 = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt1)
+t4 = timer()
+#
+print ("No memopt in {}s, memopt in {}s ".format(t2-t1, t4 -t3))
-print ("Time difference {}s {}s {}s Speedup {:.2f}".format(t1,t2,t2-t1, t2/t1))
-res = pdhg.get_output()
-res1 = pdhgo.get_output()
+#
+#%%
+#pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma)
+#pdhg.max_iteration = 2000
+#pdhg.update_objective_interval = 100
+##
+#pdhgo = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma, memopt=True)
+#pdhgo.max_iteration = 2000
+#pdhgo.update_objective_interval = 100
+##
+#steps = [timer()]
+#pdhgo.run(2000)
+#steps.append(timer())
+#t1 = dt(steps)
+##
+#pdhg.run(2000)
+#steps.append(timer())
+#t2 = dt(steps)
+#
+#print ("Time difference {}s {}s {}s Speedup {:.2f}".format(t1,t2,t2-t1, t2/t1))
+#res = pdhg.get_output()
+#res1 = pdhgo.get_output()
-diff = (res-res1)
-print ("diff norm {} max {}".format(diff.norm(), diff.abs().as_array().max()))
-print ("Sum ( abs(diff) ) {}".format(diff.abs().sum()))
+#%%
+#plt.figure(figsize=(15,15))
+#plt.subplot(3,1,1)
+#plt.imshow(res.as_array())
+#plt.title('no memopt')
+#plt.colorbar()
+#plt.subplot(3,1,2)
+#plt.imshow(res1.as_array())
+#plt.title('memopt')
+#plt.colorbar()
+#plt.subplot(3,1,3)
+#plt.imshow((res1 - res).abs().as_array())
+#plt.title('diff')
+#plt.colorbar()
+#plt.show()
-plt.figure(figsize=(5,5))
-plt.subplot(1,3,1)
-plt.imshow(res.as_array())
-plt.colorbar()
+#plt.figure(figsize=(15,15))
+#plt.subplot(3,1,1)
+#plt.imshow(pdhg.get_output().as_array())
+#plt.title('no memopt class')
+#plt.colorbar()
+#plt.subplot(3,1,2)
+#plt.imshow(res.as_array())
+#plt.title('no memopt')
+#plt.colorbar()
+#plt.subplot(3,1,3)
+#plt.imshow((pdhg.get_output() - res).abs().as_array())
+#plt.title('diff')
+#plt.colorbar()
#plt.show()
-
-#plt.figure(figsize=(5,5))
-plt.subplot(1,3,2)
-plt.imshow(res1.as_array())
-plt.colorbar()
+#
+#
+#
+#plt.figure(figsize=(15,15))
+#plt.subplot(3,1,1)
+#plt.imshow(pdhgo.get_output().as_array())
+#plt.title('memopt class')
+#plt.colorbar()
+#plt.subplot(3,1,2)
+#plt.imshow(res1.as_array())
+#plt.title('no memopt')
+#plt.colorbar()
+#plt.subplot(3,1,3)
+#plt.imshow((pdhgo.get_output() - res1).abs().as_array())
+#plt.title('diff')
+#plt.colorbar()
+#plt.show()
+
+
+
+
+# print ("Time difference {}s {}s {}s Speedup {:.2f}".format(t1,t2,t2-t1, t2/t1))
+# res = pdhg.get_output()
+# res1 = pdhgo.get_output()
+#
+# diff = (res-res1)
+# print ("diff norm {} max {}".format(diff.norm(), diff.abs().as_array().max()))
+# print ("Sum ( abs(diff) ) {}".format(diff.abs().sum()))
+#
+#
+# plt.figure(figsize=(5,5))
+# plt.subplot(1,3,1)
+# plt.imshow(res.as_array())
+# plt.colorbar()
+# #plt.show()
+#
+# #plt.figure(figsize=(5,5))
+# plt.subplot(1,3,2)
+# plt.imshow(res1.as_array())
+# plt.colorbar()
+
#plt.show()
diff --git a/Wrappers/Python/wip/pdhg_TV_denoising3D.py b/Wrappers/Python/wip/pdhg_TV_denoising3D.py
new file mode 100644
index 0000000..06ecfa2
--- /dev/null
+++ b/Wrappers/Python/wip/pdhg_TV_denoising3D.py
@@ -0,0 +1,360 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Feb 22 14:53:03 2019
+
+@author: evangelos
+"""
+
+from ccpi.framework import ImageData, ImageGeometry, BlockDataContainer
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from ccpi.optimisation.algorithms import PDHG, PDHG_old
+
+from ccpi.optimisation.operators import BlockOperator, Identity, Gradient
+from ccpi.optimisation.functions import ZeroFunction, L2NormSquared, \
+ MixedL21Norm, FunctionOperatorComposition, BlockFunction
+
+from skimage.util import random_noise
+
+from timeit import default_timer as timer
+def dt(steps):
+ return steps[-1] - steps[-2]
+
+#%%
+
+# Create phantom for TV denoising
+
+import timeit
+import os
+from tomophantom import TomoP3D
+import tomophantom
+
+print ("Building 3D phantom using TomoPhantom software")
+tic=timeit.default_timer()
+model = 13 # select a model number from the library
+N_size = 64 # Define phantom dimensions using a scalar value (cubic phantom)
+path = os.path.dirname(tomophantom.__file__)
+path_library3D = os.path.join(path, "Phantom3DLibrary.dat")
+#This will generate a N_size x N_size x N_size phantom (3D)
+phantom_tm = TomoP3D.Model(model, N_size, path_library3D)
+#toc=timeit.default_timer()
+#Run_time = toc - tic
+#print("Phantom has been built in {} seconds".format(Run_time))
+#
+#sliceSel = int(0.5*N_size)
+##plt.gray()
+#plt.figure()
+#plt.subplot(131)
+#plt.imshow(phantom_tm[sliceSel,:,:],vmin=0, vmax=1)
+#plt.title('3D Phantom, axial view')
+#
+#plt.subplot(132)
+#plt.imshow(phantom_tm[:,sliceSel,:],vmin=0, vmax=1)
+#plt.title('3D Phantom, coronal view')
+#
+#plt.subplot(133)
+#plt.imshow(phantom_tm[:,:,sliceSel],vmin=0, vmax=1)
+#plt.title('3D Phantom, sagittal view')
+#plt.show()
+
+#%%
+
+N = N_size
+ig = ImageGeometry(voxel_num_x=N, voxel_num_y=N, voxel_num_z=N)
+
+n1 = random_noise(phantom_tm, mode = 'gaussian', mean=0, var = 0.001, seed=10)
+noisy_data = ImageData(n1)
+#plt.imshow(noisy_data.as_array()[:,:,32])
+
+#%%
+
+# Regularisation Parameter
+alpha = 0.02
+
+#method = input("Enter structure of PDHG (0=Composite or 1=NotComposite): ")
+method = '0'
+
+if method == '0':
+
+ # Create operators
+ op1 = Gradient(ig)
+ op2 = Identity(ig)
+
+ # Form Composite Operator
+ operator = BlockOperator(op1, op2, shape=(2,1) )
+
+ #### Create functions
+
+ f1 = alpha * MixedL21Norm()
+ f2 = 0.5 * L2NormSquared(b = noisy_data)
+ f = BlockFunction(f1, f2)
+
+ g = ZeroFunction()
+
+else:
+
+ ###########################################################################
+ # No Composite #
+ ###########################################################################
+ operator = Gradient(ig)
+ f = alpha * FunctionOperatorComposition(operator, MixedL21Norm())
+ g = L2NormSquared(b = noisy_data)
+
+ ###########################################################################
+#%%
+
+# Compute operator Norm
+normK = operator.norm()
+
+# Primal & dual stepsizes
+sigma = 1
+tau = 1/(sigma*normK**2)
+
+opt = {'niter':2000}
+opt1 = {'niter':2000, 'memopt': True}
+
+#t1 = timer()
+#res, time, primal, dual, pdgap = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt)
+#t2 = timer()
+
+
+t3 = timer()
+res1, time1, primal1, dual1, pdgap1 = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt1)
+t4 = timer()
+
+#import cProfile
+#cProfile.run('res1, time1, primal1, dual1, pdgap1 = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt1) ')
+###
+print ("No memopt in {}s, memopt in {}s ".format(t2-t1, t4 -t3))
+#
+##
+##%%
+#
+#plt.figure(figsize=(10,10))
+#plt.subplot(311)
+#plt.imshow(res1.as_array()[sliceSel,:,:])
+#plt.colorbar()
+#plt.title('3D Phantom, axial view')
+#
+#plt.subplot(312)
+#plt.imshow(res1.as_array()[:,sliceSel,:])
+#plt.colorbar()
+#plt.title('3D Phantom, coronal view')
+#
+#plt.subplot(313)
+#plt.imshow(res1.as_array()[:,:,sliceSel])
+#plt.colorbar()
+#plt.title('3D Phantom, sagittal view')
+#plt.show()
+#
+#plt.figure(figsize=(10,10))
+#plt.subplot(311)
+#plt.imshow(res.as_array()[sliceSel,:,:])
+#plt.colorbar()
+#plt.title('3D Phantom, axial view')
+#
+#plt.subplot(312)
+#plt.imshow(res.as_array()[:,sliceSel,:])
+#plt.colorbar()
+#plt.title('3D Phantom, coronal view')
+#
+#plt.subplot(313)
+#plt.imshow(res.as_array()[:,:,sliceSel])
+#plt.colorbar()
+#plt.title('3D Phantom, sagittal view')
+#plt.show()
+#
+#diff = (res1 - res).abs()
+#
+#plt.figure(figsize=(10,10))
+#plt.subplot(311)
+#plt.imshow(diff.as_array()[sliceSel,:,:])
+#plt.colorbar()
+#plt.title('3D Phantom, axial view')
+#
+#plt.subplot(312)
+#plt.imshow(diff.as_array()[:,sliceSel,:])
+#plt.colorbar()
+#plt.title('3D Phantom, coronal view')
+#
+#plt.subplot(313)
+#plt.imshow(diff.as_array()[:,:,sliceSel])
+#plt.colorbar()
+#plt.title('3D Phantom, sagittal view')
+#plt.show()
+#
+#
+#
+#
+##%%
+#pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma)
+#pdhg.max_iteration = 2000
+#pdhg.update_objective_interval = 100
+####
+#pdhgo = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma, memopt=True)
+#pdhgo.max_iteration = 2000
+#pdhgo.update_objective_interval = 100
+####
+#steps = [timer()]
+#pdhgo.run(2000)
+#steps.append(timer())
+#t1 = dt(steps)
+##
+#pdhg.run(2000)
+#steps.append(timer())
+#t2 = dt(steps)
+#
+#print ("Time difference {}s {}s {}s Speedup {:.2f}".format(t1,t2,t2-t1, t2/t1))
+#res = pdhg.get_output()
+#res1 = pdhgo.get_output()
+
+#%%
+#plt.figure(figsize=(15,15))
+#plt.subplot(3,1,1)
+#plt.imshow(res.as_array())
+#plt.title('no memopt')
+#plt.colorbar()
+#plt.subplot(3,1,2)
+#plt.imshow(res1.as_array())
+#plt.title('memopt')
+#plt.colorbar()
+#plt.subplot(3,1,3)
+#plt.imshow((res1 - res).abs().as_array())
+#plt.title('diff')
+#plt.colorbar()
+#plt.show()
+
+
+#plt.figure(figsize=(15,15))
+#plt.subplot(3,1,1)
+#plt.imshow(pdhg.get_output().as_array())
+#plt.title('no memopt class')
+#plt.colorbar()
+#plt.subplot(3,1,2)
+#plt.imshow(res.as_array())
+#plt.title('no memopt')
+#plt.colorbar()
+#plt.subplot(3,1,3)
+#plt.imshow((pdhg.get_output() - res).abs().as_array())
+#plt.title('diff')
+#plt.colorbar()
+#plt.show()
+#
+#
+#
+#plt.figure(figsize=(15,15))
+#plt.subplot(3,1,1)
+#plt.imshow(pdhgo.get_output().as_array())
+#plt.title('memopt class')
+#plt.colorbar()
+#plt.subplot(3,1,2)
+#plt.imshow(res1.as_array())
+#plt.title('no memopt')
+#plt.colorbar()
+#plt.subplot(3,1,3)
+#plt.imshow((pdhgo.get_output() - res1).abs().as_array())
+#plt.title('diff')
+#plt.colorbar()
+#plt.show()
+
+
+
+
+
+# print ("Time difference {}s {}s {}s Speedup {:.2f}".format(t1,t2,t2-t1, t2/t1))
+# res = pdhg.get_output()
+# res1 = pdhgo.get_output()
+#
+# diff = (res-res1)
+# print ("diff norm {} max {}".format(diff.norm(), diff.abs().as_array().max()))
+# print ("Sum ( abs(diff) ) {}".format(diff.abs().sum()))
+#
+#
+# plt.figure(figsize=(5,5))
+# plt.subplot(1,3,1)
+# plt.imshow(res.as_array())
+# plt.colorbar()
+# #plt.show()
+#
+# #plt.figure(figsize=(5,5))
+# plt.subplot(1,3,2)
+# plt.imshow(res1.as_array())
+# plt.colorbar()
+
+#plt.show()
+
+
+
+#=======
+## opt = {'niter':2000, 'memopt': True}
+#
+## res, time, primal, dual, pdgap = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt)
+#
+#>>>>>>> origin/pdhg_fix
+#
+#
+## opt = {'niter':2000, 'memopt': False}
+## res1, time1, primal1, dual1, pdgap1 = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt)
+#
+## plt.figure(figsize=(5,5))
+## plt.subplot(1,3,1)
+## plt.imshow(res.as_array())
+## plt.title('memopt')
+## plt.colorbar()
+## plt.subplot(1,3,2)
+## plt.imshow(res1.as_array())
+## plt.title('no memopt')
+## plt.colorbar()
+## plt.subplot(1,3,3)
+## plt.imshow((res1 - res).abs().as_array())
+## plt.title('diff')
+## plt.colorbar()
+## plt.show()
+#pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma)
+#pdhg.max_iteration = 2000
+#pdhg.update_objective_interval = 100
+#
+#
+#pdhgo = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma, memopt=True)
+#pdhgo.max_iteration = 2000
+#pdhgo.update_objective_interval = 100
+#
+#steps = [timer()]
+#pdhgo.run(200)
+#steps.append(timer())
+#t1 = dt(steps)
+#
+#pdhg.run(200)
+#steps.append(timer())
+#t2 = dt(steps)
+#
+#print ("Time difference {} {} {}".format(t1,t2,t2-t1))
+#sol = pdhg.get_output().as_array()
+##sol = result.as_array()
+##
+#fig = plt.figure()
+#plt.subplot(1,3,1)
+#plt.imshow(noisy_data.as_array())
+#plt.colorbar()
+#plt.subplot(1,3,2)
+#plt.imshow(sol)
+#plt.colorbar()
+#plt.subplot(1,3,3)
+#plt.imshow(pdhgo.get_output().as_array())
+#plt.colorbar()
+#
+#plt.show()
+###
+##
+####
+##plt.plot(np.linspace(0,N,N), data[int(N/2),:], label = 'GTruth')
+##plt.plot(np.linspace(0,N,N), sol[int(N/2),:], label = 'Recon')
+##plt.legend()
+##plt.show()
+#
+#
+##%%
+##
diff --git a/Wrappers/Python/wip/pdhg_TV_tomography2D.py b/Wrappers/Python/wip/pdhg_TV_tomography2D.py
index e0868f7..3fec34e 100644
--- a/Wrappers/Python/wip/pdhg_TV_tomography2D.py
+++ b/Wrappers/Python/wip/pdhg_TV_tomography2D.py
@@ -56,7 +56,7 @@ detectors = 150
angles = np.linspace(0,np.pi,100)
ag = AcquisitionGeometry('parallel','2D',angles, detectors)
-Aop = AstraProjectorSimple(ig, ag, 'cpu')
+Aop = AstraProjectorSimple(ig, ag, 'gpu')
sin = Aop.direct(data)
plt.imshow(sin.as_array())
@@ -113,43 +113,28 @@ else:
sigma = 1
tau = 1/(sigma*normK**2)
-#pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma)
-#pdhg.max_iteration = 5000
-#pdhg.update_objective_interval = 250
-#
-#pdhg.run(5000)
-
-opt = {'niter':300}
-opt1 = {'niter':300, 'memopt': True}
+# Compute operator Norm
+normK = operator.norm()
+
+# Primal & dual stepsizes
+sigma = 1
+tau = 1/(sigma*normK**2)
+opt = {'niter':2000}
+opt1 = {'niter':2000, 'memopt': True}
t1 = timer()
res, time, primal, dual, pdgap = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt)
-
-print(timer()-t1)
-plt.figure(figsize=(5,5))
-plt.imshow(res.as_array())
-plt.colorbar()
-plt.show()
-
-#%%
-print("with memopt \n")
-#
t2 = timer()
+
+
+t3 = timer()
res1, time1, primal1, dual1, pdgap1 = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt1)
-#print(timer()-t2)
-#
-#
-plt.figure(figsize=(5,5))
-plt.imshow(res1.as_array())
-plt.colorbar()
-plt.show()
+t4 = timer()
#
-#%%
-plt.figure(figsize=(5,5))
-plt.imshow(np.abs(res1.as_array()-res.as_array()))
-plt.colorbar()
-plt.show()
+print ("No memopt in {}s, memopt in {}s ".format(t2-t1, t4 -t3))
+
+
#%%
#sol = pdhg.get_output().as_array()
#fig = plt.figure()