From 8fec2c7984d2145f356ea272d62254c759524a86 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 1 Jul 2019 21:11:49 +0100 Subject: calculation of flag in method, redefinition of default stop criterion --- .../Python/ccpi/optimisation/algorithms/CGLS.py | 28 ++++++++++------------ 1 file changed, 13 insertions(+), 15 deletions(-) (limited to 'Wrappers') diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py index 1695a73..661780e 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py @@ -49,19 +49,18 @@ class CGLS(Algorithm): self.x = kwargs.get('x_init', None) self.operator = kwargs.get('operator', None) self.data = kwargs.get('data', None) - self.tolerance = kwargs.get('tolerance', None) + self.tolerance = kwargs.get('tolerance', 1e-6) if self.x is not None and self.operator is not None and \ self.data is not None: + print (self.__class__.__name__ , "set_up called from creator") self.set_up(x_init =kwargs['x_init'], operator=kwargs['operator'], data =kwargs['data']) - if self.tolerance is None: - self.tolerance = 1e-6 def set_up(self, x_init, operator , data ): - self.x = x_init + self.x = x_init.copy() self.r = data - self.operator.direct(self.x) self.s = self.operator.adjoint(self.r) @@ -77,8 +76,7 @@ class CGLS(Algorithm): self.normx = self.x.norm() self.xmax = self.normx - n = Norm2Sq(self.operator, self.data) - self.loss.append(n(self.x)) + self.loss.append(self.r.squared_norm()) self.configured = True # def set_up(self, x_init, operator , data ): @@ -160,16 +158,16 @@ class CGLS(Algorithm): self.loss.append(a) def should_stop(self): + return self.flag() or self.max_iteration_stop_cryterion() + + def flag(self): + flag = (self.norms <= self.norms0 * self.tolerance) or (self.normx * self.tolerance >= 1) + + if flag: + self.update_objective() + print (self.verbose_output()) + return flag - self.update_objective() - flag = (self.norms <= self.norms0 * self.tolerance) or (self.normx * self.tolerance >= 1); - - #if self.gamma<=self.tolerance: - if flag == 1 or self.max_iteration_stop_cryterion(): - print('Tolerance is reached: Iter: {}'.format(self.iteration)) - return True - - #raise StopIteration() -- cgit v1.2.3