summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-06-04 13:36:50 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2019-06-04 13:36:50 +0100
commit3271aa9c4ca50177bf2f9e37269efa03f52f635e (patch)
tree370f9de182a5e5a80a10c58059e1d1ecfb4466ea /Wrappers/Python
parentbe88c669c995176b54d43d1fa800095d449490d6 (diff)
downloadframework-3271aa9c4ca50177bf2f9e37269efa03f52f635e.tar.gz
framework-3271aa9c4ca50177bf2f9e37269efa03f52f635e.tar.bz2
framework-3271aa9c4ca50177bf2f9e37269efa03f52f635e.tar.xz
framework-3271aa9c4ca50177bf2f9e37269efa03f52f635e.zip
fixing tests
Diffstat (limited to 'Wrappers/Python')
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/FISTA.py117
-rwxr-xr-xWrappers/Python/ccpi/optimisation/functions/Norm2Sq.py2
-rwxr-xr-xWrappers/Python/ccpi/optimisation/functions/ScaledFunction.py3
-rw-r--r--Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py21
-rwxr-xr-xWrappers/Python/test/test_run_test.py14
5 files changed, 59 insertions, 98 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
index 8ea2b6c..04e7c38 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
@@ -20,102 +20,61 @@ class FISTA(Algorithm):
x_init: initial guess
f: data fidelity
g: regularizer
- h:
- opt: additional algorithm
+ opt: additional options
'''
-
+
+
def __init__(self, **kwargs):
'''initialisation can be done at creation time if all
proper variables are passed or later with set_up'''
super(FISTA, self).__init__()
- self.f = None
- self.g = None
+ self.f = kwargs.get('f', None)
+ self.g = kwargs.get('g', None)
+ self.x_init = kwargs.get('x_init',None)
self.invL = None
self.t_old = 1
- args = ['x_init', 'f', 'g', 'opt']
- for k,v in kwargs.items():
- if k in args:
- args.pop(args.index(k))
- if len(args) == 0:
- return self.set_up(kwargs['x_init'],
- f=kwargs['f'],
- g=kwargs['g'],
- opt=kwargs['opt'])
+ if self.f is not None and self.g is not None:
+ print ("Calling from creator")
+ self.set_up(self.x_init, self.f, self.g)
+
- def set_up(self, x_init, f=None, g=None, opt=None):
+ def set_up(self, x_init, f, g, opt=None, **kwargs):
- # default inputs
- if f is None:
- self.f = ZeroFunction()
- else:
- self.f = f
- if g is None:
- g = ZeroFunction()
- self.g = g
- else:
- self.g = g
+ self.f = f
+ self.g = g
# algorithmic parameters
if opt is None:
- opt = {'tol': 1e-4, 'memopt':False}
-
- self.tol = opt['tol'] if 'tol' in opt.keys() else 1e-4
- memopt = opt['memopt'] if 'memopt' in opt.keys() else False
- self.memopt = memopt
-
- # initialization
- if memopt:
- self.y = x_init.clone()
- self.x_old = x_init.clone()
- self.x = x_init.clone()
- self.u = x_init.clone()
- else:
- self.x_old = x_init.copy()
- self.y = x_init.copy()
-
- #timing = numpy.zeros(max_iter)
- #criter = numpy.zeros(max_iter)
+ opt = {'tol': 1e-4}
-
+ self.y = x_init.copy()
+ self.x_old = x_init.copy()
+ self.x = x_init.copy()
+ self.u = x_init.copy()
+
+
self.invL = 1/f.L
self.t_old = 1
def update(self):
- # algorithm loop
- #for it in range(0, max_iter):
-
- if self.memopt:
- # u = y - invL*f.grad(y)
- # store the result in x_old
- self.f.gradient(self.y, out=self.u)
- self.u.__imul__( -self.invL )
- self.u.__iadd__( self.y )
- # x = g.prox(u,invL)
- self.g.proximal(self.u, self.invL, out=self.x)
-
- self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2)))
-
- # y = x + (t_old-1)/t*(x-x_old)
- self.x.subtract(self.x_old, out=self.y)
- self.y.__imul__ ((self.t_old-1)/self.t)
- self.y.__iadd__( self.x )
-
- self.x_old.fill(self.x)
- self.t_old = self.t
-
-
- else:
- u = self.y - self.invL*self.f.gradient(self.y)
-
- self.x = self.g.proximal(u,self.invL)
-
- self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2)))
-
- self.y = self.x + (self.t_old-1)/self.t*(self.x-self.x_old)
-
- self.x_old = self.x.copy()
- self.t_old = self.t
+
+ self.f.gradient(self.y, out=self.u)
+ self.u.__imul__( -self.invL )
+ self.u.__iadd__( self.y )
+ # x = g.prox(u,invL)
+ self.g.proximal(self.u, self.invL, out=self.x)
+
+ self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2)))
+
+ self.x.subtract(self.x_old, out=self.y)
+ self.y.__imul__ ((self.t_old-1)/self.t)
+ self.y.__iadd__( self.x )
+
+ self.x_old.fill(self.x)
+ self.t_old = self.t
def update_objective(self):
- self.loss.append( self.f(self.x) + self.g(self.x) ) \ No newline at end of file
+ self.loss.append( self.f(self.x) + self.g(self.x) )
+
+
diff --git a/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py b/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py
index 0b6bb0b..d9d9010 100755
--- a/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py
@@ -88,7 +88,7 @@ class Norm2Sq(Function):
def gradient(self, x, out = None):
if self.memopt:
#return 2.0*self.c*self.A.adjoint( self.A.direct(x) - self.b )
-
+ print (self.range_tmp, self.range_tmp.as_array())
self.A.direct(x, out=self.range_tmp)
self.range_tmp -= self.b
self.A.adjoint(self.range_tmp, out=out)
diff --git a/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
index 7caeab2..d292ac8 100755
--- a/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
@@ -18,6 +18,7 @@
# limitations under the License.
from numbers import Number
import numpy
+import warnings
class ScaledFunction(object):
@@ -88,7 +89,7 @@ class ScaledFunction(object):
'''Alias of proximal(x, tau, None)'''
warnings.warn('''This method will disappear in following
versions of the CIL. Use proximal instead''', DeprecationWarning)
- return self.proximal(x, out=None)
+ return self.proximal(x, tau, out=None)
diff --git a/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py b/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py
index 62e22e0..6306192 100644
--- a/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py
@@ -10,6 +10,9 @@ from ccpi.optimisation.operators import LinearOperator
class LinearOperatorMatrix(LinearOperator):
def __init__(self,A):
self.A = A
+ M_A, N_A = self.A.shape
+ self.gm_domain = ImageGeometry(0, N_A)
+ self.gm_range = ImageGeometry(M_A,0)
self.s1 = None # Largest singular value, initially unknown
super(LinearOperatorMatrix, self).__init__()
@@ -30,22 +33,14 @@ class LinearOperatorMatrix(LinearOperator):
def size(self):
return self.A.shape
- def get_max_sing_val(self):
+ def norm(self):
# If unknown, compute and store. If known, simply return it.
if self.s1 is None:
self.s1 = svds(self.A,1,return_singular_vectors=False)[0]
return self.s1
else:
return self.s1
- def allocate_direct(self):
- '''allocates the memory to hold the result of adjoint'''
- #numpy.dot(self.A.transpose(),x.as_array())
- M_A, N_A = self.A.shape
- out = numpy.zeros((N_A,1))
- return DataContainer(out)
- def allocate_adjoint(self):
- '''allocate the memory to hold the result of direct'''
- #numpy.dot(self.A.transpose(),x.as_array())
- M_A, N_A = self.A.shape
- out = numpy.zeros((M_A,1))
- return DataContainer(out)
+ def domain_geometry(self):
+ return self.gm_domain
+ def range_geometry(self):
+ return self.gm_range
diff --git a/Wrappers/Python/test/test_run_test.py b/Wrappers/Python/test/test_run_test.py
index a6c13f4..ebe494f 100755
--- a/Wrappers/Python/test/test_run_test.py
+++ b/Wrappers/Python/test/test_run_test.py
@@ -10,7 +10,8 @@ from ccpi.optimisation.algorithms import FISTA
#from ccpi.optimisation.algs import FBPD
from ccpi.optimisation.functions import Norm2Sq
from ccpi.optimisation.functions import ZeroFunction
-from ccpi.optimisation.funcs import Norm1
+#from ccpi.optimisation.funcs import Norm1
+from ccpi.optimisation.functions import L1Norm
from ccpi.optimisation.funcs import Norm2
from ccpi.optimisation.operators import LinearOperatorMatrix
@@ -81,6 +82,7 @@ class TestAlgorithms(unittest.TestCase):
opt = {'memopt': True}
# Create object instances with the test data A and b.
f = Norm2Sq(A, b, c=0.5, memopt=True)
+ f.L = LinearOperator.PowerMethod(A, 10)
g0 = ZeroFunction()
# Initial guess
@@ -123,6 +125,7 @@ class TestAlgorithms(unittest.TestCase):
self.assertTrue(cvx_not_installable)
def test_FISTA_Norm1_cvx(self):
+ print ("test_FISTA_Norm1_cvx")
if not cvx_not_installable:
try:
opt = {'memopt': True}
@@ -151,7 +154,8 @@ class TestAlgorithms(unittest.TestCase):
x_init = DataContainer(np.zeros((n, 1)))
# Create 1-norm object instance
- g1 = Norm1(lam)
+ #g1 = Norm1(lam)
+ g1 = lam * L1Norm()
g1(x_init)
g1.prox(x_init, 0.02)
@@ -225,7 +229,8 @@ class TestAlgorithms(unittest.TestCase):
# Create 1-norm object instance
- g1 = Norm1(lam)
+ #g1 = Norm1(lam)
+ g1 = lam * L1Norm()
# Compare to CVXPY
@@ -292,7 +297,8 @@ class TestAlgorithms(unittest.TestCase):
# 1-norm regulariser
lam1_denoise = 1.0
- g1_denoise = Norm1(lam1_denoise)
+ #g1_denoise = Norm1(lam1_denoise)
+ g1_denoise = lam1_denoise * L1Norm()
# Initial guess
x_init_denoise = ImageData(np.zeros((N, N)))