summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-04-11 12:02:02 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2019-04-11 12:02:02 +0100
commit3ff8a543fb4ef59179ce3490bc28b8f61bf979ac (patch)
tree0daecd9e9dccdf95767feb502d3cd4c6e633e028
parent5c74019510f95599b87ba869a7b8efc71edcde23 (diff)
downloadframework-3ff8a543fb4ef59179ce3490bc28b8f61bf979ac.tar.gz
framework-3ff8a543fb4ef59179ce3490bc28b8f61bf979ac.tar.bz2
framework-3ff8a543fb4ef59179ce3490bc28b8f61bf979ac.tar.xz
framework-3ff8a543fb4ef59179ce3490bc28b8f61bf979ac.zip
fix PDHG optimised
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py74
1 files changed, 52 insertions, 22 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
index 0b9921c..086e322 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
@@ -45,36 +45,66 @@ class PDHG(Algorithm):
self.y_old = self.operator.range_geometry().allocate()
self.xbar = self.x_old.copy()
- #x_tmp = x_old
+
self.x = self.x_old.copy()
self.y = self.y_old.copy()
- #y_tmp = y_old
+ if self.memopt:
+ self.y_tmp = self.y_old.copy()
+ self.x_tmp = self.x_old.copy()
#y = y_tmp
# relaxation parameter
self.theta = 1
def update(self):
- # Gradient descent, Dual problem solution
- self.y_old += self.sigma * self.operator.direct(self.xbar)
- self.y = self.f.proximal_conjugate(self.y_old, self.sigma)
-
- # Gradient ascent, Primal problem solution
- self.x_old -= self.tau * self.operator.adjoint(self.y)
- self.x = self.g.proximal(self.x_old, self.tau)
-
- #Update
- #xbar = x + theta * (x - x_old)
- self.xbar.fill(self.x)
- self.xbar -= self.x_old
- self.xbar *= self.theta
- self.xbar += self.x
-
-# self.x_old.fill(self.x)
-# self.y_old.fill(self.y)
- self.y_old = self.y.copy()
- self.x_old = self.x.copy()
- #self.y = self.y_old
+ if self.memopt:
+ # Gradient descent, Dual problem solution
+ # self.y_old += self.sigma * self.operator.direct(self.xbar)
+ self.operator.direct(self.xbar, out=self.y_tmp)
+ self.y_tmp *= self.sigma
+ self.y_old += self.y_tmp
+
+ #self.y = self.f.proximal_conjugate(self.y_old, self.sigma)
+ self.f.proximal_conjugate(self.y_old, self.sigma, out=self.y)
+
+ # Gradient ascent, Primal problem solution
+ self.operator.adjoint(self.y, out=self.x_tmp)
+ self.x_tmp *= self.tau
+ self.x_old -= self.x_tmp
+
+ self.g.proximal(self.x_old, self.tau, out=self.x)
+
+ #Update
+ self.x.subtract(self.x_old, out=self.xbar)
+ #self.xbar -= self.x_old
+ self.xbar *= self.theta
+ self.xbar += self.x
+
+ self.x_old.fill(self.x)
+ self.y_old.fill(self.y)
+ #self.y_old = self.y.copy()
+ #self.x_old = self.x.copy()
+ else:
+ # Gradient descent, Dual problem solution
+ self.y_old += self.sigma * self.operator.direct(self.xbar)
+ self.y = self.f.proximal_conjugate(self.y_old, self.sigma)
+
+ # Gradient ascent, Primal problem solution
+ self.x_old -= self.tau * self.operator.adjoint(self.y)
+ self.x = self.g.proximal(self.x_old, self.tau)
+
+ #Update
+ #xbar = x + theta * (x - x_old)
+ self.xbar.fill(self.x)
+ self.xbar -= self.x_old
+ self.xbar *= self.theta
+ self.xbar += self.x
+
+ self.x_old.fill(self.x)
+ self.y_old.fill(self.y)
+ #self.y_old = self.y.copy()
+ #self.x_old = self.x.copy()
+ #self.y = self.y_old
def update_objective(self):
p1 = self.f(self.operator.direct(self.x)) + self.g(self.x)