summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-04-01 16:33:54 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2019-04-01 16:34:04 +0100
commit93517aa9f1472458fa962beae1abebb3e1223a6c (patch)
tree9eda6ad5a668cd66d92bf950ceb49f53aa661c26 /Wrappers/Python
parentad9e67c197aa347a83f59f3a0d7ab96745bef8ad (diff)
downloadframework-93517aa9f1472458fa962beae1abebb3e1223a6c.tar.gz
framework-93517aa9f1472458fa962beae1abebb3e1223a6c.tar.bz2
framework-93517aa9f1472458fa962beae1abebb3e1223a6c.tar.xz
framework-93517aa9f1472458fa962beae1abebb3e1223a6c.zip
PDHG as Algorithm
Diffstat (limited to 'Wrappers/Python')
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py115
1 files changed, 60 insertions, 55 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
index 7e55ee8..fb2bfd8 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
@@ -5,6 +5,8 @@ Created on Mon Feb 4 16:18:06 2019
@author: evangelos
"""
+from ccpi.optimisation.algorithms import Algorithm
+
from ccpi.framework import ImageData
import numpy as np
@@ -13,67 +15,70 @@ import time
from ccpi.optimisation.operators import BlockOperator
from ccpi.framework import BlockDataContainer
-def PDHG(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):
-
- # algorithmic parameters
- if opt is None:
- opt = {'tol': 1e-6, 'niter': 500, 'show_iter': 100, \
- 'memopt': False}
-
- if sigma is None and tau is None:
- raise ValueError('Need sigma*tau||K||^2<1')
-
- niter = opt['niter'] if 'niter' in opt.keys() else 1000
- tol = opt['tol'] if 'tol' in opt.keys() else 1e-4
- memopt = opt['memopt'] if 'memopt' in opt.keys() else False
- show_iter = opt['show_iter'] if 'show_iter' in opt.keys() else False
- stop_crit = opt['stop_crit'] if 'stop_crit' in opt.keys() else False
+class PDHG(Algorithm):
+ '''Primal Dual Hybrid Gradient'''
- if isinstance(operator, BlockOperator):
- x_old = operator.domain_geometry().allocate()
- y_old = operator.range_geometry().allocate()
- else:
- x_old = operator.domain_geometry().allocate()
- y_old = operator.range_geometry().allocate()
-
-
- xbar = x_old
- x_tmp = x_old
- x = x_old
-
- y_tmp = y_old
- y = y_tmp
-
- # relaxation parameter
- theta = 1
-
- t = time.time()
-
- objective = []
+ def __init__(self, **kwargs):
+ super(PDHG, self).__init__()
+ self.f = kwargs.get('f', None)
+ self.operator = kwargs.get('operator', None)
+ self.g = kwargs.get('g', None)
+ self.tau = kwargs.get('tau', None)
+ self.sigma = kwargs.get('sigma', None)
+
+ if self.f is not None and self.operator is not None and \
+ self.g is not None:
+ print ("Calling from creator")
+ self.set_up(self.f,
+ self.operator,
+ self.g,
+ self.tau,
+ self.sigma)
+
+ def set_up(self, f, g, operator, tau = None, sigma = None, opt = None, **kwargs):
+ # algorithmic parameters
+
+ if sigma is None and tau is None:
+ raise ValueError('Need sigma*tau||K||^2<1')
+
- for i in range(niter):
+ self.x_old = self.operator.domain_geometry().allocate()
+ self.y_old = self.operator.range_geometry().allocate()
+ self.xbar = self.x_old.copy()
+ #x_tmp = x_old
+ self.x = self.x_old.copy()
+ self.y = self.y_old.copy()
+ #y_tmp = y_old
+ #y = y_tmp
+
+ # relaxation parameter
+ self.theta = 1
+
+ def update(self):
# Gradient descent, Dual problem solution
- y_tmp = y_old + sigma * operator.direct(xbar)
- y = f.proximal_conjugate(y_tmp, sigma)
+ self.y_old += self.sigma * self.operator.direct(self.xbar)
+ self.y = self.f.proximal_conjugate(self.y_old, self.sigma)
# Gradient ascent, Primal problem solution
- x_tmp = x_old - tau * operator.adjoint(y)
- x = g.proximal(x_tmp, tau)
+ self.x_old -= self.tau * self.operator.adjoint(self.y)
+ self.x = self.g.proximal(self.x_old, self.tau)
#Update
- xbar = x + theta * (x - x_old)
-
- x_old = x
- y_old = y
-
-# if i%100==0:
-#
-# plt.imshow(x.as_array()[100])
-# plt.show()
-# print(f(operator.direct(x)) + g(x), i)
-
- t_end = time.time()
-
- return x, t_end - t, objective
+ #xbar = x + theta * (x - x_old)
+ self.xbar.fill(self.x)
+ self.xbar -= self.x_old
+ self.xbar *= self.theta
+ self.xbar += self.x
+
+ self.x_old.fill(self.x)
+ self.y_old.fill(self.y)
+ #self.y_old = y.copy()
+ #self.y = self.y_old
+
+ def update_objective(self):
+ self.loss.append([self.f(self.operator.direct(self.x)) + self.g(self.x),
+ -(self.f.convex_conjugate(self.y) + self.g.convex_conjugate(- 1 * self.operator.adjoint(self.y)))
+ ])
+