From c7213fdfcd31e6ec780aab4afe1bd34374d784f5 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 11 Oct 2019 16:31:49 +0100 Subject: Pass kwargs to algorithm (#380) * add test for algorithm * fix conflict * suppress warning * pass kwargs to Algorithm class creator --- Wrappers/Python/ccpi/framework/framework.py | 8 ++-- .../Python/ccpi/optimisation/algorithms/CGLS.py | 30 +++++++++----- .../Python/ccpi/optimisation/algorithms/FISTA.py | 27 ++++++++----- .../optimisation/algorithms/GradientDescent.py | 30 +++++++++----- .../Python/ccpi/optimisation/algorithms/PDHG.py | 29 +++++++++---- .../Python/ccpi/optimisation/algorithms/SIRT.py | 27 +++++++++---- Wrappers/Python/test/test_algorithms.py | 47 ++++++++++++++++++---- 7 files changed, 145 insertions(+), 53 deletions(-) (limited to 'Wrappers') diff --git a/Wrappers/Python/ccpi/framework/framework.py b/Wrappers/Python/ccpi/framework/framework.py index 6d5bd1b..ed97862 100755 --- a/Wrappers/Python/ccpi/framework/framework.py +++ b/Wrappers/Python/ccpi/framework/framework.py @@ -722,7 +722,8 @@ class DataContainer(object): return type(self)(out, deep_copy=False, dimension_labels=self.dimension_labels, - geometry=self.geometry) + geometry=self.geometry, + suppress_warning=True) elif issubclass(type(out), DataContainer) and issubclass(type(x2), DataContainer): @@ -800,7 +801,8 @@ class DataContainer(object): return type(self)(out, deep_copy=False, dimension_labels=self.dimension_labels, - geometry=self.geometry) + geometry=self.geometry, + suppress_warning=True) elif issubclass(type(out), DataContainer): if self.check_dimensions(out): kwargs['out'] = out.as_array() @@ -885,7 +887,7 @@ class ImageData(DataContainer): if not kwargs.get('suppress_warning', False): warnings.warn('Direct invocation is deprecated and will be removed in following version. Use allocate from ImageGeometry instead', - DeprecationWarning) + DeprecationWarning, stacklevel=4) self.geometry = kwargs.get('geometry', None) if array is None: if self.geometry is not None: diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py index d2e5b29..57292df 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py @@ -47,20 +47,30 @@ class CGLS(Algorithm): Reference: https://web.stanford.edu/group/SOL/software/cgls/ ''' - def __init__(self, **kwargs): + def __init__(self, x_init=None, operator=None, data=None, tolerance=1e-6, **kwargs): + '''initialisation of the algorithm + + :param operator : Linear operator for the inverse problem + :param x_init : Initial guess ( Default x_init = 0) + :param data : Acquired data to reconstruct + :param tolerance: Tolerance/ Stopping Criterion to end CGLS algorithm + ''' + super(CGLS, self).__init__(**kwargs) - super(CGLS, self).__init__() - x_init = kwargs.get('x_init', None) - operator = kwargs.get('operator', None) - data = kwargs.get('data', None) - tolerance = kwargs.get('tolerance', 1e-6) if x_init is not None and operator is not None and data is not None: - print(self.__class__.__name__ , "set_up called from creator") self.set_up(x_init=x_init, operator=operator, data=data, tolerance=tolerance) def set_up(self, x_init, operator, data, tolerance=1e-6): - + '''initialisation of the algorithm + + :param operator : Linear operator for the inverse problem + :param x_init : Initial guess ( Default x_init = 0) + :param data : Acquired data to reconstruct + :param tolerance: Tolerance/ Stopping Criterion to end CGLS algorithm + ''' + print("{} setting up".format(self.__class__.__name__, )) + self.x = x_init * 0. self.operator = operator self.tolerance = tolerance @@ -78,7 +88,9 @@ class CGLS(Algorithm): self.xmax = self.normx self.loss.append(self.r.squared_norm()) - self.configured = True + self.configured = True + print("{} configured".format(self.__class__.__name__, )) + def update(self): diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py index 5d79b67..8c485b7 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py @@ -53,24 +53,31 @@ class FISTA(Algorithm): ''' - def __init__(self, **kwargs): + def __init__(self, x_init=None, f=None, g=ZeroFunction(), **kwargs): - '''creator + '''FISTA algorithm creator initialisation can be done at creation time if all - proper variables are passed or later with set_up''' + proper variables are passed or later with set_up + + :param x_init : Initial guess ( Default x_init = 0) + :param f : Differentiable function + :param g : Convex function with " simple " proximal operator''' + + super(FISTA, self).__init__(**kwargs) - super(FISTA, self).__init__() - f = kwargs.get('f', None) - g = kwargs.get('g', ZeroFunction()) - x_init = kwargs.get('x_init', None) - if x_init is not None and f is not None: - print(self.__class__.__name__ , "set_up called from creator") self.set_up(x_init=x_init, f=f, g=g) def set_up(self, x_init, f, g=ZeroFunction()): + '''initialisation of the algorithm + :param x_init : Initial guess ( Default x_init = 0) + :param f : Differentiable function + :param g : Convex function with " simple " proximal operator''' + + print("{} setting up".format(self.__class__.__name__, )) + self.y = x_init.copy() self.x_old = x_init.copy() self.x = x_init.copy() @@ -84,6 +91,8 @@ class FISTA(Algorithm): self.t = 1 self.update_objective() self.configured = True + print("{} configured".format(self.__class__.__name__, )) + def update(self): self.t_old = self.t diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py index f79651a..8f9c958 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py @@ -35,17 +35,20 @@ class GradientDescent(Algorithm): ''' - def __init__(self, **kwargs): - '''initialisation can be done at creation time if all - proper variables are passed or later with set_up''' - super(GradientDescent, self).__init__() - - x_init = kwargs.get('x_init', None) - objective_function = kwargs.get('objective_function', None) - rate = kwargs.get('rate', None) + def __init__(self, x_init=None, objective_function=None, rate=None, **kwargs): + '''GradientDescent algorithm creator + + initialisation can be done at creation time if all + proper variables are passed or later with set_up + + :param x_init: initial guess + :param objective_function: objective function to be minimised + :param rate: step rate + ''' + super(GradientDescent, self).__init__(**kwargs) + if x_init is not None and objective_function is not None and rate is not None: - print(self.__class__.__name__, "set_up called from creator") self.set_up(x_init=x_init, objective_function=objective_function, rate=rate) def should_stop(self): @@ -53,7 +56,13 @@ class GradientDescent(Algorithm): return self.iteration >= self.max_iteration def set_up(self, x_init, objective_function, rate): - '''initialisation of the algorithm''' + '''initialisation of the algorithm + + :param x_init: initial guess + :param objective_function: objective function to be minimised + :param rate: step rate''' + print("{} setting up".format(self.__class__.__name__, )) + self.x = x_init.copy() self.objective_function = objective_function self.rate = rate @@ -69,6 +78,7 @@ class GradientDescent(Algorithm): self.x_update = x_init.copy() self.configured = True + print("{} configured".format(self.__class__.__name__, )) def update(self): '''Single iteration''' diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py index 7bc4e11..7ed82b2 100644 --- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py @@ -61,20 +61,31 @@ class PDHG(Algorithm): SIAM J. Imaging Sci. 3, 1015–1046. ''' - def __init__(self, **kwargs): - super(PDHG, self).__init__(max_iteration=kwargs.get('max_iteration',0)) - f = kwargs.get('f', None) - operator = kwargs.get('operator', None) - g = kwargs.get('g', None) - tau = kwargs.get('tau', None) - sigma = kwargs.get('sigma', 1.) + def __init__(self, f=None, g=None, operator=None, tau=None, sigma=1.,**kwargs): + '''PDHG algorithm creator + + :param operator : Linear Operator = K + :param f : Convex function with "simple" proximal of its conjugate. + :param g : Convex function with "simple" proximal + :param sigma : Step size parameter for Primal problem + :param tau : Step size parameter for Dual problem''' + super(PDHG, self).__init__(**kwargs) + if f is not None and operator is not None and g is not None: - print(self.__class__.__name__ , "set_up called from creator") self.set_up(f=f, g=g, operator=operator, tau=tau, sigma=sigma) def set_up(self, f, g, operator, tau=None, sigma=1.): + '''initialisation of the algorithm + + :param operator : Linear Operator = K + :param f : Convex function with "simple" proximal of its conjugate. + :param g : Convex function with "simple" proximal + :param sigma : Step size parameter for Primal problem + :param tau : Step size parameter for Dual problem''' + print("{} setting up".format(self.__class__.__name__, )) + # can't happen with default sigma if sigma is None and tau is None: raise ValueError('Need sigma*tau||K||^2<1') @@ -108,6 +119,8 @@ class PDHG(Algorithm): self.theta = 1 self.update_objective() self.configured = True + print("{} configured".format(self.__class__.__name__, )) + def update(self): diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py b/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py index 8feef87..50398f4 100644 --- a/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py @@ -47,19 +47,30 @@ class SIRT(Algorithm): e.g. x\in[0, 1], IndicatorBox to enforce box constraints Default is None). ''' - def __init__(self, **kwargs): - super(SIRT, self).__init__() + def __init__(self, x_init=None, operator=None, data=None, constraint=None, **kwargs): + '''SIRT algorithm creator - x_init = kwargs.get('x_init', None) - operator = kwargs.get('operator', None) - data = kwargs.get('data', None) - constraint = kwargs.get('constraint', None) + :param x_init : Initial guess + :param operator : Linear operator for the inverse problem + :param data : Acquired data to reconstruct + :param constraint : Function proximal method + e.g. x\in[0, 1], IndicatorBox to enforce box constraints + Default is None).''' + super(SIRT, self).__init__(**kwargs) if x_init is not None and operator is not None and data is not None: - print(self.__class__.__name__, "set_up called from creator") self.set_up(x_init=x_init, operator=operator, data=data, constraint=constraint) def set_up(self, x_init, operator, data, constraint=None): + '''initialisation of the algorithm + + :param operator : Linear operator for the inverse problem + :param x_init : Initial guess + :param data : Acquired data to reconstruct + :param constraint : Function proximal method + e.g. x\in[0, 1], IndicatorBox to enforce box constraints + Default is None).''' + print("{} setting up".format(self.__class__.__name__, )) self.x = x_init.copy() self.operator = operator @@ -75,6 +86,8 @@ class SIRT(Algorithm): self.D = 1/self.operator.adjoint(self.operator.range_geometry().allocate(value=1.0)) self.update_objective() self.configured = True + print("{} configured".format(self.__class__.__name__, )) + def update(self): diff --git a/Wrappers/Python/test/test_algorithms.py b/Wrappers/Python/test/test_algorithms.py index 15a83e8..1dd198a 100755 --- a/Wrappers/Python/test/test_algorithms.py +++ b/Wrappers/Python/test/test_algorithms.py @@ -75,24 +75,40 @@ class TestAlgorithms(unittest.TestCase): alg.max_iteration = 20 alg.run(20, verbose=True) self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) + alg = GradientDescent(x_init=x_init, + objective_function=norm2sq, + rate=rate, max_iteration=20, + update_objective_interval=2) + alg.max_iteration = 20 + self.assertTrue(alg.max_iteration == 20) + self.assertTrue(alg.update_objective_interval==2) + alg.run(20, verbose=True) + self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) def test_CGLS(self): print ("Test CGLS") #ig = ImageGeometry(124,153,154) ig = ImageGeometry(10,2) numpy.random.seed(2) x_init = ig.allocate(0.) + b = ig.allocate('random') # b = x_init.copy() # fill with random numbers # b.fill(numpy.random.random(x_init.shape)) - b = ig.allocate() - bdata = numpy.reshape(numpy.asarray([i for i in range(20)]), (2,10)) - b.fill(bdata) + # b = ig.allocate() + # bdata = numpy.reshape(numpy.asarray([i for i in range(20)]), (2,10)) + # b.fill(bdata) identity = Identity(ig) alg = CGLS(x_init=x_init, operator=identity, data=b) alg.max_iteration = 200 alg.run(20, verbose=True) - self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array(), decimal=4) + self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) + + alg = CGLS(x_init=x_init, operator=identity, data=b, max_iteration=200, update_objective_interval=2) + self.assertTrue(alg.max_iteration == 200) + self.assertTrue(alg.update_objective_interval==2) + alg.run(20, verbose=True) + self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) def test_FISTA(self): print ("Test FISTA") @@ -114,6 +130,15 @@ class TestAlgorithms(unittest.TestCase): alg.max_iteration = 2 alg.run(20, verbose=True) self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) + + alg = FISTA(x_init=x_init, f=norm2sq, g=ZeroFunction(), max_iteration=2, update_objective_interval=2) + + self.assertTrue(alg.max_iteration == 2) + self.assertTrue(alg.update_objective_interval==2) + + alg.run(20, verbose=True) + self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) + def test_FISTA_Norm2Sq(self): print ("Test FISTA Norm2Sq") @@ -133,6 +158,14 @@ class TestAlgorithms(unittest.TestCase): alg.max_iteration = 2 alg.run(20, verbose=True) self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) + + alg = FISTA(x_init=x_init, f=norm2sq, g=ZeroFunction(), max_iteration=2, update_objective_interval=3) + self.assertTrue(alg.max_iteration == 2) + self.assertTrue(alg.update_objective_interval== 3) + + alg.run(20, verbose=True) + self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array()) + def test_FISTA_catch_Lipschitz(self): print ("Test FISTA catch Lipschitz") ig = ImageGeometry(127,139,149) @@ -242,9 +275,9 @@ class TestAlgorithms(unittest.TestCase): tau = 1/(sigma*normK**2) # Setup and run the PDHG algorithm - pdhg1 = PDHG(f=f1,g=g,operator=operator, tau=tau, sigma=sigma) - pdhg1.max_iteration = 2000 - pdhg1.update_objective_interval = 200 + pdhg1 = PDHG(f=f1,g=g,operator=operator, tau=tau, sigma=sigma, + max_iteration=2000, update_objective_interval=200) + pdhg1.run(1000) rmse = (pdhg1.get_output() - data).norm() / data.as_array().size -- cgit v1.2.3