author     Edoardo Pasca <edo.paskino@gmail.com>   2019-06-07 15:39:46 +0000
committer  GitHub <noreply@github.com>             2019-06-07 15:39:46 +0000
commit     53e09caf548d51ca1d071f4bffe249afab5649f6 (patch)
tree       5a7c6600e981c517bf6a6ebd62890958590d867f /Wrappers
parent     2fa4ef8cdb4ccc420007b65d975682c9939e0171 (diff)
parent     1d7324cec6f10f00b73af2bab3469202c5cc2e87 (diff)
Merge pull request #305 from vais-ral/config_alg
Config alg
Diffstat (limited to 'Wrappers')
-rwxr-xr-x  Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py         3
-rwxr-xr-x  Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py              6
-rw-r--r--  Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py             86
-rwxr-xr-x  Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py             1
-rwxr-xr-x  Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py   3
-rw-r--r--  Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py             10
-rw-r--r--  Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py              8
-rw-r--r--  Wrappers/Python/ccpi/optimisation/algorithms/__init__.py          2
-rwxr-xr-x  Wrappers/Python/test/test_algorithms.py                           1
9 files changed, 25 insertions(+), 95 deletions(-)
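In short, this merge adds a configured flag to the Algorithm base class and makes every concrete algorithm's set_up set it, so iterating an algorithm that was never set up now fails loudly instead of crashing later on missing state. It also deletes the non-functional FBPD implementation (together with its import in the tests) and drops the stale PDHG_old export.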
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py b/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py
index 4fbf83b..c62d0ea 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py
@@ -51,6 +51,7 @@ class Algorithm(object):
         self.__max_iteration = kwargs.get('max_iteration', 0)
         self.__loss = []
         self.memopt = False
+        self.configured = False
         self.timing = []
         self.update_objective_interval = kwargs.get('update_objective_interval', 1)
     def set_up(self, *args, **kwargs):
@@ -86,6 +87,8 @@ class Algorithm(object):
             raise StopIteration()
         else:
             time0 = time.time()
+            if not self.configured:
+                raise ValueError('Algorithm not configured correctly. Please run set_up.')
             self.update()
             self.timing.append( time.time() - time0 )
             if self.iteration % self.update_objective_interval == 0:
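The base-class change above is the heart of the PR: __init__ now records configured = False, and __next__ refuses to run an iteration until some set_up has flipped it. Below is a minimal standalone sketch of the pattern (plain Python, no ccpi imports; everything except the configured flag, the set_up hook and the error message is illustrative):

    import time

    class Algorithm(object):
        '''Minimal standalone sketch of the configured-guard pattern above.'''
        def __init__(self, **kwargs):
            self.configured = False   # flipped only by a successful set_up()
            self.iteration = 0
            self.max_iteration = kwargs.get('max_iteration', 0)
            self.timing = []

        def set_up(self, x_init):
            self.x = x_init
            self.configured = True    # instance is now safe to iterate

        def update(self):
            self.x += 1               # placeholder for one iteration of work

        def __next__(self):
            if self.iteration >= self.max_iteration:
                raise StopIteration()
            time0 = time.time()
            if not self.configured:
                raise ValueError('Algorithm not configured correctly. Please run set_up.')
            self.update()
            self.timing.append(time.time() - time0)
            self.iteration += 1

    algo = Algorithm(max_iteration=2)
    try:
        next(algo)            # fails: set_up() was never called
    except ValueError as err:
        print(err)
    algo.set_up(x_init=0)
    next(algo)                # succeeds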
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
index 4d4843c..6b610a0 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py
@@ -23,6 +23,7 @@ Created on Thu Feb 21 11:11:23 2019
"""
from ccpi.optimisation.algorithms import Algorithm
+from ccpi.optimisation.functions import Norm2Sq
class CGLS(Algorithm):
@@ -59,6 +60,9 @@ class CGLS(Algorithm):
         # self.normr2 = sum(self.normr2)
         #self.normr2 = numpy.sqrt(self.normr2)
         #print ("set_up" , self.normr2)
+        n = Norm2Sq(operator, self.data)
+        self.loss.append(n(x_init))
+        self.configured = True
     def update(self):
@@ -84,4 +88,4 @@ class CGLS(Algorithm):
         self.d = s + beta*self.d
     def update_objective(self):
-        self.loss.append(self.r.squared_norm())
\ No newline at end of file
+        self.loss.append(self.r.squared_norm())
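The three added lines seed the loss history with the objective at x_init, so loss[0] reflects the starting point rather than the first iterate. Norm2Sq(operator, self.data) is the least-squares objective CGLS minimises; as a rough NumPy illustration of what that first entry evaluates (the matrix A, data b and start x0 below are hypothetical stand-ins, and Norm2Sq may also carry a scaling constant):

    import numpy as np

    A = np.array([[2.0, 0.0],
                  [0.0, 3.0]])      # stand-in for the operator
    b = np.array([1.0, 1.0])        # stand-in for the data
    x0 = np.zeros(2)                # stand-in for x_init

    # Norm2Sq(A, b) at x0 evaluates || A x0 - b ||_2^2
    initial_loss = np.sum((A @ x0 - b) ** 2)
    print(initial_loss)             # 2.0 -> becomes loss[0]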
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py b/Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py
deleted file mode 100644
index aa07359..0000000
--- a/Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py
+++ /dev/null
@@ -1,86 +0,0 @@
-# -*- coding: utf-8 -*-
-# This work is part of the Core Imaging Library developed by
-# Visual Analytics and Imaging System Group of the Science Technology
-# Facilities Council, STFC
-
-# Copyright 2019 Edoardo Pasca
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-
-# http://www.apache.org/licenses/LICENSE-2.0
-
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Created on Thu Feb 21 11:09:03 2019
-
-@author: ofn77899
-"""
-
-from ccpi.optimisation.algorithms import Algorithm
-from ccpi.optimisation.functions import ZeroFunction
-
-class FBPD(Algorithm):
-    '''FBPD Algorithm
-
-    Parameters:
-      x_init: initial guess
-      f: constraint
-      g: data fidelity
-      h: regularizer
-      opt: additional algorithm
-    '''
-    constraint = None
-    data_fidelity = None
-    regulariser = None
-    def __init__(self, **kwargs):
-        pass
-    def set_up(self, x_init, operator=None, constraint=None, data_fidelity=None,\
-               regulariser=None, opt=None):
-
-        # default inputs
-        if constraint is None:
-            self.constraint = ZeroFun()
-        else:
-            self.constraint = constraint
-        if data_fidelity is None:
-            data_fidelity = ZeroFun()
-        else:
-            self.data_fidelity = data_fidelity
-        if regulariser is None:
-            self.regulariser = ZeroFun()
-        else:
-            self.regulariser = regulariser
-
-        # algorithmic parameters
-
-
-        # step-sizes
-        self.tau = 2 / (self.data_fidelity.L + 2)
-        self.sigma = (1/self.tau - self.data_fidelity.L/2) / self.regulariser.L
-
-        self.inv_sigma = 1/self.sigma
-
-        # initialization
-        self.x = x_init
-        self.y = operator.direct(self.x)
-
-
-    def update(self):
-
-        # primal forward-backward step
-        x_old = self.x
-        self.x = self.x - self.tau * ( self.data_fidelity.grad(self.x) + self.operator.adjoint(self.y) )
-        self.x = self.constraint.prox(self.x, self.tau);
-
-        # dual forward-backward step
-        self.y = self.y + self.sigma * self.operator.direct(2*self.x - x_old);
-        self.y = self.y - self.sigma * self.regulariser.prox(self.inv_sigma*self.y, self.inv_sigma);
-
-        # time and criterion
-        self.loss = self.constraint(self.x) + self.data_fidelity(self.x) + self.regulariser(self.operator.direct(self.x))
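Deleting FBPD outright rather than patching it is defensible: the module imported ZeroFunction but instantiated the undefined name ZeroFun in all three default branches, __init__ silently discarded its keyword arguments, and update referenced self.operator, which set_up never assigned, so the class could not have executed as written.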
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
index 3c7a8d1..647ae98 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
@@ -58,6 +58,7 @@ class FISTA(Algorithm):
         self.t_old = 1
         self.update_objective()
+        self.configured = True
     def update(self):
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py
index 14763c5..34bf954 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py
@@ -40,7 +40,7 @@ class GradientDescent(Algorithm):
             if k in args:
                 args.pop(args.index(k))
         if len(args) == 0:
-            return self.set_up(x_init=kwargs['x_init'],
+            self.set_up(x_init=kwargs['x_init'],
                        objective_function=kwargs['objective_function'],
                        rate=kwargs['rate'])
@@ -61,6 +61,7 @@ class GradientDescent(Algorithm):
             self.memopt = False
         if self.memopt:
             self.x_update = x_init.copy()
+        self.configured = True
     def update(self):
         '''Single iteration'''
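Dropping the return matters because this call sits in what appears to be the constructor's kwargs handling: __init__ must return None, so propagating a non-None result from set_up would raise a TypeError at construction time. It was harmless while set_up returned nothing, but fragile. A short illustration of the failure mode:

    class Broken(object):
        def __init__(self):
            return 42           # non-None return from __init__

    try:
        Broken()
    except TypeError as err:
        print(err)              # __init__() should return None, not 'int'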
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
index 39b092b..3afd8b0 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
@@ -23,10 +23,16 @@ class PDHG(Algorithm):
         self.operator = kwargs.get('operator', None)
         self.g = kwargs.get('g', None)
         self.tau = kwargs.get('tau', None)
-        self.sigma = kwargs.get('sigma', None)
+        self.sigma = kwargs.get('sigma', 1.)
+
         if self.f is not None and self.operator is not None and \
            self.g is not None:
+            if self.tau is None:
+                # Compute operator Norm
+                normK = self.operator.norm()
+                # Primal & dual stepsizes
+                self.tau = 1/(self.sigma*normK**2)
             print ("Calling from creator")
             self.set_up(self.f,
                         self.g,
@@ -57,6 +63,8 @@ class PDHG(Algorithm):
         # relaxation parameter
         self.theta = 1
+        self.update_objective()
+        self.configured = True
     def update(self):
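The new defaults implement the standard PDHG (Chambolle-Pock) step-size rule: with the dual step sigma fixed at 1, the primal step tau = 1/(sigma*||K||^2) gives sigma*tau*||K||^2 = 1, sitting exactly on the boundary of the usual convergence condition sigma*tau*||K||^2 <= 1. A sketch with a hypothetical operator norm standing in for self.operator.norm():

    norm_K = 4.0                        # hypothetical value of operator.norm()
    sigma = 1.0                         # default dual step size, as above
    tau = 1.0 / (sigma * norm_K ** 2)   # primal step size, as above

    # PDHG convergence requires sigma * tau * ||K||^2 <= 1;
    # these defaults satisfy it with equality.
    assert sigma * tau * norm_K ** 2 <= 1.0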
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py b/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py
index 30584d4..c73d323 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py
@@ -59,6 +59,7 @@ class SIRT(Algorithm):
         # Set up scaling matrices D and M.
         self.M = 1/self.operator.direct(self.operator.domain_geometry().allocate(value=1.0))
         self.D = 1/self.operator.adjoint(self.operator.range_geometry().allocate(value=1.0))
+        self.configured = True
     def update(self):
@@ -67,8 +68,9 @@ class SIRT(Algorithm):
         self.x += self.relax_par * (self.D*self.operator.adjoint(self.M*self.r))
-        if self.constraint != None:
-            self.x = self.constraint.prox(self.x,None)
+        if self.constraint is not None:
+            self.x = self.constraint.proximal(self.x,None)
+            # self.constraint.proximal(self.x,None, out=self.x)
     def update_objective(self):
-        self.loss.append(self.r.squared_norm())
\ No newline at end of file
+        self.loss.append(self.r.squared_norm())
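Besides the configured flag, two small fixes land here: the None test now uses identity (is not None) per PEP 8 instead of !=, which container types can overload element-wise, and the constraint call is renamed from prox to proximal (presumably to match the functions API); the commented line sketches a future in-place variant. Why != None is risky with array-like containers (NumPy used for illustration):

    import numpy as np

    a = np.ones(3)
    print(a != None)        # element-wise: array([ True,  True,  True])
    print(a is not None)    # identity test: True
    # 'if a != None:' would raise "truth value of an array is ambiguous"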
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py b/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py
index 2dbacfc..8f255f3 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py
@@ -27,7 +27,5 @@ from .CGLS import CGLS
 from .SIRT import SIRT
 from .GradientDescent import GradientDescent
 from .FISTA import FISTA
-from .FBPD import FBPD
 from .PDHG import PDHG
-from .PDHG import PDHG_old
diff --git a/Wrappers/Python/test/test_algorithms.py b/Wrappers/Python/test/test_algorithms.py
index 4121358..8c398f4 100755
--- a/Wrappers/Python/test/test_algorithms.py
+++ b/Wrappers/Python/test/test_algorithms.py
@@ -18,7 +18,6 @@ from ccpi.optimisation.functions import Norm2Sq, ZeroFunction, \
 from ccpi.optimisation.algorithms import GradientDescent
 from ccpi.optimisation.algorithms import CGLS
 from ccpi.optimisation.algorithms import FISTA
-from ccpi.optimisation.algorithms import FBPD