summaryrefslogtreecommitdiffstats
path: root/Wrappers
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-02-18 15:19:11 +0000
committerEdoardo Pasca <edo.paskino@gmail.com>2019-02-18 15:19:11 +0000
commit5f82583109cd218e08c2a9e1cca21adca73ffe6d (patch)
treef2b652f465db179bc149fa8984f16f5314da6c2c /Wrappers
parent6de950b093a7b3602d615e7eb3786d9469ced930 (diff)
downloadframework-5f82583109cd218e08c2a9e1cca21adca73ffe6d.tar.gz
framework-5f82583109cd218e08c2a9e1cca21adca73ffe6d.tar.bz2
framework-5f82583109cd218e08c2a9e1cca21adca73ffe6d.tar.xz
framework-5f82583109cd218e08c2a9e1cca21adca73ffe6d.zip
added CGLS
Diffstat (limited to 'Wrappers')
-rw-r--r--Wrappers/Python/ccpi/optimisation/Algorithms.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/Algorithms.py b/Wrappers/Python/ccpi/optimisation/Algorithms.py
index de7f0f8..bf7f1c3 100644
--- a/Wrappers/Python/ccpi/optimisation/Algorithms.py
+++ b/Wrappers/Python/ccpi/optimisation/Algorithms.py
@@ -302,3 +302,53 @@ class FBPD(Algorithm)
# time and criterion
self.loss = self.constraint(self.x) + self.data_fidelity(self.x) + self.regulariser(self.operator.direct(self.x))
+class CGLS(Algorithm):
+
+ '''Conjugate Gradient Least Squares algorithm
+
+ Parameters:
+ x_init: initial guess
+ operator: operator for forward/backward projections
+ data: data to operate on
+ '''
+ def __init__(self, **kwargs):
+ super(CGLS, self).__init__()
+ self.x = kwargs.get('x_init', None)
+ self.operator = kwargs.get('operator', None)
+ self.data = kwargs.get('data', None)
+ if self.x is not None and self.operator is not None and \
+ self.data is not None:
+ print ("Calling from creator")
+ return self.set_up(x_init =kwargs['x_init'],
+ operator=kwargs['operator'],
+ data =kwargs['data'])
+
+ def set_up(self, x_init, operator , data ):
+
+ self.r = data.copy()
+ self.x = x_init.copy()
+
+ self.operator = operator
+ self.d = operator.adjoint(self.r)
+
+ self.normr2 = self.d.norm()
+
+ def should_stop(self):
+ '''stopping cryterion, currently only based on number of iterations'''
+ return self.iteration >= self.max_iteration
+
+ def update(self):
+
+ Ad = self.operator.direct(self.d)
+ alpha = self.normr2/Ad.norm()
+ self.x += alpha * self.d
+ self.r -= alpha * Ad
+ s = self.operator.adjoint(self.r)
+
+ normr2_new = s.norm()
+ beta = normr2_new/self.normr2
+ self.normr2 = normr2_new
+ self.d = s + beta*self.d
+
+ def update_objective(self):
+ self.loss.append(self.r.norm())