summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorepapoutsellis <epapoutsellis@gmail.com>2019-05-15 00:43:41 +0100
committerepapoutsellis <epapoutsellis@gmail.com>2019-05-15 00:43:41 +0100
commitee580a5a8c852958c7ce28d935806c37bca12d44 (patch)
treee4f515e1784469d7650059cdfca25df766864d4f
parentabb4ce7d7aea5e88c442891da756a32e80ccb9b0 (diff)
downloadframework-ee580a5a8c852958c7ce28d935806c37bca12d44.tar.gz
framework-ee580a5a8c852958c7ce28d935806c37bca12d44.tar.bz2
framework-ee580a5a8c852958c7ce28d935806c37bca12d44.tar.xz
framework-ee580a5a8c852958c7ce28d935806c37bca12d44.zip
check FISTA, constr, regul
-rw-r--r--Wrappers/Python/demos/FISTA_examples/FISTA_Tikhonov_Poisson_Denoising.py53
1 files changed, 36 insertions, 17 deletions
diff --git a/Wrappers/Python/demos/FISTA_examples/FISTA_Tikhonov_Poisson_Denoising.py b/Wrappers/Python/demos/FISTA_examples/FISTA_Tikhonov_Poisson_Denoising.py
index 5b1bb16..be397ff 100644
--- a/Wrappers/Python/demos/FISTA_examples/FISTA_Tikhonov_Poisson_Denoising.py
+++ b/Wrappers/Python/demos/FISTA_examples/FISTA_Tikhonov_Poisson_Denoising.py
@@ -23,7 +23,7 @@
Tikhonov for Poisson denoising using FISTA algorithm:
-Problem: min_x, x>0 \alpha * ||\nabla x||_{2} + \int x - g * log(x)
+Problem: min_x, x>0 \alpha * ||\nabla x||_{2}^{2} + \int x - g * log(x)
\alpha: Regularization parameter
@@ -42,9 +42,8 @@ import matplotlib.pyplot as plt
from ccpi.optimisation.algorithms import FISTA
-from ccpi.optimisation.operators import Gradient, BlockOperator, Identity
-from ccpi.optimisation.functions import KullbackLeibler, IndicatorBox, BlockFunction, \
- L2NormSquared, IndicatorBox, FunctionOperatorComposition
+from ccpi.optimisation.operators import Gradient
+from ccpi.optimisation.functions import KullbackLeibler, L2NormSquared, FunctionOperatorComposition
from ccpi.framework import TestData
import os, sys
@@ -75,19 +74,42 @@ plt.imshow(noisy_data.as_array())
plt.title('Noisy Data')
plt.colorbar()
plt.show()
-
#%%
-
# Regularisation Parameter
-alpha = 20
+alpha = 10
# Setup and run the FISTA algorithm
-op1 = Gradient(ig)
-op2 = BlockOperator(Identity(ig), Identity(ig), shape=(2,1))
+operator = Gradient(ig)
+fid = KullbackLeibler(noisy_data)
-tmp_function = BlockFunction( KullbackLeibler(noisy_data), IndicatorBox(lower=0) )
+def KL_Prox_PosCone(x, tau, out=None):
+
+ if out is None:
+ tmp = 0.5 *( (x - fid.bnoise - tau) + ( (x + fid.bnoise - tau)**2 + 4*tau*fid.b ) .sqrt() )
+ return tmp.maximum(0)
+ else:
+ tmp = 0.5 *( (x - fid.bnoise - tau) +
+ ( (x + fid.bnoise - tau)**2 + 4*tau*fid.b ) .sqrt()
+ )
+ x.add(fid.bnoise, out=out)
+ out -= tau
+ out *= out
+ tmp = fid.b * (4 * tau)
+ out.add(tmp, out=out)
+ out.sqrt(out=out)
+
+ x.subtract(fid.bnoise, out=tmp)
+ tmp -= tau
+
+ out += tmp
+
+ out *= 0.5
+
+ # ADD the constraint here
+ out.maximum(0, out=out)
+
+fid.proximal = KL_Prox_PosCone
-fid = tmp
reg = FunctionOperatorComposition(alpha * L2NormSquared(), operator)
x_init = ig.allocate()
@@ -97,7 +119,6 @@ fista.max_iteration = 2000
fista.update_objective_interval = 500
fista.run(2000, verbose=True)
-#%%
# Show results
plt.figure(figsize=(15,15))
plt.subplot(3,1,1)
@@ -120,6 +141,7 @@ plt.legend()
plt.title('Middle Line Profiles')
plt.show()
+
#%% Check with CVX solution
from ccpi.optimisation.operators import SparseFiniteDiff
@@ -173,12 +195,9 @@ if cvx_not_installable:
plt.title('Middle Line Profiles')
plt.show()
+ #TODO what is the output of fista.objective, fista.loss
print('Primal Objective (CVX) {} '.format(obj.value))
- print('Primal Objective (FISTA) {} '.format(fista.objective[-1][0]))
-
-
-
-
+ print('Primal Objective (FISTA) {} '.format(fista.loss[1]))