summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorepapoutsellis <epapoutsellis@gmail.com>2019-04-04 17:37:11 +0100
committerepapoutsellis <epapoutsellis@gmail.com>2019-04-04 17:37:11 +0100
commit2ee7afd4cb57a51071ba454e79880e78ce24c03b (patch)
tree2b824984f1fab0c4488e68ca6a2e85df3555b05b
parentc102b119a1dd9444fba0c244ebcfe260cd679a7f (diff)
downloadframework-2ee7afd4cb57a51071ba454e79880e78ce24c03b.tar.gz
framework-2ee7afd4cb57a51071ba454e79880e78ce24c03b.tar.bz2
framework-2ee7afd4cb57a51071ba454e79880e78ce24c03b.tar.xz
framework-2ee7afd4cb57a51071ba454e79880e78ce24c03b.zip
change proxima, proximal conjugate for tau BlockDataContainer
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py17
1 files changed, 13 insertions, 4 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py b/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py
index 70216a9..81c16cd 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/BlockFunction.py
@@ -10,6 +10,7 @@ import numpy as np
#from ccpi.optimisation.funcs import Function
from ccpi.optimisation.functions import Function
from ccpi.framework import BlockDataContainer
+from numbers import Number
class BlockFunction(Function):
'''A Block vector of Functions
@@ -52,16 +53,24 @@ class BlockFunction(Function):
def proximal_conjugate(self, x, tau, out = None):
'''proximal_conjugate does not take into account the BlockOperator'''
out = [None]*self.length
- for i in range(self.length):
- out[i] = self.functions[i].proximal_conjugate(x.get_item(i), tau)
+ if isinstance(tau, Number):
+ for i in range(self.length):
+ out[i] = self.functions[i].proximal_conjugate(x.get_item(i), tau)
+ else:
+ for i in range(self.length):
+ out[i] = self.functions[i].proximal_conjugate(x.get_item(i), tau.get_item(i))
return BlockDataContainer(*out)
def proximal(self, x, tau, out = None):
'''proximal does not take into account the BlockOperator'''
out = [None]*self.length
- for i in range(self.length):
- out[i] = self.functions[i].proximal(x.get_item(i), tau)
+ if isinstance(tau, Number):
+ for i in range(self.length):
+ out[i] = self.functions[i].proximal(x.get_item(i), tau)
+ else:
+ for i in range(self.length):
+ out[i] = self.functions[i].proximal(x.get_item(i), tau.get_item(i))
return BlockDataContainer(*out)