author     Edoardo Pasca <edo.paskino@gmail.com>    2019-06-06 10:19:57 +0100
committer  Edoardo Pasca <edo.paskino@gmail.com>    2019-06-06 10:19:57 +0100
commit     1bdd5f572988caa3888b33a0b422692fa78962ef (patch)
tree       f91643afe242bc321e3ec2108b6162694f0ae970
parent     b234f4cf26ee56da94211dc15c9b277c7c29fff4 (diff)
add memopt and some checks
-rwxr-xr-x    Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py    49
1 file changed, 47 insertions(+), 2 deletions(-)
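A minimal sketch (not from the commit itself) of the in-place "memopt" CGLS step that the new update_new() in the diff below implements; op, x, r, d, s and normr2 stand in for the operator, iterate, residual, search direction, preallocated adjoint buffer and squared residual norm held by the class:

    # In-place CGLS step: no per-iteration temporaries beyond Ad.
    Ad = op.direct(d)                      # A d
    alpha = normr2 / Ad.squared_norm()     # step length; update_new() stops if either norm is 0
    d *= alpha                             # scale the direction once ...
    x += d                                 # ... so x <- x + alpha d needs no temporary
    Ad *= alpha
    r -= Ad                                # r <- r - alpha A d
    op.adjoint(r, out=s)                   # s <- A^T r, written into the buffer allocated in set_up()
    normr2_new = s.squared_norm()
    beta = normr2_new / normr2
    normr2 = normr2_new
    d *= beta / alpha                      # undo the alpha scaling, apply beta ...
    d += s                                 # ... giving d <- s + beta d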
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
index e65bc89..cb4f049 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
@@ -23,6 +23,8 @@ Created on Thu Feb 21 11:11:23 2019
"""
from ccpi.optimisation.algorithms import Algorithm
+import numpy
+
class CGLS(Algorithm):
'''Conjugate Gradient Least Squares algorithm
@@ -54,13 +56,16 @@ class CGLS(Algorithm):
         self.normr2 = self.d.squared_norm()
+
+        self.s = self.operator.domain_geometry().allocate()
         #if isinstance(self.normr2, Iterable):
         #    self.normr2 = sum(self.normr2)
         #self.normr2 = numpy.sqrt(self.normr2)
         #print ("set_up" , self.normr2)
     def update(self):
-
+        self.update_new()
+    def update_old(self):
         Ad = self.operator.direct(self.d)
         #norm = (Ad*Ad).sum()
         #if isinstance(norm, Iterable):
@@ -82,5 +87,45 @@ class CGLS(Algorithm):
         self.normr2 = normr2_new
         self.d = s + beta*self.d
+    def update_new(self):
+
+        Ad = self.operator.direct(self.d)
+        norm = Ad.squared_norm()
+        if norm == 0.:
+            print ('cannot update solution')
+            raise StopIteration()
+        alpha = self.normr2/norm
+        if alpha == 0.:
+            print ('cannot update solution')
+            raise StopIteration()
+        self.d *= alpha
+        Ad *= alpha
+        self.r -= Ad
+        if numpy.isnan(self.r.as_array()).any():
+            print ("some nan")
+            raise StopIteration()
+        self.x += self.d
+
+        self.operator.adjoint(self.r, out=self.s)
+        s = self.s
+
+        normr2_new = s.squared_norm()
+
+        beta = normr2_new/self.normr2
+        self.normr2 = normr2_new
+        self.d *= (beta/alpha)
+        self.d += s
+
     def update_objective(self):
-        self.loss.append(self.r.squared_norm())
+        a = self.r.squared_norm()
+        if numpy.isnan(a):
+            raise StopIteration()
+        self.loss.append(a)
+
+#    def should_stop(self):
+#        if self.iteration > 0:
+#            x = self.get_last_objective()
+#            a = x > 0
+#            return self.max_iteration_stop_cryterion() or (not a)
+#        else:
+#            return False
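A minimal standalone sketch, assuming only numpy, of the NaN guard pattern used in update_new() and update_objective(): numpy.isnan() is needed because a NaN produced by squared_norm() is a fresh float object, so an identity test against numpy.nan would never fire.

    import numpy

    # The residual check in update_new(): flag any non-finite entry.
    r = numpy.array([1.0, numpy.nan, 3.0])
    print(numpy.isnan(r).any())           # True  -> the iteration raises StopIteration

    # The objective check in update_objective(): a NaN propagates into the norm,
    # and only isnan() detects it reliably.
    a = r.sum() ** 2                      # NaN
    print(numpy.isnan(a))                 # True
    print(a is numpy.nan)                 # False -- identity comparison misses it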