diff options
Diffstat (limited to 'Wrappers/Python')
-rwxr-xr-x | Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py | 49 |
1 files changed, 47 insertions, 2 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py index e65bc89..cb4f049 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py @@ -23,6 +23,8 @@ Created on Thu Feb 21 11:11:23 2019 """ from ccpi.optimisation.algorithms import Algorithm +import numpy + class CGLS(Algorithm): '''Conjugate Gradient Least Squares algorithm @@ -54,13 +56,16 @@ class CGLS(Algorithm): self.normr2 = self.d.squared_norm() + + self.s = self.operator.domain_geometry().allocate() #if isinstance(self.normr2, Iterable): # self.normr2 = sum(self.normr2) #self.normr2 = numpy.sqrt(self.normr2) #print ("set_up" , self.normr2) def update(self): - + self.update_new() + def update_old(self): Ad = self.operator.direct(self.d) #norm = (Ad*Ad).sum() #if isinstance(norm, Iterable): @@ -82,5 +87,45 @@ class CGLS(Algorithm): self.normr2 = normr2_new self.d = s + beta*self.d + def update_new(self): + + Ad = self.operator.direct(self.d) + norm = Ad.squared_norm() + if norm == 0.: + print ('cannot update solution') + raise StopIteration() + alpha = self.normr2/norm + if alpha == 0.: + print ('cannot update solution') + raise StopIteration() + self.d *= alpha + Ad *= alpha + self.r -= Ad + if numpy.isnan(self.r.as_array()).any(): + print ("some nan") + raise StopIteration() + self.x += self.d + + self.operator.adjoint(self.r, out=self.s) + s = self.s + + normr2_new = s.squared_norm() + + beta = normr2_new/self.normr2 + self.normr2 = normr2_new + self.d *= (beta/alpha) + self.d += s + def update_objective(self): - self.loss.append(self.r.squared_norm()) + a = self.r.squared_norm() + if a is numpy.nan: + raise StopIteration() + self.loss.append(a) + +# def should_stop(self): +# if self.iteration > 0: +# x = self.get_last_objective() +# a = x > 0 +# return self.max_iteration_stop_cryterion() or (not a) +# else: +# return False |