diff options
author | epapoutsellis <epapoutsellis@gmail.com> | 2019-06-03 22:51:43 +0100 |
---|---|---|
committer | epapoutsellis <epapoutsellis@gmail.com> | 2019-06-03 22:51:43 +0100 |
commit | 14e06f7ad88202114b22ed478ba6efab952fa30b (patch) | |
tree | e5a990b77bcf2ca3649f942e876d0f3f85f70154 | |
parent | 1aa94932776f3a95b02304b1dfd8a18459d7e37c (diff) | |
download | framework-14e06f7ad88202114b22ed478ba6efab952fa30b.tar.gz framework-14e06f7ad88202114b22ed478ba6efab952fa30b.tar.bz2 framework-14e06f7ad88202114b22ed478ba6efab952fa30b.tar.xz framework-14e06f7ad88202114b22ed478ba6efab952fa30b.zip |
fix call kl div
-rw-r--r-- | Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py index 0d3c8f5..6920829 100644 --- a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py +++ b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py @@ -52,8 +52,8 @@ class KullbackLeibler(Function): ''' - # TODO avoid scipy import ???? - tmp = scipy.special.kl_div(self.b.as_array(), x.as_array()) + ind = x.as_array()>0 + tmp = scipy.special.kl_div(self.b.as_array()[ind], x.as_array()[ind]) return numpy.sum(tmp) @@ -78,9 +78,8 @@ class KullbackLeibler(Function): def convex_conjugate(self, x): - # TODO avoid scipy import ???? - xlogy = scipy.special.xlogy(self.b.as_array(), 1 - x.as_array()) - return numpy.sum(-xlogy) + xlogy = - scipy.special.xlogy(self.b.as_array(), 1 - x.as_array()) + return numpy.sum(xlogy) def proximal(self, x, tau, out=None): |