diff options
Diffstat (limited to 'Wrappers')
-rwxr-xr-x | Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py | 162 |
1 files changed, 87 insertions, 75 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py index 378cbda..55e6e53 100755 --- a/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py +++ b/Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py @@ -1,22 +1,20 @@ # -*- coding: utf-8 -*- -# Copyright 2019 Science Technology Facilities Council -# Copyright 2019 University of Manchester -# -# This work is part of the Core Imaging Library developed by Science Technology -# Facilities Council and University of Manchester -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0.txt -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# CCP in Tomographic Imaging (CCPi) Core Imaging Library (CIL). + +# Copyright 2017 UKRI-STFC +# Copyright 2017 University of Manchester +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from __future__ import absolute_import from __future__ import division @@ -25,17 +23,15 @@ from __future__ import unicode_literals from ccpi.optimisation.functions import Function, ScaledFunction from ccpi.framework import BlockDataContainer +import numpy as np import functools -import numpy class MixedL21Norm(Function): - r'''MixedL21Norm: .. math:: f(x) = ||x||_{2,1} = \int \|x\|_{2} dx - - where x is a vector/tensor vield - + ''' + f(x) = ||x||_{2,1} = \sum |x|_{2} ''' def __init__(self, **kwargs): @@ -45,13 +41,15 @@ class MixedL21Norm(Function): def __call__(self, x): - '''Evaluates MixedL21Norm at point x + ''' Evaluates L2,1Norm at point x - :param: x: is a BlockDataContainer + :param: x is a BlockDataContainer + ''' if not isinstance(x, BlockDataContainer): raise ValueError('__call__ expected BlockDataContainer, got {}'.format(type(x))) - tmp = x.get_item(0) * 0 + + tmp = x.get_item(0) * 0. for el in x.containers: tmp += el.power(2.) return tmp.sqrt().sum() @@ -62,11 +60,8 @@ class MixedL21Norm(Function): def convex_conjugate(self,x): - r'''Convex conjugate of of MixedL21Norm: - - Indicator function of .. math:: ||\cdot||_{2, \infty} - which is either 0 if .. math:: ||x||_{2, \infty}<1 or \infty - + ''' This is the Indicator function of ||\cdot||_{2, \infty} + which is either 0 if ||x||_{2, \infty} or \infty ''' return 0.0 @@ -74,68 +69,59 @@ class MixedL21Norm(Function): def proximal(self, x, tau, out=None): - r'''Proximal operator of MixedL21Norm at x: - - .. math:: prox_{\tau * f(x) - ''' - pass - - def proximal_conjugate(self, x, tau, out=None): - - r'''Proximal operator of the convex conjugate of MixedL21Norm at x: - - .. math:: prox_{\tau * f^{*}}(x) + if out is None: + + tmp = sum([ el*el for el in x.containers]).sqrt() + res = (tmp - tau).maximum(0.0) * x/tmp + return res + + else: + + tmp = functools.reduce(lambda a,b: a + b*b, x.containers, x.get_item(0) * 0 ).sqrt() + res = (tmp - tau).maximum(0.0) * x/tmp - ''' + for el in res.containers: + el.as_array()[np.isnan(el.as_array())]=0 + out.fill(res) + + + def proximal_conjugate(self, x, tau, out=None): + if out is None: - # tmp = [ el*el for el in x.containers] - # res = sum(tmp).sqrt().maximum(1.0) - # frac = [el/res for el in x.containers] - # return BlockDataContainer(*frac) - tmp = x.get_item(0) * 0 - for el in x.containers: - tmp += el.power(2.) - tmp.sqrt(out=tmp) - tmp.maximum(1.0, out=tmp) - frac = [ el.divide(tmp) for el in x.containers ] + tmp = x.get_item(0) * 0 + for el in x.containers: + tmp += el.power(2.) + tmp.sqrt(out=tmp) + tmp.maximum(1.0, out=tmp) + frac = [ el.divide(tmp) for el in x.containers ] return BlockDataContainer(*frac) - - + + else: res1 = functools.reduce(lambda a,b: a + b*b, x.containers, x.get_item(0) * 0 ) - if False: - res = res1.sqrt().maximum(1.0) - x.divide(res, out=out) - else: - res1.sqrt(out=res1) - res1.maximum(1.0, out=res1) - x.divide(res1, out=out) - + res1.sqrt(out=res1) + res1.maximum(1.0, out=res1) + x.divide(res1, out=out) + def __rmul__(self, scalar): - '''Multiplication of MixedL21Norm with a scalar - - Returns: ScaledFunction + ''' Multiplication of MixedL21Norm with a scalar + + Returns: ScaledFunction ''' return ScaledFunction(self, scalar) -def sqrt_maximum(x, a): - y = numpy.sqrt(x) - if y >= a: - return y - else: - return a # if __name__ == '__main__': - M, N, K = 2,3,5 - from ccpi.framework import BlockGeometry + M, N, K = 2,3,50 + from ccpi.framework import BlockGeometry, ImageGeometry import numpy ig = ImageGeometry(M, N) @@ -145,8 +131,9 @@ if __name__ == '__main__': U = BG.allocate('random_int') # Define no scale and scaled + alpha = 0.5 f_no_scaled = MixedL21Norm() - f_scaled = 0.5 * MixedL21Norm() + f_scaled = alpha * MixedL21Norm() # call @@ -174,11 +161,36 @@ if __name__ == '__main__': numpy.testing.assert_array_almost_equal(res_no_out[1].as_array(), \ res_out[1].as_array(), decimal=4) -# + tau = 0.4 + d1 = f_scaled.proximal(U, tau) + + tmp = (U.get_item(0)**2 + U.get_item(1)**2).sqrt() + + d2 = (tmp - alpha*tau).maximum(0) * U/tmp + numpy.testing.assert_array_almost_equal(d1.get_item(0).as_array(), \ + d2.get_item(0).as_array(), decimal=4) + numpy.testing.assert_array_almost_equal(d1.get_item(1).as_array(), \ + d2.get_item(1).as_array(), decimal=4) + out1 = BG.allocate('random_int') + + + f_scaled.proximal(U, tau, out = out1) + + numpy.testing.assert_array_almost_equal(out1.get_item(0).as_array(), \ + d1.get_item(0).as_array(), decimal=4) + numpy.testing.assert_array_almost_equal(out1.get_item(1).as_array(), \ + d1.get_item(1).as_array(), decimal=4) +# + + + + + +
\ No newline at end of file |