From 93517aa9f1472458fa962beae1abebb3e1223a6c Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 1 Apr 2019 16:33:54 +0100 Subject: PDHG as Algorithm --- .../Python/ccpi/optimisation/algorithms/PDHG.py | 115 +++++++++++---------- 1 file changed, 60 insertions(+), 55 deletions(-) (limited to 'Wrappers/Python') diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py index 7e55ee8..fb2bfd8 100644 --- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py @@ -5,6 +5,8 @@ Created on Mon Feb 4 16:18:06 2019 @author: evangelos """ +from ccpi.optimisation.algorithms import Algorithm + from ccpi.framework import ImageData import numpy as np @@ -13,67 +15,70 @@ import time from ccpi.optimisation.operators import BlockOperator from ccpi.framework import BlockDataContainer -def PDHG(f, g, operator, tau = None, sigma = None, opt = None, **kwargs): - - # algorithmic parameters - if opt is None: - opt = {'tol': 1e-6, 'niter': 500, 'show_iter': 100, \ - 'memopt': False} - - if sigma is None and tau is None: - raise ValueError('Need sigma*tau||K||^2<1') - - niter = opt['niter'] if 'niter' in opt.keys() else 1000 - tol = opt['tol'] if 'tol' in opt.keys() else 1e-4 - memopt = opt['memopt'] if 'memopt' in opt.keys() else False - show_iter = opt['show_iter'] if 'show_iter' in opt.keys() else False - stop_crit = opt['stop_crit'] if 'stop_crit' in opt.keys() else False +class PDHG(Algorithm): + '''Primal Dual Hybrid Gradient''' - if isinstance(operator, BlockOperator): - x_old = operator.domain_geometry().allocate() - y_old = operator.range_geometry().allocate() - else: - x_old = operator.domain_geometry().allocate() - y_old = operator.range_geometry().allocate() - - - xbar = x_old - x_tmp = x_old - x = x_old - - y_tmp = y_old - y = y_tmp - - # relaxation parameter - theta = 1 - - t = time.time() - - objective = [] + def __init__(self, **kwargs): + super(PDHG, self).__init__() + self.f = kwargs.get('f', None) + self.operator = kwargs.get('operator', None) + self.g = kwargs.get('g', None) + self.tau = kwargs.get('tau', None) + self.sigma = kwargs.get('sigma', None) + + if self.f is not None and self.operator is not None and \ + self.g is not None: + print ("Calling from creator") + self.set_up(self.f, + self.operator, + self.g, + self.tau, + self.sigma) + + def set_up(self, f, g, operator, tau = None, sigma = None, opt = None, **kwargs): + # algorithmic parameters + + if sigma is None and tau is None: + raise ValueError('Need sigma*tau||K||^2<1') + - for i in range(niter): + self.x_old = self.operator.domain_geometry().allocate() + self.y_old = self.operator.range_geometry().allocate() + self.xbar = self.x_old.copy() + #x_tmp = x_old + self.x = self.x_old.copy() + self.y = self.y_old.copy() + #y_tmp = y_old + #y = y_tmp + + # relaxation parameter + self.theta = 1 + + def update(self): # Gradient descent, Dual problem solution - y_tmp = y_old + sigma * operator.direct(xbar) - y = f.proximal_conjugate(y_tmp, sigma) + self.y_old += self.sigma * self.operator.direct(self.xbar) + self.y = self.f.proximal_conjugate(self.y_old, self.sigma) # Gradient ascent, Primal problem solution - x_tmp = x_old - tau * operator.adjoint(y) - x = g.proximal(x_tmp, tau) + self.x_old -= self.tau * self.operator.adjoint(self.y) + self.x = self.g.proximal(self.x_old, self.tau) #Update - xbar = x + theta * (x - x_old) - - x_old = x - y_old = y - -# if i%100==0: -# -# plt.imshow(x.as_array()[100]) -# plt.show() -# print(f(operator.direct(x)) + g(x), i) - - t_end = time.time() - - return x, t_end - t, objective + #xbar = x + theta * (x - x_old) + self.xbar.fill(self.x) + self.xbar -= self.x_old + self.xbar *= self.theta + self.xbar += self.x + + self.x_old.fill(self.x) + self.y_old.fill(self.y) + #self.y_old = y.copy() + #self.y = self.y_old + + def update_objective(self): + self.loss.append([self.f(self.operator.direct(self.x)) + self.g(self.x), + -(self.f.convex_conjugate(self.y) + self.g.convex_conjugate(- 1 * self.operator.adjoint(self.y))) + ]) + -- cgit v1.2.3