From 14e06f7ad88202114b22ed478ba6efab952fa30b Mon Sep 17 00:00:00 2001 From: epapoutsellis Date: Mon, 3 Jun 2019 22:51:43 +0100 Subject: fix call kl div --- Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'Wrappers/Python') diff --git a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py index 0d3c8f5..6920829 100644 --- a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py +++ b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py @@ -52,8 +52,8 @@ class KullbackLeibler(Function): ''' - # TODO avoid scipy import ???? - tmp = scipy.special.kl_div(self.b.as_array(), x.as_array()) + ind = x.as_array()>0 + tmp = scipy.special.kl_div(self.b.as_array()[ind], x.as_array()[ind]) return numpy.sum(tmp) @@ -78,9 +78,8 @@ class KullbackLeibler(Function): def convex_conjugate(self, x): - # TODO avoid scipy import ???? - xlogy = scipy.special.xlogy(self.b.as_array(), 1 - x.as_array()) - return numpy.sum(-xlogy) + xlogy = - scipy.special.xlogy(self.b.as_array(), 1 - x.as_array()) + return numpy.sum(xlogy) def proximal(self, x, tau, out=None): -- cgit v1.2.3