summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-03-14 14:52:36 +0000
committerEdoardo Pasca <edo.paskino@gmail.com>2019-03-14 14:52:36 +0000
commitb3be9080f736964486c8f647a68720d2836eb89d (patch)
treedc13762d96aa721f0c7b7bb8b918f9427e747421
parent53689e374625441867c6169829b1ee9b167547f4 (diff)
downloadframework-b3be9080f736964486c8f647a68720d2836eb89d.tar.gz
framework-b3be9080f736964486c8f647a68720d2836eb89d.tar.bz2
framework-b3be9080f736964486c8f647a68720d2836eb89d.tar.xz
framework-b3be9080f736964486c8f647a68720d2836eb89d.zip
use ScaledFunction
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py23
1 files changed, 17 insertions, 6 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py b/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py
index 0f3defe..3ac4358 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition.py
@@ -9,16 +9,20 @@ Created on Fri Mar 8 09:55:36 2019
import numpy as np
#from ccpi.optimisation.funcs import Function
from ccpi.optimisation.functions import Function
+from ccpi.optimisation.functions import ScaledFunction
class FunctionOperatorComposition(Function):
def __init__(self, operator, function):
-
+ super(FunctionOperatorComposition, self).__init__()
self.function = function
self.operator = operator
- self.L = 2*self.function.alpha*operator.norm()**2
- super(FunctionOperatorComposition, self).__init__()
+ alpha = 1
+ if isinstance (function, ScaledFunction):
+ alpha = function.scalar
+ self.L = 2 * alpha * operator.norm()**2
+
def __call__(self, x):
@@ -45,10 +49,17 @@ class FunctionOperatorComposition(Function):
return self.function.proximal_conjugate(x, tau)
- def gradient(self, x):
+ def gradient(self, x, out=None):
''' Gradient takes into account the Operator'''
-
- return self.operator.adjoint(self.function.gradient(self.operator.direct(x)))
+ if out is None:
+ return self.operator.adjoint(
+ self.function.gradient(self.operator.direct(x))
+ )
+ else:
+ self.operator.adjoint(
+ self.function.gradient(self.operator.direct(x),
+ out=out)
+ )
\ No newline at end of file