diff options
author | epapoutsellis <epapoutsellis@gmail.com> | 2019-04-25 16:23:57 +0100 |
---|---|---|
committer | epapoutsellis <epapoutsellis@gmail.com> | 2019-04-25 16:23:57 +0100 |
commit | d34028a672b22f7c4a02464736d74c70fb354362 (patch) | |
tree | a40c918b8c662fa811e0e2cdbcc65c89caa4d8f7 /Wrappers/Python | |
parent | b36285116596d62aefc878395a142b1541bdd1e8 (diff) | |
download | framework-d34028a672b22f7c4a02464736d74c70fb354362.tar.gz framework-d34028a672b22f7c4a02464736d74c70fb354362.tar.bz2 framework-d34028a672b22f7c4a02464736d74c70fb354362.tar.xz framework-d34028a672b22f7c4a02464736d74c70fb354362.zip |
memopt fix prox conjugate
Diffstat (limited to 'Wrappers/Python')
-rw-r--r-- | Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py | 39 |
1 files changed, 11 insertions, 28 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py index 22d21fd..14b5ea0 100644 --- a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py +++ b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py @@ -62,19 +62,6 @@ class KullbackLeibler(Function): if out is None: return 1 - self.b/(x + self.bnoise) else: -#<<<<<<< HEAD -# self.b.divide(x+self.bnoise, out=out) -# out.subtract(1, out=out) -# -# def convex_conjugate(self, x): -# -# tmp = self.b/( 1 - x ) -# ind = tmp.as_array()>0 -# -# sel -# -# return (self.b * ( ImageData( numpy.log(tmp) ) - 1 ) - self.bnoise * (x - 1)).sum() -#======= x.add(self.bnoise, out=out) self.b.divide(out, out=out) out.subtract(1, out=out) @@ -116,23 +103,19 @@ class KullbackLeibler(Function): if out is None: z = x + tau * self.bnoise - return 0.5*((z + 1) - ((z-1)**2 + 4 * tau * self.b).sqrt()) else: - z = x + tau * self.bnoise - res = 0.5*((z + 1) - ((z-1)**2 + 4 * tau * self.b).sqrt()) - out.fill(res) -# else: -# z_m = x + tau * self.bnoise -1 -# self.b.multiply(4*tau, out=out) -# z_m.multiply(z_m, out=z_m) -# out += z_m -# out.sqrt(out=out) -# z = z_m + 2 -# z_m.sqrt(out=z_m) -# z_m += 2 -# out *= -1 -# out += z_m + z_m = x + tau * self.bnoise -1 + self.b.multiply(4*tau, out=out) + z_m.multiply(z_m, out=z_m) + out += z_m + out.sqrt(out=out) + z_m.sqrt(out=z_m) + z_m += 2 + out *= -1 + out += z_m + out *= 0.5 + def __rmul__(self, scalar): |