diff options
author | epapoutsellis <epapoutsellis@gmail.com> | 2019-06-03 21:35:53 +0100 |
---|---|---|
committer | epapoutsellis <epapoutsellis@gmail.com> | 2019-06-03 21:35:53 +0100 |
commit | 521cbed2e02c38f8a277d23c02f1a7eb9c8542ca (patch) | |
tree | 37fa2d865bb654f6c5440e75b9f5077778580133 | |
parent | 8ebe128bf1a893843f9ae34a2a7d5fb4ae91da98 (diff) | |
download | framework-521cbed2e02c38f8a277d23c02f1a7eb9c8542ca.tar.gz framework-521cbed2e02c38f8a277d23c02f1a7eb9c8542ca.tar.bz2 framework-521cbed2e02c38f8a277d23c02f1a7eb9c8542ca.tar.xz framework-521cbed2e02c38f8a277d23c02f1a7eb9c8542ca.zip |
fix memopt/input FISTA
-rwxr-xr-x | Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py | 42 |
1 files changed, 16 insertions, 26 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py index 3f285be..04e7c38 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py @@ -20,40 +20,28 @@ class FISTA(Algorithm): x_init: initial guess f: data fidelity g: regularizer - h: - opt: additional algorithm + opt: additional options ''' - + + def __init__(self, **kwargs): '''initialisation can be done at creation time if all proper variables are passed or later with set_up''' super(FISTA, self).__init__() - self.f = None - self.g = None + self.f = kwargs.get('f', None) + self.g = kwargs.get('g', None) + self.x_init = kwargs.get('x_init',None) self.invL = None self.t_old = 1 - args = ['x_init', 'f', 'g', 'opt'] - for k,v in kwargs.items(): - if k in args: - args.pop(args.index(k)) - if len(args) == 0: - return self.set_up(kwargs['x_init'], - f=kwargs['f'], - g=kwargs['g'], - opt=kwargs['opt']) + if self.f is not None and self.g is not None: + print ("Calling from creator") + self.set_up(self.x_init, self.f, self.g) + - def set_up(self, x_init, f=None, g=None, opt=None): + def set_up(self, x_init, f, g, opt=None, **kwargs): - # default inputs - if f is None: - self.f = ZeroFunction() - else: - self.f = f - if g is None: - g = ZeroFunction() - self.g = g - else: - self.g = g + self.f = f + self.g = g # algorithmic parameters if opt is None: @@ -87,4 +75,6 @@ class FISTA(Algorithm): self.t_old = self.t def update_objective(self): - self.loss.append( self.f(self.x) + self.g(self.x) )
\ No newline at end of file + self.loss.append( self.f(self.x) + self.g(self.x) ) + + |