author    Edoardo Pasca <edo.paskino@gmail.com>    2019-06-04 13:36:50 +0100
committer Edoardo Pasca <edo.paskino@gmail.com>    2019-06-04 13:36:50 +0100
commit    3271aa9c4ca50177bf2f9e37269efa03f52f635e (patch)
tree      370f9de182a5e5a80a10c58059e1d1ecfb4466ea /Wrappers/Python
parent    be88c669c995176b54d43d1fa800095d449490d6 (diff)
fixing tests
Diffstat (limited to 'Wrappers/Python')
5 files changed, 59 insertions, 98 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
index 8ea2b6c..04e7c38 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
@@ -20,102 +20,61 @@ class FISTA(Algorithm):
     x_init: initial guess
     f: data fidelity
     g: regularizer
-    h:
-    opt: additional algorithm
+    opt: additional options
     '''
-
+
+
     def __init__(self, **kwargs):
         '''initialisation can be done at creation time if all
         proper variables are passed or later with set_up'''
         super(FISTA, self).__init__()
-        self.f = None
-        self.g = None
+        self.f = kwargs.get('f', None)
+        self.g = kwargs.get('g', None)
+        self.x_init = kwargs.get('x_init',None)
         self.invL = None
         self.t_old = 1
-        args = ['x_init', 'f', 'g', 'opt']
-        for k,v in kwargs.items():
-            if k in args:
-                args.pop(args.index(k))
-        if len(args) == 0:
-            return self.set_up(kwargs['x_init'],
-                               f=kwargs['f'],
-                               g=kwargs['g'],
-                               opt=kwargs['opt'])
+        if self.f is not None and self.g is not None:
+            print ("Calling from creator")
+            self.set_up(self.x_init, self.f, self.g)
+
-    def set_up(self, x_init, f=None, g=None, opt=None):
+    def set_up(self, x_init, f, g, opt=None, **kwargs):
 
-        # default inputs
-        if f is None:
-            self.f = ZeroFunction()
-        else:
-            self.f = f
-        if g is None:
-            g = ZeroFunction()
-            self.g = g
-        else:
-            self.g = g
+        self.f = f
+        self.g = g
 
         # algorithmic parameters
         if opt is None:
-            opt = {'tol': 1e-4, 'memopt':False}
-
-        self.tol = opt['tol'] if 'tol' in opt.keys() else 1e-4
-        memopt = opt['memopt'] if 'memopt' in opt.keys() else False
-        self.memopt = memopt
-
-        # initialization
-        if memopt:
-            self.y = x_init.clone()
-            self.x_old = x_init.clone()
-            self.x = x_init.clone()
-            self.u = x_init.clone()
-        else:
-            self.x_old = x_init.copy()
-            self.y = x_init.copy()
-
-        #timing = numpy.zeros(max_iter)
-        #criter = numpy.zeros(max_iter)
+            opt = {'tol': 1e-4}
-
+        self.y = x_init.copy()
+        self.x_old = x_init.copy()
+        self.x = x_init.copy()
+        self.u = x_init.copy()
+
+
         self.invL = 1/f.L
 
         self.t_old = 1
 
     def update(self):
-        # algorithm loop
-        #for it in range(0, max_iter):
-
-        if self.memopt:
-            # u = y - invL*f.grad(y)
-            # store the result in x_old
-            self.f.gradient(self.y, out=self.u)
-            self.u.__imul__( -self.invL )
-            self.u.__iadd__( self.y )
-            # x = g.prox(u,invL)
-            self.g.proximal(self.u, self.invL, out=self.x)
-
-            self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2)))
-
-            # y = x + (t_old-1)/t*(x-x_old)
-            self.x.subtract(self.x_old, out=self.y)
-            self.y.__imul__ ((self.t_old-1)/self.t)
-            self.y.__iadd__( self.x )
-
-            self.x_old.fill(self.x)
-            self.t_old = self.t
-
-
-        else:
-            u = self.y - self.invL*self.f.gradient(self.y)
-
-            self.x = self.g.proximal(u,self.invL)
-
-            self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2)))
-
-            self.y = self.x + (self.t_old-1)/self.t*(self.x-self.x_old)
-
-            self.x_old = self.x.copy()
-            self.t_old = self.t
+
+        self.f.gradient(self.y, out=self.u)
+        self.u.__imul__( -self.invL )
+        self.u.__iadd__( self.y )
+        # x = g.prox(u,invL)
+        self.g.proximal(self.u, self.invL, out=self.x)
+
+        self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2)))
+
+        self.x.subtract(self.x_old, out=self.y)
+        self.y.__imul__ ((self.t_old-1)/self.t)
+        self.y.__iadd__( self.x )
+
+        self.x_old.fill(self.x)
+        self.t_old = self.t
 
     def update_objective(self):
-        self.loss.append( self.f(self.x) + self.g(self.x) )
\ No newline at end of file
+        self.loss.append( self.f(self.x) + self.g(self.x) )
+
+
diff --git a/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py b/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py
index 0b6bb0b..d9d9010 100755
--- a/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py
@@ -88,7 +88,7 @@ class Norm2Sq(Function):
     def gradient(self, x, out = None):
         if self.memopt:
             #return 2.0*self.c*self.A.adjoint( self.A.direct(x) - self.b )
-
+            print (self.range_tmp, self.range_tmp.as_array())
             self.A.direct(x, out=self.range_tmp)
             self.range_tmp -= self.b
             self.A.adjoint(self.range_tmp, out=out)
diff --git a/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
index 7caeab2..d292ac8 100755
--- a/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
@@ -18,6 +18,7 @@
 # limitations under the License.
from numbers import Number
import numpy
+import warnings
class ScaledFunction(object):
@@ -88,7 +89,7 @@ class ScaledFunction(object):
         '''Alias of proximal(x, tau, None)'''
warnings.warn('''This method will disappear in following
versions of the CIL. Use proximal instead''', DeprecationWarning)
- return self.proximal(x, out=None)
+ return self.proximal(x, tau, out=None)
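Note on the ScaledFunction fix above: prox previously delegated to proximal without passing tau, so the alias always raised a TypeError for the missing positional argument; the corrected line forwards it. Below is a minimal runnable sketch of the delegation pattern, with a hypothetical AbsSum function standing in for L1Norm on plain numpy arrays (the real ScaledFunction lives in the file patched above):

import warnings
import numpy as np

class ScaledFunction(object):
    '''Sketch: wraps a function so that (scalar*f) evaluates and proxes correctly.'''
    def __init__(self, function, scalar):
        self.function = function
        self.scalar = scalar

    def __call__(self, x):
        return self.scalar * self.function(x)

    def proximal(self, x, tau, out=None):
        # prox of tau*(scalar*f) equals prox of (tau*scalar)*f
        return self.function.proximal(x, tau * self.scalar, out=out)

    def prox(self, x, tau):
        '''Alias of proximal(x, tau, None)'''
        warnings.warn('Use proximal instead', DeprecationWarning)
        # the fixed line: before the patch, tau was dropped here and
        # proximal was called with a required argument missing
        return self.proximal(x, tau, out=None)

class AbsSum(object):
    '''Hypothetical stand-in for L1Norm on plain numpy arrays.'''
    def __call__(self, x):
        return np.abs(x).sum()
    def proximal(self, x, tau, out=None):
        return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

g = ScaledFunction(AbsSum(), 2.0)
print(g.prox(np.ones(3), 0.02))   # soft-threshold at 2.0*0.02: entries become 0.96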
diff --git a/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py b/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py
index 62e22e0..6306192 100644
--- a/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py
@@ -10,6 +10,9 @@ from ccpi.optimisation.operators import LinearOperator
 class LinearOperatorMatrix(LinearOperator):
     def __init__(self,A):
         self.A = A
+        M_A, N_A = self.A.shape
+        self.gm_domain = ImageGeometry(0, N_A)
+        self.gm_range = ImageGeometry(M_A,0)
         self.s1 = None   # Largest singular value, initially unknown
         super(LinearOperatorMatrix, self).__init__()
@@ -30,22 +33,14 @@ class LinearOperatorMatrix(LinearOperator):
     def size(self):
         return self.A.shape
 
-    def get_max_sing_val(self):
+    def norm(self):
         # If unknown, compute and store. If known, simply return it.
         if self.s1 is None:
             self.s1 = svds(self.A,1,return_singular_vectors=False)[0]
             return self.s1
         else:
             return self.s1
-    def allocate_direct(self):
-        '''allocates the memory to hold the result of adjoint'''
-        #numpy.dot(self.A.transpose(),x.as_array())
-        M_A, N_A = self.A.shape
-        out = numpy.zeros((N_A,1))
-        return DataContainer(out)
-    def allocate_adjoint(self):
-        '''allocate the memory to hold the result of direct'''
-        #numpy.dot(self.A.transpose(),x.as_array())
-        M_A, N_A = self.A.shape
-        out = numpy.zeros((M_A,1))
-        return DataContainer(out)
+    def domain_geometry(self):
+        return self.gm_domain
+    def range_geometry(self):
+        return self.gm_range
diff --git a/Wrappers/Python/test/test_run_test.py b/Wrappers/Python/test/test_run_test.py
index a6c13f4..ebe494f 100755
--- a/Wrappers/Python/test/test_run_test.py
+++ b/Wrappers/Python/test/test_run_test.py
@@ -10,7 +10,8 @@ from ccpi.optimisation.algorithms import FISTA
 #from ccpi.optimisation.algs import FBPD
 from ccpi.optimisation.functions import Norm2Sq
 from ccpi.optimisation.functions import ZeroFunction
-from ccpi.optimisation.funcs import Norm1
+#from ccpi.optimisation.funcs import Norm1
+from ccpi.optimisation.functions import L1Norm
 from ccpi.optimisation.funcs import Norm2
 from ccpi.optimisation.operators import LinearOperatorMatrix
@@ -81,6 +82,7 @@ class TestAlgorithms(unittest.TestCase):
             opt = {'memopt': True}
             # Create object instances with the test data A and b.
             f = Norm2Sq(A, b, c=0.5, memopt=True)
+            f.L = LinearOperator.PowerMethod(A, 10)
             g0 = ZeroFunction()
 
             # Initial guess
@@ -123,6 +125,7 @@ class TestAlgorithms(unittest.TestCase):
             self.assertTrue(cvx_not_installable)
 
     def test_FISTA_Norm1_cvx(self):
+        print ("test_FISTA_Norm1_cvx")
         if not cvx_not_installable:
             try:
                 opt = {'memopt': True}
@@ -151,7 +154,8 @@ class TestAlgorithms(unittest.TestCase):
                 x_init = DataContainer(np.zeros((n, 1)))
 
                 # Create 1-norm object instance
-                g1 = Norm1(lam)
+                #g1 = Norm1(lam)
+                g1 = lam * L1Norm()
 
                 g1(x_init)
                 g1.prox(x_init, 0.02)
@@ -225,7 +229,8 @@ class TestAlgorithms(unittest.TestCase):
 
                 # Create 1-norm object instance
-                g1 = Norm1(lam)
+                #g1 = Norm1(lam)
+                g1 = lam * L1Norm()
 
                 # Compare to CVXPY
@@ -292,7 +297,8 @@ class TestAlgorithms(unittest.TestCase):
                 # 1-norm regulariser
                 lam1_denoise = 1.0
-                g1_denoise = Norm1(lam1_denoise)
+                #g1_denoise = Norm1(lam1_denoise)
+                g1_denoise = lam1_denoise * L1Norm()
 
                 # Initial guess
                 x_init_denoise = ImageData(np.zeros((N, N)))
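For reference, the update() body that survives the memopt clean-up in FISTA.py is the standard FISTA iteration: a gradient step on f, a proximal step on g with step size 1/f.L, the t update, and a momentum extrapolation on y (which is why set_up now requires f.L, and why the test above assigns it explicitly). The same rule written in plain numpy for a made-up LASSO problem, as a sketch of the mathematics rather than of the CIL API:

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((10, 5))
b = rng.standard_normal((10, 1))
lam = 0.1

L = np.linalg.norm(A, 2) ** 2              # Lipschitz constant of grad f, i.e. f.L
x_old = np.zeros((5, 1))
y = x_old.copy()
t_old = 1.0

for _ in range(100):
    u = y - (A.T @ (A @ y - b)) / L        # u = y - invL*f.grad(y)
    x = np.sign(u) * np.maximum(np.abs(u) - lam / L, 0.0)   # x = g.prox(u, invL)
    t = 0.5 * (1 + np.sqrt(1 + 4 * t_old ** 2))
    y = x + ((t_old - 1) / t) * (x - x_old)                 # momentum extrapolation
    x_old, t_old = x, t

print(0.5 * np.linalg.norm(A @ x - b) ** 2 + lam * np.abs(x).sum())  # final objective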
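The rename of get_max_sing_val to norm() brings LinearOperatorMatrix in line with the generic LinearOperator interface; the largest singular value is computed once with scipy's svds and then cached in s1. A standalone sketch of that lazy-caching pattern (the class name here is illustrative):

import numpy as np
from scipy.sparse.linalg import svds

class MatrixNormCache(object):
    '''Sketch of the caching behind LinearOperatorMatrix.norm().'''
    def __init__(self, A):
        self.A = A
        self.s1 = None   # largest singular value, initially unknown

    def norm(self):
        # compute once with a truncated SVD, then return the stored value
        if self.s1 is None:
            self.s1 = svds(self.A, 1, return_singular_vectors=False)[0]
        return self.s1

op = MatrixNormCache(np.random.rand(10, 5))
print(op.norm())   # agrees with np.linalg.norm(op.A, 2) up to solver tolerance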
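Finally, the test updates swap the deprecated funcs.Norm1(lam) for lam * L1Norm(): multiplying a Function by a scalar returns a ScaledFunction, so both evaluation and the proximal map pick up the factor, and the prox alias fixed above now forwards tau correctly. A short sketch of the behaviour the tests rely on, assuming L1Norm's proximal performs soft-thresholding as its use in these tests implies (values are illustrative):

import numpy as np
from ccpi.framework import DataContainer
from ccpi.optimisation.functions import L1Norm

lam = 2.0
g1 = lam * L1Norm()           # a ScaledFunction wrapping L1Norm

x = DataContainer(np.ones((5, 1)))
print(g1(x))                  # lam * ||x||_1 = 2.0 * 5.0 = 10.0
y = g1.prox(x, 0.02)          # soft-threshold at lam*0.02 = 0.04
print(y.as_array())           # entries shrink from 1.0 to 0.96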