summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py78
-rw-r--r--Wrappers/Python/wip/pdhg_TV_tomography2D.py2
2 files changed, 51 insertions, 29 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
index 110f998..0f5e8ef 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
@@ -47,34 +47,43 @@ class PDHG(Algorithm):
self.y_old = self.operator.range_geometry().allocate()
self.xbar = self.x_old.copy()
-
+ self.x_tmp = self.x_old.copy()
self.x = self.x_old.copy()
- self.y = self.y_old.copy()
- if self.memopt:
- self.y_tmp = self.y_old.copy()
- self.x_tmp = self.x_old.copy()
- #y = y_tmp
+
+ self.y_tmp = self.y_old.copy()
+ self.y = self.y_tmp.copy()
+
+
+
+ #self.y = self.y_old.copy()
+
+
+ #if self.memopt:
+ # self.y_tmp = self.y_old.copy()
+ # self.x_tmp = self.x_old.copy()
+
# relaxation parameter
self.theta = 1
def update(self):
+
if self.memopt:
# Gradient descent, Dual problem solution
# self.y_old += self.sigma * self.operator.direct(self.xbar)
self.operator.direct(self.xbar, out=self.y_tmp)
self.y_tmp *= self.sigma
- self.y_old += self.y_tmp
+ self.y_tmp += self.y_old
#self.y = self.f.proximal_conjugate(self.y_old, self.sigma)
- self.f.proximal_conjugate(self.y_old, self.sigma, out=self.y)
+ self.f.proximal_conjugate(self.y_tmp, self.sigma, out=self.y)
# Gradient ascent, Primal problem solution
self.operator.adjoint(self.y, out=self.x_tmp)
- self.x_tmp *= self.tau
- self.x_old -= self.x_tmp
+ self.x_tmp *= -1*self.tau
+ self.x_tmp += self.x_old
- self.g.proximal(self.x_old, self.tau, out=self.x)
+ self.g.proximal(self.x_tmp, self.tau, out=self.x)
#Update
self.x.subtract(self.x_old, out=self.xbar)
@@ -83,7 +92,8 @@ class PDHG(Algorithm):
self.x_old.fill(self.x)
self.y_old.fill(self.y)
-
+
+
else:
# Gradient descent, Dual problem solution
self.y_old += self.sigma * self.operator.direct(self.xbar)
@@ -94,14 +104,23 @@ class PDHG(Algorithm):
self.x = self.g.proximal(self.x_old, self.tau)
#Update
- #xbar = x + theta * (x - x_old)
- self.xbar.fill(self.x)
- self.xbar -= self.x_old
+ self.x.subtract(self.x_old, out=self.xbar)
self.xbar *= self.theta
self.xbar += self.x
- self.x_old = self.x
- self.y_old = self.y
+ self.x_old.fill(self.x)
+ self.y_old.fill(self.y)
+
+ #xbar = x + theta * (x - x_old)
+# self.xbar.fill(self.x)
+# self.xbar -= self.x_old
+# self.xbar *= self.theta
+# self.xbar += self.x
+
+# self.x_old.fill(self.x)
+# self.y_old.fill(self.y)
+
+
def update_objective(self):
@@ -153,7 +172,7 @@ def PDHG_old(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):
if not memopt:
-
+
y_tmp = y_old + sigma * operator.direct(xbar)
y = f.proximal_conjugate(y_tmp, sigma)
@@ -163,10 +182,7 @@ def PDHG_old(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):
x.subtract(x_old, out=xbar)
xbar *= theta
xbar += x
-
- x_old.fill(x)
- y_old.fill(y)
-
+
if i%50==0:
p1 = f(operator.direct(x)) + g(x)
@@ -174,7 +190,11 @@ def PDHG_old(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):
primal.append(p1)
dual.append(d1)
pdgap.append(p1-d1)
- print(p1, d1, p1-d1)
+ print(p1, d1, p1-d1)
+
+ x_old.fill(x)
+ y_old.fill(y)
+
else:
@@ -192,10 +212,7 @@ def PDHG_old(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):
x.subtract(x_old, out=xbar)
xbar *= theta
xbar += x
-
- x_old.fill(x)
- y_old.fill(y)
-
+
if i%50==0:
p1 = f(operator.direct(x)) + g(x)
@@ -203,7 +220,12 @@ def PDHG_old(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):
primal.append(p1)
dual.append(d1)
pdgap.append(p1-d1)
- print(p1, d1, p1-d1)
+ print(p1, d1, p1-d1)
+
+ x_old.fill(x)
+ y_old.fill(y)
+
+
t_end = time.time()
diff --git a/Wrappers/Python/wip/pdhg_TV_tomography2D.py b/Wrappers/Python/wip/pdhg_TV_tomography2D.py
index 91c48c7..0e167e3 100644
--- a/Wrappers/Python/wip/pdhg_TV_tomography2D.py
+++ b/Wrappers/Python/wip/pdhg_TV_tomography2D.py
@@ -126,7 +126,7 @@ t3 = timer()
res1, time1, primal1, dual1, pdgap1 = PDHG_old(f, g, operator, tau = tau, sigma = sigma, opt = opt1)
t4 = timer()
#
-
+#
plt.figure(figsize=(15,15))
plt.subplot(3,1,1)
plt.imshow(res.as_array())