summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorepapoutsellis <epapoutsellis@gmail.com>2019-06-03 21:35:53 +0100
committerepapoutsellis <epapoutsellis@gmail.com>2019-06-03 21:35:53 +0100
commit521cbed2e02c38f8a277d23c02f1a7eb9c8542ca (patch)
tree37fa2d865bb654f6c5440e75b9f5077778580133
parent8ebe128bf1a893843f9ae34a2a7d5fb4ae91da98 (diff)
downloadframework-521cbed2e02c38f8a277d23c02f1a7eb9c8542ca.tar.gz
framework-521cbed2e02c38f8a277d23c02f1a7eb9c8542ca.tar.bz2
framework-521cbed2e02c38f8a277d23c02f1a7eb9c8542ca.tar.xz
framework-521cbed2e02c38f8a277d23c02f1a7eb9c8542ca.zip
fix memopt/input FISTA
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/FISTA.py42
1 files changed, 16 insertions, 26 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
index 3f285be..04e7c38 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
@@ -20,40 +20,28 @@ class FISTA(Algorithm):
x_init: initial guess
f: data fidelity
g: regularizer
- h:
- opt: additional algorithm
+ opt: additional options
'''
-
+
+
def __init__(self, **kwargs):
'''initialisation can be done at creation time if all
proper variables are passed or later with set_up'''
super(FISTA, self).__init__()
- self.f = None
- self.g = None
+ self.f = kwargs.get('f', None)
+ self.g = kwargs.get('g', None)
+ self.x_init = kwargs.get('x_init',None)
self.invL = None
self.t_old = 1
- args = ['x_init', 'f', 'g', 'opt']
- for k,v in kwargs.items():
- if k in args:
- args.pop(args.index(k))
- if len(args) == 0:
- return self.set_up(kwargs['x_init'],
- f=kwargs['f'],
- g=kwargs['g'],
- opt=kwargs['opt'])
+ if self.f is not None and self.g is not None:
+ print ("Calling from creator")
+ self.set_up(self.x_init, self.f, self.g)
+
- def set_up(self, x_init, f=None, g=None, opt=None):
+ def set_up(self, x_init, f, g, opt=None, **kwargs):
- # default inputs
- if f is None:
- self.f = ZeroFunction()
- else:
- self.f = f
- if g is None:
- g = ZeroFunction()
- self.g = g
- else:
- self.g = g
+ self.f = f
+ self.g = g
# algorithmic parameters
if opt is None:
@@ -87,4 +75,6 @@ class FISTA(Algorithm):
self.t_old = self.t
def update_objective(self):
- self.loss.append( self.f(self.x) + self.g(self.x) ) \ No newline at end of file
+ self.loss.append( self.f(self.x) + self.g(self.x) )
+
+