summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorepapoutsellis <epapoutsellis@gmail.com>2019-04-25 16:23:57 +0100
committerepapoutsellis <epapoutsellis@gmail.com>2019-04-25 16:23:57 +0100
commitd34028a672b22f7c4a02464736d74c70fb354362 (patch)
treea40c918b8c662fa811e0e2cdbcc65c89caa4d8f7 /Wrappers/Python
parentb36285116596d62aefc878395a142b1541bdd1e8 (diff)
downloadframework-d34028a672b22f7c4a02464736d74c70fb354362.tar.gz
framework-d34028a672b22f7c4a02464736d74c70fb354362.tar.bz2
framework-d34028a672b22f7c4a02464736d74c70fb354362.tar.xz
framework-d34028a672b22f7c4a02464736d74c70fb354362.zip
memopt fix prox conjugate
Diffstat (limited to 'Wrappers/Python')
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py39
1 files changed, 11 insertions, 28 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
index 22d21fd..14b5ea0 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
@@ -62,19 +62,6 @@ class KullbackLeibler(Function):
if out is None:
return 1 - self.b/(x + self.bnoise)
else:
-#<<<<<<< HEAD
-# self.b.divide(x+self.bnoise, out=out)
-# out.subtract(1, out=out)
-#
-# def convex_conjugate(self, x):
-#
-# tmp = self.b/( 1 - x )
-# ind = tmp.as_array()>0
-#
-# sel
-#
-# return (self.b * ( ImageData( numpy.log(tmp) ) - 1 ) - self.bnoise * (x - 1)).sum()
-#=======
x.add(self.bnoise, out=out)
self.b.divide(out, out=out)
out.subtract(1, out=out)
@@ -116,23 +103,19 @@ class KullbackLeibler(Function):
if out is None:
z = x + tau * self.bnoise
-
return 0.5*((z + 1) - ((z-1)**2 + 4 * tau * self.b).sqrt())
else:
- z = x + tau * self.bnoise
- res = 0.5*((z + 1) - ((z-1)**2 + 4 * tau * self.b).sqrt())
- out.fill(res)
-# else:
-# z_m = x + tau * self.bnoise -1
-# self.b.multiply(4*tau, out=out)
-# z_m.multiply(z_m, out=z_m)
-# out += z_m
-# out.sqrt(out=out)
-# z = z_m + 2
-# z_m.sqrt(out=z_m)
-# z_m += 2
-# out *= -1
-# out += z_m
+ z_m = x + tau * self.bnoise -1
+ self.b.multiply(4*tau, out=out)
+ z_m.multiply(z_m, out=z_m)
+ out += z_m
+ out.sqrt(out=out)
+ z_m.sqrt(out=z_m)
+ z_m += 2
+ out *= -1
+ out += z_m
+ out *= 0.5
+
def __rmul__(self, scalar):