summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py25
-rw-r--r--Wrappers/Python/test/test_functions.py6
2 files changed, 20 insertions, 11 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
index 5489d92..597d4d8 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
@@ -1,12 +1,21 @@
# -*- coding: utf-8 -*-
+# This work is part of the Core Imaging Library developed by
+# Visual Analytics and Imaging System Group of the Science Technology
+# Facilities Council, STFC
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-Created on Thu Feb 7 13:10:56 2019
+# Copyright 2018-2019 Evangelos Papoutsellis and Edoardo Pasca
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
-@author: evangelos
-"""
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import numpy
from ccpi.optimisation.functions import Function
@@ -75,10 +84,10 @@ class L2NormSquared(Function):
if out is None:
# FIXME: this is a number
- return (1/4) * x.squared_norm() + tmp
+ return (1./4.) * x.squared_norm() + tmp
else:
# FIXME: this is a DataContainer
- out.fill((1/4) * x.squared_norm() + tmp)
+ out.fill((1./4.) * x.squared_norm() + tmp)
def proximal(self, x, tau, out = None):
diff --git a/Wrappers/Python/test/test_functions.py b/Wrappers/Python/test/test_functions.py
index 3e5f26f..54dfa57 100644
--- a/Wrappers/Python/test/test_functions.py
+++ b/Wrappers/Python/test/test_functions.py
@@ -62,7 +62,7 @@ class TestFunction(unittest.TestCase):
self.assertEqual(a2, g(d))
# Compare convex conjugate of g
- a3 = 0.5 * d.power(2).sum() + (d*noisy_data).sum()
+ a3 = 0.5 * d.squared_norm() + d.dot(noisy_data)
self.assertEqual(a3, g.convex_conjugate(d))
#print( a3, g.convex_conjugate(d))
@@ -91,12 +91,12 @@ class TestFunction(unittest.TestCase):
#check convex conjuagate no data
c1 = f.convex_conjugate(u)
- c2 = 1/4 * u.squared_norm()
+ c2 = 1/4. * u.squared_norm()
numpy.testing.assert_equal(c1, c2)
#check convex conjuagate with data
d1 = f1.convex_conjugate(u)
- d2 = (1/4) * u.squared_norm() + (u*b).sum()
+ d2 = (1./4.) * u.squared_norm() + (u*b).sum()
numpy.testing.assert_equal(d1, d2)
# check proximal no data