diff options
Diffstat (limited to 'src/Python/ccpi')
-rw-r--r-- | src/Python/ccpi/fista/FISTAReconstructor.py | 184 |
1 files changed, 160 insertions, 24 deletions
diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index cbd27da..8318ea6 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -78,19 +78,28 @@ class FISTAReconstructor(): # handle parmeters: # obligatory parameters self.pars = dict() - self.pars['projector_geometry'] = projector_geometry - self.pars['output_geometry'] = output_geometry - self.pars['input_sinogram'] = input_sinogram + self.pars['projector_geometry'] = projector_geometry # proj_geom + self.pars['output_geometry'] = output_geometry # vol_geom + self.pars['input_sinogram'] = input_sinogram # sino detectors, nangles, sliceZ = numpy.shape(input_sinogram) self.pars['detectors'] = detectors - self.pars['number_og_angles'] = nangles + self.pars['number_of_angles'] = nangles self.pars['SlicesZ'] = sliceZ print (self.pars) # handle optional input parameters (at instantiation) # Accepted input keywords - kw = ('number_of_iterations', + kw = ( + # mandatory fields + 'projector_geometry', + 'output_geometry', + 'input_sinogram', + 'detectors', + 'number_of_angles', + 'SlicesZ', + # optional fields + 'number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , 'weights' , @@ -98,8 +107,9 @@ class FISTAReconstructor(): 'initialize' , 'regularizer' , 'ring_lambda_R_L1', - 'ring_alpha') - self.acceptedInputKeywords = kw + 'ring_alpha', + 'subsets') + self.acceptedInputKeywords = list(kw) # handle keyworded parameters if kwargs is not None: @@ -122,8 +132,7 @@ class FISTAReconstructor(): if 'Lipschitz_constant' in kwargs.keys(): self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] else: - self.pars['Lipschitz_constant'] = \ - self.calculateLipschitzConstantWithPowerMethod() + self.pars['Lipschitz_constant'] = None if not 'ideal_image' in kwargs.keys(): self.pars['ideal_image'] = None @@ -143,31 +152,44 @@ class FISTAReconstructor(): self.pars['ring_lambda_R_L1'] = 0 if not 'ring_alpha' in kwargs.keys(): self.pars['ring_alpha'] = 1 - + + if not 'subsets' in kwargs.keys(): + self.pars['subsets'] = 0 + else: + self.createOrderedSubsets() + + if not 'initialize' in kwargs.keys(): + self.pars['initialize'] = False def setParameter(self, **kwargs): - '''set named parameter for the regularization engine + '''set named parameter for the reconstructor engine raises Exception if the named parameter is not recognized - Typical usage is: - - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - reg.setParameter(input=u0) - reg.setParameter(regularization_parameter=10.) - it can be also used as - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - reg.setParameter(input=u0 , regularization_parameter=10.) ''' - for key , value in kwargs.items(): - if key in self.acceptedInputKeywords.keys(): + if key in self.acceptedInputKeywords: self.pars[key] = value else: - raise Exception('Wrong parameter {0} for '.format(key) + - 'Reconstruction algorithm') + raise Exception('Wrong parameter {0} for '.format(key) + + 'reconstructor') # setParameter + + def getParameter(self, key): + if type(key) is str: + if key in self.acceptedInputKeywords: + return self.pars[key] + else: + raise Exception('Unrecongnised parameter: {0} '.format(key) ) + elif type(key) is list: + outpars = [] + for k in key: + outpars.append(self.getParameter(k)) + return outpars + else: + raise Exception('Unhandled input {0}' .format(str(type(key)))) + def calculateLipschitzConstantWithPowerMethod(self): ''' using Power method (PM) to establish L constant''' @@ -289,5 +311,119 @@ class FISTAReconstructor(): if regularizer is not None: self.pars['regularizer'] = regularizer + + def initialize(self): + # convenience variable storage + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] + sino = self.pars['input_sinogram'] + + # a 'warm start' with SIRT method + # Create a data object for the reconstruction + rec_id = astra.matlab.data3d('create', '-vol', + vol_geom); + + #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino); + sinogram_id = astra.matlab.data3d('create', '-proj3d', + proj_geom, + sino) + + sirt_config = astra.astra_dict('SIRT3D_CUDA') + sirt_config['ReconstructionDataId' ] = rec_id + sirt_config['ProjectionDataId'] = sinogram_id + + sirt = astra.algorithm.create(sirt_config) + astra.algorithm.run(sirt, iterations=35) + X = astra.matlab.data3d('get', rec_id) + + # clean up memory + astra.matlab.data3d('delete', rec_id) + astra.matlab.data3d('delete', sinogram_id) + astra.algorithm.delete(sirt) + + + + return X + + def createOrderedSubsets(self, subsets=None): + if subsets is None: + try: + subsets = self.getParameter('subsets') + except Exception(): + subsets = 0 + #return subsets + + angles = self.getParameter('projector_geometry')['ProjectionAngles'] + + + + + + + def prepareForIteration(self): + self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) + self.objective = numpy.zeros((self.pars['number_of_iterations'])) + + #2D array (for 3D data) of sparse "ring" + detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram']) + self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float) + # another ring variable + self.rx = self.r.copy() + + self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) + + if self.getParameter('Lipschitz_constant') is None: + self.pars['Lipschitz_constant'] = \ + self.calculateLipschitzConstantWithPowerMethod() + + # prepareForIteration + + def iterate(self, Xin=None): + # convenience variable storage + proj_geom , vol_geom, sino , \ + SlicesZ = self.getParameter(['projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ']) + + t = 1 + if Xin is None: + if self.getParameter('initialize'): + X = self.initialize() + else: + N = vol_geom['GridColCount'] + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + X = Xin.copy() + + X_t = X.copy() + + for i in range(self.getParameter('number_of_iterations')): + X_old = X.copy() + t_old = t + r_old = self.r.copy() + if self.pars['projector_geometry']['type'] == 'parallel' or \ + self.pars['projector_geometry']['type'] == 'parallel3d': + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) + + #for kkk = 1:SlicesZ + # [sino_id, sino_updt(:,:,kkk)] = + # astra_create_sino3d_cuda(X_t(:,:,kkk), proj_geomT, vol_geomT); + # astra_mex_data3d('delete', sino_id); + for kkk in range(SlicesZ): + sino_id, sino_updt[kkk] = \ + astra.creators.create_sino3d_gpu( + X_t[kkk], proj_geomT, vol_geomT) + + else: + # for divergent 3D geometry (watch GPU memory overflow in + # Astra < 1.8 + sino_id, y = astra.creators.create_sino3d_gpu(X_t, + proj_geom, + vol_geom) - + + |