diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2019-06-07 15:39:46 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-06-07 15:39:46 +0000 |
commit | 53e09caf548d51ca1d071f4bffe249afab5649f6 (patch) | |
tree | 5a7c6600e981c517bf6a6ebd62890958590d867f /Wrappers | |
parent | 2fa4ef8cdb4ccc420007b65d975682c9939e0171 (diff) | |
parent | 1d7324cec6f10f00b73af2bab3469202c5cc2e87 (diff) | |
download | framework-53e09caf548d51ca1d071f4bffe249afab5649f6.tar.gz framework-53e09caf548d51ca1d071f4bffe249afab5649f6.tar.bz2 framework-53e09caf548d51ca1d071f4bffe249afab5649f6.tar.xz framework-53e09caf548d51ca1d071f4bffe249afab5649f6.zip |
Merge pull request #305 from vais-ral/config_alg
Config alg
Diffstat (limited to 'Wrappers')
9 files changed, 25 insertions, 95 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py b/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py index 4fbf83b..c62d0ea 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py @@ -51,6 +51,7 @@ class Algorithm(object): self.__max_iteration = kwargs.get('max_iteration', 0) self.__loss = [] self.memopt = False + self.configured = False self.timing = [] self.update_objective_interval = kwargs.get('update_objective_interval', 1) def set_up(self, *args, **kwargs): @@ -86,6 +87,8 @@ class Algorithm(object): raise StopIteration() else: time0 = time.time() + if not self.configured: + raise ValueError('Algorithm not configured correctly. Please run set_up.') self.update() self.timing.append( time.time() - time0 ) if self.iteration % self.update_objective_interval == 0: diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py index 4d4843c..6b610a0 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py @@ -23,6 +23,7 @@ Created on Thu Feb 21 11:11:23 2019 """ from ccpi.optimisation.algorithms import Algorithm +from ccpi.optimisation.functions import Norm2Sq class CGLS(Algorithm): @@ -59,6 +60,9 @@ class CGLS(Algorithm): # self.normr2 = sum(self.normr2) #self.normr2 = numpy.sqrt(self.normr2) #print ("set_up" , self.normr2) + n = Norm2Sq(operator, self.data) + self.loss.append(n(x_init)) + self.configured = True def update(self): @@ -84,4 +88,4 @@ class CGLS(Algorithm): self.d = s + beta*self.d def update_objective(self): - self.loss.append(self.r.squared_norm())
\ No newline at end of file + self.loss.append(self.r.squared_norm()) diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py b/Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py deleted file mode 100644 index aa07359..0000000 --- a/Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py +++ /dev/null @@ -1,86 +0,0 @@ -# -*- coding: utf-8 -*- -# This work is part of the Core Imaging Library developed by -# Visual Analytics and Imaging System Group of the Science Technology -# Facilities Council, STFC - -# Copyright 2019 Edoardo Pasca - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Created on Thu Feb 21 11:09:03 2019 - -@author: ofn77899 -""" - -from ccpi.optimisation.algorithms import Algorithm -from ccpi.optimisation.functions import ZeroFunction - -class FBPD(Algorithm): - '''FBPD Algorithm - - Parameters: - x_init: initial guess - f: constraint - g: data fidelity - h: regularizer - opt: additional algorithm - ''' - constraint = None - data_fidelity = None - regulariser = None - def __init__(self, **kwargs): - pass - def set_up(self, x_init, operator=None, constraint=None, data_fidelity=None,\ - regulariser=None, opt=None): - - # default inputs - if constraint is None: - self.constraint = ZeroFun() - else: - self.constraint = constraint - if data_fidelity is None: - data_fidelity = ZeroFun() - else: - self.data_fidelity = data_fidelity - if regulariser is None: - self.regulariser = ZeroFun() - else: - self.regulariser = regulariser - - # algorithmic parameters - - - # step-sizes - self.tau = 2 / (self.data_fidelity.L + 2) - self.sigma = (1/self.tau - self.data_fidelity.L/2) / self.regulariser.L - - self.inv_sigma = 1/self.sigma - - # initialization - self.x = x_init - self.y = operator.direct(self.x) - - - def update(self): - - # primal forward-backward step - x_old = self.x - self.x = self.x - self.tau * ( self.data_fidelity.grad(self.x) + self.operator.adjoint(self.y) ) - self.x = self.constraint.prox(self.x, self.tau); - - # dual forward-backward step - self.y = self.y + self.sigma * self.operator.direct(2*self.x - x_old); - self.y = self.y - self.sigma * self.regulariser.prox(self.inv_sigma*self.y, self.inv_sigma); - - # time and criterion - self.loss = self.constraint(self.x) + self.data_fidelity(self.x) + self.regulariser(self.operator.direct(self.x)) diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py index 3c7a8d1..647ae98 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py @@ -58,6 +58,7 @@ class FISTA(Algorithm): self.t_old = 1 self.update_objective() + self.configured = True def update(self): diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py index 14763c5..34bf954 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py @@ -40,7 +40,7 @@ class GradientDescent(Algorithm): if k in args: args.pop(args.index(k)) if len(args) == 0: - return self.set_up(x_init=kwargs['x_init'], + self.set_up(x_init=kwargs['x_init'], objective_function=kwargs['objective_function'], rate=kwargs['rate']) @@ -61,6 +61,7 @@ class GradientDescent(Algorithm): self.memopt = False if self.memopt: self.x_update = x_init.copy() + self.configured = True def update(self): '''Single iteration''' diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py index 39b092b..3afd8b0 100644 --- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py @@ -23,10 +23,16 @@ class PDHG(Algorithm): self.operator = kwargs.get('operator', None) self.g = kwargs.get('g', None) self.tau = kwargs.get('tau', None) - self.sigma = kwargs.get('sigma', None) + self.sigma = kwargs.get('sigma', 1.) + if self.f is not None and self.operator is not None and \ self.g is not None: + if self.tau is None: + # Compute operator Norm + normK = self.operator.norm() + # Primal & dual stepsizes + self.tau = 1/(self.sigma*normK**2) print ("Calling from creator") self.set_up(self.f, self.g, @@ -57,6 +63,8 @@ class PDHG(Algorithm): # relaxation parameter self.theta = 1 + self.update_objective() + self.configured = True def update(self): diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py b/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py index 30584d4..c73d323 100644 --- a/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py @@ -59,6 +59,7 @@ class SIRT(Algorithm): # Set up scaling matrices D and M. self.M = 1/self.operator.direct(self.operator.domain_geometry().allocate(value=1.0)) self.D = 1/self.operator.adjoint(self.operator.range_geometry().allocate(value=1.0)) + self.configured = True def update(self): @@ -67,8 +68,9 @@ class SIRT(Algorithm): self.x += self.relax_par * (self.D*self.operator.adjoint(self.M*self.r)) - if self.constraint != None: - self.x = self.constraint.prox(self.x,None) + if self.constraint is not None: + self.x = self.constraint.proximal(self.x,None) + # self.constraint.proximal(self.x,None, out=self.x) def update_objective(self): - self.loss.append(self.r.squared_norm())
\ No newline at end of file + self.loss.append(self.r.squared_norm()) diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py b/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py index 2dbacfc..8f255f3 100644 --- a/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py @@ -27,7 +27,5 @@ from .CGLS import CGLS from .SIRT import SIRT from .GradientDescent import GradientDescent from .FISTA import FISTA -from .FBPD import FBPD from .PDHG import PDHG -from .PDHG import PDHG_old diff --git a/Wrappers/Python/test/test_algorithms.py b/Wrappers/Python/test/test_algorithms.py index 4121358..8c398f4 100755 --- a/Wrappers/Python/test/test_algorithms.py +++ b/Wrappers/Python/test/test_algorithms.py @@ -18,7 +18,6 @@ from ccpi.optimisation.functions import Norm2Sq, ZeroFunction, \ from ccpi.optimisation.algorithms import GradientDescent from ccpi.optimisation.algorithms import CGLS from ccpi.optimisation.algorithms import FISTA -from ccpi.optimisation.algorithms import FBPD |