summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition_old.py85
-rw-r--r--Wrappers/Python/wip/Demos/fista_test.py127
-rw-r--r--Wrappers/Python/wip/pdhg_TV_tomography2D.py2
3 files changed, 213 insertions, 1 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition_old.py b/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition_old.py
new file mode 100644
index 0000000..70511bb
--- /dev/null
+++ b/Wrappers/Python/ccpi/optimisation/functions/FunctionOperatorComposition_old.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Mar 8 09:55:36 2019
+
+@author: evangelos
+"""
+
+from ccpi.optimisation.functions import Function
+from ccpi.optimisation.functions import ScaledFunction
+
+
+class FunctionOperatorComposition(Function):
+
+ ''' Function composition with Operator, i.e., f(Ax)
+
+ A: operator
+ f: function
+
+ '''
+
+ def __init__(self, operator, function):
+
+ super(FunctionOperatorComposition, self).__init__()
+ self.function = function
+ self.operator = operator
+ alpha = 1
+
+ if isinstance (function, ScaledFunction):
+ alpha = function.scalar
+ self.L = 2 * alpha * operator.norm()**2
+
+
+ def __call__(self, x):
+
+ ''' Evaluate FunctionOperatorComposition at x
+
+ returns f(Ax)
+
+ '''
+
+ return self.function(self.operator.direct(x))
+
+ #TODO do not know if we need it
+ def call_adjoint(self, x):
+
+ return self.function(self.operator.adjoint(x))
+
+
+ def convex_conjugate(self, x):
+
+ ''' convex_conjugate does not take into account the Operator'''
+ return self.function.convex_conjugate(x)
+
+ def proximal(self, x, tau, out=None):
+
+ '''proximal does not take into account the Operator'''
+ if out is None:
+ return self.function.proximal(x, tau)
+ else:
+ self.function.proximal(x, tau, out=out)
+
+
+ def proximal_conjugate(self, x, tau, out=None):
+
+ ''' proximal conjugate does not take into account the Operator'''
+ if out is None:
+ return self.function.proximal_conjugate(x, tau)
+ else:
+ self.function.proximal_conjugate(x, tau, out=out)
+
+ def gradient(self, x, out=None):
+
+ ''' Gradient takes into account the Operator'''
+ if out is None:
+ return self.operator.adjoint(
+ self.function.gradient(self.operator.direct(x))
+ )
+ else:
+ self.operator.adjoint(
+ self.function.gradient(self.operator.direct(x),
+ out=out)
+ )
+
+ \ No newline at end of file
diff --git a/Wrappers/Python/wip/Demos/fista_test.py b/Wrappers/Python/wip/Demos/fista_test.py
new file mode 100644
index 0000000..dd1f6fa
--- /dev/null
+++ b/Wrappers/Python/wip/Demos/fista_test.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+# This work is part of the Core Imaging Library developed by
+# Visual Analytics and Imaging System Group of the Science Technology
+# Facilities Council, STFC
+
+# Copyright 2018-2019 Evangelos Papoutsellis and Edoardo Pasca
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ccpi.framework import ImageData, ImageGeometry
+
+import numpy as np
+import numpy
+import matplotlib.pyplot as plt
+
+from ccpi.optimisation.algorithms import FISTA, PDHG
+
+from ccpi.optimisation.operators import BlockOperator, Gradient, Identity
+from ccpi.optimisation.functions import L2NormSquared, L1Norm, \
+ MixedL21Norm, FunctionOperatorComposition, BlockFunction, ZeroFunction
+
+from skimage.util import random_noise
+
+# Create phantom for TV Gaussian denoising
+N = 100
+
+data = np.zeros((N,N))
+data[round(N/4):round(3*N/4),round(N/4):round(3*N/4)] = 0.5
+data[round(N/8):round(7*N/8),round(3*N/8):round(5*N/8)] = 1
+data = ImageData(data)
+ig = ImageGeometry(voxel_num_x = N, voxel_num_y = N)
+ag = ig
+
+# Create noisy data. Add Gaussian noise
+n1 = random_noise(data.as_array(), mode = 's&p', salt_vs_pepper = 0.9, amount=0.2)
+noisy_data = ImageData(n1)
+
+# Regularisation Parameter
+alpha = 5
+
+operator = Gradient(ig)
+
+#fidelity = L1Norm(b=noisy_data)
+#regulariser = FunctionOperatorComposition(alpha * L2NormSquared(), operator)
+
+fidelity = FunctionOperatorComposition(alpha * MixedL21Norm(), operator)
+regulariser = 0.5 * L2NormSquared(b = noisy_data)
+
+x_init = ig.allocate()
+
+## Setup and run the PDHG algorithm
+opt = {'tol': 1e-4, 'memopt':True}
+fista = FISTA(x_init=x_init , f=regulariser, g=fidelity, opt=opt)
+fista.max_iteration = 2000
+fista.update_objective_interval = 50
+fista.run(2000, verbose=True)
+
+plt.figure(figsize=(15,15))
+plt.subplot(3,1,1)
+plt.imshow(data.as_array())
+plt.title('Ground Truth')
+plt.colorbar()
+plt.subplot(3,1,2)
+plt.imshow(noisy_data.as_array())
+plt.title('Noisy Data')
+plt.colorbar()
+plt.subplot(3,1,3)
+plt.imshow(fista.get_output().as_array())
+plt.title('TV Reconstruction')
+plt.colorbar()
+plt.show()
+
+# Compare with PDHG
+method = '0'
+#
+if method == '0':
+#
+# # Create operators
+ op1 = Gradient(ig)
+ op2 = Identity(ig, ag)
+#
+# # Create BlockOperator
+ operator = BlockOperator(op1, op2, shape=(2,1) )
+ f = BlockFunction(alpha * L2NormSquared(), fidelity)
+ g = ZeroFunction()
+
+## Compute operator Norm
+normK = operator.norm()
+#
+## Primal & dual stepsizes
+sigma = 1
+tau = 1/(sigma*normK**2)
+#
+#
+## Setup and run the PDHG algorithm
+pdhg = PDHG(f=f,g=g,operator=operator, tau=tau, sigma=sigma, memopt=True)
+pdhg.max_iteration = 2000
+pdhg.update_objective_interval = 50
+pdhg.run(2000)
+#
+#%%
+plt.figure(figsize=(15,15))
+plt.subplot(3,1,1)
+plt.imshow(fista.get_output().as_array())
+plt.title('FISTA')
+plt.colorbar()
+plt.subplot(3,1,2)
+plt.imshow(pdhg.get_output().as_array())
+plt.title('PDHG')
+plt.colorbar()
+plt.subplot(3,1,3)
+plt.imshow(np.abs(pdhg.get_output().as_array()-fista.get_output().as_array()))
+plt.title('Diff FISTA-PDHG')
+plt.colorbar()
+plt.show()
+
+
diff --git a/Wrappers/Python/wip/pdhg_TV_tomography2D.py b/Wrappers/Python/wip/pdhg_TV_tomography2D.py
index cd91409..b04a609 100644
--- a/Wrappers/Python/wip/pdhg_TV_tomography2D.py
+++ b/Wrappers/Python/wip/pdhg_TV_tomography2D.py
@@ -91,7 +91,7 @@ g = ZeroFunction()
normK = operator.norm()
## Primal & dual stepsizes
-diag_precon = False
+diag_precon = True
if diag_precon: