diff options
-rw-r--r-- | src/Python/ccpi/reconstruction/FISTAReconstructor.py | 186 |
1 files changed, 173 insertions, 13 deletions
diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py index b8e1027..4f1709c 100644 --- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -430,7 +430,13 @@ class FISTAReconstructor(): # prepareForIteration - def iterate(self, Xin=None): + def iterate (self, Xin=None): + if self.getParameter('subset') == 0: + return self.iterateStandard(Xin) + else: + return self.iterateOrderedSubsets(Xin) + + def iterateStandard(self, Xin=None): print ("FISTA Reconstructor: iterate") if Xin is None: @@ -613,14 +619,14 @@ class FISTAReconstructor(): def updateLoop(self, i, X, X_old, r_old, t, t_old): print ("FISTA Reconstructor: update loop") lambdaR_L1 = self.getParameter('ring_lambda_R_L1') - if lambdaR_L1 > 0: - self.r = numpy.max( - numpy.abs(self.r) - lambdaR_L1 , 0) * \ - numpy.sign(self.r) + t = (1 + numpy.sqrt(1 + 4 * t**2))/2 X_t = X + (((t_old -1)/t) * (X - X_old)) if lambdaR_L1 > 0: + self.r = numpy.max( + numpy.abs(self.r) - lambdaR_L1 , 0) * \ + numpy.sign(self.r) self.r_x = self.r + \ (((t_old-1)/t) * (self.r - r_old)) @@ -636,8 +642,8 @@ class FISTAReconstructor(): print (string.format(i,Resid_error[i], self.objective[i])) return (X , X_t, t) - def os_iterate(self, Xin=None): - print ("FISTA Reconstructor: iterate") + def iterateOS(self, Xin=None): + print ("FISTA Reconstructor: Ordered Subsets iterate") if Xin is None: if self.getParameter('initialize'): @@ -653,9 +659,163 @@ class FISTAReconstructor(): X_t = X.copy() # some useful constants - proj_geom , vol_geom, sino , \ - SlicesZ, weights , alpha_ring ,\ - lambdaR_L1 , L_const = self.getParameter( - ['projector_geometry' , 'output_geometry', - 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' , - 'ring_lambda_R_L1', 'Lipschitz_constant']) + proj_geom , vol_geom, sino , \ + SlicesZ, weights , alpha_ring ,\ + lambdaR_L1 , L_const , iterFISTA = self.getParameter( + ['projector_geometry' , 'output_geometry', 'input_sinogram', + 'SlicesZ' , 'weights', 'ring_alpha' , + 'ring_lambda_R_L1', 'Lipschitz_constant', + 'number_of_iterations']) + + + # errors vector (if the ground truth is given) + Resid_error = numpy.zeros((iterFISTA)); + # objective function values vector + objective = numpy.zeros((iterFISTA)); + + + t = 1 + + ## additional for + proj_geomSUB = proj_geom.copy() + self.residual2 = numpy.zeros(numpy.shape(sino)) + residual2 = self.residual2 + sino_updt_FULL = self.residual.copy() + r_x = self.r.copy() + + print ("starting iterations") + ## % Outer FISTA iterations loop + for i in range(fistaRecon.getParameter('number_of_iterations')): + # With OS approach it becomes trickier to correlate independent + # subsets, hence additional work is required one solution is to + # work with a full sinogram at times + + r_old = self.r.copy() + t_old = t + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(self.getParameter('input_sinogram')) ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 + if (i > 1 and lambdaR_L1 > 0) : + for kkk in range(anglesNumb): + + residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ + ((sino_updt_FULL[:,kkk,:]).squeeze() - \ + (sino[:,kkk,:]).squeeze() -\ + (alpha_ring * r_x) + ) + + vec = self.residual.sum(axis = 1) + #if SlicesZ > 1: + # vec = vec[:,1,:] # 1 or 0? + r_x = self.r_x + # update ring variable + self.r = (r_x - (1./L_const) * vec).copy() + + # subset loop + counterInd = 1 + geometry_type = self.getParameter('projector_geometry')['type'] + angles = self.getParameter('projector_geometry')['ProjectionAngles'] + + for ss in range(self.getParameter('subsets')): + print ("Subset {0}".format(ss)) + X_old = X.copy() + t_old = t + + # the number of projections per subset + numProjSub = self.getParameter('os_bins')[ss] + CurrSubIndices = self.getParameter('os_indices')\ + [counterInd:counterInd+numProjSub] + #print ("Len CurrSubIndices {0}".format(numProjSub)) + mask = numpy.zeros(numpy.shape(angles), dtype=bool) + cc = 0 + for j in range(len(CurrSubIndices)): + mask[int(CurrSubIndices[j])] = True + proj_geomSUB['ProjectionAngles'] = angles[mask] + + shape = list(numpy.shape(self.getParameter('input_sinogram'))) + shape[1] = numProjSub + sino_updt_Sub = numpy.zeros(shape) + if geometry_type == 'parallel' or \ + geometry_type == 'fanflat' or \ + geometry_type == 'fanflat_vec' : + + for kkk in range(SlicesZ): + sino_id, sinoT = astra.creators.create_sino3d_gpu ( + X_t[kkk:kkk+1] , proj_geomSUB, vol_geom) + sino_updt_Sub[kkk] = sinoT.T.copy() + + else: + # for 3D geometry (watch the GPU memory overflow in + # ASTRA < 1.8) + sino_id, sino_updt_Sub = \ + astra.creators.create_sino3d_gpu (X_t, proj_geomSUB, vol_geom) + + astra.matlab.data3d('delete', sino_id) + + + if lambdaR_L1 > 0 : + ## RING REMOVAL + print ("ring removal") + residualSub = self.ringRemovalOrderedSubsets(sino_updt_Sub, + sino_updt_FULL) + else: + #PWLS model + print ("PWLS model") + residualSub = weights[:,CurrSubIndices,:] * \ + ( sino_updt_Sub - \ + sino[:,CurrSubIndices,:].squeeze() ) + objective[i] = 0.5 * numpy.linalg.norm(residualSub) + + # projection/backprojection routine + if geometry_type == 'parallel' or \ + geometry_type == 'fanflat' or \ + geometry_type == 'fanflat_vec' : + # if geometry is 2D use slice-by-slice projection-backprojection + # routine + x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32) + for kkk in range(SlicesZ): + + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residualSub[kkk:kkk+1], + proj_geomSUB, vol_geom) + + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residualSub, proj_geomSUB, vol_geom) + + astra.matlab.data3d('delete', x_id) + X = X_t - (1/L_const) * x_temp + ## REGULARIZATION + X = self.regularize(X) + + # FINAL + ## Update Loop + X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old) + self.setParameter(output_volume=X) + counterInd = counterInd + numProjSub + + return X + + def ringRemovalOrderedSubsets(self, sino_updt_Sub, sino_updt_FULL): + residual = self.residual + r_x = self.r_x + weights , alpha_ring , sino = \ + self.getParameter( ['weights', 'ring_alpha', 'input_sinogram']) + numProjSub = self.getParameter('os_bins')[ss] + CurrSubIndices = self.getParameter('os_indices')\ + [counterInd:counterInd+numProjSub] + residualSub = numpy.zeros(shape) + + for kkk in range(numProjSub): + #print ("ring removal indC ... {0}".format(kkk)) + indC = int(CurrSubIndices[kkk]) + residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \ + (sino_updt_Sub[:,kkk,:].squeeze() - \ + sino[:,indC,:].squeeze() - alpha_ring * r_x) + # filling the full sinogram + sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze() + + return residualSub + + |