Diffstat (limited to 'Wrappers/Python')
-rwxr-xr-x  Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py                       |  8
-rw-r--r--  Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py  | 87
2 files changed, 51 insertions, 44 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
index 8ea2b6c..ee51049 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
@@ -65,10 +65,10 @@ class FISTA(Algorithm):
 
         # initialization
         if memopt:
-            self.y = x_init.clone()
-            self.x_old = x_init.clone()
-            self.x = x_init.clone()
-            self.u = x_init.clone()
+            self.y = x_init.copy()
+            self.x_old = x_init.copy()
+            self.x = x_init.copy()
+            self.u = x_init.copy()
         else:
             self.x_old = x_init.copy()
             self.y = x_init.copy()
diff --git a/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py b/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py
index 70511bb..8895f3d 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py
@@ -19,16 +19,13 @@ class FunctionOperatorComposition(Function):
 
     '''
 
-    def __init__(self, operator, function):
+    def __init__(self, function, operator):
 
         super(FunctionOperatorComposition, self).__init__()
+
         self.function = function
         self.operator = operator
-        alpha = 1
-
-        if isinstance (function, ScaledFunction):
-            alpha = function.scalar
-        self.L = 2 * alpha * operator.norm()**2
+        self.L = function.L * operator.norm()**2
 
     def __call__(self, x):
 
@@ -39,47 +36,57 @@ class FunctionOperatorComposition(Function):
 
         '''
 
-        return self.function(self.operator.direct(x))
+        return self.function(self.operator.direct(x))
+
+    def gradient(self, x, out=None):
+#
+        ''' Gradient takes into account the Operator'''
+        if out is None:
+            return self.operator.adjoint(self.function.gradient(self.operator.direct(x)))
+        else:
+            tmp = self.operator.range_geometry().allocate()
+            self.operator.direct(x, out=tmp)
+            self.function.gradient(tmp, out=tmp)
+            self.operator.adjoint(tmp, out=out)
 
-    #TODO do not know if we need it
-    def call_adjoint(self, x):
-
-        return self.function(self.operator.adjoint(x))
+
+    #TODO do not know if we need it
+    #def call_adjoint(self, x):
+    #
+    #    return self.function(self.operator.adjoint(x))
 
-    def convex_conjugate(self, x):
-
-        ''' convex_conjugate does not take into account the Operator'''
-        return self.function.convex_conjugate(x)
 
-    def proximal(self, x, tau, out=None):
-
-        '''proximal does not take into account the Operator'''
-        if out is None:
-            return self.function.proximal(x, tau)
-        else:
-            self.function.proximal(x, tau, out=out)
-
+    #def convex_conjugate(self, x):
+    #
+    #    ''' convex_conjugate does not take into account the Operator'''
+    #    return self.function.convex_conjugate(x)
 
-    def proximal_conjugate(self, x, tau, out=None):
+
-        ''' proximal conjugate does not take into account the Operator'''
-        if out is None:
-            return self.function.proximal_conjugate(x, tau)
-        else:
-            self.function.proximal_conjugate(x, tau, out=out)
 
-    def gradient(self, x, out=None):
-
-        ''' Gradient takes into account the Operator'''
-        if out is None:
-            return self.operator.adjoint(
-                self.function.gradient(self.operator.direct(x))
-                )
-        else:
-            self.operator.adjoint(
-                self.function.gradient(self.operator.direct(x),
-                    out=out)
-                )
+
+if __name__ == '__main__':
+
+    from ccpi.framework import ImageGeometry
+    from ccpi.optimisation.operators import Gradient
+    from ccpi.optimisation.functions import L2NormSquared
+
+    M, N, K = 2,3
+    ig = ImageGeometry(voxel_num_x=M, voxel_num_y = N)
+
+    G = Gradient(ig)
+    alpha = 0.5
+
+    f = L2NormSquared()
+    f_comp = FunctionOperatorComposition(G, alpha * f)
+    x = ig.allocate('random_int')
+    print(f_comp.gradient(x).shape
+
+          )
+
+
\ No newline at end of file
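
For reference, a minimal usage sketch (not part of the commit) of FunctionOperatorComposition after this change. It assumes the ccpi package built from this tree is importable and that FunctionOperatorComposition is exported from ccpi.optimisation.functions. The arguments follow the new __init__(self, function, operator) order, whereas the __main__ snippet inside the commit still passes the operator first; the geometry sizes are illustrative only.

# Hypothetical usage sketch, not part of the commit: exercises the new
# FunctionOperatorComposition(function, operator) signature and the
# out-of-place / in-place gradient paths introduced above.
from ccpi.framework import ImageGeometry
from ccpi.optimisation.operators import Gradient
from ccpi.optimisation.functions import L2NormSquared, FunctionOperatorComposition

# Small 2D image geometry, just for illustration
M, N = 2, 3
ig = ImageGeometry(voxel_num_x=M, voxel_num_y=N)

G = Gradient(ig)        # the operator A
f = L2NormSquared()     # the function f; its Lipschitz constant is f.L

# Function first, operator second: the argument order after this commit
f_comp = FunctionOperatorComposition(f, G)

# Lipschitz constant is now taken from the wrapped function: L = f.L * ||A||^2
print(f_comp.L)

x = ig.allocate('random_int')

# Out-of-place gradient: A^T f'(A x)
grad = f_comp.gradient(x)
print(grad.shape)

# In-place gradient, writing into a preallocated container in the domain of A
out = ig.allocate()
f_comp.gradient(x, out=out)

The same sketch would also work with a scaled function such as 0.5 * f, since the new constructor reads the Lipschitz constant from function.L instead of special-casing ScaledFunction with the old isinstance check.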