diff options
author | epapoutsellis <epapoutsellis@gmail.com> | 2019-05-15 00:43:41 +0100 |
---|---|---|
committer | epapoutsellis <epapoutsellis@gmail.com> | 2019-05-15 00:43:41 +0100 |
commit | ee580a5a8c852958c7ce28d935806c37bca12d44 (patch) | |
tree | e4f515e1784469d7650059cdfca25df766864d4f | |
parent | abb4ce7d7aea5e88c442891da756a32e80ccb9b0 (diff) | |
download | framework-ee580a5a8c852958c7ce28d935806c37bca12d44.tar.gz framework-ee580a5a8c852958c7ce28d935806c37bca12d44.tar.bz2 framework-ee580a5a8c852958c7ce28d935806c37bca12d44.tar.xz framework-ee580a5a8c852958c7ce28d935806c37bca12d44.zip |
check FISTA, constr, regul
-rw-r--r-- | Wrappers/Python/demos/FISTA_examples/FISTA_Tikhonov_Poisson_Denoising.py | 53 |
1 files changed, 36 insertions, 17 deletions
diff --git a/Wrappers/Python/demos/FISTA_examples/FISTA_Tikhonov_Poisson_Denoising.py b/Wrappers/Python/demos/FISTA_examples/FISTA_Tikhonov_Poisson_Denoising.py index 5b1bb16..be397ff 100644 --- a/Wrappers/Python/demos/FISTA_examples/FISTA_Tikhonov_Poisson_Denoising.py +++ b/Wrappers/Python/demos/FISTA_examples/FISTA_Tikhonov_Poisson_Denoising.py @@ -23,7 +23,7 @@ Tikhonov for Poisson denoising using FISTA algorithm: -Problem: min_x, x>0 \alpha * ||\nabla x||_{2} + \int x - g * log(x) +Problem: min_x, x>0 \alpha * ||\nabla x||_{2}^{2} + \int x - g * log(x) \alpha: Regularization parameter @@ -42,9 +42,8 @@ import matplotlib.pyplot as plt from ccpi.optimisation.algorithms import FISTA -from ccpi.optimisation.operators import Gradient, BlockOperator, Identity -from ccpi.optimisation.functions import KullbackLeibler, IndicatorBox, BlockFunction, \ - L2NormSquared, IndicatorBox, FunctionOperatorComposition +from ccpi.optimisation.operators import Gradient +from ccpi.optimisation.functions import KullbackLeibler, L2NormSquared, FunctionOperatorComposition from ccpi.framework import TestData import os, sys @@ -75,19 +74,42 @@ plt.imshow(noisy_data.as_array()) plt.title('Noisy Data') plt.colorbar() plt.show() - #%% - # Regularisation Parameter -alpha = 20 +alpha = 10 # Setup and run the FISTA algorithm -op1 = Gradient(ig) -op2 = BlockOperator(Identity(ig), Identity(ig), shape=(2,1)) +operator = Gradient(ig) +fid = KullbackLeibler(noisy_data) -tmp_function = BlockFunction( KullbackLeibler(noisy_data), IndicatorBox(lower=0) ) +def KL_Prox_PosCone(x, tau, out=None): + + if out is None: + tmp = 0.5 *( (x - fid.bnoise - tau) + ( (x + fid.bnoise - tau)**2 + 4*tau*fid.b ) .sqrt() ) + return tmp.maximum(0) + else: + tmp = 0.5 *( (x - fid.bnoise - tau) + + ( (x + fid.bnoise - tau)**2 + 4*tau*fid.b ) .sqrt() + ) + x.add(fid.bnoise, out=out) + out -= tau + out *= out + tmp = fid.b * (4 * tau) + out.add(tmp, out=out) + out.sqrt(out=out) + + x.subtract(fid.bnoise, out=tmp) + tmp -= tau + + out += tmp + + out *= 0.5 + + # ADD the constraint here + out.maximum(0, out=out) + +fid.proximal = KL_Prox_PosCone -fid = tmp reg = FunctionOperatorComposition(alpha * L2NormSquared(), operator) x_init = ig.allocate() @@ -97,7 +119,6 @@ fista.max_iteration = 2000 fista.update_objective_interval = 500 fista.run(2000, verbose=True) -#%% # Show results plt.figure(figsize=(15,15)) plt.subplot(3,1,1) @@ -120,6 +141,7 @@ plt.legend() plt.title('Middle Line Profiles') plt.show() + #%% Check with CVX solution from ccpi.optimisation.operators import SparseFiniteDiff @@ -173,12 +195,9 @@ if cvx_not_installable: plt.title('Middle Line Profiles') plt.show() + #TODO what is the output of fista.objective, fista.loss print('Primal Objective (CVX) {} '.format(obj.value)) - print('Primal Objective (FISTA) {} '.format(fista.objective[-1][0])) - - - - + print('Primal Objective (FISTA) {} '.format(fista.loss[1])) |