summaryrefslogtreecommitdiffstats
path: root/Wrappers
diff options
context:
space:
mode:
Diffstat (limited to 'Wrappers')
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py39
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py10
2 files changed, 16 insertions, 33 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
index 22d21fd..14b5ea0 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
@@ -62,19 +62,6 @@ class KullbackLeibler(Function):
if out is None:
return 1 - self.b/(x + self.bnoise)
else:
-#<<<<<<< HEAD
-# self.b.divide(x+self.bnoise, out=out)
-# out.subtract(1, out=out)
-#
-# def convex_conjugate(self, x):
-#
-# tmp = self.b/( 1 - x )
-# ind = tmp.as_array()>0
-#
-# sel
-#
-# return (self.b * ( ImageData( numpy.log(tmp) ) - 1 ) - self.bnoise * (x - 1)).sum()
-#=======
x.add(self.bnoise, out=out)
self.b.divide(out, out=out)
out.subtract(1, out=out)
@@ -116,23 +103,19 @@ class KullbackLeibler(Function):
if out is None:
z = x + tau * self.bnoise
-
return 0.5*((z + 1) - ((z-1)**2 + 4 * tau * self.b).sqrt())
else:
- z = x + tau * self.bnoise
- res = 0.5*((z + 1) - ((z-1)**2 + 4 * tau * self.b).sqrt())
- out.fill(res)
-# else:
-# z_m = x + tau * self.bnoise -1
-# self.b.multiply(4*tau, out=out)
-# z_m.multiply(z_m, out=z_m)
-# out += z_m
-# out.sqrt(out=out)
-# z = z_m + 2
-# z_m.sqrt(out=z_m)
-# z_m += 2
-# out *= -1
-# out += z_m
+ z_m = x + tau * self.bnoise -1
+ self.b.multiply(4*tau, out=out)
+ z_m.multiply(z_m, out=z_m)
+ out += z_m
+ out.sqrt(out=out)
+ z_m.sqrt(out=z_m)
+ z_m += 2
+ out *= -1
+ out += z_m
+ out *= 0.5
+
def __rmul__(self, scalar):
diff --git a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
index 5490782..20e754e 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
@@ -94,12 +94,12 @@ class L2NormSquared(Function):
return x/(1+2*tau)
else:
-# tmp = x.subtract(self.b)
+ tmp = x.subtract(self.b)
# tmp -= self.b
-# tmp /= (1+2*tau)
-# tmp += self.b
-# return tmp
- return (x-self.b)/(1+2*tau) + self.b
+ tmp /= (1+2*tau)
+ tmp += self.b
+ return tmp
+# return (x-self.b)/(1+2*tau) + self.b
# if self.b is not None:
# out=x