Diffstat (limited to 'Wrappers/Python')
-rw-r--r--   Wrappers/Python/ccpi/optimisation/Algorithms.py | 50
1 file changed, 50 insertions, 0 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/Algorithms.py b/Wrappers/Python/ccpi/optimisation/Algorithms.py
index de7f0f8..bf7f1c3 100644
--- a/Wrappers/Python/ccpi/optimisation/Algorithms.py
+++ b/Wrappers/Python/ccpi/optimisation/Algorithms.py
@@ -302,3 +302,53 @@ class FBPD(Algorithm)
         # time and criterion
         self.loss = self.constraint(self.x) + self.data_fidelity(self.x) + self.regulariser(self.operator.direct(self.x))
 
+class CGLS(Algorithm):
+
+    '''Conjugate Gradient Least Squares algorithm
+
+    Parameters:
+      x_init: initial guess
+      operator: operator for forward/backward projections
+      data: data to operate on
+    '''
+    def __init__(self, **kwargs):
+        super(CGLS, self).__init__()
+        self.x        = kwargs.get('x_init', None)
+        self.operator = kwargs.get('operator', None)
+        self.data     = kwargs.get('data', None)
+        if self.x is not None and self.operator is not None and \
+           self.data is not None:
+            print ("Calling from creator")
+            return self.set_up(x_init  =kwargs['x_init'],
+                               operator=kwargs['operator'],
+                               data    =kwargs['data'])
+
+    def set_up(self, x_init, operator, data):
+
+        self.r = data.copy()
+        self.x = x_init.copy()
+
+        self.operator = operator
+        self.d = operator.adjoint(self.r)
+
+        self.normr2 = self.d.norm()
+
+    def should_stop(self):
+        '''stopping criterion, currently only based on number of iterations'''
+        return self.iteration >= self.max_iteration
+
+    def update(self):
+
+        Ad = self.operator.direct(self.d)
+        alpha = self.normr2/Ad.norm()
+        self.x += alpha * self.d
+        self.r -= alpha * Ad
+        s  = self.operator.adjoint(self.r)
+
+        normr2_new = s.norm()
+        beta = normr2_new/self.normr2
+        self.normr2 = normr2_new
+        self.d = s + beta*self.d
+
+    def update_objective(self):
+        self.loss.append(self.r.norm())
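For reference, the update above is the textbook CGLS recurrence for minimising ||Ax - b||^2, in which alpha = ||A^T r||^2 / ||A d||^2 and beta = ||A^T r_new||^2 / ||A^T r_old||^2, i.e. both scalars are ratios of squared norms. Below is a minimal, self-contained NumPy sketch of that iteration; it is an illustration only, not part of the framework. The names cgls, A, b and x0 are hypothetical, and it substitutes dense NumPy arrays for the framework's operator and DataContainer objects, computing the squared norms explicitly via dot products rather than relying on the behaviour of DataContainer.norm() used in the diff.

    import numpy as np

    def cgls(A, b, x0, n_iter=20):
        # Textbook CGLS for min_x ||A x - b||^2 with a dense matrix A (illustrative sketch).
        x = x0.astype(float)
        r = b - A @ x                       # residual; the diff takes r = data, i.e. it assumes a zero initial guess
        d = A.T @ r                         # initial search direction d = A^T r
        normr2 = d.dot(d)                   # ||A^T r||^2
        for _ in range(n_iter):
            Ad = A @ d
            alpha = normr2 / Ad.dot(Ad)     # step length: ||A^T r||^2 / ||A d||^2
            x += alpha * d
            r -= alpha * Ad
            s = A.T @ r
            normr2_new = s.dot(s)           # ||A^T r_new||^2
            beta = normr2_new / normr2      # conjugation coefficient
            normr2 = normr2_new
            d = s + beta * d
        return x

    # Hypothetical usage on a small overdetermined least-squares problem.
    rng = np.random.default_rng(0)
    A = rng.standard_normal((50, 10))
    x_true = rng.standard_normal(10)
    b = A @ x_true
    x = cgls(A, b, np.zeros(10), n_iter=30)
    print(np.linalg.norm(x - x_true))       # should be close to zero

One practical detail visible in the diff itself: set_up initialises the residual as r = data rather than data - operator.direct(x_init), so the implementation effectively assumes a zero (or negligible) initial guess.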
