diff options
Diffstat (limited to 'Wrappers')
| -rwxr-xr-x | Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py | 39 | 
1 files changed, 4 insertions, 35 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py index d509621..3f285be 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py @@ -59,35 +59,18 @@ class FISTA(Algorithm):          if opt is None:               opt = {'tol': 1e-4} -#        self.tol = opt['tol'] if 'tol' in opt.keys() else 1e-4 -#        memopt = opt['memopt'] if 'memopt' in opt.keys() else False -#        self.memopt = memopt -             -        # initialization -#        if memopt:          self.y = x_init.copy()          self.x_old = x_init.copy()          self.x = x_init.copy()          self.u = x_init.copy()             -#        else: -#            self.x_old = x_init.copy() -#            self.y = x_init.copy() -         -        #timing = numpy.zeros(max_iter) -        #criter = numpy.zeros(max_iter) -         -     + +          self.invL = 1/f.L          self.t_old = 1      def update(self): -    # algorithm loop -    #for it in range(0, max_iter): -     -#        if self.memopt: -            # u = y - invL*f.grad(y) -            # store the result in x_old +          self.f.gradient(self.y, out=self.u)          self.u.__imul__( -self.invL )          self.u.__iadd__( self.y ) @@ -96,26 +79,12 @@ class FISTA(Algorithm):          self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2))) -        # y = x + (t_old-1)/t*(x-x_old)          self.x.subtract(self.x_old, out=self.y)          self.y.__imul__ ((self.t_old-1)/self.t)          self.y.__iadd__( self.x )          self.x_old.fill(self.x) -        self.t_old = self.t -             -             -#        else: -#            u = self.y - self.invL*self.f.gradient(self.y) -#             -#            self.x = self.g.proximal(u,self.invL) -#             -#            self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2))) -#             -#            self.y = self.x + (self.t_old-1)/self.t*(self.x-self.x_old) -#             -#            self.x_old = self.x.copy() -#            self.t_old = self.t +        self.t_old = self.t                  def update_objective(self):          self.loss.append( self.f(self.x) + self.g(self.x) )
\ No newline at end of file  | 
