diff options
-rw-r--r-- | Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py | 17 |
1 files changed, 13 insertions, 4 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py b/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py index 70216a9..81c16cd 100644 --- a/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py +++ b/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py @@ -10,6 +10,7 @@ import numpy as np #from ccpi.optimisation.funcs import Function from ccpi.optimisation.functions import Function from ccpi.framework import BlockDataContainer +from numbers import Number class BlockFunction(Function): '''A Block vector of Functions @@ -52,16 +53,24 @@ class BlockFunction(Function): def proximal_conjugate(self, x, tau, out = None): '''proximal_conjugate does not take into account the BlockOperator''' out = [None]*self.length - for i in range(self.length): - out[i] = self.functions[i].proximal_conjugate(x.get_item(i), tau) + if isinstance(tau, Number): + for i in range(self.length): + out[i] = self.functions[i].proximal_conjugate(x.get_item(i), tau) + else: + for i in range(self.length): + out[i] = self.functions[i].proximal_conjugate(x.get_item(i), tau.get_item(i)) return BlockDataContainer(*out) def proximal(self, x, tau, out = None): '''proximal does not take into account the BlockOperator''' out = [None]*self.length - for i in range(self.length): - out[i] = self.functions[i].proximal(x.get_item(i), tau) + if isinstance(tau, Number): + for i in range(self.length): + out[i] = self.functions[i].proximal(x.get_item(i), tau) + else: + for i in range(self.length): + out[i] = self.functions[i].proximal(x.get_item(i), tau.get_item(i)) return BlockDataContainer(*out) |