diff options
| author | Edoardo Pasca <edo.paskino@gmail.com> | 2019-03-14 15:34:04 +0000 | 
|---|---|---|
| committer | Edoardo Pasca <edo.paskino@gmail.com> | 2019-03-14 15:34:04 +0000 | 
| commit | f2b62709a1e4a9529dcee17b2cf7de87a5f02d2c (patch) | |
| tree | eca0629ff88154b8b17df14f5f8208d419cee014 /Wrappers | |
| parent | 9769759d3f7f1eab53631627474eade8e4c6f96a (diff) | |
| download | framework-f2b62709a1e4a9529dcee17b2cf7de87a5f02d2c.tar.gz framework-f2b62709a1e4a9529dcee17b2cf7de87a5f02d2c.tar.bz2 framework-f2b62709a1e4a9529dcee17b2cf7de87a5f02d2c.tar.xz framework-f2b62709a1e4a9529dcee17b2cf7de87a5f02d2c.zip  | |
removed alpha parameter
Diffstat (limited to 'Wrappers')
| -rw-r--r-- | Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py | 39 | ||||
| -rwxr-xr-x | Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py | 39 | 
2 files changed, 45 insertions, 33 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py index 5817317..54c947a 100644 --- a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py +++ b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py @@ -19,34 +19,41 @@ class SimpleL2NormSq(Function):      def __init__(self, alpha=1):          super(SimpleL2NormSq, self).__init__()          -        self.alpha = alpha -                  # Lispchitz constant of gradient -        self.L = 2*self.alpha +        self.L = 2      def __call__(self, x):          return self.alpha * x.power(2).sum() -    def gradient(self,x): -        return 2 * self.alpha * x +    def gradient(self,x, out=None): +        if out is None: +            return 2 * x +        else: +            out.fill(2*x)      def convex_conjugate(self,x): -        return (1/(4*self.alpha)) * x.power(2).sum() -     -    def proximal(self, x, tau): -        return x.divide(1+2*tau*self.alpha) +        return (1/4) * x.squared_norm() +         +    def proximal(self, x, tau, out=None): +        if out is None: +            return x.divide(1+2*tau) +        else: +            x.divide(1+2*tau, out=out) -    def proximal_conjugate(self, x, tau): -        return x.divide(1 + tau/(2*self.alpha) )     +    def proximal_conjugate(self, x, tau, out=None): +        if out is None: +            return x.divide(1 + tau/2)     +        else: +            x.divide(1+tau/2, out=out) +  ############################   L2NORM FUNCTIONS   #############################  class L2NormSq(SimpleL2NormSq): -    def __init__(self, alpha, **kwargs): +    def __init__(self, **kwargs): -        super(L2NormSq, self).__init__(alpha) -        self.alpha = alpha +        super(L2NormSq, self).__init__()          self.b = kwargs.get('b',None)                    def __call__(self, x): @@ -59,9 +66,9 @@ class L2NormSq(SimpleL2NormSq):      def gradient(self, x):          if self.b is None: -            return 2*self.alpha * x  +            return 2 * x           else: -            return 2*self.alpha * (x - self.b)  +            return 2 * (x - self.b)       def convex_conjugate(self, x): diff --git a/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py index f2e39fb..7e2f20a 100755 --- a/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py +++ b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py @@ -4,22 +4,17 @@ import numpy  class ScaledFunction(object):
      '''ScaledFunction
 -    A class to represent the scalar multiplication of an Operator with a scalar.
 -    It holds an operator and a scalar. Basically it returns the multiplication
 -    of the result of direct and adjoint of the operator with the scalar.
 -    For the rest it behaves like the operator it holds.
 +    A class to represent the scalar multiplication of an Function with a scalar.
 +    It holds a function and a scalar. Basically it returns the multiplication
 +    of the product of the function __call__, convex_conjugate and gradient with the scalar.
 +    For the rest it behaves like the function it holds.
      Args:
 -       operator (Operator): a Operator or LinearOperator
 +       function (Function): a Function or BlockOperator
         scalar (Number): a scalar multiplier
      Example:
         The scaled operator behaves like the following:
 -       sop = ScaledOperator(operator, scalar)
 -       sop.direct(x) = scalar * operator.direct(x)
 -       sop.adjoint(x) = scalar * operator.adjoint(x)
 -       sop.norm() = operator.norm()
 -       sop.range_geometry() = operator.range_geometry()
 -       sop.domain_geometry() = operator.domain_geometry()
 +       
      '''
      def __init__(self, function, scalar):
          super(ScaledFunction, self).__init__()
 @@ -30,31 +25,41 @@ class ScaledFunction(object):          self.function = function
      def __call__(self,x, out=None):
 +        '''Evaluates the function at x '''
          return self.scalar * self.function(x)
 -    def call_adjoint(self, x, out=None):
 -        return self.scalar * self.function.call_adjoint(x, out=out)
 -
      def convex_conjugate(self, x, out=None):
 -        return self.scalar * self.function.convex_conjugate(x, out=out)
 +        '''returns the convex_conjugate of the scaled function '''
 +        if out is None:
 +            return self.scalar * self.function.convex_conjugate(x/self.scalar, out=out)
 +        else:
 +            out.fill(self.function.convex_conjugate(x/self.scalar))
 +            out *= self.scalar
      def proximal_conjugate(self, x, tau, out = None):
 -        '''TODO check if this is mathematically correct'''
 +        '''This returns the proximal operator for the function at x, tau
 +        
 +        TODO check if this is mathematically correct'''
          return self.function.proximal_conjugate(x, tau, out=out)
      def grad(self, x):
 +        '''Alias of gradient(x,None)'''
          warnings.warn('''This method will disappear in following 
          versions of the CIL. Use gradient instead''', DeprecationWarning)
          return self.gradient(x, out=None)
      def prox(self, x, tau):
 +        '''Alias of proximal(x, tau, None)'''
          warnings.warn('''This method will disappear in following 
          versions of the CIL. Use proximal instead''', DeprecationWarning)
          return self.proximal(x, out=None)
      def gradient(self, x, out=None):
 +        '''Returns the gradient of the function at x, if the function is differentiable'''
          return self.scalar * self.function.gradient(x, out=out)
      def proximal(self, x, tau, out=None):
 -        '''TODO check if this is mathematically correct'''
 +        '''This returns the proximal operator for the function at x, tau
 +        
 +        TODO check if this is mathematically correct'''
          return self.function.proximal(x, tau, out=out)
  | 
