summaryrefslogtreecommitdiffstats
path: root/Wrappers
diff options
context:
space:
mode:
Diffstat (limited to 'Wrappers')
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/CGLS.py28
1 files changed, 13 insertions, 15 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
index 1695a73..661780e 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
@@ -49,19 +49,18 @@ class CGLS(Algorithm):
self.x = kwargs.get('x_init', None)
self.operator = kwargs.get('operator', None)
self.data = kwargs.get('data', None)
- self.tolerance = kwargs.get('tolerance', None)
+ self.tolerance = kwargs.get('tolerance', 1e-6)
if self.x is not None and self.operator is not None and \
self.data is not None:
+ print (self.__class__.__name__ , "set_up called from creator")
self.set_up(x_init =kwargs['x_init'],
operator=kwargs['operator'],
data =kwargs['data'])
- if self.tolerance is None:
- self.tolerance = 1e-6
def set_up(self, x_init, operator , data ):
- self.x = x_init
+ self.x = x_init.copy()
self.r = data - self.operator.direct(self.x)
self.s = self.operator.adjoint(self.r)
@@ -77,8 +76,7 @@ class CGLS(Algorithm):
self.normx = self.x.norm()
self.xmax = self.normx
- n = Norm2Sq(self.operator, self.data)
- self.loss.append(n(self.x))
+ self.loss.append(self.r.squared_norm())
self.configured = True
# def set_up(self, x_init, operator , data ):
@@ -160,16 +158,16 @@ class CGLS(Algorithm):
self.loss.append(a)
def should_stop(self):
+ return self.flag() or self.max_iteration_stop_cryterion()
+
+ def flag(self):
+ flag = (self.norms <= self.norms0 * self.tolerance) or (self.normx * self.tolerance >= 1)
+
+ if flag:
+ self.update_objective()
+ print (self.verbose_output())
+ return flag
- self.update_objective()
- flag = (self.norms <= self.norms0 * self.tolerance) or (self.normx * self.tolerance >= 1);
-
- #if self.gamma<=self.tolerance:
- if flag == 1 or self.max_iteration_stop_cryterion():
- print('Tolerance is reached: Iter: {}'.format(self.iteration))
- return True
-
-
#raise StopIteration()