diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2019-02-13 15:46:06 +0000 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2019-02-13 15:46:06 +0000 |
commit | 00626d27f25aa19986a711703187a88bad2d2c43 (patch) | |
tree | 1bf4d63393179011f6026f7ec2cda4485cc0acce | |
parent | 64371504bd0bfeea4bba2b1fb3aa064034baadb1 (diff) | |
download | framework-00626d27f25aa19986a711703187a88bad2d2c43.tar.gz framework-00626d27f25aa19986a711703187a88bad2d2c43.tar.bz2 framework-00626d27f25aa19986a711703187a88bad2d2c43.tar.xz framework-00626d27f25aa19986a711703187a88bad2d2c43.zip |
Removed class members of Algorithm class
added update_objective
-rw-r--r-- | Wrappers/Python/ccpi/optimisation/Algorithms.py | 60 |
1 files changed, 33 insertions, 27 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/Algorithms.py b/Wrappers/Python/ccpi/optimisation/Algorithms.py index 325ed77..de7f0f8 100644 --- a/Wrappers/Python/ccpi/optimisation/Algorithms.py +++ b/Wrappers/Python/ccpi/optimisation/Algorithms.py @@ -22,22 +22,21 @@ from ccpi.optimisation.funcs import ZeroFun class Algorithm(object): '''Base class for iterative algorithms - - provides the minimal infrastructure. - Algorithms are iterables so can be easily run in a for loop. They will - stop as soon as the stop cryterion is met. - - The user is required to implement the set_up, __init__, update and - should_stop methods - ''' - iteration = 0 - stop_cryterion = 'max_iter' - __max_iteration = 0 - __loss = [] - memopt = False - timing = [] - def __init__(self, *args, **kwargs): - pass + + provides the minimal infrastructure. + Algorithms are iterables so can be easily run in a for loop. They will + stop as soon as the stop cryterion is met. + The user is required to implement the set_up, __init__, update and + should_stop and update_objective methods + ''' + + def __init__(self): + self.iteration = 0 + self.stop_cryterion = 'max_iter' + self.__max_iteration = 0 + self.__loss = [] + self.memopt = False + self.timing = [] def set_up(self, *args, **kwargs): raise NotImplementedError() def update(self): @@ -59,6 +58,7 @@ class Algorithm(object): time0 = time.time() self.update() self.timing.append( time.time() - time0 ) + self.update_objective() self.iteration += 1 def get_output(self): '''Returns the solution found''' @@ -66,6 +66,8 @@ class Algorithm(object): def get_current_loss(self): '''Returns the current value of the loss function''' return self.__loss[-1] + def update_objective(self): + raise NotImplementedError() @property def loss(self): return self.__loss @@ -80,13 +82,15 @@ class Algorithm(object): class GradientDescent(Algorithm): '''Implementation of a simple Gradient Descent algorithm ''' - x = None - rate = 0 - objective_function = None - regulariser = None + def __init__(self, **kwargs): '''initialisation can be done at creation time if all proper variables are passed or later with set_up''' + super(GradientDescent, self).__init__() + self.x = None + self.rate = 0 + self.objective_function = None + self.regulariser = None args = ['x_init', 'objective_function', 'rate'] for k,v in kwargs.items(): if k in args: @@ -117,7 +121,8 @@ class GradientDescent(Algorithm): self.x += self.x_update else: self.x += -self.rate * self.objective_function.grad(self.x) - + + def update_objective(self): self.loss.append(self.objective_function(self.x)) @@ -136,13 +141,15 @@ class FISTA(Algorithm): h: opt: additional algorithm ''' - f = None - g = None - invL = None - t_old = 1 + def __init__(self, **kwargs): '''initialisation can be done at creation time if all proper variables are passed or later with set_up''' + super(FISTA, self).__init__() + self.f = None + self.g = None + self.invL = None + self.t_old = 1 args = ['x_init', 'f', 'g', 'opt'] for k,v in kwargs.items(): if k in args: @@ -232,10 +239,9 @@ class FISTA(Algorithm): self.x_old = self.x.copy() self.t_old = self.t - + def update_objective(self): self.loss.append( self.f(self.x) + self.g(self.x) ) - class FBPD(Algorithm) '''FBPD Algorithm |