diff options
author | epapoutsellis <epapoutsellis@gmail.com> | 2019-04-04 17:37:11 +0100 |
---|---|---|
committer | epapoutsellis <epapoutsellis@gmail.com> | 2019-04-04 17:37:11 +0100 |
commit | 2ee7afd4cb57a51071ba454e79880e78ce24c03b (patch) | |
tree | 2b824984f1fab0c4488e68ca6a2e85df3555b05b | |
parent | c102b119a1dd9444fba0c244ebcfe260cd679a7f (diff) | |
download | framework-2ee7afd4cb57a51071ba454e79880e78ce24c03b.tar.gz framework-2ee7afd4cb57a51071ba454e79880e78ce24c03b.tar.bz2 framework-2ee7afd4cb57a51071ba454e79880e78ce24c03b.tar.xz framework-2ee7afd4cb57a51071ba454e79880e78ce24c03b.zip |
change proxima, proximal conjugate for tau BlockDataContainer
-rw-r--r-- | Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py | 17 |
1 files changed, 13 insertions, 4 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py b/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py index 70216a9..81c16cd 100644 --- a/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py +++ b/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py @@ -10,6 +10,7 @@ import numpy as np #from ccpi.optimisation.funcs import Function from ccpi.optimisation.functions import Function from ccpi.framework import BlockDataContainer +from numbers import Number class BlockFunction(Function): '''A Block vector of Functions @@ -52,16 +53,24 @@ class BlockFunction(Function): def proximal_conjugate(self, x, tau, out = None): '''proximal_conjugate does not take into account the BlockOperator''' out = [None]*self.length - for i in range(self.length): - out[i] = self.functions[i].proximal_conjugate(x.get_item(i), tau) + if isinstance(tau, Number): + for i in range(self.length): + out[i] = self.functions[i].proximal_conjugate(x.get_item(i), tau) + else: + for i in range(self.length): + out[i] = self.functions[i].proximal_conjugate(x.get_item(i), tau.get_item(i)) return BlockDataContainer(*out) def proximal(self, x, tau, out = None): '''proximal does not take into account the BlockOperator''' out = [None]*self.length - for i in range(self.length): - out[i] = self.functions[i].proximal(x.get_item(i), tau) + if isinstance(tau, Number): + for i in range(self.length): + out[i] = self.functions[i].proximal(x.get_item(i), tau) + else: + for i in range(self.length): + out[i] = self.functions[i].proximal(x.get_item(i), tau.get_item(i)) return BlockDataContainer(*out) |