summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/Python/ccpi/reconstruction/FISTAReconstructor.py186
1 files changed, 173 insertions, 13 deletions
diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py
index b8e1027..4f1709c 100644
--- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py
+++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py
@@ -430,7 +430,13 @@ class FISTAReconstructor():
# prepareForIteration
- def iterate(self, Xin=None):
+ def iterate (self, Xin=None):
+ if self.getParameter('subset') == 0:
+ return self.iterateStandard(Xin)
+ else:
+ return self.iterateOrderedSubsets(Xin)
+
+ def iterateStandard(self, Xin=None):
print ("FISTA Reconstructor: iterate")
if Xin is None:
@@ -613,14 +619,14 @@ class FISTAReconstructor():
def updateLoop(self, i, X, X_old, r_old, t, t_old):
print ("FISTA Reconstructor: update loop")
lambdaR_L1 = self.getParameter('ring_lambda_R_L1')
- if lambdaR_L1 > 0:
- self.r = numpy.max(
- numpy.abs(self.r) - lambdaR_L1 , 0) * \
- numpy.sign(self.r)
+
t = (1 + numpy.sqrt(1 + 4 * t**2))/2
X_t = X + (((t_old -1)/t) * (X - X_old))
if lambdaR_L1 > 0:
+ self.r = numpy.max(
+ numpy.abs(self.r) - lambdaR_L1 , 0) * \
+ numpy.sign(self.r)
self.r_x = self.r + \
(((t_old-1)/t) * (self.r - r_old))
@@ -636,8 +642,8 @@ class FISTAReconstructor():
print (string.format(i,Resid_error[i], self.objective[i]))
return (X , X_t, t)
- def os_iterate(self, Xin=None):
- print ("FISTA Reconstructor: iterate")
+ def iterateOS(self, Xin=None):
+ print ("FISTA Reconstructor: Ordered Subsets iterate")
if Xin is None:
if self.getParameter('initialize'):
@@ -653,9 +659,163 @@ class FISTAReconstructor():
X_t = X.copy()
# some useful constants
- proj_geom , vol_geom, sino , \
- SlicesZ, weights , alpha_ring ,\
- lambdaR_L1 , L_const = self.getParameter(
- ['projector_geometry' , 'output_geometry',
- 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' ,
- 'ring_lambda_R_L1', 'Lipschitz_constant'])
+ proj_geom , vol_geom, sino , \
+ SlicesZ, weights , alpha_ring ,\
+ lambdaR_L1 , L_const , iterFISTA = self.getParameter(
+ ['projector_geometry' , 'output_geometry', 'input_sinogram',
+ 'SlicesZ' , 'weights', 'ring_alpha' ,
+ 'ring_lambda_R_L1', 'Lipschitz_constant',
+ 'number_of_iterations'])
+
+
+ # errors vector (if the ground truth is given)
+ Resid_error = numpy.zeros((iterFISTA));
+ # objective function values vector
+ objective = numpy.zeros((iterFISTA));
+
+
+ t = 1
+
+ ## additional for
+ proj_geomSUB = proj_geom.copy()
+ self.residual2 = numpy.zeros(numpy.shape(sino))
+ residual2 = self.residual2
+ sino_updt_FULL = self.residual.copy()
+ r_x = self.r.copy()
+
+ print ("starting iterations")
+ ## % Outer FISTA iterations loop
+ for i in range(fistaRecon.getParameter('number_of_iterations')):
+ # With OS approach it becomes trickier to correlate independent
+ # subsets, hence additional work is required one solution is to
+ # work with a full sinogram at times
+
+ r_old = self.r.copy()
+ t_old = t
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(self.getParameter('input_sinogram')) ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4
+ if (i > 1 and lambdaR_L1 > 0) :
+ for kkk in range(anglesNumb):
+
+ residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
+ ((sino_updt_FULL[:,kkk,:]).squeeze() - \
+ (sino[:,kkk,:]).squeeze() -\
+ (alpha_ring * r_x)
+ )
+
+ vec = self.residual.sum(axis = 1)
+ #if SlicesZ > 1:
+ # vec = vec[:,1,:] # 1 or 0?
+ r_x = self.r_x
+ # update ring variable
+ self.r = (r_x - (1./L_const) * vec).copy()
+
+ # subset loop
+ counterInd = 1
+ geometry_type = self.getParameter('projector_geometry')['type']
+ angles = self.getParameter('projector_geometry')['ProjectionAngles']
+
+ for ss in range(self.getParameter('subsets')):
+ print ("Subset {0}".format(ss))
+ X_old = X.copy()
+ t_old = t
+
+ # the number of projections per subset
+ numProjSub = self.getParameter('os_bins')[ss]
+ CurrSubIndices = self.getParameter('os_indices')\
+ [counterInd:counterInd+numProjSub]
+ #print ("Len CurrSubIndices {0}".format(numProjSub))
+ mask = numpy.zeros(numpy.shape(angles), dtype=bool)
+ cc = 0
+ for j in range(len(CurrSubIndices)):
+ mask[int(CurrSubIndices[j])] = True
+ proj_geomSUB['ProjectionAngles'] = angles[mask]
+
+ shape = list(numpy.shape(self.getParameter('input_sinogram')))
+ shape[1] = numProjSub
+ sino_updt_Sub = numpy.zeros(shape)
+ if geometry_type == 'parallel' or \
+ geometry_type == 'fanflat' or \
+ geometry_type == 'fanflat_vec' :
+
+ for kkk in range(SlicesZ):
+ sino_id, sinoT = astra.creators.create_sino3d_gpu (
+ X_t[kkk:kkk+1] , proj_geomSUB, vol_geom)
+ sino_updt_Sub[kkk] = sinoT.T.copy()
+
+ else:
+ # for 3D geometry (watch the GPU memory overflow in
+ # ASTRA < 1.8)
+ sino_id, sino_updt_Sub = \
+ astra.creators.create_sino3d_gpu (X_t, proj_geomSUB, vol_geom)
+
+ astra.matlab.data3d('delete', sino_id)
+
+
+ if lambdaR_L1 > 0 :
+ ## RING REMOVAL
+ print ("ring removal")
+ residualSub = self.ringRemovalOrderedSubsets(sino_updt_Sub,
+ sino_updt_FULL)
+ else:
+ #PWLS model
+ print ("PWLS model")
+ residualSub = weights[:,CurrSubIndices,:] * \
+ ( sino_updt_Sub - \
+ sino[:,CurrSubIndices,:].squeeze() )
+ objective[i] = 0.5 * numpy.linalg.norm(residualSub)
+
+ # projection/backprojection routine
+ if geometry_type == 'parallel' or \
+ geometry_type == 'fanflat' or \
+ geometry_type == 'fanflat_vec' :
+ # if geometry is 2D use slice-by-slice projection-backprojection
+ # routine
+ x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32)
+ for kkk in range(SlicesZ):
+
+ x_id, x_temp[kkk] = \
+ astra.creators.create_backprojection3d_gpu(
+ residualSub[kkk:kkk+1],
+ proj_geomSUB, vol_geom)
+
+ else:
+ x_id, x_temp = \
+ astra.creators.create_backprojection3d_gpu(
+ residualSub, proj_geomSUB, vol_geom)
+
+ astra.matlab.data3d('delete', x_id)
+ X = X_t - (1/L_const) * x_temp
+ ## REGULARIZATION
+ X = self.regularize(X)
+
+ # FINAL
+ ## Update Loop
+ X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old)
+ self.setParameter(output_volume=X)
+ counterInd = counterInd + numProjSub
+
+ return X
+
+ def ringRemovalOrderedSubsets(self, sino_updt_Sub, sino_updt_FULL):
+ residual = self.residual
+ r_x = self.r_x
+ weights , alpha_ring , sino = \
+ self.getParameter( ['weights', 'ring_alpha', 'input_sinogram'])
+ numProjSub = self.getParameter('os_bins')[ss]
+ CurrSubIndices = self.getParameter('os_indices')\
+ [counterInd:counterInd+numProjSub]
+ residualSub = numpy.zeros(shape)
+
+ for kkk in range(numProjSub):
+ #print ("ring removal indC ... {0}".format(kkk))
+ indC = int(CurrSubIndices[kkk])
+ residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \
+ (sino_updt_Sub[:,kkk,:].squeeze() - \
+ sino[:,indC,:].squeeze() - alpha_ring * r_x)
+ # filling the full sinogram
+ sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze()
+
+ return residualSub
+
+