diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2017-10-18 11:48:16 +0100 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2017-10-18 11:48:16 +0100 |
commit | 9a126e05d03a474850c122cc44e971383069fb8d (patch) | |
tree | 1adfdbdc8bfd81f10893f3ab85ff843226aac1dd | |
parent | 99e8a3130d6ee161fc8e73faf526d7e0a7a9db44 (diff) | |
download | regularization-9a126e05d03a474850c122cc44e971383069fb8d.tar.gz regularization-9a126e05d03a474850c122cc44e971383069fb8d.tar.bz2 regularization-9a126e05d03a474850c122cc44e971383069fb8d.tar.xz regularization-9a126e05d03a474850c122cc44e971383069fb8d.zip |
minor reorganization of the code
added RSME
-rw-r--r-- | src/Python/test_reconstructor.py | 97 |
1 files changed, 45 insertions, 52 deletions
diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py index f8f6b3c..2f188b4 100644 --- a/src/Python/test_reconstructor.py +++ b/src/Python/test_reconstructor.py @@ -11,10 +11,17 @@ import numpy from ccpi.fista.FISTAReconstructor import FISTAReconstructor import astra +import matplotlib.pyplot as plt -##def getEntry(nx, location): -## for item in nx[location].keys(): -## print (item) +def RMSE(signal1, signal2): + '''RMSE Root Mean Squared Error''' + if numpy.shape(signal1) == numpy.shape(signal2): + err = (signal1 - signal2) + err = numpy.sum( err * err )/numpy.size(signal1); # MSE + err = sqrt(err); # RMSE + return err + else: + raise Exception('Input signals must have the same shape') filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' nx = h5py.File(filename, "r") @@ -68,7 +75,6 @@ fistaRecon.setParameter(number_of_iterations = 12) fistaRecon.setParameter(Lipschitz_constant = 767893952.0) fistaRecon.setParameter(ring_alpha = 21) fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) -#fistaRecon.setParameter(use_studentt_fidelity= True) ## Ordered subset if False: @@ -95,18 +101,33 @@ if False: if True: - fistaRecon.prepareForIteration() print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) - + print ("prepare for iteration") + fistaRecon.prepareForIteration() + + print("initializing ...") + if False: + # if X doesn't exist + #N = params.vol_geom.GridColCount + N = vol_geom['GridColCount'] + print ("N " + str(N)) + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + #X = fistaRecon.initialize() + X = numpy.load("X.npy") + + print (numpy.shape(X)) + X_t = X.copy() + print ("initialized") proj_geom , vol_geom, sino , \ SlicesZ = fistaRecon.getParameter(['projector_geometry' , 'output_geometry', 'input_sinogram', 'SlicesZ']) - fistaRecon.setParameter(number_of_iterations = 3) + #fistaRecon.setParameter(number_of_iterations = 3) iterFISTA = fistaRecon.getParameter('number_of_iterations') # errors vector (if the ground truth is given) Resid_error = numpy.zeros((iterFISTA)); @@ -114,23 +135,10 @@ if True: objective = numpy.zeros((iterFISTA)); - print ("line") t = 1 - print ("line") - if False: - # if X doesn't exist - #N = params.vol_geom.GridColCount - N = vol_geom['GridColCount'] - print ("N " + str(N)) - X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) - else: - #X = fistaRecon.initialize() - X = numpy.load("X.npy") - - print (numpy.shape(X)) - X_t = X.copy() - print ("X_t copy") + + print ("starting iterations") ## % Outer FISTA iterations loop for i in range(fistaRecon.getParameter('number_of_iterations')): X_old = X.copy() @@ -147,7 +155,6 @@ if True: vol_geomT['GridSliceCount'] = 1; sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) for kkk in range(SlicesZ): - print (kkk) sino_id, sino_updt[kkk] = \ astra.creators.create_sino3d_gpu( X_t[kkk:kkk+1], proj_geomT, vol_geomT) @@ -169,8 +176,9 @@ if True: SlicesZ, anglesNumb, Detectors = \ numpy.shape(fistaRecon.getParameter('input_sinogram')) if lambdaR_L1 > 0 : + print ("ring removal") for kkk in range(anglesNumb): - print ("angles {0}".format(kkk)) + residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ ((sino_updt[:,kkk,:]).squeeze() - \ (sino[:,kkk,:]).squeeze() -\ @@ -194,39 +202,15 @@ if True: ## r = r_x - (1./L_const).*vec; ## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output - else: - if fistaRecon.getParameter('use_studentt_fidelity'): - residual = weights * (sino_updt - sino) - for kkk in range(SlicesZ): - # reshape(residual(:,:,kkk), Detectors*anglesNumb, 1) - # 1D - res_vec = numpy.reshape(residual[kkk], (Detectors * anglesNumb,1)) - -## else -## if (studentt == 1) -## % artifacts removal with Students t penalty -## residual = weights.*(sino_updt - sino); -## for kkk = 1:SlicesZ -## res_vec = reshape(residual(:,:,kkk), Detectors*anglesNumb, 1); % 1D vectorized sinogram -## %s = 100; -## %gr = (2)*res_vec./(s*2 + conj(res_vec).*res_vec); -## [ff, gr] = studentst(res_vec, 1); -## residual(:,:,kkk) = reshape(gr, Detectors, anglesNumb); -## end -## objective(i) = ff; % for the objective function output -## else -## % no ring removal (LS model) -## residual = weights.*(sino_updt - sino); -## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output -## end -## end + # Projection/Backprojection Routine if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ fistaRecon.getParameter('projector_geometry')['type'] == 'parallel3d': x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) + print ("Projection/Backprojection Routine") for kkk in range(SlicesZ): - print ("Projection/Backprojection Routine {0}".format( kkk )) + x_id, x_temp[kkk] = \ astra.creators.create_backprojection3d_gpu( residual[kkk:kkk+1], @@ -248,9 +232,11 @@ if True: # regularizer = fistaRecon.getParameter('regularizer') # for slices: # out = regularizer(input=X) + print ("skipping regularizer") ## FINAL + print ("final") lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1') if lambdaR_L1 > 0: fistaRecon.r = numpy.max( @@ -263,9 +249,16 @@ if True: fistaRecon.r_x = fistaRecon.r + \ (((t_old-1)/t) * (fistaRecon.r - r_old)) - if fistaRecon.getParameter('ideal_image') is None: + if fistaRecon.getParameter('region_of_interest') is None: string = 'Iteration Number {0} | Objective {1} \n' print (string.format( i, objective[i])) + else: + ROI , X_ideal = fistaRecon.getParameter('region_of_interest', + 'ideal_image') + + Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) + string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' + print (string.format(i,Resid_error[i], objective[i])) ## if (lambdaR_L1 > 0) ## r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector |