summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2017-10-18 11:48:16 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2017-10-18 11:48:16 +0100
commit9a126e05d03a474850c122cc44e971383069fb8d (patch)
tree1adfdbdc8bfd81f10893f3ab85ff843226aac1dd
parent99e8a3130d6ee161fc8e73faf526d7e0a7a9db44 (diff)
downloadregularization-9a126e05d03a474850c122cc44e971383069fb8d.tar.gz
regularization-9a126e05d03a474850c122cc44e971383069fb8d.tar.bz2
regularization-9a126e05d03a474850c122cc44e971383069fb8d.tar.xz
regularization-9a126e05d03a474850c122cc44e971383069fb8d.zip
minor reorganization of the code
added RSME
-rw-r--r--src/Python/test_reconstructor.py97
1 files changed, 45 insertions, 52 deletions
diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py
index f8f6b3c..2f188b4 100644
--- a/src/Python/test_reconstructor.py
+++ b/src/Python/test_reconstructor.py
@@ -11,10 +11,17 @@ import numpy
from ccpi.fista.FISTAReconstructor import FISTAReconstructor
import astra
+import matplotlib.pyplot as plt
-##def getEntry(nx, location):
-## for item in nx[location].keys():
-## print (item)
+def RMSE(signal1, signal2):
+ '''RMSE Root Mean Squared Error'''
+ if numpy.shape(signal1) == numpy.shape(signal2):
+ err = (signal1 - signal2)
+ err = numpy.sum( err * err )/numpy.size(signal1); # MSE
+ err = sqrt(err); # RMSE
+ return err
+ else:
+ raise Exception('Input signals must have the same shape')
filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5'
nx = h5py.File(filename, "r")
@@ -68,7 +75,6 @@ fistaRecon.setParameter(number_of_iterations = 12)
fistaRecon.setParameter(Lipschitz_constant = 767893952.0)
fistaRecon.setParameter(ring_alpha = 21)
fistaRecon.setParameter(ring_lambda_R_L1 = 0.002)
-#fistaRecon.setParameter(use_studentt_fidelity= True)
## Ordered subset
if False:
@@ -95,18 +101,33 @@ if False:
if True:
- fistaRecon.prepareForIteration()
print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant']))
-
+ print ("prepare for iteration")
+ fistaRecon.prepareForIteration()
+
+ print("initializing ...")
+ if False:
+ # if X doesn't exist
+ #N = params.vol_geom.GridColCount
+ N = vol_geom['GridColCount']
+ print ("N " + str(N))
+ X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+ else:
+ #X = fistaRecon.initialize()
+ X = numpy.load("X.npy")
+
+ print (numpy.shape(X))
+ X_t = X.copy()
+ print ("initialized")
proj_geom , vol_geom, sino , \
SlicesZ = fistaRecon.getParameter(['projector_geometry' ,
'output_geometry',
'input_sinogram',
'SlicesZ'])
- fistaRecon.setParameter(number_of_iterations = 3)
+ #fistaRecon.setParameter(number_of_iterations = 3)
iterFISTA = fistaRecon.getParameter('number_of_iterations')
# errors vector (if the ground truth is given)
Resid_error = numpy.zeros((iterFISTA));
@@ -114,23 +135,10 @@ if True:
objective = numpy.zeros((iterFISTA));
- print ("line")
t = 1
- print ("line")
- if False:
- # if X doesn't exist
- #N = params.vol_geom.GridColCount
- N = vol_geom['GridColCount']
- print ("N " + str(N))
- X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
- else:
- #X = fistaRecon.initialize()
- X = numpy.load("X.npy")
-
- print (numpy.shape(X))
- X_t = X.copy()
- print ("X_t copy")
+
+ print ("starting iterations")
## % Outer FISTA iterations loop
for i in range(fistaRecon.getParameter('number_of_iterations')):
X_old = X.copy()
@@ -147,7 +155,6 @@ if True:
vol_geomT['GridSliceCount'] = 1;
sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
for kkk in range(SlicesZ):
- print (kkk)
sino_id, sino_updt[kkk] = \
astra.creators.create_sino3d_gpu(
X_t[kkk:kkk+1], proj_geomT, vol_geomT)
@@ -169,8 +176,9 @@ if True:
SlicesZ, anglesNumb, Detectors = \
numpy.shape(fistaRecon.getParameter('input_sinogram'))
if lambdaR_L1 > 0 :
+ print ("ring removal")
for kkk in range(anglesNumb):
- print ("angles {0}".format(kkk))
+
residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
((sino_updt[:,kkk,:]).squeeze() - \
(sino[:,kkk,:]).squeeze() -\
@@ -194,39 +202,15 @@ if True:
## r = r_x - (1./L_const).*vec;
## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output
- else:
- if fistaRecon.getParameter('use_studentt_fidelity'):
- residual = weights * (sino_updt - sino)
- for kkk in range(SlicesZ):
- # reshape(residual(:,:,kkk), Detectors*anglesNumb, 1)
- # 1D
- res_vec = numpy.reshape(residual[kkk], (Detectors * anglesNumb,1))
-
-## else
-## if (studentt == 1)
-## % artifacts removal with Students t penalty
-## residual = weights.*(sino_updt - sino);
-## for kkk = 1:SlicesZ
-## res_vec = reshape(residual(:,:,kkk), Detectors*anglesNumb, 1); % 1D vectorized sinogram
-## %s = 100;
-## %gr = (2)*res_vec./(s*2 + conj(res_vec).*res_vec);
-## [ff, gr] = studentst(res_vec, 1);
-## residual(:,:,kkk) = reshape(gr, Detectors, anglesNumb);
-## end
-## objective(i) = ff; % for the objective function output
-## else
-## % no ring removal (LS model)
-## residual = weights.*(sino_updt - sino);
-## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output
-## end
-## end
+
# Projection/Backprojection Routine
if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \
fistaRecon.getParameter('projector_geometry')['type'] == 'parallel3d':
x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32)
+ print ("Projection/Backprojection Routine")
for kkk in range(SlicesZ):
- print ("Projection/Backprojection Routine {0}".format( kkk ))
+
x_id, x_temp[kkk] = \
astra.creators.create_backprojection3d_gpu(
residual[kkk:kkk+1],
@@ -248,9 +232,11 @@ if True:
# regularizer = fistaRecon.getParameter('regularizer')
# for slices:
# out = regularizer(input=X)
+ print ("skipping regularizer")
## FINAL
+ print ("final")
lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1')
if lambdaR_L1 > 0:
fistaRecon.r = numpy.max(
@@ -263,9 +249,16 @@ if True:
fistaRecon.r_x = fistaRecon.r + \
(((t_old-1)/t) * (fistaRecon.r - r_old))
- if fistaRecon.getParameter('ideal_image') is None:
+ if fistaRecon.getParameter('region_of_interest') is None:
string = 'Iteration Number {0} | Objective {1} \n'
print (string.format( i, objective[i]))
+ else:
+ ROI , X_ideal = fistaRecon.getParameter('region_of_interest',
+ 'ideal_image')
+
+ Resid_error[i] = RMSE(X*ROI, X_ideal*ROI)
+ string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
+ print (string.format(i,Resid_error[i], objective[i]))
## if (lambdaR_L1 > 0)
## r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector