summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-04-01 16:34:38 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2019-04-01 16:34:38 +0100
commit0e8dac47faf88379175310552a4611ca34f407ea (patch)
tree2a5ea27a8537eb7ff9b7dea933482440d6621857 /Wrappers/Python
parent93517aa9f1472458fa962beae1abebb3e1223a6c (diff)
downloadframework-0e8dac47faf88379175310552a4611ca34f407ea.tar.gz
framework-0e8dac47faf88379175310552a4611ca34f407ea.tar.bz2
framework-0e8dac47faf88379175310552a4611ca34f407ea.tar.xz
framework-0e8dac47faf88379175310552a4611ca34f407ea.zip
massive clean
Diffstat (limited to 'Wrappers/Python')
-rwxr-xr-xWrappers/Python/ccpi/optimisation/funcs.py193
1 files changed, 12 insertions, 181 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/funcs.py b/Wrappers/Python/ccpi/optimisation/funcs.py
index 8ce54c7..6741020 100755
--- a/Wrappers/Python/ccpi/optimisation/funcs.py
+++ b/Wrappers/Python/ccpi/optimisation/funcs.py
@@ -21,108 +21,9 @@ from ccpi.optimisation.ops import Identity, FiniteDiff2D
import numpy
from ccpi.framework import DataContainer
import warnings
+from ccpi.optimisation.functions import Function
-def isSizeCorrect(data1 ,data2):
- if issubclass(type(data1), DataContainer) and \
- issubclass(type(data2), DataContainer):
- # check dimensionality
- if data1.check_dimensions(data2):
- return True
- elif issubclass(type(data1) , numpy.ndarray) and \
- issubclass(type(data2) , numpy.ndarray):
- return data1.shape == data2.shape
- else:
- raise ValueError("{0}: getting two incompatible types: {1} {2}"\
- .format('Function', type(data1), type(data2)))
- return False
-
-class Function(object):
- def __init__(self):
- self.L = None
- def __call__(self,x, out=None): raise NotImplementedError
- def grad(self, x):
- warnings.warn("grad method is deprecated. use gradient instead", DeprecationWarning)
- return self.gradient(x, out=None)
- def prox(self, x, tau):
- warnings.warn("prox method is deprecated. use proximal instead", DeprecationWarning)
- return self.proximal(x,tau,out=None)
- def gradient(self, x, out=None): raise NotImplementedError
- def proximal(self, x, tau, out=None): raise NotImplementedError
-
-
-class Norm2(Function):
-
- def __init__(self,
- gamma=1.0,
- direction=None):
- super(Norm2, self).__init__()
- self.gamma = gamma;
- self.direction = direction;
-
- def __call__(self, x, out=None):
-
- if out is None:
- xx = numpy.sqrt(numpy.sum(numpy.square(x.as_array()), self.direction,
- keepdims=True))
- else:
- if isSizeCorrect(out, x):
- # check dimensionality
- if issubclass(type(out), DataContainer):
- arr = out.as_array()
- numpy.square(x.as_array(), out=arr)
- xx = numpy.sqrt(numpy.sum(arr, self.direction, keepdims=True))
-
- elif issubclass(type(out) , numpy.ndarray):
- numpy.square(x.as_array(), out=out)
- xx = numpy.sqrt(numpy.sum(out, self.direction, keepdims=True))
- else:
- raise ValueError ('Wrong size: x{0} out{1}'.format(x.shape,out.shape) )
-
- p = numpy.sum(self.gamma*xx)
-
- return p
-
- def prox(self, x, tau):
-
- xx = numpy.sqrt(numpy.sum( numpy.square(x.as_array()), self.direction,
- keepdims=True ))
- xx = numpy.maximum(0, 1 - tau*self.gamma / xx)
- p = x.as_array() * xx
-
- return type(x)(p,geometry=x.geometry)
- def proximal(self, x, tau, out=None):
- if out is None:
- return self.prox(x,tau)
- else:
- if isSizeCorrect(out, x):
- # check dimensionality
- if issubclass(type(out), DataContainer):
- numpy.square(x.as_array(), out = out.as_array())
- xx = numpy.sqrt(numpy.sum( out.as_array() , self.direction,
- keepdims=True ))
- xx = numpy.maximum(0, 1 - tau*self.gamma / xx)
- x.multiply(xx, out= out.as_array())
-
-
- elif issubclass(type(out) , numpy.ndarray):
- numpy.square(x.as_array(), out=out)
- xx = numpy.sqrt(numpy.sum(out, self.direction, keepdims=True))
-
- xx = numpy.maximum(0, 1 - tau*self.gamma / xx)
- x.multiply(xx, out= out)
- else:
- raise ValueError ('Wrong size: x{0} out{1}'.format(x.shape,out.shape) )
-
-
-class TV2D(Norm2):
-
- def __init__(self, gamma):
- super(TV2D,self).__init__(gamma, 0)
- self.op = FiniteDiff2D()
- self.L = self.op.get_max_sing_val()
-
-
# Define a class for squared 2-norm
class Norm2sq(Function):
'''
@@ -148,8 +49,8 @@ class Norm2sq(Function):
self.c = c # Default 1.
if memopt:
try:
- self.adjoint_placehold = A.range_geometry().allocate()
- self.direct_placehold = A.domain_geometry().allocate()
+ self.range_tmp = A.range_geometry().allocate()
+ self.domain_tmp = A.domain_geometry().allocate()
self.memopt = True
except NameError as ne:
warnings.warn(str(ne))
@@ -164,7 +65,7 @@ class Norm2sq(Function):
# Compute the Lipschitz parameter from the operator if possible
# Leave it initialised to None otherwise
try:
- self.L = 2.0*self.c*(self.A.get_max_sing_val()**2)
+ self.L = 2.0*self.c*(self.A.norm()**2)
except AttributeError as ae:
pass
except NotImplementedError as noe:
@@ -192,88 +93,16 @@ class Norm2sq(Function):
if self.memopt:
#return 2.0*self.c*self.A.adjoint( self.A.direct(x) - self.b )
- self.A.direct(x, out=self.adjoint_placehold)
- self.adjoint_placehold.__isub__( self.b )
- self.A.adjoint(self.adjoint_placehold, out=self.direct_placehold)
- #self.direct_placehold.__imul__(2.0 * self.c)
- ## can this be avoided?
- #out.fill(self.direct_placehold)
- self.direct_placehold.multiply(2.0*self.c, out=out)
+ self.A.direct(x, out=self.range_tmp)
+ self.range_tmp -= self.b
+ self.A.adjoint(self.range_tmp, out=out)
+ #self.direct_placehold.multiply(2.0*self.c, out=out)
+ out *= (self.c * 2.0)
else:
return (2.0*self.c)*self.A.adjoint( self.A.direct(x) - self.b )
-class ZeroFun(Function):
-
- def __init__(self,gamma=0,L=1):
- self.gamma = gamma
- self.L = L
- super(ZeroFun, self).__init__()
-
- def __call__(self,x):
- return 0
-
- def prox(self,x,tau):
- return x.copy()
-
- def proximal(self, x, tau, out=None):
- if out is None:
- return self.prox(x, tau)
- else:
- if isSizeCorrect(out, x):
- # check dimensionality
- if issubclass(type(out), DataContainer):
- out.fill(x)
-
- elif issubclass(type(out) , numpy.ndarray):
- out[:] = x
- else:
- raise ValueError ('Wrong size: x{0} out{1}'
- .format(x.shape,out.shape) )
-
-# A more interesting example, least squares plus 1-norm minimization.
-# Define class to represent 1-norm including prox function
-class Norm1(Function):
-
- def __init__(self,gamma):
- super(Norm1, self).__init__()
- self.gamma = gamma
- self.L = 1
- self.sign_x = None
-
- def __call__(self,x,out=None):
- if out is None:
- return self.gamma*(x.abs().sum())
- else:
- if not x.shape == out.shape:
- raise ValueError('Norm1 Incompatible size:',
- x.shape, out.shape)
- x.abs(out=out)
- return out.sum() * self.gamma
-
- def prox(self,x,tau):
- return (x.abs() - tau*self.gamma).maximum(0) * x.sign()
-
- def proximal(self, x, tau, out=None):
- if out is None:
- return self.prox(x, tau)
- else:
- if isSizeCorrect(x,out):
- # check dimensionality
- if issubclass(type(out), DataContainer):
- v = (x.abs() - tau*self.gamma).maximum(0)
- x.sign(out=out)
- out *= v
- #out.fill(self.prox(x,tau))
- elif issubclass(type(out) , numpy.ndarray):
- v = (x.abs() - tau*self.gamma).maximum(0)
- out[:] = x.sign()
- out *= v
- #out[:] = self.prox(x,tau)
- else:
- raise ValueError ('Wrong size: x{0} out{1}'.format(x.shape,out.shape) )
-
# Box constraints indicator function. Calling returns 0 if argument is within
# the box. The prox operator is projection onto the box. Only implements one
# scalar lower and one upper as constraint on all elements. Should generalise
@@ -282,9 +111,10 @@ class IndicatorBox(Function):
def __init__(self,lower=-numpy.inf,upper=numpy.inf):
# Do nothing
+ super(IndicatorBox, self).__init__()
self.lower = lower
self.upper = upper
- super(IndicatorBox, self).__init__()
+
def __call__(self,x):
@@ -315,3 +145,4 @@ class IndicatorBox(Function):
x.sign(out=self.sign_x)
out.__imul__( self.sign_x )
+