diff options
| author | Edoardo Pasca <edo.paskino@gmail.com> | 2017-10-24 11:31:36 +0100 | 
|---|---|---|
| committer | Edoardo Pasca <edo.paskino@gmail.com> | 2017-10-24 11:31:36 +0100 | 
| commit | 546104f8dfea5691801137c1be99d09e1e999d82 (patch) | |
| tree | 7b66e2ec46c49ea4ff7b872cd8ac602fe2a9b8d7 /src/Python/ccpi | |
| parent | 909a7bb4d71bdb14d4e68f42c2297f6154a77ed0 (diff) | |
| download | regularization-546104f8dfea5691801137c1be99d09e1e999d82.tar.gz regularization-546104f8dfea5691801137c1be99d09e1e999d82.tar.bz2 regularization-546104f8dfea5691801137c1be99d09e1e999d82.tar.xz regularization-546104f8dfea5691801137c1be99d09e1e999d82.zip | |
removed fista directory
use the standard package reconstruction directory for the fista code
Diffstat (limited to 'src/Python/ccpi')
| -rw-r--r-- | src/Python/ccpi/fista/FISTAReconstructor.py | 609 | ||||
| -rw-r--r-- | src/Python/ccpi/fista/Reconstructor.py | 425 | ||||
| -rw-r--r-- | src/Python/ccpi/fista/__init__.py | 0 | 
3 files changed, 0 insertions, 1034 deletions
| diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py deleted file mode 100644 index 85bfac5..0000000 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ /dev/null @@ -1,609 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -#from ccpi.reconstruction.parallelbeam import alg - -#from ccpi.imaging.Regularizer import Regularizer -from enum import Enum - -import astra - -    -     -class FISTAReconstructor(): -    '''FISTA-based reconstruction algorithm using ASTRA-toolbox -     -    ''' -    # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> -    # ___Input___: -    # params.[] file: -    #       - .proj_geom (geometry of the projector) [required] -    #       - .vol_geom (geometry of the reconstructed object) [required] -    #       - .sino (vectorized in 2D or 3D sinogram) [required] -    #       - .iterFISTA (iterations for the main loop, default 40) -    #       - .L_const (Lipschitz constant, default Power method)                                                                                                    ) -    #       - .X_ideal (ideal image, if given) -    #       - .weights (statisitcal weights, size of the sinogram) -    #       - .ROI (Region-of-interest, only if X_ideal is given) -    #       - .initialize (a 'warm start' using SIRT method from ASTRA) -    #----------------Regularization choices------------------------ -    #       - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) -    #       - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) -    #       - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) -    #       - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) -    #       - .Regul_Iterations (iterations for the selected penalty, default 25) -    #       - .Regul_tauLLT (time step parameter for LLT term) -    #       - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) -    #       - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) -    #----------------Visualization parameters------------------------ -    #       - .show (visualize reconstruction 1/0, (0 default)) -    #       - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) -    #       - .slice (for 3D volumes - slice number to imshow) -    # ___Output___: -    # 1. X - reconstructed image/volume -    # 2. output - a structure with -    #    - .Resid_error - residual error (if X_ideal is given) -    #    - .objective: value of the objective function -    #    - .L_const: Lipshitz constant to avoid recalculations -     -    # References: -    # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse -    # Problems" by A. Beck and M Teboulle -    # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo -    # 3. "A novel tomographic reconstruction method based on the robust -    # Student's t function for suppressing data outliers" D. Kazantsev et.al. -    # D. Kazantsev, 2016-17 -    def __init__(self, projector_geometry, output_geometry, input_sinogram, -                 **kwargs): -        # handle parmeters: -        # obligatory parameters -        self.pars = dict() -        self.pars['projector_geometry'] = projector_geometry # proj_geom -        self.pars['output_geometry'] = output_geometry       # vol_geom -        self.pars['input_sinogram'] = input_sinogram         # sino -        sliceZ, nangles, detectors = numpy.shape(input_sinogram) -        self.pars['detectors'] = detectors -        self.pars['number_of_angles'] = nangles -        self.pars['SlicesZ'] = sliceZ -        self.pars['output_volume'] = None - -        print (self.pars) -        # handle optional input parameters (at instantiation) -         -        # Accepted input keywords -        kw = ( -              # mandatory fields -              'projector_geometry', -              'output_geometry', -              'input_sinogram', -              'detectors', -              'number_of_angles', -              'SlicesZ', -              # optional fields -              'number_of_iterations',  -              'Lipschitz_constant' ,  -              'ideal_image' , -              'weights' ,  -              'region_of_interest' ,  -              'initialize' ,  -              'regularizer' ,  -              'ring_lambda_R_L1', -              'ring_alpha', -              'subsets', -              'output_volume', -              'os_subsets', -              'os_indices', -              'os_bins') -        self.acceptedInputKeywords = list(kw) -         -        # handle keyworded parameters -        if kwargs is not None: -            for key, value in kwargs.items(): -                if key in kw: -                    #print("{0} = {1}".format(key, value))                         -                    self.pars[key] = value -                     -        # set the default values for the parameters if not set -        if 'number_of_iterations' in kwargs.keys(): -            self.pars['number_of_iterations'] = kwargs['number_of_iterations'] -        else: -            self.pars['number_of_iterations'] = 40 -        if 'weights' in kwargs.keys(): -            self.pars['weights'] = kwargs['weights'] -        else: -            self.pars['weights'] = \ -                                 numpy.ones(numpy.shape( -                                     self.pars['input_sinogram'])) -        if 'Lipschitz_constant' in kwargs.keys(): -            self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] -        else: -            self.pars['Lipschitz_constant'] = None -         -        if not 'ideal_image' in kwargs.keys(): -            self.pars['ideal_image'] = None -         -        if not 'region_of_interest'in kwargs.keys() : -            if self.pars['ideal_image'] == None: -                self.pars['region_of_interest'] = None -            else: -                ## nonzero if the image is larger than m -                fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1) -                 -                self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0) -                 -        # the regularizer must be a correctly instantiated object     -        if not 'regularizer' in kwargs.keys() : -            self.pars['regularizer'] = None - -        #RING REMOVAL -        if not 'ring_lambda_R_L1' in kwargs.keys(): -            self.pars['ring_lambda_R_L1'] = 0 -        if not 'ring_alpha' in kwargs.keys(): -            self.pars['ring_alpha'] = 1 - -        # ORDERED SUBSET -        if not 'subsets' in kwargs.keys(): -            self.pars['subsets'] = 0 -        else: -            self.createOrderedSubsets() - -        if not 'initialize' in kwargs.keys(): -            self.pars['initialize'] = False - -         -             -             -    def setParameter(self, **kwargs): -        '''set named parameter for the reconstructor engine -         -        raises Exception if the named parameter is not recognized -         -        ''' -        for key , value in kwargs.items(): -            if key in self.acceptedInputKeywords: -                self.pars[key] = value -            else: -                raise Exception('Wrong parameter {0} for '.format(key) + -                                'reconstructor') -    # setParameter - -    def getParameter(self, key): -        if type(key) is str: -            if key in self.acceptedInputKeywords: -                return self.pars[key] -            else: -                raise Exception('Unrecongnised parameter: {0} '.format(key) ) -        elif type(key) is list: -            outpars = [] -            for k in key: -                outpars.append(self.getParameter(k)) -            return outpars -        else: -            raise Exception('Unhandled input {0}' .format(str(type(key)))) -             -     -    def calculateLipschitzConstantWithPowerMethod(self): -        ''' using Power method (PM) to establish L constant''' -         -        N = self.pars['output_geometry']['GridColCount'] -        proj_geom = self.pars['projector_geometry'] -        vol_geom = self.pars['output_geometry'] -        weights = self.pars['weights'] -        SlicesZ = self.pars['SlicesZ'] -         -             -                                -        if (proj_geom['type'] == 'parallel') or \ -           (proj_geom['type'] == 'parallel3d'): -            #% for parallel geometry we can do just one slice -            #print('Calculating Lipshitz constant for parallel beam geometry...') -            niter = 5;# % number of iteration for the PM -            #N = params.vol_geom.GridColCount; -            #x1 = rand(N,N,1); -            x1 = numpy.random.rand(1,N,N) -            #sqweight = sqrt(weights(:,:,1)); -            sqweight = numpy.sqrt(weights[0]) -            proj_geomT = proj_geom.copy(); -            proj_geomT['DetectorRowCount'] = 1; -            vol_geomT = vol_geom.copy(); -            vol_geomT['GridSliceCount'] = 1; -             -            #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); -             -             -            for i in range(niter): -            #        [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); -            #            s = norm(x1(:)); -            #            x1 = x1/s; -            #            [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); -            #            y = sqweight.*y; -            #            astra_mex_data3d('delete', sino_id); -            #            astra_mex_data3d('delete', id); -                #print ("iteration {0}".format(i)) -                                 -                sino_id, y = astra.creators.create_sino3d_gpu(x1, -                                                          proj_geomT, -                                                          vol_geomT) -                 -                y = (sqweight * y).copy() # element wise multiplication -                 -                #b=fig.add_subplot(2,1,2) -                #imgplot = plt.imshow(x1[0]) -                #plt.show() -                 -                #astra_mex_data3d('delete', sino_id); -                astra.matlab.data3d('delete', sino_id) -                del x1 -                     -                idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(),  -                                                                    proj_geomT, -                                                                    vol_geomT) -                del y -                 -                                                                     -                s = numpy.linalg.norm(x1) -                ### this line? -                x1 = (x1/s).copy(); -                 -            #        ### this line? -            #        sino_id, y = astra.creators.create_sino3d_gpu(x1,  -            #                                                      proj_geomT,  -            #                                                      vol_geomT); -            #        y = sqweight * y; -                astra.matlab.data3d('delete', sino_id); -                astra.matlab.data3d('delete', idx) -                print ("iteration {0} s= {1}".format(i,s)) -                 -            #end -            del proj_geomT -            del vol_geomT -            #plt.show() -        else: -            #% divergen beam geometry -            print('Calculating Lipshitz constant for divergen beam geometry...') -            niter = 8; #% number of iteration for PM -            x1 = numpy.random.rand(SlicesZ , N , N); -            #sqweight = sqrt(weights); -            sqweight = numpy.sqrt(weights[0]) -             -            sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); -            y = sqweight*y; -            #astra_mex_data3d('delete', sino_id); -            astra.matlab.data3d('delete', sino_id); -             -            for i in range(niter): -                #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); -                idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,  -                                                                    proj_geom,  -                                                                    vol_geom) -                s = numpy.linalg.norm(x1) -                ### this line? -                x1 = x1/s; -                ### this line? -                #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); -                sino_id, y = astra.creators.create_sino3d_gpu(x1,  -                                                              proj_geom,  -                                                              vol_geom); -                 -                y = sqweight*y; -                #astra_mex_data3d('delete', sino_id); -                #astra_mex_data3d('delete', id); -                astra.matlab.data3d('delete', sino_id); -                astra.matlab.data3d('delete', idx); -            #end -            #clear x1 -            del x1 - -         -        return s -     -     -    def setRegularizer(self, regularizer): -        if regularizer is not None: -            self.pars['regularizer'] = regularizer -         - -    def initialize(self): -        # convenience variable storage -        proj_geom = self.pars['projector_geometry'] -        vol_geom = self.pars['output_geometry'] -        sino = self.pars['input_sinogram'] -         -        # a 'warm start' with SIRT method -        # Create a data object for the reconstruction -        rec_id = astra.matlab.data3d('create', '-vol', -                                    vol_geom); -         -        #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino); -        sinogram_id = astra.matlab.data3d('create', '-proj3d', -                                          proj_geom, -                                          sino) - -        sirt_config = astra.astra_dict('SIRT3D_CUDA') -        sirt_config['ReconstructionDataId' ] = rec_id -        sirt_config['ProjectionDataId'] = sinogram_id - -        sirt = astra.algorithm.create(sirt_config) -        astra.algorithm.run(sirt, iterations=35) -        X = astra.matlab.data3d('get', rec_id) - -        # clean up memory -        astra.matlab.data3d('delete', rec_id) -        astra.matlab.data3d('delete', sinogram_id) -        astra.algorithm.delete(sirt) - -         - -        return X - -    def createOrderedSubsets(self, subsets=None): -        if subsets is None: -            try: -                subsets = self.getParameter('subsets') -            except Exception(): -                subsets = 0 -            #return subsets - -        angles = self.getParameter('projector_geometry')['ProjectionAngles']  -         -        #binEdges = numpy.linspace(angles.min(), -        #                          angles.max(), -        #                          subsets + 1) -        binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) -        # get rearranged subset indices -        IndicesReorg = numpy.zeros((numpy.shape(angles))) -        counterM = 0 -        for ii in range(binsDiscr.max()): -            counter = 0 -            for jj in range(subsets): -                curr_index = ii + jj  + counter -                #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) -                if binsDiscr[jj] > ii: -                    if (counterM < numpy.size(IndicesReorg)): -                        IndicesReorg[counterM] = curr_index -                    counterM = counterM + 1 -                     -                counter = counter + binsDiscr[jj] - 1     -                 -        # store the OS in parameters -        self.setParameter(os_subsets=subsets, -                          os_bins=binsDiscr, -                          os_indices=IndicesReorg) -             - -    def prepareForIteration(self): -        print ("FISTA Reconstructor: prepare for iteration") -         -        self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) -        self.objective = numpy.zeros((self.pars['number_of_iterations'])) - -        #2D array (for 3D data) of sparse "ring"  -        detectors, nangles, sliceZ  = numpy.shape(self.pars['input_sinogram']) -        self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float) -        # another ring variable -        self.r_x = self.r.copy() - -        self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) -         -        if self.getParameter('Lipschitz_constant') is None: -            self.pars['Lipschitz_constant'] = \ -                            self.calculateLipschitzConstantWithPowerMethod() -        # errors vector (if the ground truth is given) -        self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations'))); -        # objective function values vector -        self.objective = numpy.zeros((self.getParameter('number_of_iterations')));       -         - -    # prepareForIteration - -    def iterate(self, Xin=None): -        print ("FISTA Reconstructor: iterate") -         -        if Xin is None:     -            if self.getParameter('initialize'): -                X = self.initialize() -            else: -                N = vol_geom['GridColCount'] -                X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) -        else: -            # copy by reference -            X = Xin -        # store the output volume in the parameters -        self.setParameter(output_volume=X) -        X_t = X.copy() -        # convenience variable storage -        proj_geom , vol_geom, sino , \ -          SlicesZ  = self.getParameter([ 'projector_geometry' , -                                                'output_geometry', -                                                'input_sinogram', -                                                'SlicesZ' ]) -                    -        t = 1 -         -        for i in range(self.getParameter('number_of_iterations')): -            X_old = X.copy() -            t_old = t -            r_old = self.r.copy() -            if self.getParameter('projector_geometry')['type'] == 'parallel' or \ -               self.getParameter('projector_geometry')['type'] == 'fanflat' or \ -               self.getParameter('projector_geometry')['type'] == 'fanflat_vec': -                # if the geometry is parallel use slice-by-slice -                # projection-backprojection routine -                #sino_updt = zeros(size(sino),'single'); -                proj_geomT = proj_geom.copy() -                proj_geomT['DetectorRowCount'] = 1 -                vol_geomT = vol_geom.copy() -                vol_geomT['GridSliceCount'] = 1; -                self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) -                for kkk in range(SlicesZ): -                    sino_id, self.sino_updt[kkk] = \ -                             astra.creators.create_sino3d_gpu( -                                 X_t[kkk:kkk+1], proj_geomT, vol_geomT) -                    astra.matlab.data3d('delete', sino_id) -            else: -                # for divergent 3D geometry (watch the GPU memory overflow in -                # ASTRA versions < 1.8) -                #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); -                sino_id, self.sino_updt = astra.creators.create_sino3d_gpu( -                    X_t, proj_geom, vol_geom) - - -            ## RING REMOVAL -            self.ringRemoval(i) -            ## Projection/Backprojection Routine -            self.projectionBackprojection(X, X_t) -            astra.matlab.data3d('delete', sino_id) -            ## REGULARIZATION -            X = self.regularize(X) -            ## Update Loop -            X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old) -            self.setParameter(output_volume=X) -        return X -    ## iterate -     -    def ringRemoval(self, i): -        print ("FISTA Reconstructor: ring removal") -        residual = self.residual -        lambdaR_L1 , alpha_ring , weights , L_const , sino= \ -                   self.getParameter(['ring_lambda_R_L1', -                                      'ring_alpha' , 'weights', -                                      'Lipschitz_constant', -                                      'input_sinogram']) -        r_x = self.r_x -        sino_updt = self.sino_updt -         -        SlicesZ, anglesNumb, Detectors = \ -                    numpy.shape(self.getParameter('input_sinogram')) -        if lambdaR_L1 > 0 : -             for kkk in range(anglesNumb): -                  -                 residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ -                                       ((sino_updt[:,kkk,:]).squeeze() - \ -                                        (sino[:,kkk,:]).squeeze() -\ -                                        (alpha_ring * r_x) -                                        ) -             vec = residual.sum(axis = 1) -             #if SlicesZ > 1: -             #    vec = vec[:,1,:].squeeze() -             self.r = (r_x - (1./L_const) * vec).copy() -             self.objective[i] = (0.5 * (residual ** 2).sum()) - -    def projectionBackprojection(self, X, X_t): -        print ("FISTA Reconstructor: projection-backprojection routine") -         -        # a few useful variables -        SlicesZ, anglesNumb, Detectors = \ -                    numpy.shape(self.getParameter('input_sinogram')) -        residual = self.residual -        proj_geom , vol_geom , L_const = \ -                  self.getParameter(['projector_geometry' , -                                                  'output_geometry', -                                                  'Lipschitz_constant']) -         -         -        if self.getParameter('projector_geometry')['type'] == 'parallel' or \ -           self.getParameter('projector_geometry')['type'] == 'fanflat' or \ -           self.getParameter('projector_geometry')['type'] == 'fanflat_vec': -            # if the geometry is parallel use slice-by-slice -            # projection-backprojection routine -            #sino_updt = zeros(size(sino),'single'); -            proj_geomT = proj_geom.copy() -            proj_geomT['DetectorRowCount'] = 1 -            vol_geomT = vol_geom.copy() -            vol_geomT['GridSliceCount'] = 1; -            x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) -             -            for kkk in range(SlicesZ): -                 -                x_id, x_temp[kkk] = \ -                         astra.creators.create_backprojection3d_gpu( -                             residual[kkk:kkk+1], -                             proj_geomT, vol_geomT) -                astra.matlab.data3d('delete', x_id) -        else: -            x_id, x_temp = \ -                  astra.creators.create_backprojection3d_gpu( -                      residual, proj_geom, vol_geom)             - -        X = X_t - (1/L_const) * x_temp -        #astra.matlab.data3d('delete', sino_id) -        astra.matlab.data3d('delete', x_id) - -    def regularize(self, X): -        print ("FISTA Reconstructor: regularize") -         -        regularizer = self.getParameter('regularizer') -        if regularizer is not None: -            return regularizer(input=X) -        else: -            return X - -    def updateLoop(self, i, X, X_old, r_old, t, t_old): -        print ("FISTA Reconstructor: update loop") -        lambdaR_L1 = self.getParameter('ring_lambda_R_L1') -        if lambdaR_L1 > 0: -            self.r = numpy.max( -                numpy.abs(self.r) - lambdaR_L1 , 0) * \ -                numpy.sign(self.r) -        t = (1 + numpy.sqrt(1 + 4 * t**2))/2 -        X_t = X + (((t_old -1)/t) * (X - X_old)) - -        if lambdaR_L1 > 0: -            self.r_x = self.r + \ -                             (((t_old-1)/t) * (self.r - r_old)) - -        if self.getParameter('region_of_interest') is None: -            string = 'Iteration Number {0} | Objective {1} \n' -            print (string.format( i, self.objective[i])) -        else: -            ROI , X_ideal = fistaRecon.getParameter('region_of_interest', -                                                    'ideal_image') -             -            Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) -            string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' -            print (string.format(i,Resid_error[i], self.objective[i])) -        return (X , X_t, t) - -    def os_iterate(self, Xin=None): -        print ("FISTA Reconstructor: iterate") -         -        if Xin is None:     -            if self.getParameter('initialize'): -                X = self.initialize() -            else: -                N = vol_geom['GridColCount'] -                X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) -        else: -            # copy by reference -            X = Xin -        # store the output volume in the parameters -        self.setParameter(output_volume=X) -        X_t = X.copy() - -        # some useful constants -        proj_geom , vol_geom, sino , \ -          SlicesZ, weights , alpha_ring , -          lambdaR_L1 , L_const = self.getParameter( -            ['projector_geometry' , 'output_geometry', -             'input_sinogram', 'SlicesZ' ,  'weights', 'ring_alpha' , -             'ring_lambda_R_L1', 'Lipschitz_constant']) diff --git a/src/Python/ccpi/fista/Reconstructor.py b/src/Python/ccpi/fista/Reconstructor.py deleted file mode 100644 index d29ac0d..0000000 --- a/src/Python/ccpi/fista/Reconstructor.py +++ /dev/null @@ -1,425 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -import h5py -from ccpi.reconstruction.parallelbeam import alg - -from Regularizer import Regularizer -from enum import Enum - -import astra - -    -     -class FISTAReconstructor(): -    '''FISTA-based reconstruction algorithm using ASTRA-toolbox -     -    ''' -    # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> -    # ___Input___: -    # params.[] file: -    #       - .proj_geom (geometry of the projector) [required] -    #       - .vol_geom (geometry of the reconstructed object) [required] -    #       - .sino (vectorized in 2D or 3D sinogram) [required] -    #       - .iterFISTA (iterations for the main loop, default 40) -    #       - .L_const (Lipschitz constant, default Power method)                                                                                                    ) -    #       - .X_ideal (ideal image, if given) -    #       - .weights (statisitcal weights, size of the sinogram) -    #       - .ROI (Region-of-interest, only if X_ideal is given) -    #       - .initialize (a 'warm start' using SIRT method from ASTRA) -    #----------------Regularization choices------------------------ -    #       - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) -    #       - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) -    #       - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) -    #       - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) -    #       - .Regul_Iterations (iterations for the selected penalty, default 25) -    #       - .Regul_tauLLT (time step parameter for LLT term) -    #       - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) -    #       - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) -    #----------------Visualization parameters------------------------ -    #       - .show (visualize reconstruction 1/0, (0 default)) -    #       - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) -    #       - .slice (for 3D volumes - slice number to imshow) -    # ___Output___: -    # 1. X - reconstructed image/volume -    # 2. output - a structure with -    #    - .Resid_error - residual error (if X_ideal is given) -    #    - .objective: value of the objective function -    #    - .L_const: Lipshitz constant to avoid recalculations -     -    # References: -    # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse -    # Problems" by A. Beck and M Teboulle -    # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo -    # 3. "A novel tomographic reconstruction method based on the robust -    # Student's t function for suppressing data outliers" D. Kazantsev et.al. -    # D. Kazantsev, 2016-17 -    def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): -        self.params = dict() -        self.params['projector_geometry'] = projector_geometry -        self.params['output_geometry'] = output_geometry -        self.params['input_sinogram'] = input_sinogram -        detectors, nangles, sliceZ = numpy.shape(input_sinogram) -        self.params['detectors'] = detectors -        self.params['number_og_angles'] = nangles -        self.params['SlicesZ'] = sliceZ -         -        # Accepted input keywords -        kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , -              'weights' , 'region_of_interest' , 'initialize' ,  -              'regularizer' ,  -              'ring_lambda_R_L1', -              'ring_alpha') -         -        # handle keyworded parameters -        if kwargs is not None: -            for key, value in kwargs.items(): -                if key in kw: -                    #print("{0} = {1}".format(key, value))                         -                    self.pars[key] = value -                     -        # set the default values for the parameters if not set -        if 'number_of_iterations' in kwargs.keys(): -            self.pars['number_of_iterations'] = kwargs['number_of_iterations'] -        else: -            self.pars['number_of_iterations'] = 40 -        if 'weights' in kwargs.keys(): -            self.pars['weights'] = kwargs['weights'] -        else: -            self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) -        if 'Lipschitz_constant' in kwargs.keys(): -            self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] -        else: -            self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() -         -        if not self.pars['ideal_image'] in kwargs.keys(): -            self.pars['ideal_image'] = None -         -        if not self.pars['region_of_interest'] : -            if self.pars['ideal_image'] == None: -                pass -            else: -                self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) -             -        if not self.pars['regularizer'] : -            self.pars['regularizer'] = None -        else: -            # the regularizer must be a correctly instantiated object -            if not self.pars['ring_lambda_R_L1']: -                self.pars['ring_lambda_R_L1'] = 0 -            if not self.pars['ring_alpha']: -                self.pars['ring_alpha'] = 1 -         -             -             -         -    def calculateLipschitzConstantWithPowerMethod(self): -        ''' using Power method (PM) to establish L constant''' -         -        #N = params.vol_geom.GridColCount -        N = self.pars['output_geometry'].GridColCount -        proj_geom = self.params['projector_geometry'] -        vol_geom = self.params['output_geometry'] -        weights = self.pars['weights'] -        SlicesZ = self.pars['SlicesZ'] -         -        if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): -            #% for parallel geometry we can do just one slice -            #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); -            niter = 15;# % number of iteration for the PM -            #N = params.vol_geom.GridColCount; -            #x1 = rand(N,N,1); -            x1 = numpy.random.rand(1,N,N) -            #sqweight = sqrt(weights(:,:,1)); -            sqweight = numpy.sqrt(weights.T[0]) -            proj_geomT = proj_geom.copy(); -            proj_geomT.DetectorRowCount = 1; -            vol_geomT = vol_geom.copy(); -            vol_geomT['GridSliceCount'] = 1; -             -             -            for i in range(niter): -                if i == 0: -                    #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); -                    sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); -                    y = sqweight * y # element wise multiplication -                    #astra_mex_data3d('delete', sino_id); -                    astra.matlab.data3d('delete', sino_id) -                     -                idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); -                s = numpy.linalg.norm(x1) -                ### this line? -                x1 = x1/s; -                ### this line? -                sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); -                y = sqweight*y; -                astra.matlab.data3d('delete', sino_id); -                astra.matlab.data3d('delete', idx); -            #end -            del proj_geomT -            del vol_geomT -        else -            #% divergen beam geometry -            #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); -            niter = 8; #% number of iteration for PM -            x1 = numpy.random.rand(SlicesZ , N , N); -            #sqweight = sqrt(weights); -            sqweight = numpy.sqrt(weights.T[0]) -             -            sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); -            y = sqweight*y; -            #astra_mex_data3d('delete', sino_id); -            astra.matlab.data3d('delete', sino_id); -             -            for i in range(niter): -                #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); -                idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,  -                                                                    proj_geom,  -                                                                    vol_geom) -                s = numpy.linalg.norm(x1) -                ### this line? -                x1 = x1/s; -                ### this line? -                #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); -                sino_id, y = astra.creators.create_sino3d_gpu(x1,  -                                                              proj_geom,  -                                                              vol_geom); -                 -                y = sqweight*y; -                #astra_mex_data3d('delete', sino_id); -                #astra_mex_data3d('delete', id); -                astra.matlab.data3d('delete', sino_id); -                astra.matlab.data3d('delete', idx); -            #end -            #clear x1 -            del x1 -         -        return s -     -     -    def setRegularizer(self, regularizer): -        if regularizer -        self.pars['regularizer'] = regularizer -         -     -     - - -def getEntry(location): -    for item in nx[location].keys(): -        print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -nx = h5py.File(fname, "r") - -# the data are stored in a particular location in the hdf5 -for item in nx['entry1/tomo_entry/data'].keys(): -    print (item) - -data = nx.get('entry1/tomo_entry/data/rotation_angle') -angles = numpy.zeros(data.shape) -data.read_direct(angles) -print (angles) -# angles should be in degrees - -data = nx.get('entry1/tomo_entry/data/data') -stack = numpy.zeros(data.shape) -data.read_direct(stack) -print (data.shape) - -print ("Data Loaded") - - -# Normalize -data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -itype = numpy.zeros(data.shape) -data.read_direct(itype) -# 2 is dark field -darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -dark = darks[0] -for i in range(1, len(darks)): -    dark += darks[i] -dark = dark / len(darks) -#dark[0][0] = dark[0][1] - -# 1 is flat field -flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -flat = flats[0] -for i in range(1, len(flats)): -    flat += flats[i] -flat = flat / len(flats) -#flat[0][0] = dark[0][1] - - -# 0 is projection data -proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = numpy.asarray (angle_proj) -angle_proj = angle_proj.astype(numpy.float32) - -# normalized data are -# norm = (projection - dark)/(flat-dark) - -def normalize(projection, dark, flat, def_val=0.1): -    a = (projection - dark) -    b = (flat-dark) -    with numpy.errstate(divide='ignore', invalid='ignore'): -        c = numpy.true_divide( a, b ) -        c[ ~ numpy.isfinite( c )] = def_val  # set to not zero if 0/0  -    return c -     - -norm = [normalize(projection, dark, flat) for projection in proj] -norm = numpy.asarray (norm) -norm = norm.astype(numpy.float32) - -#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, -#                 angles = angle_proj, center_of_rotation = 86.2 ,  -#                 flat_field = flat, dark_field = dark,  -#                 iterations = 15, resolution = 1, isLogScale = False, threads = 3) - -#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, -#                 angles = angle_proj, center_of_rotation = 86.2 ,  -#                 flat_field = flat, dark_field = dark,  -#                 iterations = 15, resolution = 1, isLogScale = False, threads = 3) -#img_cgls = recon.reconstruct() -# -#pars = dict() -#pars['algorithm'] = Reconstructor.Algorithm.SIRT -#pars['projection_data'] = proj -#pars['angles'] = angle_proj -#pars['center_of_rotation'] = numpy.double(86.2) -#pars['flat_field'] = flat -#pars['iterations'] = 15 -#pars['dark_field'] = dark -#pars['resolution'] = 1 -#pars['isLogScale'] = False -#pars['threads'] = 3 -# -#img_sirt = recon.reconstruct(pars) -# -#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM -#img_mlem = recon.reconstruct() - -############################################################ -############################################################ -#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV -#recon.pars['regularize'] = numpy.double(0.1) -#img_cgls_conv = recon.reconstruct() - -niterations = 15 -threads = 3 - -img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -                              iteration_values, False) -print ("iteration values %s" % str(iteration_values)) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -                                      numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) -iteration_values = numpy.zeros((niterations,)) -img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -                                      numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) - - -##numpy.save("cgls_recon.npy", img_data) -import matplotlib.pyplot as plt -fig, ax = plt.subplots(1,6,sharey=True) -ax[0].imshow(img_cgls[80]) -ax[0].axis('off')  # clear x- and y-axes -ax[1].imshow(img_sirt[80]) -ax[1].axis('off')  # clear x- and y-axes -ax[2].imshow(img_mlem[80]) -ax[2].axis('off')  # clear x- and y-axesplt.show() -ax[3].imshow(img_cgls_conv[80]) -ax[3].axis('off')  # clear x- and y-axesplt.show() -ax[4].imshow(img_cgls_tikhonov[80]) -ax[4].axis('off')  # clear x- and y-axesplt.show() -ax[5].imshow(img_cgls_TVreg[80]) -ax[5].axis('off')  # clear x- and y-axesplt.show() - - -plt.show() - -#viewer = edo.CILViewer() -#viewer.setInputAsNumpy(img_cgls2) -#viewer.displaySliceActor(0) -#viewer.startRenderLoop() - -import vtk - -def NumpyToVTKImageData(numpyarray): -    if (len(numpy.shape(numpyarray)) == 3): -        doubleImg = vtk.vtkImageData() -        shape = numpy.shape(numpyarray) -        doubleImg.SetDimensions(shape[0], shape[1], shape[2]) -        doubleImg.SetOrigin(0,0,0) -        doubleImg.SetSpacing(1,1,1) -        doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) -        #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) -        doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) -         -        for i in range(shape[0]): -            for j in range(shape[1]): -                for k in range(shape[2]): -                    doubleImg.SetScalarComponentFromDouble( -                        i,j,k,0, numpyarray[i][j][k]) -    #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) -        # rescale to appropriate VTK_UNSIGNED_SHORT -        stats = vtk.vtkImageAccumulate() -        stats.SetInputData(doubleImg) -        stats.Update() -        iMin = stats.GetMin()[0] -        iMax = stats.GetMax()[0] -        scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) - -        shiftScaler = vtk.vtkImageShiftScale () -        shiftScaler.SetInputData(doubleImg) -        shiftScaler.SetScale(scale) -        shiftScaler.SetShift(iMin) -        shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) -        shiftScaler.Update() -        return shiftScaler.GetOutput() -         -#writer = vtk.vtkMetaImageWriter() -#writer.SetFileName(alg + "_recon.mha") -#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) -#writer.Write() diff --git a/src/Python/ccpi/fista/__init__.py b/src/Python/ccpi/fista/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/src/Python/ccpi/fista/__init__.py +++ /dev/null | 
