summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorepapoutsellis <epapoutsellis@gmail.com>2019-05-14 23:49:12 +0100
committerepapoutsellis <epapoutsellis@gmail.com>2019-05-14 23:49:12 +0100
commit75239edf1e0d616caa5f2af0b6779739819ed265 (patch)
treee243e7083c95915d7d49bbb6f8daf491926dd0f0 /Wrappers/Python
parent1b8c5193938a68548a34b42433ddb97fcaec587a (diff)
downloadframework-75239edf1e0d616caa5f2af0b6779739819ed265.tar.gz
framework-75239edf1e0d616caa5f2af0b6779739819ed265.tar.bz2
framework-75239edf1e0d616caa5f2af0b6779739819ed265.tar.xz
framework-75239edf1e0d616caa5f2af0b6779739819ed265.zip
minor fix
Diffstat (limited to 'Wrappers/Python')
-rw-r--r--Wrappers/Python/demos/CompareAlgorithms/CGLS_PDHG_Tikhonov.py52
1 files changed, 16 insertions, 36 deletions
diff --git a/Wrappers/Python/demos/CompareAlgorithms/CGLS_PDHG_Tikhonov.py b/Wrappers/Python/demos/CompareAlgorithms/CGLS_PDHG_Tikhonov.py
index 942d328..9b6d10f 100644
--- a/Wrappers/Python/demos/CompareAlgorithms/CGLS_PDHG_Tikhonov.py
+++ b/Wrappers/Python/demos/CompareAlgorithms/CGLS_PDHG_Tikhonov.py
@@ -32,8 +32,7 @@ Problem: min_x alpha * ||\grad x ||^{2}_{2} + || A x - g ||_{2}^{2}
"""
-from ccpi.framework import ImageData, ImageGeometry, \
- AcquisitionGeometry, BlockDataContainer, AcquisitionData
+from ccpi.framework import AcquisitionGeometry, BlockDataContainer, AcquisitionData
import numpy as np
import numpy
@@ -42,28 +41,30 @@ import matplotlib.pyplot as plt
from ccpi.optimisation.algorithms import PDHG, CGLS
from ccpi.optimisation.operators import BlockOperator, Gradient
-from ccpi.optimisation.functions import ZeroFunction, BlockFunction, L2NormSquared
+from ccpi.optimisation.functions import ZeroFunction, BlockFunction, L2NormSquared
+from ccpi.astra.ops import AstraProjectorSimple
+from ccpi.framework import TestData
+import os, sys
-# Create Ground truth phantom and Sinogram
-N = 128
-x = np.zeros((N,N))
-x[round(N/4):round(3*N/4),round(N/4):round(3*N/4)] = 0.5
-x[round(N/8):round(7*N/8),round(3*N/8):round(5*N/8)] = 1
+loader = TestData(data_dir=os.path.join(sys.prefix, 'share','ccpi'))
-data = ImageData(x)
-ig = ImageGeometry(voxel_num_x = N, voxel_num_y = N)
+# Create Ground truth phantom and Sinogram
+N = 150
+M = 150
+data = loader.load(TestData.SIMPLE_PHANTOM_2D, size=(N,M), scale=(0,1))
+ig = data.geometry
detectors = N
angles = np.linspace(0, np.pi, N, dtype=np.float32)
-
ag = AcquisitionGeometry('parallel','2D', angles, detectors)
+
device = input('Available device: GPU==1 / CPU==0 ')
-ag = AcquisitionGeometry('parallel','2D', angles, detectors)
if device=='1':
dev = 'gpu'
else:
dev = 'cpu'
+Aop = AstraProjectorSimple(ig, ag, dev)
sin = Aop.direct(data)
noisy_data = AcquisitionData( sin.as_array() + np.random.normal(0,3,ig.shape))
@@ -77,7 +78,6 @@ op_CGLS = BlockOperator( Aop, alpha * Grad, shape=(2,1))
block_data = BlockDataContainer(noisy_data, Grad.range_geometry().allocate())
x_init = ig.allocate()
-
cgls = CGLS(x_init=x_init, operator=op_CGLS, data=block_data)
cgls.max_iteration = 1000
cgls.update_objective_interval = 200
@@ -88,7 +88,6 @@ cgls.run(1000,verbose=False)
# Create BlockOperator
op_PDHG = BlockOperator(Grad, Aop, shape=(2,1) )
-
# Create functions
f1 = 0.5 * alpha**2 * L2NormSquared()
f2 = 0.5 * L2NormSquared(b = noisy_data)
@@ -102,13 +101,11 @@ normK = op_PDHG.norm()
sigma = 10
tau = 1/(sigma*normK**2)
-pdhg = PDHG(f=f,g=g,operator=op_PDHG, tau=tau, sigma=sigma, memopt=True)
+pdhg = PDHG(f=f,g=g,operator=op_PDHG, tau=tau, sigma=sigma)
pdhg.max_iteration = 1000
pdhg.update_objective_interval = 200
pdhg.run(1000, verbose=False)
-
-#%%
# Show results
plt.figure(figsize=(10,10))
@@ -129,25 +126,8 @@ plt.title('Diff PDHG vs CGLS')
plt.colorbar()
plt.show()
-plt.plot(np.linspace(0,N,N), pdhg.get_output().as_array()[int(N/2),:], label = 'PDHG')
-plt.plot(np.linspace(0,N,N), cgls.get_output().as_array()[int(N/2),:], label = 'CGLS')
+plt.plot(np.linspace(0,N,M), pdhg.get_output().as_array()[int(N/2),:], label = 'PDHG')
+plt.plot(np.linspace(0,N,M), cgls.get_output().as_array()[int(N/2),:], label = 'CGLS')
plt.legend()
plt.title('Middle Line Profiles')
plt.show()
-
-
-
-
-
-
-
-
-
-#
-#
-#
-#
-#
-#
-#
-#