diff options
-rw-r--r-- | Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition_old.py | 85 | ||||
-rw-r--r-- | Wrappers/Python/wip/Demos/fista_test.py | 127 | ||||
-rw-r--r-- | Wrappers/Python/wip/pdhg_TV_tomography2D.py | 2 |
3 files changed, 213 insertions, 1 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition_old.py b/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition_old.py new file mode 100644 index 0000000..70511bb --- /dev/null +++ b/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition_old.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Fri Mar 8 09:55:36 2019 + +@author: evangelos +""" + +from ccpi.optimisation.functions import Function +from ccpi.optimisation.functions import ScaledFunction + + +class FunctionOperatorComposition(Function): + + ''' Function composition with Operator, i.e., f(Ax) + + A: operator + f: function + + ''' + + def __init__(self, operator, function): + + super(FunctionOperatorComposition, self).__init__() + self.function = function + self.operator = operator + alpha = 1 + + if isinstance (function, ScaledFunction): + alpha = function.scalar + self.L = 2 * alpha * operator.norm()**2 + + + def __call__(self, x): + + ''' Evaluate FunctionOperatorComposition at x + + returns f(Ax) + + ''' + + return self.function(self.operator.direct(x)) + + #TODO do not know if we need it + def call_adjoint(self, x): + + return self.function(self.operator.adjoint(x)) + + + def convex_conjugate(self, x): + + ''' convex_conjugate does not take into account the Operator''' + return self.function.convex_conjugate(x) + + def proximal(self, x, tau, out=None): + + '''proximal does not take into account the Operator''' + if out is None: + return self.function.proximal(x, tau) + else: + self.function.proximal(x, tau, out=out) + + + def proximal_conjugate(self, x, tau, out=None): + + ''' proximal conjugate does not take into account the Operator''' + if out is None: + return self.function.proximal_conjugate(x, tau) + else: + self.function.proximal_conjugate(x, tau, out=out) + + def gradient(self, x, out=None): + + ''' Gradient takes into account the Operator''' + if out is None: + return self.operator.adjoint( + self.function.gradient(self.operator.direct(x)) + ) + else: + self.operator.adjoint( + self.function.gradient(self.operator.direct(x), + out=out) + ) + +
\ No newline at end of file diff --git a/Wrappers/Python/wip/Demos/fista_test.py b/Wrappers/Python/wip/Demos/fista_test.py new file mode 100644 index 0000000..dd1f6fa --- /dev/null +++ b/Wrappers/Python/wip/Demos/fista_test.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# This work is part of the Core Imaging Library developed by +# Visual Analytics and Imaging System Group of the Science Technology +# Facilities Council, STFC + +# Copyright 2018-2019 Evangelos Papoutsellis and Edoardo Pasca + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ccpi.framework import ImageData, ImageGeometry + +import numpy as np +import numpy +import matplotlib.pyplot as plt + +from ccpi.optimisation.algorithms import FISTA, PDHG + +from ccpi.optimisation.operators import BlockOperator, Gradient, Identity +from ccpi.optimisation.functions import L2NormSquared, L1Norm, \ + MixedL21Norm, FunctionOperatorComposition, BlockFunction, ZeroFunction + +from skimage.util import random_noise + +# Create phantom for TV Gaussian denoising +N = 100 + +data = np.zeros((N,N)) +data[round(N/4):round(3*N/4),round(N/4):round(3*N/4)] = 0.5 +data[round(N/8):round(7*N/8),round(3*N/8):round(5*N/8)] = 1 +data = ImageData(data) +ig = ImageGeometry(voxel_num_x = N, voxel_num_y = N) +ag = ig + +# Create noisy data. Add Gaussian noise +n1 = random_noise(data.as_array(), mode = 's&p', salt_vs_pepper = 0.9, amount=0.2) +noisy_data = ImageData(n1) + +# Regularisation Parameter +alpha = 5 + +operator = Gradient(ig) + +#fidelity = L1Norm(b=noisy_data) +#regulariser = FunctionOperatorComposition(alpha * L2NormSquared(), operator) + +fidelity = FunctionOperatorComposition(alpha * MixedL21Norm(), operator) +regulariser = 0.5 * L2NormSquared(b = noisy_data) + +x_init = ig.allocate() + +## Setup and run the PDHG algorithm +opt = {'tol': 1e-4, 'memopt':True} +fista = FISTA(x_init=x_init , f=regulariser, g=fidelity, opt=opt) +fista.max_iteration = 2000 +fista.update_objective_interval = 50 +fista.run(2000, verbose=True) + +plt.figure(figsize=(15,15)) +plt.subplot(3,1,1) +plt.imshow(data.as_array()) +plt.title('Ground Truth') +plt.colorbar() +plt.subplot(3,1,2) +plt.imshow(noisy_data.as_array()) +plt.title('Noisy Data') +plt.colorbar() +plt.subplot(3,1,3) +plt.imshow(fista.get_output().as_array()) +plt.title('TV Reconstruction') +plt.colorbar() +plt.show() + +# Compare with PDHG +method = '0' +# +if method == '0': +# +# # Create operators + op1 = Gradient(ig) + op2 = Identity(ig, ag) +# +# # Create BlockOperator + operator = BlockOperator(op1, op2, shape=(2,1) ) + f = BlockFunction(alpha * L2NormSquared(), fidelity) + g = ZeroFunction() + +## Compute operator Norm +normK = operator.norm() +# +## Primal & dual stepsizes +sigma = 1 +tau = 1/(sigma*normK**2) +# +# +## Setup and run the PDHG algorithm +pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma, memopt=True) +pdhg.max_iteration = 2000 +pdhg.update_objective_interval = 50 +pdhg.run(2000) +# +#%% +plt.figure(figsize=(15,15)) +plt.subplot(3,1,1) +plt.imshow(fista.get_output().as_array()) +plt.title('FISTA') +plt.colorbar() +plt.subplot(3,1,2) +plt.imshow(pdhg.get_output().as_array()) +plt.title('PDHG') +plt.colorbar() +plt.subplot(3,1,3) +plt.imshow(np.abs(pdhg.get_output().as_array()-fista.get_output().as_array())) +plt.title('Diff FISTA-PDHG') +plt.colorbar() +plt.show() + + diff --git a/Wrappers/Python/wip/pdhg_TV_tomography2D.py b/Wrappers/Python/wip/pdhg_TV_tomography2D.py index cd91409..b04a609 100644 --- a/Wrappers/Python/wip/pdhg_TV_tomography2D.py +++ b/Wrappers/Python/wip/pdhg_TV_tomography2D.py @@ -91,7 +91,7 @@ g = ZeroFunction() normK = operator.norm() ## Primal & dual stepsizes -diag_precon = False +diag_precon = True if diag_precon: |