author     epapoutsellis <epapoutsellis@gmail.com>  2019-04-20 18:48:39 +0100
committer  epapoutsellis <epapoutsellis@gmail.com>  2019-04-20 18:48:39 +0100
commit     cfe16a4d31f4c6d1748edadcfc0706bd6f9ee7cf (patch)
tree       80e6b5512aced636bd7e3063af85da84cfb99aa9
parent     cc774263646d61bbc224911903e4e2e8f5e323dc (diff)
changes for SymmetrizedGradient
-rwxr-xr-x  Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py | 60
1 file changed, 22 insertions(+), 38 deletions(-)
diff --git a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
index 2004e5f..e8f6da4 100755
--- a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
@@ -43,19 +43,9 @@ class MixedL21Norm(Function):
         '''
         if not isinstance(x, BlockDataContainer):
             raise ValueError('__call__ expected BlockDataContainer, got {}'.format(type(x)))
-
-        if self.SymTensor:
-
-            #TODO fix this case
-            param = [1]*x.shape[0]
-            param[-1] = 2
-            tmp = [param[i]*(x[i] ** 2) for i in range(x.shape[0])]
-            res = sum(tmp).sqrt().sum()
-
-        else:
-
-            tmp = [ el**2 for el in x.containers ]
-            res = sum(tmp).sqrt().sum()
+
+        tmp = [ el**2 for el in x.containers ]
+        res = sum(tmp).sqrt().sum()
         return res
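For reference, a minimal NumPy sketch of what the simplified __call__ now computes: the pointwise 2-norm across the block components, summed over all entries (mixed_l21_norm is a hypothetical stand-in; the framework operates on DataContainers, not plain arrays):

import numpy as np

def mixed_l21_norm(components):
    # components: list of equally shaped arrays, one per block entry
    pointwise_norm = np.sqrt(sum(c**2 for c in components))   # ||.||_2 across the block
    return pointwise_norm.sum()                                # then sum over all entries

# example with two 2x2 "gradient" components
gx = np.array([[3.0, 0.0], [1.0, 2.0]])
gy = np.array([[4.0, 0.0], [2.0, 2.0]])
print(mixed_l21_norm([gx, gy]))   # 5 + 0 + sqrt(5) + sqrt(8)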
@@ -67,7 +57,12 @@ class MixedL21Norm(Function):
         ''' This is the Indicator function of ||\cdot||_{2, \infty},
             which is 0 if ||x||_{2, \infty} <= 1 and \infty otherwise
         '''
+
         return 0.0
+
+        #tmp = [ el**2 for el in x.containers ]
+        #print(sum(tmp).sqrt().as_array().max())
+        #return sum(tmp).sqrt().as_array().max()
     def proximal(self, x, tau, out=None):
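As the commented-out lines suggest, the convex conjugate of the mixed L2,1 norm is the indicator of the unit ball of the dual norm ||.||_{2, inf}, so returning 0.0 assumes the argument is always feasible. A hedged sketch of an explicit check (hypothetical helper, using plain NumPy arrays in place of the block's DataContainers):

import numpy as np

def l21_convex_conjugate(components, tol=1e-10):
    # ||x||_{2, inf}: maximum over entries of the pointwise 2-norm
    max_pointwise_norm = np.sqrt(sum(c**2 for c in components)).max()
    return 0.0 if max_pointwise_norm <= 1 + tol else np.inf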
@@ -80,35 +75,24 @@ class MixedL21Norm(Function):
     def proximal_conjugate(self, x, tau, out=None):
-        if self.SymTensor:
-
-            param = [1]*x.shape[0]
-            param[-1] = 2
-            tmp = [param[i]*(x[i] ** 2) for i in range(x.shape[0])]
-            frac = [x[i]/(sum(tmp).sqrt()).maximum(1.0) for i in range(x.shape[0])]
-            res = BlockDataContainer(*frac)
-
-            return res
-
-        else:
-            if out is None:
-                tmp = [ el*el for el in x.containers]
-                res = sum(tmp).sqrt().maximum(1.0)
-                frac = [el/res for el in x.containers]
-                return BlockDataContainer(*frac)
+        if out is None:
+            tmp = [ el*el for el in x.containers]
+            res = sum(tmp).sqrt().maximum(1.0)
+            frac = [el/res for el in x.containers]
+            return BlockDataContainer(*frac)
+
-
-                #TODO this is slow, why???
+            #TODO this is slow, why???
 #            return x.divide(x.pnorm().maximum(1.0))
-            else:
-
-                res1 = functools.reduce(lambda a,b: a + b*b, x.containers, x.get_item(0) * 0 )
-                res = res1.sqrt().maximum(1.0)
-                x.divide(res, out=out)
-
+        else:
+
+            res1 = functools.reduce(lambda a,b: a + b*b, x.containers, x.get_item(0) * 0 )
+            res = res1.sqrt().maximum(1.0)
+            x.divide(res, out=out)
+
 #            x.divide(sum([el*el for el in x.containers]).sqrt().maximum(1.0), out=out)
-            #TODO this is slow, why ???
+            #TODO this is slow, why ???
 #            x.divide(x.pnorm().maximum(1.0), out=out)
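A hedged NumPy sketch of the operation proximal_conjugate performs in both branches: each block component is divided by max(pointwise 2-norm, 1), i.e. a pointwise projection onto the unit 2-ball across components, independent of tau (hypothetical helper; the framework version works on BlockDataContainers and can write into a preallocated out):

import numpy as np

def l21_proximal_conjugate(components):
    # pointwise 2-norm across the block components, clipped below at 1.0
    denom = np.maximum(np.sqrt(sum(c**2 for c in components)), 1.0)
    # divide every component by the same denominator; entries whose pointwise
    # norm is already <= 1 are left unchanged
    return [c / denom for c in components]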