diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2019-02-18 15:19:48 +0000 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2019-02-18 15:19:48 +0000 |
commit | b9d3b0722c03cded15973417514d4639e390311e (patch) | |
tree | 3aaedd9e3e97f76157ad298974c8de00bcc486a8 /Wrappers/Python | |
parent | 5f82583109cd218e08c2a9e1cca21adca73ffe6d (diff) | |
download | framework-b9d3b0722c03cded15973417514d4639e390311e.tar.gz framework-b9d3b0722c03cded15973417514d4639e390311e.tar.bz2 framework-b9d3b0722c03cded15973417514d4639e390311e.tar.xz framework-b9d3b0722c03cded15973417514d4639e390311e.zip |
working unit test, initial tomography test
Diffstat (limited to 'Wrappers/Python')
-rwxr-xr-x | Wrappers/Python/ccpi/optimisation/operators/CompositeOperator.py | 184 |
1 files changed, 54 insertions, 130 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/operators/CompositeOperator.py b/Wrappers/Python/ccpi/optimisation/operators/CompositeOperator.py index ad307b7..be2d525 100755 --- a/Wrappers/Python/ccpi/optimisation/operators/CompositeOperator.py +++ b/Wrappers/Python/ccpi/optimisation/operators/CompositeOperator.py @@ -8,8 +8,12 @@ Created on Thu Feb 14 12:36:40 2019 import numpy
from numbers import Number
import functools
+from ccpi.framework import AcquisitionData, ImageData
+
class Operator(object):
'''Operator that maps from a space X -> Y'''
+ def __init__(self, **kwargs):
+ self.scalar = 1
def is_linear(self):
'''Returns if the operator is linear'''
return False
@@ -30,6 +34,10 @@ class Operator(object): raise NotImplementedError
def domain_dim(self):
raise NotImplementedError
+ def __rmul__(self, other):
+ assert isinstance(other, Number)
+ self.scalar = other
+ return self
class LinearOperator(Operator):
'''Operator that maps from a space X -> Y'''
@@ -38,6 +46,8 @@ class LinearOperator(Operator): return True
def adjoint(self,x, out=None):
raise NotImplementedError
+
+# this should go in the framework
class CompositeDataContainer(object):
'''Class to hold a composite operator'''
@@ -260,118 +270,9 @@ class CompositeDataContainer(object): def __itruediv__(self, other):
return self.__idiv__(other)
def norm(self):
- y = numpy.asarray([el.norm() for el in self.containers])
- return numpy.reshape(y, self.shape)
-
-import time
-from ccpi.optimisation.funcs import ZeroFun
-
-class Algorithm(object):
- '''Base class for iterative algorithms
-
- provides the minimal infrastructure.
- Algorithms are iterables so can be easily run in a for loop. They will
- stop as soon as the stop cryterion is met.
- The user is required to implement the set_up, __init__, update and
- should_stop and update_objective methods
- '''
-
- def __init__(self):
- self.iteration = 0
- self.stop_cryterion = 'max_iter'
- self.__max_iteration = 0
- self.__loss = []
- self.memopt = False
- self.timing = []
- def set_up(self, *args, **kwargs):
- raise NotImplementedError()
- def update(self):
- raise NotImplementedError()
-
- def should_stop(self):
- '''stopping cryterion'''
- raise NotImplementedError()
-
- def __iter__(self):
- return self
- def next(self):
- '''python2 backwards compatibility'''
- return self.__next__()
- def __next__(self):
- if self.should_stop():
- raise StopIteration()
- else:
- time0 = time.time()
- self.update()
- self.timing.append( time.time() - time0 )
- self.update_objective()
- self.iteration += 1
- def get_output(self):
- '''Returns the solution found'''
- return self.x
- def get_current_loss(self):
- '''Returns the current value of the loss function'''
- return self.__loss[-1]
- def update_objective(self):
- raise NotImplementedError()
- @property
- def loss(self):
- return self.__loss
- @property
- def max_iteration(self):
- return self.__max_iteration
- @max_iteration.setter
- def max_iteration(self, value):
- assert isinstance(value, int)
- self.__max_iteration = value
-
-class GradientDescent(Algorithm):
- '''Implementation of a simple Gradient Descent algorithm
- '''
-
- def __init__(self, **kwargs):
- '''initialisation can be done at creation time if all
- proper variables are passed or later with set_up'''
- super(GradientDescent, self).__init__()
- self.x = None
- self.rate = 0
- self.objective_function = None
- self.regulariser = None
- args = ['x_init', 'objective_function', 'rate']
- for k,v in kwargs.items():
- if k in args:
- args.pop(args.index(k))
- if len(args) == 0:
- return self.set_up(x_init=kwargs['x_init'],
- objective_function=kwargs['objective_function'],
- rate=kwargs['rate'])
-
- def should_stop(self):
- '''stopping cryterion, currently only based on number of iterations'''
- return self.iteration >= self.max_iteration
-
- def set_up(self, x_init, objective_function, rate):
- '''initialisation of the algorithm'''
- self.x = x_init.copy()
- if self.memopt:
- self.x_update = x_init.copy()
- self.objective_function = objective_function
- self.rate = rate
- self.loss.append(objective_function(x_init))
-
- def update(self):
- '''Single iteration'''
- if self.memopt:
- self.objective_function.gradient(self.x, out=self.x_update)
- self.x_update *= -self.rate
- self.x += self.x_update
- else:
- self.x += -self.rate * self.objective_function.grad(self.x)
-
- def update_objective(self):
- self.loss.append(self.objective_function(self.x))
-
-
+ y = numpy.asarray([el.norm().sum() for el in self.containers])
+ return y.sum()
+
class CompositeOperator(Operator):
'''Class to hold a composite operator'''
def __init__(self, *args, shape=None):
@@ -416,10 +317,10 @@ class CompositeOperator(Operator): return CompositeDataContainer(*res, shape=shape)
def adjoint(self, x, out=None):
- shape = self.get_output_shape(x.shape)
+ shape = self.get_output_shape(x.shape, adjoint=True)
res = []
- for row in range(self.shape[0]):
- for col in range(self.shape[1]):
+ for row in range(self.shape[1]):
+ for col in range(self.shape[0]):
if col == 0:
prod = self.get_item(row,col).adjoint(x.get_item(col))
else:
@@ -427,18 +328,25 @@ class CompositeOperator(Operator): res.append(prod)
return CompositeDataContainer(*res, shape=shape)
- def get_output_shape(self, xshape):
+ def get_output_shape(self, xshape, adjoint=False):
print ("operator shape {} data shape {}".format(self.shape, xshape))
- if self.shape[1] != xshape[0]:
+ sshape = self.shape[1]
+ oshape = self.shape[0]
+ if adjoint:
+ sshape = self.shape[0]
+ oshape = self.shape[1]
+ if sshape != xshape[0]:
raise ValueError('Incompatible shapes {} {}'.format(self.shape, xshape))
- print ((self.shape[0], xshape[-1]))
- return (self.shape[0], xshape[-1])
+ print ((oshape, xshape[-1]))
+ return (oshape, xshape[-1])
if __name__ == '__main__':
#from ccpi.optimisation.Algorithms import GradientDescent
from ccpi.plugins.ops import CCPiProjectorSimple
- from ccpi.optimisation.ops import TomoIdentity, PowerMethodNonsquare
+ from ccpi.optimisation.ops import PowerMethodNonsquare
+ from ccpi.optimisation.ops import TomoIdentity
from ccpi.optimisation.funcs import Norm2sq, Norm1
- from ccpi.framework import ImageGeometry, ImageData, AcquisitionGeometry
+ from ccpi.framework import ImageGeometry, AcquisitionGeometry
+ from ccpi.optimisation.Algorithms import CGLS
import matplotlib.pyplot as plt
ig0 = ImageGeometry(2,3,4)
@@ -699,25 +607,41 @@ if __name__ == '__main__': ImageData(geometry=ig, dimension_labels=['horizontal_x','horizontal_y','vertical']))
# setup a tomo identity
- I = TomoIdentity(geometry=ig)
+ I = 0.3 * TomoIdentity(geometry=ig)
# composite operator
- K = CompositeOperator(A, I)
+ K = CompositeOperator(A, I, shape=(2,1))
out = K.direct(X_init)
f = Norm2sq(K,B)
- f.L = 0.001
+ f.L = 0.1
+
+ cg = CGLS()
+ cg.set_up(X_init, K, B )
+ cg.max_iteration = 1
- gd = GradientDescent()
- gd.set_up(X_init, f, 0.001 )
- gd.max_iteration = 2
+ cgs = CGLS()
+ cgs.set_up(x_init, A, b )
+ cgs.max_iteration = 2
out.__isub__(B)
out2 = K.adjoint(out)
#(2.0*self.c)*self.A.adjoint( self.A.direct(x) - self.b )
- for _ in gd:
- print ("iteration {} {}".format(gd.iteration, gd.get_current_loss()))
-
\ No newline at end of file + for _ in cg:
+ print ("iteration {} {}".format(cg.iteration, cg.get_current_loss()))
+
+ fig = plt.figure()
+ plt.imshow(cg.get_output().get_item(0,0).subset(vertical=0).as_array())
+ plt.title('Composite CGLS')
+ plt.show()
+
+ for _ in cgs:
+ print ("iteration {} {}".format(cgs.iteration, cgs.get_current_loss()))
+
+ fig = plt.figure()
+ plt.imshow(cgs.get_output().subset(vertical=0).as_array())
+ plt.title('Simple CGLS')
+ plt.show()
\ No newline at end of file |