summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-04-15 16:50:31 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2019-04-15 16:50:31 +0100
commitbc48b72e1e2cda814aa3c39569eb021fc5f0420d (patch)
tree91cb675152f0d894449561a19c4b60e772b3b9cd
parent5fd55e750e66d0e753cd7dca05c394d353b14b7f (diff)
parent89318446494a491c01b077aed802da5951aed910 (diff)
downloadframework-bc48b72e1e2cda814aa3c39569eb021fc5f0420d.tar.gz
framework-bc48b72e1e2cda814aa3c39569eb021fc5f0420d.tar.bz2
framework-bc48b72e1e2cda814aa3c39569eb021fc5f0420d.tar.xz
framework-bc48b72e1e2cda814aa3c39569eb021fc5f0420d.zip
Merge remote-tracking branch 'origin/composite_operator_datacontainer' into power_method
-rwxr-xr-xWrappers/Python/ccpi/framework/framework.py246
-rw-r--r--Wrappers/Python/ccpi/io/reader.py21
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/CGLS.py3
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py2
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/FISTA.py6
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py2
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py60
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algs.py2
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py5
-rwxr-xr-xWrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py25
-rwxr-xr-xWrappers/Python/test/test_DataContainer.py16
-rw-r--r--Wrappers/Python/test/test_functions.py4
-rwxr-xr-xWrappers/Python/test/test_run_test.py31
-rwxr-xr-xWrappers/Python/wip/pdhg_TV_denoising.py146
-rw-r--r--Wrappers/Python/wip/pdhg_TV_denoising3D.py360
-rw-r--r--Wrappers/Python/wip/pdhg_TV_tomography2D.py47
16 files changed, 694 insertions, 282 deletions
diff --git a/Wrappers/Python/ccpi/framework/framework.py b/Wrappers/Python/ccpi/framework/framework.py
index af4139b..7516447 100755
--- a/Wrappers/Python/ccpi/framework/framework.py
+++ b/Wrappers/Python/ccpi/framework/framework.py
@@ -772,61 +772,18 @@ class DataContainer(object):
class ImageData(DataContainer):
'''DataContainer for holding 2D or 3D DataContainer'''
__container_priority__ = 1
+
+
def __init__(self,
array = None,
deep_copy=False,
dimension_labels=None,
**kwargs):
- self.geometry = None
+ self.geometry = kwargs.get('geometry', None)
if array is None:
- if 'geometry' in kwargs.keys():
- geometry = kwargs['geometry']
- self.geometry = geometry
- channels = geometry.channels
- horiz_x = geometry.voxel_num_x
- horiz_y = geometry.voxel_num_y
- vert = 1 if geometry.voxel_num_z is None\
- else geometry.voxel_num_z # this should be 1 for 2D
- if dimension_labels is None:
- if channels > 1:
- if vert > 1:
- shape = (channels, vert, horiz_y, horiz_x)
- dim_labels = [ImageGeometry.CHANNEL,
- ImageGeometry.VERTICAL,
- ImageGeometry.HORIZONTAL_Y,
- ImageGeometry.HORIZONTAL_X]
- else:
- shape = (channels , horiz_y, horiz_x)
- dim_labels = [ImageGeometry.CHANNEL,
- ImageGeometry.HORIZONTAL_Y,
- ImageGeometry.HORIZONTAL_X]
- else:
- if vert > 1:
- shape = (vert, horiz_y, horiz_x)
- dim_labels = [ImageGeometry.VERTICAL,
- ImageGeometry.HORIZONTAL_Y,
- ImageGeometry.HORIZONTAL_X]
- else:
- shape = (horiz_y, horiz_x)
- dim_labels = [ImageGeometry.HORIZONTAL_Y,
- ImageGeometry.HORIZONTAL_X]
- dimension_labels = dim_labels
- else:
- shape = []
- for dim in dimension_labels:
- if dim == ImageGeometry.CHANNEL:
- shape.append(channels)
- elif dim == ImageGeometry.HORIZONTAL_Y:
- shape.append(horiz_y)
- elif dim == ImageGeometry.VERTICAL:
- shape.append(vert)
- elif dim == ImageGeometry.HORIZONTAL_X:
- shape.append(horiz_x)
- if len(shape) != len(dimension_labels):
- raise ValueError('Missing {0} axes'.format(
- len(dimension_labels) - len(shape)))
- shape = tuple(shape)
+ if self.geometry is not None:
+ shape, dimension_labels = self.get_shape_labels(self.geometry)
array = numpy.zeros( shape , dtype=numpy.float32)
super(ImageData, self).__init__(array, deep_copy,
@@ -836,6 +793,11 @@ class ImageData(DataContainer):
raise ValueError('Please pass either a DataContainer, ' +\
'a numpy array or a geometry')
else:
+ if self.geometry is not None:
+ shape, labels = self.get_shape_labels(self.geometry, dimension_labels)
+ if array.shape != shape:
+ raise ValueError('Shape mismatch {} {}'.format(shape, array.shape))
+
if issubclass(type(array) , DataContainer):
# if the array is a DataContainer get the info from there
if not ( array.number_of_dimensions == 2 or \
@@ -890,78 +852,85 @@ class ImageData(DataContainer):
#out.geometry = self.recalculate_geometry(dimensions , **kw)
out.geometry = self.geometry
return out
-
+
+ def get_shape_labels(self, geometry, dimension_labels=None):
+ channels = geometry.channels
+ horiz_x = geometry.voxel_num_x
+ horiz_y = geometry.voxel_num_y
+ vert = 1 if geometry.voxel_num_z is None\
+ else geometry.voxel_num_z # this should be 1 for 2D
+ if dimension_labels is None:
+ if channels > 1:
+ if vert > 1:
+ shape = (channels, vert, horiz_y, horiz_x)
+ dim_labels = [ImageGeometry.CHANNEL,
+ ImageGeometry.VERTICAL,
+ ImageGeometry.HORIZONTAL_Y,
+ ImageGeometry.HORIZONTAL_X]
+ else:
+ shape = (channels , horiz_y, horiz_x)
+ dim_labels = [ImageGeometry.CHANNEL,
+ ImageGeometry.HORIZONTAL_Y,
+ ImageGeometry.HORIZONTAL_X]
+ else:
+ if vert > 1:
+ shape = (vert, horiz_y, horiz_x)
+ dim_labels = [ImageGeometry.VERTICAL,
+ ImageGeometry.HORIZONTAL_Y,
+ ImageGeometry.HORIZONTAL_X]
+ else:
+ shape = (horiz_y, horiz_x)
+ dim_labels = [ImageGeometry.HORIZONTAL_Y,
+ ImageGeometry.HORIZONTAL_X]
+ dimension_labels = dim_labels
+ else:
+ shape = []
+ for i in range(len(dimension_labels)):
+ dim = dimension_labels[i]
+ if dim == ImageGeometry.CHANNEL:
+ shape.append(channels)
+ elif dim == ImageGeometry.HORIZONTAL_Y:
+ shape.append(horiz_y)
+ elif dim == ImageGeometry.VERTICAL:
+ shape.append(vert)
+ elif dim == ImageGeometry.HORIZONTAL_X:
+ shape.append(horiz_x)
+ if len(shape) != len(dimension_labels):
+ raise ValueError('Missing {0} axes {1} shape {2}'.format(
+ len(dimension_labels) - len(shape), dimension_labels, shape))
+ shape = tuple(shape)
+
+ return (shape, dimension_labels)
+
class AcquisitionData(DataContainer):
'''DataContainer for holding 2D or 3D sinogram'''
__container_priority__ = 1
+
+
def __init__(self,
array = None,
deep_copy=True,
dimension_labels=None,
**kwargs):
- self.geometry = None
+ self.geometry = kwargs.get('geometry', None)
if array is None:
if 'geometry' in kwargs.keys():
geometry = kwargs['geometry']
self.geometry = geometry
- channels = geometry.channels
- horiz = geometry.pixel_num_h
- vert = geometry.pixel_num_v
- angles = geometry.angles
- num_of_angles = numpy.shape(angles)[0]
- if dimension_labels is None:
- if channels > 1:
- if vert > 1:
- shape = (channels, num_of_angles , vert, horiz)
- dim_labels = [AcquisitionGeometry.CHANNEL,
- AcquisitionGeometry.ANGLE,
- AcquisitionGeometry.VERTICAL,
- AcquisitionGeometry.HORIZONTAL]
- else:
- shape = (channels , num_of_angles, horiz)
- dim_labels = [AcquisitionGeometry.CHANNEL,
- AcquisitionGeometry.ANGLE,
- AcquisitionGeometry.HORIZONTAL]
- else:
- if vert > 1:
- shape = (num_of_angles, vert, horiz)
- dim_labels = [AcquisitionGeometry.ANGLE,
- AcquisitionGeometry.VERTICAL,
- AcquisitionGeometry.HORIZONTAL
- ]
- else:
- shape = (num_of_angles, horiz)
- dim_labels = [AcquisitionGeometry.ANGLE,
- AcquisitionGeometry.HORIZONTAL
- ]
-
- dimension_labels = dim_labels
- else:
- shape = []
- for dim in dimension_labels:
- if dim == AcquisitionGeometry.CHANNEL:
- shape.append(channels)
- elif dim == AcquisitionGeometry.ANGLE:
- shape.append(num_of_angles)
- elif dim == AcquisitionGeometry.VERTICAL:
- shape.append(vert)
- elif dim == AcquisitionGeometry.HORIZONTAL:
- shape.append(horiz)
- if len(shape) != len(dimension_labels):
- raise ValueError('Missing {0} axes.\nExpected{1} got {2}'\
- .format(
- len(dimension_labels) - len(shape),
- dimension_labels, shape)
- )
- shape = tuple(shape)
+ shape, dimension_labels = self.get_shape_labels(geometry, dimension_labels)
+
array = numpy.zeros( shape , dtype=numpy.float32)
super(AcquisitionData, self).__init__(array, deep_copy,
dimension_labels, **kwargs)
else:
-
+ if self.geometry is not None:
+ shape, labels = self.get_shape_labels(self.geometry, dimension_labels)
+ if array.shape != shape:
+ raise ValueError('Shape mismatch {} {}'.format(shape, array.shape))
+
if issubclass(type(array) ,DataContainer):
# if the array is a DataContainer get the info from there
if not ( array.number_of_dimensions == 2 or \
@@ -982,19 +951,78 @@ class AcquisitionData(DataContainer):
if dimension_labels is None:
if array.ndim == 4:
- dimension_labels = ['channel' ,'angle' , 'vertical' ,
- 'horizontal']
+ dimension_labels = [AcquisitionGeometry.CHANNEL,
+ AcquisitionGeometry.ANGLE,
+ AcquisitionGeometry.VERTICAL,
+ AcquisitionGeometry.HORIZONTAL]
elif array.ndim == 3:
- dimension_labels = ['angle' , 'vertical' ,
- 'horizontal']
+ dimension_labels = [AcquisitionGeometry.ANGLE,
+ AcquisitionGeometry.VERTICAL,
+ AcquisitionGeometry.HORIZONTAL]
else:
- dimension_labels = ['angle' ,
- 'horizontal']
-
- #DataContainer.__init__(self, array, deep_copy, dimension_labels, **kwargs)
+ dimension_labels = [AcquisitionGeometry.ANGLE,
+ AcquisitionGeometry.HORIZONTAL]
+
super(AcquisitionData, self).__init__(array, deep_copy,
dimension_labels, **kwargs)
+ def get_shape_labels(self, geometry, dimension_labels=None):
+ channels = geometry.channels
+ horiz = geometry.pixel_num_h
+ vert = geometry.pixel_num_v
+ angles = geometry.angles
+ num_of_angles = numpy.shape(angles)[0]
+
+ if dimension_labels is None:
+ if channels > 1:
+ if vert > 1:
+ shape = (channels, num_of_angles , vert, horiz)
+ dim_labels = [AcquisitionGeometry.CHANNEL,
+ AcquisitionGeometry.ANGLE,
+ AcquisitionGeometry.VERTICAL,
+ AcquisitionGeometry.HORIZONTAL]
+ else:
+ shape = (channels , num_of_angles, horiz)
+ dim_labels = [AcquisitionGeometry.CHANNEL,
+ AcquisitionGeometry.ANGLE,
+ AcquisitionGeometry.HORIZONTAL]
+ else:
+ if vert > 1:
+ shape = (num_of_angles, vert, horiz)
+ dim_labels = [AcquisitionGeometry.ANGLE,
+ AcquisitionGeometry.VERTICAL,
+ AcquisitionGeometry.HORIZONTAL
+ ]
+ else:
+ shape = (num_of_angles, horiz)
+ dim_labels = [AcquisitionGeometry.ANGLE,
+ AcquisitionGeometry.HORIZONTAL
+ ]
+
+ dimension_labels = dim_labels
+ else:
+ shape = []
+ for i in range(len(dimension_labels)):
+ dim = dimension_labels[i]
+
+ if dim == AcquisitionGeometry.CHANNEL:
+ shape.append(channels)
+ elif dim == AcquisitionGeometry.ANGLE:
+ shape.append(num_of_angles)
+ elif dim == AcquisitionGeometry.VERTICAL:
+ shape.append(vert)
+ elif dim == AcquisitionGeometry.HORIZONTAL:
+ shape.append(horiz)
+ if len(shape) != len(dimension_labels):
+ raise ValueError('Missing {0} axes.\nExpected{1} got {2}'\
+ .format(
+ len(dimension_labels) - len(shape),
+ dimension_labels, shape)
+ )
+ shape = tuple(shape)
+ return (shape, dimension_labels)
+
+
class DataProcessor(object):
diff --git a/Wrappers/Python/ccpi/io/reader.py b/Wrappers/Python/ccpi/io/reader.py
index 856f5e0..07e3bf9 100644
--- a/Wrappers/Python/ccpi/io/reader.py
+++ b/Wrappers/Python/ccpi/io/reader.py
@@ -241,26 +241,37 @@ class NexusReader(object):
pass
dims = file[self.data_path].shape
if ymin is None and ymax is None:
- data = np.array(file[self.data_path])
+
+ try:
+ image_keys = self.get_image_keys()
+ print ("image_keys", image_keys)
+ projections = np.array(file[self.data_path])
+ data = projections[image_keys==0]
+ except KeyError as ke:
+ print (ke)
+ data = np.array(file[self.data_path])
+
else:
+ image_keys = self.get_image_keys()
+ print ("image_keys", image_keys)
+ projections = np.array(file[self.data_path])[image_keys==0]
if ymin is None:
ymin = 0
if ymax > dims[1]:
raise ValueError('ymax out of range')
- data = np.array(file[self.data_path][:,:ymax,:])
+ data = projections[:,:ymax,:]
elif ymax is None:
ymax = dims[1]
if ymin < 0:
raise ValueError('ymin out of range')
- data = np.array(file[self.data_path][:,ymin:,:])
+ data = projections[:,ymin:,:]
else:
if ymax > dims[1]:
raise ValueError('ymax out of range')
if ymin < 0:
raise ValueError('ymin out of range')
- data = np.array(file[self.data_path]
- [: , ymin:ymax , :] )
+ data = projections[: , ymin:ymax , :]
except:
print("Error reading nexus file")
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
index 7194eb8..e65bc89 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
@@ -23,7 +23,6 @@ Created on Thu Feb 21 11:11:23 2019
"""
from ccpi.optimisation.algorithms import Algorithm
-#from collections.abc import Iterable
class CGLS(Algorithm):
'''Conjugate Gradient Least Squares algorithm
@@ -84,4 +83,4 @@ class CGLS(Algorithm):
self.d = s + beta*self.d
def update_objective(self):
- self.loss.append(self.r.squared_norm()) \ No newline at end of file
+ self.loss.append(self.r.squared_norm())
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py b/Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py
index 445ba7a..aa07359 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py
@@ -23,7 +23,7 @@ Created on Thu Feb 21 11:09:03 2019
"""
from ccpi.optimisation.algorithms import Algorithm
-from ccpi.optimisation.functions import ZeroFun
+from ccpi.optimisation.functions import ZeroFunction
class FBPD(Algorithm):
'''FBPD Algorithm
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
index 93ba178..064cb33 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
@@ -6,7 +6,7 @@ Created on Thu Feb 21 11:07:30 2019
"""
from ccpi.optimisation.algorithms import Algorithm
-from ccpi.optimisation.functions import ZeroFun
+from ccpi.optimisation.functions import ZeroFunction
import numpy
class FISTA(Algorithm):
@@ -46,11 +46,11 @@ class FISTA(Algorithm):
# default inputs
if f is None:
- self.f = ZeroFun()
+ self.f = ZeroFunction()
else:
self.f = f
if g is None:
- g = ZeroFun()
+ g = ZeroFunction()
self.g = g
else:
self.g = g
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py
index f1e4132..14763c5 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py
@@ -73,4 +73,4 @@ class GradientDescent(Algorithm):
def update_objective(self):
self.loss.append(self.objective_function(self.x))
- \ No newline at end of file
+
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
index 439149c..5e92767 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
@@ -126,10 +126,6 @@ def PDHG_old(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):
show_iter = opt['show_iter'] if 'show_iter' in opt.keys() else False
stop_crit = opt['stop_crit'] if 'stop_crit' in opt.keys() else False
- if memopt:
- print ("memopt")
- else:
- print("no memopt")
x_old = operator.domain_geometry().allocate()
y_old = operator.range_geometry().allocate()
@@ -183,65 +179,13 @@ def PDHG_old(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):
g.proximal(x_tmp, tau, out = x)
- xbar = x - x_old
+ x.subtract(x_old, out=xbar)
xbar *= theta
xbar += x
-
-
+
x_old.fill(x)
y_old.fill(y)
-
-# pass
-#
-## # Gradient descent, Dual problem solution
-## y_tmp = y_old + sigma * operator.direct(xbar)
-# y_tmp = operator.direct(xbar)
-# y_tmp *= sigma
-# y_tmp +=y_old
-#
-# y = f.proximal_conjugate(y_tmp, sigma)
-## f.proximal_conjugate(y_tmp, sigma, out = y)
-#
-# # Gradient ascent, Primal problem solution
-## x_tmp = x_old - tau * operator.adjoint(y)
-#
-# x_tmp = operator.adjoint(y)
-# x_tmp *=-tau
-# x_tmp +=x_old
-#
-# x = g.proximal(x_tmp, tau)
-## g.proximal(x_tmp, tau, out = x)
-#
-# #Update
-## xbar = x + theta * (x - x_old)
-# xbar = x - x_old
-# xbar *= theta
-# xbar += x
-#
-# x_old = x
-# y_old = y
-#
-## operator.direct(xbar, out = y_tmp)
-## y_tmp *= sigma
-## y_tmp +=y_old
-# if isinstance(f, FunctionOperatorComposition):
-# p1 = f(x) + g(x)
-# else:
-# p1 = f(operator.direct(x)) + g(x)
-# d1 = -(f.convex_conjugate(y) + g(-1*operator.adjoint(y)))
-# pd1 = p1 - d1
-
-# primal.append(p1)
-# dual.append(d1)
-# pdgap.append(pd1)
-
-# if i%100==0:
-# print(p1, d1, pd1)
-# if isinstance(f, FunctionOperatorComposition):
-# p1 = f(x) + g(x)
-# else:
-
t_end = time.time()
diff --git a/Wrappers/Python/ccpi/optimisation/algs.py b/Wrappers/Python/ccpi/optimisation/algs.py
index 6b6ae2c..c142eda 100755
--- a/Wrappers/Python/ccpi/optimisation/algs.py
+++ b/Wrappers/Python/ccpi/optimisation/algs.py
@@ -21,7 +21,7 @@ import numpy
import time
from ccpi.optimisation.functions import Function
-from ccpi.optimisation.functions import ZeroFun
+from ccpi.optimisation.functions import ZeroFunction
from ccpi.framework import ImageData
from ccpi.framework import AcquisitionData
from ccpi.optimisation.spdhg import spdhg
diff --git a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
index 7397cfb..2d0a00a 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
@@ -116,9 +116,10 @@ class L2NormSquared(Function):
return x/(1 + tau/2)
else:
if self.b is not None:
- out.fill( (x - tau*self.b)/(1 + tau/2) )
+ x.subtract(tau*self.b, out=out)
+ out.divide(1+tau/2, out=out)
else:
- out.fill( x/(1 + tau/2) )
+ x.divide(1 + tau/2, out=out)
def __rmul__(self, scalar):
diff --git a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
index f524c5f..3541bc2 100755
--- a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
@@ -94,19 +94,22 @@ class MixedL21Norm(Function):
else:
if out is None:
-# tmp = [ el*el for el in x.containers]
-# res = sum(tmp).sqrt().maximum(1.0)
-# frac = [el/res for el in x.containers]
-# res = BlockDataContainer(*frac)
-# return res
-
- return x.divide(x.pnorm().maximum(1.0))
+ tmp = [ el*el for el in x.containers]
+ res = sum(tmp).sqrt().maximum(1.0)
+ frac = [el/res for el in x.containers]
+ return BlockDataContainer(*frac)
+
+ #TODO this is slow, why???
+# return x.divide(x.pnorm().maximum(1.0))
else:
-# res1 = functools.reduce(lambda a,b: a + b*b, x.containers, x.get_item(0) * 0 )
-# res = res1.sqrt().maximum(1.0)
-# x.divide(res, out=out)
- x.divide(x.pnorm().maximum(1.0), out=out)
+ res1 = functools.reduce(lambda a,b: a + b*b, x.containers, x.get_item(0) * 0 )
+ res = res1.sqrt().maximum(1.0)
+ x.divide(res, out=out)
+
+# x.divide(sum([el*el for el in x.containers]).sqrt().maximum(1.0), out=out)
+ #TODO this is slow, why ???
+# x.divide(x.norm().maximum(1.0), out=out)
def __rmul__(self, scalar):
diff --git a/Wrappers/Python/test/test_DataContainer.py b/Wrappers/Python/test/test_DataContainer.py
index 8edfd8b..40cd244 100755
--- a/Wrappers/Python/test/test_DataContainer.py
+++ b/Wrappers/Python/test/test_DataContainer.py
@@ -494,6 +494,14 @@ class TestDataContainer(unittest.TestCase):
self.assertEqual(order[0], image.dimension_labels[0])
self.assertEqual(order[1], image.dimension_labels[1])
self.assertEqual(order[2], image.dimension_labels[2])
+
+ ig = ImageGeometry(2,3,2)
+ try:
+ z = ImageData(numpy.random.randint(10, size=(2,3)), geometry=ig)
+ self.assertTrue(False)
+ except ValueError as ve:
+ print (ve)
+ self.assertTrue(True)
#vgeometry.allocate('')
def test_AcquisitionGeometry_allocate(self):
@@ -525,6 +533,14 @@ class TestDataContainer(unittest.TestCase):
self.assertEqual(order[1], sino.dimension_labels[1])
self.assertEqual(order[2], sino.dimension_labels[2])
self.assertEqual(order[2], sino.dimension_labels[2])
+
+
+ try:
+ z = AcquisitionData(numpy.random.randint(10, size=(2,3)), geometry=ageometry)
+ self.assertTrue(False)
+ except ValueError as ve:
+ print (ve)
+ self.assertTrue(True)
def assertNumpyArrayEqual(self, first, second):
res = True
diff --git a/Wrappers/Python/test/test_functions.py b/Wrappers/Python/test/test_functions.py
index 22721fa..05bdd7a 100644
--- a/Wrappers/Python/test/test_functions.py
+++ b/Wrappers/Python/test/test_functions.py
@@ -26,7 +26,7 @@ from ccpi.optimisation.funcs import Norm2sq
# from ccpi.optimisation.functions.L2NormSquared import SimpleL2NormSq, L2NormSq
# from ccpi.optimisation.functions.L1Norm import SimpleL1Norm, L1Norm
#from ccpi.optimisation.functions import mixed_L12Norm
-from ccpi.optimisation.functions import ZeroFun
+from ccpi.optimisation.functions import ZeroFunction
from ccpi.optimisation.functions import FunctionOperatorComposition
import unittest
@@ -329,7 +329,7 @@ class TestFunction(unittest.TestCase):
a1 = f_no_scaled(U)
a2 = f_scaled(U)
- self.assertAlmostEqual(a1,a2)
+ self.assertNumpyArrayAlmostEqual(a1.as_array(),a2.as_array())
tmp = [ el**2 for el in U.containers ]
diff --git a/Wrappers/Python/test/test_run_test.py b/Wrappers/Python/test/test_run_test.py
index 8cef925..c698032 100755
--- a/Wrappers/Python/test/test_run_test.py
+++ b/Wrappers/Python/test/test_run_test.py
@@ -6,10 +6,10 @@ from ccpi.framework import ImageData
from ccpi.framework import AcquisitionData
from ccpi.framework import ImageGeometry
from ccpi.framework import AcquisitionGeometry
-from ccpi.optimisation.algs import FISTA
-from ccpi.optimisation.algs import FBPD
+from ccpi.optimisation.algorithms import FISTA
+#from ccpi.optimisation.algs import FBPD
from ccpi.optimisation.funcs import Norm2sq
-from ccpi.optimisation.functions import ZeroFun
+from ccpi.optimisation.functions import ZeroFunction
from ccpi.optimisation.funcs import Norm1
from ccpi.optimisation.funcs import TV2D
from ccpi.optimisation.funcs import Norm2
@@ -82,7 +82,7 @@ class TestAlgorithms(unittest.TestCase):
opt = {'memopt': True}
# Create object instances with the test data A and b.
f = Norm2sq(A, b, c=0.5, memopt=True)
- g0 = ZeroFun()
+ g0 = ZeroFunction()
# Initial guess
x_init = DataContainer(np.zeros((n, 1)))
@@ -90,12 +90,15 @@ class TestAlgorithms(unittest.TestCase):
f.grad(x_init)
# Run FISTA for least squares plus zero function.
- x_fista0, it0, timing0, criter0 = FISTA(x_init, f, g0, opt=opt)
-
+ #x_fista0, it0, timing0, criter0 = FISTA(x_init, f, g0, opt=opt)
+ fa = FISTA(x_init=x_init, f=f, g=g0)
+ fa.max_iteration = 10
+ fa.run(10)
+
# Print solution and final objective/criterion value for comparison
print("FISTA least squares plus zero function solution and objective value:")
- print(x_fista0.array)
- print(criter0[-1])
+ print(fa.get_output())
+ print(fa.get_last_objective())
# Compare to CVXPY
@@ -143,7 +146,7 @@ class TestAlgorithms(unittest.TestCase):
opt = {'memopt': True}
# Create object instances with the test data A and b.
f = Norm2sq(A, b, c=0.5, memopt=True)
- g0 = ZeroFun()
+ g0 = ZeroFunction()
# Initial guess
x_init = DataContainer(np.zeros((n, 1)))
@@ -155,12 +158,16 @@ class TestAlgorithms(unittest.TestCase):
g1.prox(x_init, 0.02)
# Combine with least squares and solve using generic FISTA implementation
- x_fista1, it1, timing1, criter1 = FISTA(x_init, f, g1, opt=opt)
+ #x_fista1, it1, timing1, criter1 = FISTA(x_init, f, g1, opt=opt)
+ fa = FISTA(x_init=x_init, f=f, g=g1)
+ fa.max_iteration = 10
+ fa.run(10)
+
# Print for comparison
print("FISTA least squares plus 1-norm solution and objective value:")
- print(x_fista1.as_array().squeeze())
- print(criter1[-1])
+ print(fa.get_output())
+ print(fa.get_last_objective())
# Compare to CVXPY
diff --git a/Wrappers/Python/wip/pdhg_TV_denoising.py b/Wrappers/Python/wip/pdhg_TV_denoising.py
index d885bca..e142d94 100755
--- a/Wrappers/Python/wip/pdhg_TV_denoising.py
+++ b/Wrappers/Python/wip/pdhg_TV_denoising.py
@@ -27,7 +27,7 @@ def dt(steps):
# Create phantom for TV denoising
-N = 200
+N = 500
data = np.zeros((N,N))
data[round(N/4):round(3*N/4),round(N/4):round(3*N/4)] = 0.5
@@ -40,8 +40,8 @@ ag = ig
n1 = random_noise(data, mode = 'gaussian', mean=0, var = 0.05, seed=10)
noisy_data = ImageData(n1)
-#plt.imshow(noisy_data.as_array())
-#plt.show()
+plt.imshow(noisy_data.as_array())
+plt.show()
#%%
@@ -82,7 +82,6 @@ else:
# Compute operator Norm
normK = operator.norm()
-print ("normK", normK)
# Primal & dual stepsizes
sigma = 1
@@ -91,54 +90,113 @@ tau = 1/(sigma*normK**2)
opt = {'niter':2000}
opt1 = {'niter':2000, 'memopt': True}
-#t1 = timer()
-#res, time, primal, dual, pdgap = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt)
-#print(timer()-t1)
-#
-#print("with memopt \n")
-#
-#t2 = timer()
-#res1, time1, primal1, dual1, pdgap1 = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt1)
-#print(timer()-t2)
-
-pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma)
-pdhg.max_iteration = 2000
-pdhg.update_objective_interval = 100
-
+t1 = timer()
+res, time, primal, dual, pdgap = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt)
+t2 = timer()
-pdhgo = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma, memopt=True)
-pdhgo.max_iteration = 2000
-pdhgo.update_objective_interval = 100
-steps = [timer()]
-pdhgo.run(2000)
-steps.append(timer())
-t1 = dt(steps)
-
-pdhg.run(2000)
-steps.append(timer())
-t2 = dt(steps)
+t3 = timer()
+res1, time1, primal1, dual1, pdgap1 = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt1)
+t4 = timer()
+#
+print ("No memopt in {}s, memopt in {}s ".format(t2-t1, t4 -t3))
-print ("Time difference {}s {}s {}s Speedup {:.2f}".format(t1,t2,t2-t1, t2/t1))
-res = pdhg.get_output()
-res1 = pdhgo.get_output()
+#
+#%%
+#pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma)
+#pdhg.max_iteration = 2000
+#pdhg.update_objective_interval = 100
+##
+#pdhgo = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma, memopt=True)
+#pdhgo.max_iteration = 2000
+#pdhgo.update_objective_interval = 100
+##
+#steps = [timer()]
+#pdhgo.run(2000)
+#steps.append(timer())
+#t1 = dt(steps)
+##
+#pdhg.run(2000)
+#steps.append(timer())
+#t2 = dt(steps)
+#
+#print ("Time difference {}s {}s {}s Speedup {:.2f}".format(t1,t2,t2-t1, t2/t1))
+#res = pdhg.get_output()
+#res1 = pdhgo.get_output()
-diff = (res-res1)
-print ("diff norm {} max {}".format(diff.norm(), diff.abs().as_array().max()))
-print ("Sum ( abs(diff) ) {}".format(diff.abs().sum()))
+#%%
+#plt.figure(figsize=(15,15))
+#plt.subplot(3,1,1)
+#plt.imshow(res.as_array())
+#plt.title('no memopt')
+#plt.colorbar()
+#plt.subplot(3,1,2)
+#plt.imshow(res1.as_array())
+#plt.title('memopt')
+#plt.colorbar()
+#plt.subplot(3,1,3)
+#plt.imshow((res1 - res).abs().as_array())
+#plt.title('diff')
+#plt.colorbar()
+#plt.show()
-plt.figure(figsize=(5,5))
-plt.subplot(1,3,1)
-plt.imshow(res.as_array())
-plt.colorbar()
+#plt.figure(figsize=(15,15))
+#plt.subplot(3,1,1)
+#plt.imshow(pdhg.get_output().as_array())
+#plt.title('no memopt class')
+#plt.colorbar()
+#plt.subplot(3,1,2)
+#plt.imshow(res.as_array())
+#plt.title('no memopt')
+#plt.colorbar()
+#plt.subplot(3,1,3)
+#plt.imshow((pdhg.get_output() - res).abs().as_array())
+#plt.title('diff')
+#plt.colorbar()
#plt.show()
-
-#plt.figure(figsize=(5,5))
-plt.subplot(1,3,2)
-plt.imshow(res1.as_array())
-plt.colorbar()
+#
+#
+#
+#plt.figure(figsize=(15,15))
+#plt.subplot(3,1,1)
+#plt.imshow(pdhgo.get_output().as_array())
+#plt.title('memopt class')
+#plt.colorbar()
+#plt.subplot(3,1,2)
+#plt.imshow(res1.as_array())
+#plt.title('no memopt')
+#plt.colorbar()
+#plt.subplot(3,1,3)
+#plt.imshow((pdhgo.get_output() - res1).abs().as_array())
+#plt.title('diff')
+#plt.colorbar()
+#plt.show()
+
+
+
+
+# print ("Time difference {}s {}s {}s Speedup {:.2f}".format(t1,t2,t2-t1, t2/t1))
+# res = pdhg.get_output()
+# res1 = pdhgo.get_output()
+#
+# diff = (res-res1)
+# print ("diff norm {} max {}".format(diff.norm(), diff.abs().as_array().max()))
+# print ("Sum ( abs(diff) ) {}".format(diff.abs().sum()))
+#
+#
+# plt.figure(figsize=(5,5))
+# plt.subplot(1,3,1)
+# plt.imshow(res.as_array())
+# plt.colorbar()
+# #plt.show()
+#
+# #plt.figure(figsize=(5,5))
+# plt.subplot(1,3,2)
+# plt.imshow(res1.as_array())
+# plt.colorbar()
+
#plt.show()
diff --git a/Wrappers/Python/wip/pdhg_TV_denoising3D.py b/Wrappers/Python/wip/pdhg_TV_denoising3D.py
new file mode 100644
index 0000000..06ecfa2
--- /dev/null
+++ b/Wrappers/Python/wip/pdhg_TV_denoising3D.py
@@ -0,0 +1,360 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Feb 22 14:53:03 2019
+
+@author: evangelos
+"""
+
+from ccpi.framework import ImageData, ImageGeometry, BlockDataContainer
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from ccpi.optimisation.algorithms import PDHG, PDHG_old
+
+from ccpi.optimisation.operators import BlockOperator, Identity, Gradient
+from ccpi.optimisation.functions import ZeroFunction, L2NormSquared, \
+ MixedL21Norm, FunctionOperatorComposition, BlockFunction
+
+from skimage.util import random_noise
+
+from timeit import default_timer as timer
+def dt(steps):
+ return steps[-1] - steps[-2]
+
+#%%
+
+# Create phantom for TV denoising
+
+import timeit
+import os
+from tomophantom import TomoP3D
+import tomophantom
+
+print ("Building 3D phantom using TomoPhantom software")
+tic=timeit.default_timer()
+model = 13 # select a model number from the library
+N_size = 64 # Define phantom dimensions using a scalar value (cubic phantom)
+path = os.path.dirname(tomophantom.__file__)
+path_library3D = os.path.join(path, "Phantom3DLibrary.dat")
+#This will generate a N_size x N_size x N_size phantom (3D)
+phantom_tm = TomoP3D.Model(model, N_size, path_library3D)
+#toc=timeit.default_timer()
+#Run_time = toc - tic
+#print("Phantom has been built in {} seconds".format(Run_time))
+#
+#sliceSel = int(0.5*N_size)
+##plt.gray()
+#plt.figure()
+#plt.subplot(131)
+#plt.imshow(phantom_tm[sliceSel,:,:],vmin=0, vmax=1)
+#plt.title('3D Phantom, axial view')
+#
+#plt.subplot(132)
+#plt.imshow(phantom_tm[:,sliceSel,:],vmin=0, vmax=1)
+#plt.title('3D Phantom, coronal view')
+#
+#plt.subplot(133)
+#plt.imshow(phantom_tm[:,:,sliceSel],vmin=0, vmax=1)
+#plt.title('3D Phantom, sagittal view')
+#plt.show()
+
+#%%
+
+N = N_size
+ig = ImageGeometry(voxel_num_x=N, voxel_num_y=N, voxel_num_z=N)
+
+n1 = random_noise(phantom_tm, mode = 'gaussian', mean=0, var = 0.001, seed=10)
+noisy_data = ImageData(n1)
+#plt.imshow(noisy_data.as_array()[:,:,32])
+
+#%%
+
+# Regularisation Parameter
+alpha = 0.02
+
+#method = input("Enter structure of PDHG (0=Composite or 1=NotComposite): ")
+method = '0'
+
+if method == '0':
+
+ # Create operators
+ op1 = Gradient(ig)
+ op2 = Identity(ig)
+
+ # Form Composite Operator
+ operator = BlockOperator(op1, op2, shape=(2,1) )
+
+ #### Create functions
+
+ f1 = alpha * MixedL21Norm()
+ f2 = 0.5 * L2NormSquared(b = noisy_data)
+ f = BlockFunction(f1, f2)
+
+ g = ZeroFunction()
+
+else:
+
+ ###########################################################################
+ # No Composite #
+ ###########################################################################
+ operator = Gradient(ig)
+ f = alpha * FunctionOperatorComposition(operator, MixedL21Norm())
+ g = L2NormSquared(b = noisy_data)
+
+ ###########################################################################
+#%%
+
+# Compute operator Norm
+normK = operator.norm()
+
+# Primal & dual stepsizes
+sigma = 1
+tau = 1/(sigma*normK**2)
+
+opt = {'niter':2000}
+opt1 = {'niter':2000, 'memopt': True}
+
+#t1 = timer()
+#res, time, primal, dual, pdgap = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt)
+#t2 = timer()
+
+
+t3 = timer()
+res1, time1, primal1, dual1, pdgap1 = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt1)
+t4 = timer()
+
+#import cProfile
+#cProfile.run('res1, time1, primal1, dual1, pdgap1 = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt1) ')
+###
+print ("No memopt in {}s, memopt in {}s ".format(t2-t1, t4 -t3))
+#
+##
+##%%
+#
+#plt.figure(figsize=(10,10))
+#plt.subplot(311)
+#plt.imshow(res1.as_array()[sliceSel,:,:])
+#plt.colorbar()
+#plt.title('3D Phantom, axial view')
+#
+#plt.subplot(312)
+#plt.imshow(res1.as_array()[:,sliceSel,:])
+#plt.colorbar()
+#plt.title('3D Phantom, coronal view')
+#
+#plt.subplot(313)
+#plt.imshow(res1.as_array()[:,:,sliceSel])
+#plt.colorbar()
+#plt.title('3D Phantom, sagittal view')
+#plt.show()
+#
+#plt.figure(figsize=(10,10))
+#plt.subplot(311)
+#plt.imshow(res.as_array()[sliceSel,:,:])
+#plt.colorbar()
+#plt.title('3D Phantom, axial view')
+#
+#plt.subplot(312)
+#plt.imshow(res.as_array()[:,sliceSel,:])
+#plt.colorbar()
+#plt.title('3D Phantom, coronal view')
+#
+#plt.subplot(313)
+#plt.imshow(res.as_array()[:,:,sliceSel])
+#plt.colorbar()
+#plt.title('3D Phantom, sagittal view')
+#plt.show()
+#
+#diff = (res1 - res).abs()
+#
+#plt.figure(figsize=(10,10))
+#plt.subplot(311)
+#plt.imshow(diff.as_array()[sliceSel,:,:])
+#plt.colorbar()
+#plt.title('3D Phantom, axial view')
+#
+#plt.subplot(312)
+#plt.imshow(diff.as_array()[:,sliceSel,:])
+#plt.colorbar()
+#plt.title('3D Phantom, coronal view')
+#
+#plt.subplot(313)
+#plt.imshow(diff.as_array()[:,:,sliceSel])
+#plt.colorbar()
+#plt.title('3D Phantom, sagittal view')
+#plt.show()
+#
+#
+#
+#
+##%%
+#pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma)
+#pdhg.max_iteration = 2000
+#pdhg.update_objective_interval = 100
+####
+#pdhgo = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma, memopt=True)
+#pdhgo.max_iteration = 2000
+#pdhgo.update_objective_interval = 100
+####
+#steps = [timer()]
+#pdhgo.run(2000)
+#steps.append(timer())
+#t1 = dt(steps)
+##
+#pdhg.run(2000)
+#steps.append(timer())
+#t2 = dt(steps)
+#
+#print ("Time difference {}s {}s {}s Speedup {:.2f}".format(t1,t2,t2-t1, t2/t1))
+#res = pdhg.get_output()
+#res1 = pdhgo.get_output()
+
+#%%
+#plt.figure(figsize=(15,15))
+#plt.subplot(3,1,1)
+#plt.imshow(res.as_array())
+#plt.title('no memopt')
+#plt.colorbar()
+#plt.subplot(3,1,2)
+#plt.imshow(res1.as_array())
+#plt.title('memopt')
+#plt.colorbar()
+#plt.subplot(3,1,3)
+#plt.imshow((res1 - res).abs().as_array())
+#plt.title('diff')
+#plt.colorbar()
+#plt.show()
+
+
+#plt.figure(figsize=(15,15))
+#plt.subplot(3,1,1)
+#plt.imshow(pdhg.get_output().as_array())
+#plt.title('no memopt class')
+#plt.colorbar()
+#plt.subplot(3,1,2)
+#plt.imshow(res.as_array())
+#plt.title('no memopt')
+#plt.colorbar()
+#plt.subplot(3,1,3)
+#plt.imshow((pdhg.get_output() - res).abs().as_array())
+#plt.title('diff')
+#plt.colorbar()
+#plt.show()
+#
+#
+#
+#plt.figure(figsize=(15,15))
+#plt.subplot(3,1,1)
+#plt.imshow(pdhgo.get_output().as_array())
+#plt.title('memopt class')
+#plt.colorbar()
+#plt.subplot(3,1,2)
+#plt.imshow(res1.as_array())
+#plt.title('no memopt')
+#plt.colorbar()
+#plt.subplot(3,1,3)
+#plt.imshow((pdhgo.get_output() - res1).abs().as_array())
+#plt.title('diff')
+#plt.colorbar()
+#plt.show()
+
+
+
+
+
+# print ("Time difference {}s {}s {}s Speedup {:.2f}".format(t1,t2,t2-t1, t2/t1))
+# res = pdhg.get_output()
+# res1 = pdhgo.get_output()
+#
+# diff = (res-res1)
+# print ("diff norm {} max {}".format(diff.norm(), diff.abs().as_array().max()))
+# print ("Sum ( abs(diff) ) {}".format(diff.abs().sum()))
+#
+#
+# plt.figure(figsize=(5,5))
+# plt.subplot(1,3,1)
+# plt.imshow(res.as_array())
+# plt.colorbar()
+# #plt.show()
+#
+# #plt.figure(figsize=(5,5))
+# plt.subplot(1,3,2)
+# plt.imshow(res1.as_array())
+# plt.colorbar()
+
+#plt.show()
+
+
+
+#=======
+## opt = {'niter':2000, 'memopt': True}
+#
+## res, time, primal, dual, pdgap = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt)
+#
+#>>>>>>> origin/pdhg_fix
+#
+#
+## opt = {'niter':2000, 'memopt': False}
+## res1, time1, primal1, dual1, pdgap1 = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt)
+#
+## plt.figure(figsize=(5,5))
+## plt.subplot(1,3,1)
+## plt.imshow(res.as_array())
+## plt.title('memopt')
+## plt.colorbar()
+## plt.subplot(1,3,2)
+## plt.imshow(res1.as_array())
+## plt.title('no memopt')
+## plt.colorbar()
+## plt.subplot(1,3,3)
+## plt.imshow((res1 - res).abs().as_array())
+## plt.title('diff')
+## plt.colorbar()
+## plt.show()
+#pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma)
+#pdhg.max_iteration = 2000
+#pdhg.update_objective_interval = 100
+#
+#
+#pdhgo = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma, memopt=True)
+#pdhgo.max_iteration = 2000
+#pdhgo.update_objective_interval = 100
+#
+#steps = [timer()]
+#pdhgo.run(200)
+#steps.append(timer())
+#t1 = dt(steps)
+#
+#pdhg.run(200)
+#steps.append(timer())
+#t2 = dt(steps)
+#
+#print ("Time difference {} {} {}".format(t1,t2,t2-t1))
+#sol = pdhg.get_output().as_array()
+##sol = result.as_array()
+##
+#fig = plt.figure()
+#plt.subplot(1,3,1)
+#plt.imshow(noisy_data.as_array())
+#plt.colorbar()
+#plt.subplot(1,3,2)
+#plt.imshow(sol)
+#plt.colorbar()
+#plt.subplot(1,3,3)
+#plt.imshow(pdhgo.get_output().as_array())
+#plt.colorbar()
+#
+#plt.show()
+###
+##
+####
+##plt.plot(np.linspace(0,N,N), data[int(N/2),:], label = 'GTruth')
+##plt.plot(np.linspace(0,N,N), sol[int(N/2),:], label = 'Recon')
+##plt.legend()
+##plt.show()
+#
+#
+##%%
+##
diff --git a/Wrappers/Python/wip/pdhg_TV_tomography2D.py b/Wrappers/Python/wip/pdhg_TV_tomography2D.py
index e0868f7..3fec34e 100644
--- a/Wrappers/Python/wip/pdhg_TV_tomography2D.py
+++ b/Wrappers/Python/wip/pdhg_TV_tomography2D.py
@@ -56,7 +56,7 @@ detectors = 150
angles = np.linspace(0,np.pi,100)
ag = AcquisitionGeometry('parallel','2D',angles, detectors)
-Aop = AstraProjectorSimple(ig, ag, 'cpu')
+Aop = AstraProjectorSimple(ig, ag, 'gpu')
sin = Aop.direct(data)
plt.imshow(sin.as_array())
@@ -113,43 +113,28 @@ else:
sigma = 1
tau = 1/(sigma*normK**2)
-#pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma)
-#pdhg.max_iteration = 5000
-#pdhg.update_objective_interval = 250
-#
-#pdhg.run(5000)
-
-opt = {'niter':300}
-opt1 = {'niter':300, 'memopt': True}
+# Compute operator Norm
+normK = operator.norm()
+
+# Primal & dual stepsizes
+sigma = 1
+tau = 1/(sigma*normK**2)
+opt = {'niter':2000}
+opt1 = {'niter':2000, 'memopt': True}
t1 = timer()
res, time, primal, dual, pdgap = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt)
-
-print(timer()-t1)
-plt.figure(figsize=(5,5))
-plt.imshow(res.as_array())
-plt.colorbar()
-plt.show()
-
-#%%
-print("with memopt \n")
-#
t2 = timer()
+
+
+t3 = timer()
res1, time1, primal1, dual1, pdgap1 = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt1)
-#print(timer()-t2)
-#
-#
-plt.figure(figsize=(5,5))
-plt.imshow(res1.as_array())
-plt.colorbar()
-plt.show()
+t4 = timer()
#
-#%%
-plt.figure(figsize=(5,5))
-plt.imshow(np.abs(res1.as_array()-res.as_array()))
-plt.colorbar()
-plt.show()
+print ("No memopt in {}s, memopt in {}s ".format(t2-t1, t4 -t3))
+
+
#%%
#sol = pdhg.get_output().as_array()
#fig = plt.figure()