From cfe16a4d31f4c6d1748edadcfc0706bd6f9ee7cf Mon Sep 17 00:00:00 2001 From: epapoutsellis Date: Sat, 20 Apr 2019 18:48:39 +0100 Subject: changes for SymmetrizedGradient --- .../ccpi/optimisation/functions/MixedL21Norm.py | 60 ++++++++-------------- 1 file changed, 22 insertions(+), 38 deletions(-) diff --git a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py index 2004e5f..e8f6da4 100755 --- a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py +++ b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py @@ -43,19 +43,9 @@ class MixedL21Norm(Function): ''' if not isinstance(x, BlockDataContainer): raise ValueError('__call__ expected BlockDataContainer, got {}'.format(type(x))) - - if self.SymTensor: - - #TODO fix this case - param = [1]*x.shape[0] - param[-1] = 2 - tmp = [param[i]*(x[i] ** 2) for i in range(x.shape[0])] - res = sum(tmp).sqrt().sum() - - else: - - tmp = [ el**2 for el in x.containers ] - res = sum(tmp).sqrt().sum() + + tmp = [ el**2 for el in x.containers ] + res = sum(tmp).sqrt().sum() return res @@ -67,7 +57,12 @@ class MixedL21Norm(Function): ''' This is the Indicator function of ||\cdot||_{2, \infty} which is either 0 if ||x||_{2, \infty} or \infty ''' + return 0.0 + + #tmp = [ el**2 for el in x.containers ] + #print(sum(tmp).sqrt().as_array().max()) + #return sum(tmp).sqrt().as_array().max() def proximal(self, x, tau, out=None): @@ -80,35 +75,24 @@ class MixedL21Norm(Function): def proximal_conjugate(self, x, tau, out=None): - if self.SymTensor: - - param = [1]*x.shape[0] - param[-1] = 2 - tmp = [param[i]*(x[i] ** 2) for i in range(x.shape[0])] - frac = [x[i]/(sum(tmp).sqrt()).maximum(1.0) for i in range(x.shape[0])] - res = BlockDataContainer(*frac) - - return res - - else: - if out is None: - tmp = [ el*el for el in x.containers] - res = sum(tmp).sqrt().maximum(1.0) - frac = [el/res for el in x.containers] - return BlockDataContainer(*frac) + if out is None: + tmp = [ el*el for el in x.containers] + res = sum(tmp).sqrt().maximum(1.0) + frac = [el/res for el in x.containers] + return BlockDataContainer(*frac) + - - #TODO this is slow, why??? + #TODO this is slow, why??? # return x.divide(x.pnorm().maximum(1.0)) - else: - - res1 = functools.reduce(lambda a,b: a + b*b, x.containers, x.get_item(0) * 0 ) - res = res1.sqrt().maximum(1.0) - x.divide(res, out=out) - + else: + + res1 = functools.reduce(lambda a,b: a + b*b, x.containers, x.get_item(0) * 0 ) + res = res1.sqrt().maximum(1.0) + x.divide(res, out=out) + # x.divide(sum([el*el for el in x.containers]).sqrt().maximum(1.0), out=out) - #TODO this is slow, why ??? + #TODO this is slow, why ??? # x.divide(x.pnorm().maximum(1.0), out=out) -- cgit v1.2.3