summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py48
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py60
2 files changed, 71 insertions, 37 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
index d1b5351..a165e55 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
@@ -152,28 +152,28 @@ def PDHG_old(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):
if not memopt:
- y_old += sigma * operator.direct(xbar)
- y = f.proximal_conjugate(y_old, sigma)
-
- x_old -= tau*operator.adjoint(y)
- x = g.proximal(x_old, tau)
-
+ y_tmp = y_old + sigma * operator.direct(xbar)
+ y = f.proximal_conjugate(y_tmp, sigma)
+
+ x_tmp = x_old - tau*operator.adjoint(y)
+ x = g.proximal(x_tmp, tau)
+
x.subtract(x_old, out=xbar)
xbar *= theta
xbar += x
-
+
x_old.fill(x)
- y_old.fill(y)
+ y_old.fill(y)
-# if i%100==0:
-#
-# p1 = f(operator.direct(x)) + g(x)
-# d1 = - ( f.convex_conjugate(y) + g(-1*operator.adjoint(y)) )
-# primal.append(p1)
-# dual.append(d1)
-# pdgap.append(p1-d1)
-# print(p1, d1, p1-d1)
+ if i%100==0:
+
+ p1 = f(operator.direct(x)) + g(x)
+ d1 = - ( f.convex_conjugate(y) + g(-1*operator.adjoint(y)) )
+ primal.append(p1)
+ dual.append(d1)
+ pdgap.append(p1-d1)
+ print(p1, d1, p1-d1)
else:
@@ -196,14 +196,14 @@ def PDHG_old(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):
x_old.fill(x)
y_old.fill(y)
-# if i%100==0:
-#
-# p1 = f(operator.direct(x)) + g(x)
-# d1 = - ( f.convex_conjugate(y) + g(-1*operator.adjoint(y)) )
-# primal.append(p1)
-# dual.append(d1)
-# pdgap.append(p1-d1)
-# print(p1, d1, p1-d1)
+ if i%100==0:
+
+ p1 = f(operator.direct(x)) + g(x)
+ d1 = - ( f.convex_conjugate(y) + g(-1*operator.adjoint(y)) )
+ primal.append(p1)
+ dual.append(d1)
+ pdgap.append(p1-d1)
+ print(p1, d1, p1-d1)
t_end = time.time()
diff --git a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
index 18af154..ae25bdb 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
@@ -19,23 +19,29 @@
import numpy
from ccpi.optimisation.functions import Function
-from ccpi.optimisation.functions.ScaledFunction import ScaledFunction
-from ccpi.framework import DataContainer, ImageData, ImageGeometry
+from ccpi.optimisation.functions.ScaledFunction import ScaledFunction
class KullbackLeibler(Function):
- def __init__(self,data,**kwargs):
+ ''' Assume that data > 0
+
+ '''
+
+ def __init__(self,data, **kwargs):
super(KullbackLeibler, self).__init__()
self.b = data
self.bnoise = kwargs.get('bnoise', 0)
+
+ def __call__(self, x):
+
+ # TODO check
+
self.sum_value = self.b + self.bnoise
if (self.sum_value.as_array()<0).any():
self.sum_value = numpy.inf
-
- def __call__(self, x):
if self.sum_value==numpy.inf:
return numpy.inf
@@ -43,22 +49,50 @@ class KullbackLeibler(Function):
return numpy.sum( x.as_array() - self.b.as_array() * numpy.log(self.sum_value.as_array()))
- def gradient(self, x):
+ def gradient(self, x, out=None):
#TODO Division check
- return 1 - self.b/(x + self.bnoise)
+ if out is None:
+ return 1 - self.b/(x + self.bnoise)
+ else:
+ self.b.divide(x+self.bnoise, out=out)
+ out.subtract(1, out=out)
- def convex_conjugate(self, x, out=None):
- pass
+ def convex_conjugate(self, x):
+
+ return self.b * ( numpy.log(self.b/(1-x)) - 1 ) - self.bnoise * (x - 1)
def proximal(self, x, tau, out=None):
- z = x + tau * self.bnoise
- return (z + 1) - ((z-1)**2 + 4 * tau * self.b).sqrt()
-
+ if out is None:
+ return 0.5 *( (x - self.bnoise - tau) + ( (x + self.bnoise - tau)**2 + 4*tau*self.b ) .sqrt() )
+ else:
+ tmp = 0.5 *( (x - self.bnoise - tau) + ( (x + self.bnoise - tau)**2 + 4*tau*self.b ) .sqrt() )
+ out.fill(tmp)
+
def proximal_conjugate(self, x, tau, out=None):
- pass
+
+
+ if out is None:
+ z = x + tau * self.bnoise
+ return 0.5*((z + 1) - ((z-1)**2 + 4 * tau * self.b).sqrt())
+ else:
+ z = x + tau * self.bnoise
+ res = 0.5*((z + 1) - ((z-1)**2 + 4 * tau * self.b).sqrt())
+ out.fill(res)
+
+
+
+ def __rmul__(self, scalar):
+
+ ''' Multiplication of L2NormSquared with a scalar
+
+ Returns: ScaledFunction
+
+ '''
+
+ return ScaledFunction(self, scalar)