summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorepapoutsellis <epapoutsellis@gmail.com>2019-06-03 22:51:43 +0100
committerepapoutsellis <epapoutsellis@gmail.com>2019-06-03 22:51:43 +0100
commit14e06f7ad88202114b22ed478ba6efab952fa30b (patch)
treee5a990b77bcf2ca3649f942e876d0f3f85f70154 /Wrappers/Python
parent1aa94932776f3a95b02304b1dfd8a18459d7e37c (diff)
downloadframework-14e06f7ad88202114b22ed478ba6efab952fa30b.tar.gz
framework-14e06f7ad88202114b22ed478ba6efab952fa30b.tar.bz2
framework-14e06f7ad88202114b22ed478ba6efab952fa30b.tar.xz
framework-14e06f7ad88202114b22ed478ba6efab952fa30b.zip
fix call kl div
Diffstat (limited to 'Wrappers/Python')
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
index 0d3c8f5..6920829 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
@@ -52,8 +52,8 @@ class KullbackLeibler(Function):
'''
- # TODO avoid scipy import ????
- tmp = scipy.special.kl_div(self.b.as_array(), x.as_array())
+ ind = x.as_array()>0
+ tmp = scipy.special.kl_div(self.b.as_array()[ind], x.as_array()[ind])
return numpy.sum(tmp)
@@ -78,9 +78,8 @@ class KullbackLeibler(Function):
def convex_conjugate(self, x):
- # TODO avoid scipy import ????
- xlogy = scipy.special.xlogy(self.b.as_array(), 1 - x.as_array())
- return numpy.sum(-xlogy)
+ xlogy = - scipy.special.xlogy(self.b.as_array(), 1 - x.as_array())
+ return numpy.sum(xlogy)
def proximal(self, x, tau, out=None):