From b9d3b0722c03cded15973417514d4639e390311e Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 18 Feb 2019 15:19:48 +0000 Subject: working unit test, initial tomography test --- .../optimisation/operators/CompositeOperator.py | 184 ++++++--------------- 1 file changed, 54 insertions(+), 130 deletions(-) (limited to 'Wrappers/Python') diff --git a/Wrappers/Python/ccpi/optimisation/operators/CompositeOperator.py b/Wrappers/Python/ccpi/optimisation/operators/CompositeOperator.py index ad307b7..be2d525 100755 --- a/Wrappers/Python/ccpi/optimisation/operators/CompositeOperator.py +++ b/Wrappers/Python/ccpi/optimisation/operators/CompositeOperator.py @@ -8,8 +8,12 @@ Created on Thu Feb 14 12:36:40 2019 import numpy from numbers import Number import functools +from ccpi.framework import AcquisitionData, ImageData + class Operator(object): '''Operator that maps from a space X -> Y''' + def __init__(self, **kwargs): + self.scalar = 1 def is_linear(self): '''Returns if the operator is linear''' return False @@ -30,6 +34,10 @@ class Operator(object): raise NotImplementedError def domain_dim(self): raise NotImplementedError + def __rmul__(self, other): + assert isinstance(other, Number) + self.scalar = other + return self class LinearOperator(Operator): '''Operator that maps from a space X -> Y''' @@ -38,6 +46,8 @@ class LinearOperator(Operator): return True def adjoint(self,x, out=None): raise NotImplementedError + +# this should go in the framework class CompositeDataContainer(object): '''Class to hold a composite operator''' @@ -260,118 +270,9 @@ class CompositeDataContainer(object): def __itruediv__(self, other): return self.__idiv__(other) def norm(self): - y = numpy.asarray([el.norm() for el in self.containers]) - return numpy.reshape(y, self.shape) - -import time -from ccpi.optimisation.funcs import ZeroFun - -class Algorithm(object): - '''Base class for iterative algorithms - - provides the minimal infrastructure. - Algorithms are iterables so can be easily run in a for loop. They will - stop as soon as the stop cryterion is met. - The user is required to implement the set_up, __init__, update and - should_stop and update_objective methods - ''' - - def __init__(self): - self.iteration = 0 - self.stop_cryterion = 'max_iter' - self.__max_iteration = 0 - self.__loss = [] - self.memopt = False - self.timing = [] - def set_up(self, *args, **kwargs): - raise NotImplementedError() - def update(self): - raise NotImplementedError() - - def should_stop(self): - '''stopping cryterion''' - raise NotImplementedError() - - def __iter__(self): - return self - def next(self): - '''python2 backwards compatibility''' - return self.__next__() - def __next__(self): - if self.should_stop(): - raise StopIteration() - else: - time0 = time.time() - self.update() - self.timing.append( time.time() - time0 ) - self.update_objective() - self.iteration += 1 - def get_output(self): - '''Returns the solution found''' - return self.x - def get_current_loss(self): - '''Returns the current value of the loss function''' - return self.__loss[-1] - def update_objective(self): - raise NotImplementedError() - @property - def loss(self): - return self.__loss - @property - def max_iteration(self): - return self.__max_iteration - @max_iteration.setter - def max_iteration(self, value): - assert isinstance(value, int) - self.__max_iteration = value - -class GradientDescent(Algorithm): - '''Implementation of a simple Gradient Descent algorithm - ''' - - def __init__(self, **kwargs): - '''initialisation can be done at creation time if all - proper variables are passed or later with set_up''' - super(GradientDescent, self).__init__() - self.x = None - self.rate = 0 - self.objective_function = None - self.regulariser = None - args = ['x_init', 'objective_function', 'rate'] - for k,v in kwargs.items(): - if k in args: - args.pop(args.index(k)) - if len(args) == 0: - return self.set_up(x_init=kwargs['x_init'], - objective_function=kwargs['objective_function'], - rate=kwargs['rate']) - - def should_stop(self): - '''stopping cryterion, currently only based on number of iterations''' - return self.iteration >= self.max_iteration - - def set_up(self, x_init, objective_function, rate): - '''initialisation of the algorithm''' - self.x = x_init.copy() - if self.memopt: - self.x_update = x_init.copy() - self.objective_function = objective_function - self.rate = rate - self.loss.append(objective_function(x_init)) - - def update(self): - '''Single iteration''' - if self.memopt: - self.objective_function.gradient(self.x, out=self.x_update) - self.x_update *= -self.rate - self.x += self.x_update - else: - self.x += -self.rate * self.objective_function.grad(self.x) - - def update_objective(self): - self.loss.append(self.objective_function(self.x)) - - + y = numpy.asarray([el.norm().sum() for el in self.containers]) + return y.sum() + class CompositeOperator(Operator): '''Class to hold a composite operator''' def __init__(self, *args, shape=None): @@ -416,10 +317,10 @@ class CompositeOperator(Operator): return CompositeDataContainer(*res, shape=shape) def adjoint(self, x, out=None): - shape = self.get_output_shape(x.shape) + shape = self.get_output_shape(x.shape, adjoint=True) res = [] - for row in range(self.shape[0]): - for col in range(self.shape[1]): + for row in range(self.shape[1]): + for col in range(self.shape[0]): if col == 0: prod = self.get_item(row,col).adjoint(x.get_item(col)) else: @@ -427,18 +328,25 @@ class CompositeOperator(Operator): res.append(prod) return CompositeDataContainer(*res, shape=shape) - def get_output_shape(self, xshape): + def get_output_shape(self, xshape, adjoint=False): print ("operator shape {} data shape {}".format(self.shape, xshape)) - if self.shape[1] != xshape[0]: + sshape = self.shape[1] + oshape = self.shape[0] + if adjoint: + sshape = self.shape[0] + oshape = self.shape[1] + if sshape != xshape[0]: raise ValueError('Incompatible shapes {} {}'.format(self.shape, xshape)) - print ((self.shape[0], xshape[-1])) - return (self.shape[0], xshape[-1]) + print ((oshape, xshape[-1])) + return (oshape, xshape[-1]) if __name__ == '__main__': #from ccpi.optimisation.Algorithms import GradientDescent from ccpi.plugins.ops import CCPiProjectorSimple - from ccpi.optimisation.ops import TomoIdentity, PowerMethodNonsquare + from ccpi.optimisation.ops import PowerMethodNonsquare + from ccpi.optimisation.ops import TomoIdentity from ccpi.optimisation.funcs import Norm2sq, Norm1 - from ccpi.framework import ImageGeometry, ImageData, AcquisitionGeometry + from ccpi.framework import ImageGeometry, AcquisitionGeometry + from ccpi.optimisation.Algorithms import CGLS import matplotlib.pyplot as plt ig0 = ImageGeometry(2,3,4) @@ -699,25 +607,41 @@ if __name__ == '__main__': ImageData(geometry=ig, dimension_labels=['horizontal_x','horizontal_y','vertical'])) # setup a tomo identity - I = TomoIdentity(geometry=ig) + I = 0.3 * TomoIdentity(geometry=ig) # composite operator - K = CompositeOperator(A, I) + K = CompositeOperator(A, I, shape=(2,1)) out = K.direct(X_init) f = Norm2sq(K,B) - f.L = 0.001 + f.L = 0.1 + + cg = CGLS() + cg.set_up(X_init, K, B ) + cg.max_iteration = 1 - gd = GradientDescent() - gd.set_up(X_init, f, 0.001 ) - gd.max_iteration = 2 + cgs = CGLS() + cgs.set_up(x_init, A, b ) + cgs.max_iteration = 2 out.__isub__(B) out2 = K.adjoint(out) #(2.0*self.c)*self.A.adjoint( self.A.direct(x) - self.b ) - for _ in gd: - print ("iteration {} {}".format(gd.iteration, gd.get_current_loss())) - \ No newline at end of file + for _ in cg: + print ("iteration {} {}".format(cg.iteration, cg.get_current_loss())) + + fig = plt.figure() + plt.imshow(cg.get_output().get_item(0,0).subset(vertical=0).as_array()) + plt.title('Composite CGLS') + plt.show() + + for _ in cgs: + print ("iteration {} {}".format(cgs.iteration, cgs.get_current_loss())) + + fig = plt.figure() + plt.imshow(cgs.get_output().subset(vertical=0).as_array()) + plt.title('Simple CGLS') + plt.show() \ No newline at end of file -- cgit v1.2.3