summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/FISTA.py78
1 files changed, 39 insertions, 39 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
index ee51049..d509621 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py
@@ -57,21 +57,21 @@ class FISTA(Algorithm):
# algorithmic parameters
if opt is None:
- opt = {'tol': 1e-4, 'memopt':False}
+ opt = {'tol': 1e-4}
- self.tol = opt['tol'] if 'tol' in opt.keys() else 1e-4
- memopt = opt['memopt'] if 'memopt' in opt.keys() else False
- self.memopt = memopt
+# self.tol = opt['tol'] if 'tol' in opt.keys() else 1e-4
+# memopt = opt['memopt'] if 'memopt' in opt.keys() else False
+# self.memopt = memopt
# initialization
- if memopt:
- self.y = x_init.copy()
- self.x_old = x_init.copy()
- self.x = x_init.copy()
- self.u = x_init.copy()
- else:
- self.x_old = x_init.copy()
- self.y = x_init.copy()
+# if memopt:
+ self.y = x_init.copy()
+ self.x_old = x_init.copy()
+ self.x = x_init.copy()
+ self.u = x_init.copy()
+# else:
+# self.x_old = x_init.copy()
+# self.y = x_init.copy()
#timing = numpy.zeros(max_iter)
#criter = numpy.zeros(max_iter)
@@ -85,37 +85,37 @@ class FISTA(Algorithm):
# algorithm loop
#for it in range(0, max_iter):
- if self.memopt:
+# if self.memopt:
# u = y - invL*f.grad(y)
# store the result in x_old
- self.f.gradient(self.y, out=self.u)
- self.u.__imul__( -self.invL )
- self.u.__iadd__( self.y )
- # x = g.prox(u,invL)
- self.g.proximal(self.u, self.invL, out=self.x)
-
- self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2)))
-
- # y = x + (t_old-1)/t*(x-x_old)
- self.x.subtract(self.x_old, out=self.y)
- self.y.__imul__ ((self.t_old-1)/self.t)
- self.y.__iadd__( self.x )
-
- self.x_old.fill(self.x)
- self.t_old = self.t
-
-
- else:
- u = self.y - self.invL*self.f.gradient(self.y)
-
- self.x = self.g.proximal(u,self.invL)
-
- self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2)))
+ self.f.gradient(self.y, out=self.u)
+ self.u.__imul__( -self.invL )
+ self.u.__iadd__( self.y )
+ # x = g.prox(u,invL)
+ self.g.proximal(self.u, self.invL, out=self.x)
+
+ self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2)))
+
+ # y = x + (t_old-1)/t*(x-x_old)
+ self.x.subtract(self.x_old, out=self.y)
+ self.y.__imul__ ((self.t_old-1)/self.t)
+ self.y.__iadd__( self.x )
+
+ self.x_old.fill(self.x)
+ self.t_old = self.t
- self.y = self.x + (self.t_old-1)/self.t*(self.x-self.x_old)
- self.x_old = self.x.copy()
- self.t_old = self.t
+# else:
+# u = self.y - self.invL*self.f.gradient(self.y)
+#
+# self.x = self.g.proximal(u,self.invL)
+#
+# self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2)))
+#
+# self.y = self.x + (self.t_old-1)/self.t*(self.x-self.x_old)
+#
+# self.x_old = self.x.copy()
+# self.t_old = self.t
def update_objective(self):
self.loss.append( self.f(self.x) + self.g(self.x) ) \ No newline at end of file