diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2019-02-18 15:19:11 +0000 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2019-02-18 15:19:11 +0000 |
commit | 5f82583109cd218e08c2a9e1cca21adca73ffe6d (patch) | |
tree | f2b652f465db179bc149fa8984f16f5314da6c2c /Wrappers | |
parent | 6de950b093a7b3602d615e7eb3786d9469ced930 (diff) | |
download | framework-5f82583109cd218e08c2a9e1cca21adca73ffe6d.tar.gz framework-5f82583109cd218e08c2a9e1cca21adca73ffe6d.tar.bz2 framework-5f82583109cd218e08c2a9e1cca21adca73ffe6d.tar.xz framework-5f82583109cd218e08c2a9e1cca21adca73ffe6d.zip |
added CGLS
Diffstat (limited to 'Wrappers')
-rw-r--r-- | Wrappers/Python/ccpi/optimisation/Algorithms.py | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/Algorithms.py b/Wrappers/Python/ccpi/optimisation/Algorithms.py index de7f0f8..bf7f1c3 100644 --- a/Wrappers/Python/ccpi/optimisation/Algorithms.py +++ b/Wrappers/Python/ccpi/optimisation/Algorithms.py @@ -302,3 +302,53 @@ class FBPD(Algorithm) # time and criterion self.loss = self.constraint(self.x) + self.data_fidelity(self.x) + self.regulariser(self.operator.direct(self.x)) +class CGLS(Algorithm): + + '''Conjugate Gradient Least Squares algorithm + + Parameters: + x_init: initial guess + operator: operator for forward/backward projections + data: data to operate on + ''' + def __init__(self, **kwargs): + super(CGLS, self).__init__() + self.x = kwargs.get('x_init', None) + self.operator = kwargs.get('operator', None) + self.data = kwargs.get('data', None) + if self.x is not None and self.operator is not None and \ + self.data is not None: + print ("Calling from creator") + return self.set_up(x_init =kwargs['x_init'], + operator=kwargs['operator'], + data =kwargs['data']) + + def set_up(self, x_init, operator , data ): + + self.r = data.copy() + self.x = x_init.copy() + + self.operator = operator + self.d = operator.adjoint(self.r) + + self.normr2 = self.d.norm() + + def should_stop(self): + '''stopping cryterion, currently only based on number of iterations''' + return self.iteration >= self.max_iteration + + def update(self): + + Ad = self.operator.direct(self.d) + alpha = self.normr2/Ad.norm() + self.x += alpha * self.d + self.r -= alpha * Ad + s = self.operator.adjoint(self.r) + + normr2_new = s.norm() + beta = normr2_new/self.normr2 + self.normr2 = normr2_new + self.d = s + beta*self.d + + def update_objective(self): + self.loss.append(self.r.norm()) |