summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorepapoutsellis <epapoutsellis@gmail.com>2019-04-29 16:18:38 +0100
committerepapoutsellis <epapoutsellis@gmail.com>2019-04-29 16:18:38 +0100
commit4bacf27deec2abe993750ca7e1c873fd7aece8cd (patch)
treec3a676e4eef247af43a48e72ac26aedee7ec9296 /Wrappers/Python
parent7203efb28376e42e3c346b00c9c266f8c9febaf0 (diff)
downloadframework-4bacf27deec2abe993750ca7e1c873fd7aece8cd.tar.gz
framework-4bacf27deec2abe993750ca7e1c873fd7aece8cd.tar.bz2
framework-4bacf27deec2abe993750ca7e1c873fd7aece8cd.tar.xz
framework-4bacf27deec2abe993750ca7e1c873fd7aece8cd.zip
fix FISTA and FunctionOperatorComposition
Diffstat (limited to 'Wrappers/Python')
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/FISTA.py8
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py87
2 files changed, 51 insertions, 44 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
index 8ea2b6c..ee51049 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
@@ -65,10 +65,10 @@ class FISTA(Algorithm):
# initialization
if memopt:
- self.y = x_init.clone()
- self.x_old = x_init.clone()
- self.x = x_init.clone()
- self.u = x_init.clone()
+ self.y = x_init.copy()
+ self.x_old = x_init.copy()
+ self.x = x_init.copy()
+ self.u = x_init.copy()
else:
self.x_old = x_init.copy()
self.y = x_init.copy()
diff --git a/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py b/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py
index 70511bb..8895f3d 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py
@@ -19,16 +19,13 @@ class FunctionOperatorComposition(Function):
'''
- def __init__(self, operator, function):
+ def __init__(self, function, operator):
super(FunctionOperatorComposition, self).__init__()
+
self.function = function
self.operator = operator
- alpha = 1
-
- if isinstance (function, ScaledFunction):
- alpha = function.scalar
- self.L = 2 * alpha * operator.norm()**2
+ self.L = function.L * operator.norm()**2
def __call__(self, x):
@@ -39,47 +36,57 @@ class FunctionOperatorComposition(Function):
'''
- return self.function(self.operator.direct(x))
+ return self.function(self.operator.direct(x))
+
+ def gradient(self, x, out=None):
+#
+ ''' Gradient takes into account the Operator'''
+ if out is None:
+ return self.operator.adjoint(self.function.gradient(self.operator.direct(x)))
+ else:
+ tmp = self.operator.range_geometry().allocate()
+ self.operator.direct(x, out=tmp)
+ self.function.gradient(tmp, out=tmp)
+ self.operator.adjoint(tmp, out=out)
- #TODO do not know if we need it
- def call_adjoint(self, x):
- return self.function(self.operator.adjoint(x))
+
+ #TODO do not know if we need it
+ #def call_adjoint(self, x):
+ #
+ # return self.function(self.operator.adjoint(x))
- def convex_conjugate(self, x):
-
- ''' convex_conjugate does not take into account the Operator'''
- return self.function.convex_conjugate(x)
- def proximal(self, x, tau, out=None):
-
- '''proximal does not take into account the Operator'''
- if out is None:
- return self.function.proximal(x, tau)
- else:
- self.function.proximal(x, tau, out=out)
-
+ #def convex_conjugate(self, x):
+ #
+ # ''' convex_conjugate does not take into account the Operator'''
+ # return self.function.convex_conjugate(x)
- def proximal_conjugate(self, x, tau, out=None):
+
- ''' proximal conjugate does not take into account the Operator'''
- if out is None:
- return self.function.proximal_conjugate(x, tau)
- else:
- self.function.proximal_conjugate(x, tau, out=out)
- def gradient(self, x, out=None):
-
- ''' Gradient takes into account the Operator'''
- if out is None:
- return self.operator.adjoint(
- self.function.gradient(self.operator.direct(x))
- )
- else:
- self.operator.adjoint(
- self.function.gradient(self.operator.direct(x),
- out=out)
- )
+
+if __name__ == '__main__':
+
+ from ccpi.framework import ImageGeometry
+ from ccpi.optimisation.operators import Gradient
+ from ccpi.optimisation.functions import L2NormSquared
+
+ M, N, K = 2,3
+ ig = ImageGeometry(voxel_num_x=M, voxel_num_y = N)
+
+ G = Gradient(ig)
+ alpha = 0.5
+
+ f = L2NormSquared()
+ f_comp = FunctionOperatorComposition(G, alpha * f)
+ x = ig.allocate('random_int')
+ print(f_comp.gradient(x).shape
+
+ )
+
+
+
\ No newline at end of file