summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Wrappers/Python/ccpi/optimisation/Algorithms.py60
1 files changed, 33 insertions, 27 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/Algorithms.py b/Wrappers/Python/ccpi/optimisation/Algorithms.py
index 325ed77..de7f0f8 100644
--- a/Wrappers/Python/ccpi/optimisation/Algorithms.py
+++ b/Wrappers/Python/ccpi/optimisation/Algorithms.py
@@ -22,22 +22,21 @@ from ccpi.optimisation.funcs import ZeroFun
class Algorithm(object):
'''Base class for iterative algorithms
-
- provides the minimal infrastructure.
- Algorithms are iterables so can be easily run in a for loop. They will
- stop as soon as the stop cryterion is met.
-
- The user is required to implement the set_up, __init__, update and
- should_stop methods
- '''
- iteration = 0
- stop_cryterion = 'max_iter'
- __max_iteration = 0
- __loss = []
- memopt = False
- timing = []
- def __init__(self, *args, **kwargs):
- pass
+
+ provides the minimal infrastructure.
+ Algorithms are iterables so can be easily run in a for loop. They will
+ stop as soon as the stop cryterion is met.
+ The user is required to implement the set_up, __init__, update and
+ should_stop and update_objective methods
+ '''
+
+ def __init__(self):
+ self.iteration = 0
+ self.stop_cryterion = 'max_iter'
+ self.__max_iteration = 0
+ self.__loss = []
+ self.memopt = False
+ self.timing = []
def set_up(self, *args, **kwargs):
raise NotImplementedError()
def update(self):
@@ -59,6 +58,7 @@ class Algorithm(object):
time0 = time.time()
self.update()
self.timing.append( time.time() - time0 )
+ self.update_objective()
self.iteration += 1
def get_output(self):
'''Returns the solution found'''
@@ -66,6 +66,8 @@ class Algorithm(object):
def get_current_loss(self):
'''Returns the current value of the loss function'''
return self.__loss[-1]
+ def update_objective(self):
+ raise NotImplementedError()
@property
def loss(self):
return self.__loss
@@ -80,13 +82,15 @@ class Algorithm(object):
class GradientDescent(Algorithm):
'''Implementation of a simple Gradient Descent algorithm
'''
- x = None
- rate = 0
- objective_function = None
- regulariser = None
+
def __init__(self, **kwargs):
'''initialisation can be done at creation time if all
proper variables are passed or later with set_up'''
+ super(GradientDescent, self).__init__()
+ self.x = None
+ self.rate = 0
+ self.objective_function = None
+ self.regulariser = None
args = ['x_init', 'objective_function', 'rate']
for k,v in kwargs.items():
if k in args:
@@ -117,7 +121,8 @@ class GradientDescent(Algorithm):
self.x += self.x_update
else:
self.x += -self.rate * self.objective_function.grad(self.x)
-
+
+ def update_objective(self):
self.loss.append(self.objective_function(self.x))
@@ -136,13 +141,15 @@ class FISTA(Algorithm):
h:
opt: additional algorithm
'''
- f = None
- g = None
- invL = None
- t_old = 1
+
def __init__(self, **kwargs):
'''initialisation can be done at creation time if all
proper variables are passed or later with set_up'''
+ super(FISTA, self).__init__()
+ self.f = None
+ self.g = None
+ self.invL = None
+ self.t_old = 1
args = ['x_init', 'f', 'g', 'opt']
for k,v in kwargs.items():
if k in args:
@@ -232,10 +239,9 @@ class FISTA(Algorithm):
self.x_old = self.x.copy()
self.t_old = self.t
-
+ def update_objective(self):
self.loss.append( self.f(self.x) + self.g(self.x) )
-
class FBPD(Algorithm)
'''FBPD Algorithm