diff options
author | epapoutsellis <epapoutsellis@gmail.com> | 2019-05-14 23:49:12 +0100 |
---|---|---|
committer | epapoutsellis <epapoutsellis@gmail.com> | 2019-05-14 23:49:12 +0100 |
commit | 75239edf1e0d616caa5f2af0b6779739819ed265 (patch) | |
tree | e243e7083c95915d7d49bbb6f8daf491926dd0f0 /Wrappers/Python | |
parent | 1b8c5193938a68548a34b42433ddb97fcaec587a (diff) | |
download | framework-75239edf1e0d616caa5f2af0b6779739819ed265.tar.gz framework-75239edf1e0d616caa5f2af0b6779739819ed265.tar.bz2 framework-75239edf1e0d616caa5f2af0b6779739819ed265.tar.xz framework-75239edf1e0d616caa5f2af0b6779739819ed265.zip |
minor fix
Diffstat (limited to 'Wrappers/Python')
-rw-r--r-- | Wrappers/Python/demos/CompareAlgorithms/CGLS_PDHG_Tikhonov.py | 52 |
1 files changed, 16 insertions, 36 deletions
diff --git a/Wrappers/Python/demos/CompareAlgorithms/CGLS_PDHG_Tikhonov.py b/Wrappers/Python/demos/CompareAlgorithms/CGLS_PDHG_Tikhonov.py index 942d328..9b6d10f 100644 --- a/Wrappers/Python/demos/CompareAlgorithms/CGLS_PDHG_Tikhonov.py +++ b/Wrappers/Python/demos/CompareAlgorithms/CGLS_PDHG_Tikhonov.py @@ -32,8 +32,7 @@ Problem: min_x alpha * ||\grad x ||^{2}_{2} + || A x - g ||_{2}^{2} """ -from ccpi.framework import ImageData, ImageGeometry, \ - AcquisitionGeometry, BlockDataContainer, AcquisitionData +from ccpi.framework import AcquisitionGeometry, BlockDataContainer, AcquisitionData import numpy as np import numpy @@ -42,28 +41,30 @@ import matplotlib.pyplot as plt from ccpi.optimisation.algorithms import PDHG, CGLS from ccpi.optimisation.operators import BlockOperator, Gradient -from ccpi.optimisation.functions import ZeroFunction, BlockFunction, L2NormSquared +from ccpi.optimisation.functions import ZeroFunction, BlockFunction, L2NormSquared +from ccpi.astra.ops import AstraProjectorSimple +from ccpi.framework import TestData +import os, sys -# Create Ground truth phantom and Sinogram -N = 128 -x = np.zeros((N,N)) -x[round(N/4):round(3*N/4),round(N/4):round(3*N/4)] = 0.5 -x[round(N/8):round(7*N/8),round(3*N/8):round(5*N/8)] = 1 +loader = TestData(data_dir=os.path.join(sys.prefix, 'share','ccpi')) -data = ImageData(x) -ig = ImageGeometry(voxel_num_x = N, voxel_num_y = N) +# Create Ground truth phantom and Sinogram +N = 150 +M = 150 +data = loader.load(TestData.SIMPLE_PHANTOM_2D, size=(N,M), scale=(0,1)) +ig = data.geometry detectors = N angles = np.linspace(0, np.pi, N, dtype=np.float32) - ag = AcquisitionGeometry('parallel','2D', angles, detectors) + device = input('Available device: GPU==1 / CPU==0 ') -ag = AcquisitionGeometry('parallel','2D', angles, detectors) if device=='1': dev = 'gpu' else: dev = 'cpu' +Aop = AstraProjectorSimple(ig, ag, dev) sin = Aop.direct(data) noisy_data = AcquisitionData( sin.as_array() + np.random.normal(0,3,ig.shape)) @@ -77,7 +78,6 @@ op_CGLS = BlockOperator( Aop, alpha * Grad, shape=(2,1)) block_data = BlockDataContainer(noisy_data, Grad.range_geometry().allocate()) x_init = ig.allocate() - cgls = CGLS(x_init=x_init, operator=op_CGLS, data=block_data) cgls.max_iteration = 1000 cgls.update_objective_interval = 200 @@ -88,7 +88,6 @@ cgls.run(1000,verbose=False) # Create BlockOperator op_PDHG = BlockOperator(Grad, Aop, shape=(2,1) ) - # Create functions f1 = 0.5 * alpha**2 * L2NormSquared() f2 = 0.5 * L2NormSquared(b = noisy_data) @@ -102,13 +101,11 @@ normK = op_PDHG.norm() sigma = 10 tau = 1/(sigma*normK**2) -pdhg = PDHG(f=f,g=g,operator=op_PDHG, tau=tau, sigma=sigma, memopt=True) +pdhg = PDHG(f=f,g=g,operator=op_PDHG, tau=tau, sigma=sigma) pdhg.max_iteration = 1000 pdhg.update_objective_interval = 200 pdhg.run(1000, verbose=False) - -#%% # Show results plt.figure(figsize=(10,10)) @@ -129,25 +126,8 @@ plt.title('Diff PDHG vs CGLS') plt.colorbar() plt.show() -plt.plot(np.linspace(0,N,N), pdhg.get_output().as_array()[int(N/2),:], label = 'PDHG') -plt.plot(np.linspace(0,N,N), cgls.get_output().as_array()[int(N/2),:], label = 'CGLS') +plt.plot(np.linspace(0,N,M), pdhg.get_output().as_array()[int(N/2),:], label = 'PDHG') +plt.plot(np.linspace(0,N,M), cgls.get_output().as_array()[int(N/2),:], label = 'CGLS') plt.legend() plt.title('Middle Line Profiles') plt.show() - - - - - - - - - -# -# -# -# -# -# -# -# |