summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-03-14 14:51:11 +0000
committerEdoardo Pasca <edo.paskino@gmail.com>2019-03-14 14:51:11 +0000
commit2bc9cce049c6ae588562ac88e089553a3dcc6d19 (patch)
treefa62be19a67aa932cbea8032dfe87778a8339a77
parentc9748b96e531a64c4e56909ab19a0b82fc01eb45 (diff)
downloadframework-2bc9cce049c6ae588562ac88e089553a3dcc6d19.tar.gz
framework-2bc9cce049c6ae588562ac88e089553a3dcc6d19.tar.bz2
framework-2bc9cce049c6ae588562ac88e089553a3dcc6d19.tar.xz
framework-2bc9cce049c6ae588562ac88e089553a3dcc6d19.zip
added ScaledFunction
-rwxr-xr-xWrappers/Python/ccpi/optimisation/functions/Function.py48
-rwxr-xr-xWrappers/Python/ccpi/optimisation/functions/ScaledFunction.py60
2 files changed, 108 insertions, 0 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/Function.py b/Wrappers/Python/ccpi/optimisation/functions/Function.py
new file mode 100755
index 0000000..43ce900
--- /dev/null
+++ b/Wrappers/Python/ccpi/optimisation/functions/Function.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+# This work is part of the Core Imaging Library developed by
+# Visual Analytics and Imaging System Group of the Science Technology
+# Facilities Council, STFC
+
+# Copyright 2018-2019 Jakob Jorgensen, Daniil Kazantsev and Edoardo Pasca
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+
+class Function(object):
+ '''Abstract class representing a function
+
+
+ '''
+ def __init__(self):
+ self.L = None
+ def __call__(self,x, out=None):
+ raise NotImplementedError
+ def call_adjoint(self, x, out=None):
+ raise NotImplementedError
+ def convex_conjugate(self, x, out=None):
+ raise NotImplementedError
+ def proximal_conjugate(self, x, tau, out = None):
+ raise NotImplementedError
+ def grad(self, x):
+ warnings.warn('''This method will disappear in following
+ versions of the CIL. Use gradient instead''', DeprecationWarning)
+ return self.gradient(x, out=None)
+ def prox(self, x, tau):
+ warnings.warn('''This method will disappear in following
+ versions of the CIL. Use proximal instead''', DeprecationWarning)
+ return self.proximal(x, out=None)
+ def gradient(self, x, out=None):
+ raise NotImplementedError
+ def proximal(self, x, tau, out=None):
+ raise NotImplementedError \ No newline at end of file
diff --git a/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
new file mode 100755
index 0000000..f2e39fb
--- /dev/null
+++ b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
@@ -0,0 +1,60 @@
+from numbers import Number
+import numpy
+
+class ScaledFunction(object):
+ '''ScaledFunction
+
+ A class to represent the scalar multiplication of an Operator with a scalar.
+ It holds an operator and a scalar. Basically it returns the multiplication
+ of the result of direct and adjoint of the operator with the scalar.
+ For the rest it behaves like the operator it holds.
+
+ Args:
+ operator (Operator): a Operator or LinearOperator
+ scalar (Number): a scalar multiplier
+ Example:
+ The scaled operator behaves like the following:
+ sop = ScaledOperator(operator, scalar)
+ sop.direct(x) = scalar * operator.direct(x)
+ sop.adjoint(x) = scalar * operator.adjoint(x)
+ sop.norm() = operator.norm()
+ sop.range_geometry() = operator.range_geometry()
+ sop.domain_geometry() = operator.domain_geometry()
+ '''
+ def __init__(self, function, scalar):
+ super(ScaledFunction, self).__init__()
+ self.L = None
+ if not isinstance (scalar, Number):
+ raise TypeError('expected scalar: got {}'.format(type(scalar)))
+ self.scalar = scalar
+ self.function = function
+
+ def __call__(self,x, out=None):
+ return self.scalar * self.function(x)
+
+ def call_adjoint(self, x, out=None):
+ return self.scalar * self.function.call_adjoint(x, out=out)
+
+ def convex_conjugate(self, x, out=None):
+ return self.scalar * self.function.convex_conjugate(x, out=out)
+
+ def proximal_conjugate(self, x, tau, out = None):
+ '''TODO check if this is mathematically correct'''
+ return self.function.proximal_conjugate(x, tau, out=out)
+
+ def grad(self, x):
+ warnings.warn('''This method will disappear in following
+ versions of the CIL. Use gradient instead''', DeprecationWarning)
+ return self.gradient(x, out=None)
+
+ def prox(self, x, tau):
+ warnings.warn('''This method will disappear in following
+ versions of the CIL. Use proximal instead''', DeprecationWarning)
+ return self.proximal(x, out=None)
+
+ def gradient(self, x, out=None):
+ return self.scalar * self.function.gradient(x, out=out)
+
+ def proximal(self, x, tau, out=None):
+ '''TODO check if this is mathematically correct'''
+ return self.function.proximal(x, tau, out=out)