diff options
| author | Edoardo Pasca <edo.paskino@gmail.com> | 2017-10-31 12:49:10 +0000 | 
|---|---|---|
| committer | Edoardo Pasca <edo.paskino@gmail.com> | 2018-01-19 14:26:06 +0000 | 
| commit | 79dee31236454d56136cea8b63ede769b78d9839 (patch) | |
| tree | ce0ed9e414d1d193513fb5ee9a6e78ce69e2aff0 /src/Python/ccpi | |
| parent | 9e2f024fb1e961f5978124d394ec26d2802273bb (diff) | |
| download | regularization-79dee31236454d56136cea8b63ede769b78d9839.tar.gz regularization-79dee31236454d56136cea8b63ede769b78d9839.tar.bz2 regularization-79dee31236454d56136cea8b63ede769b78d9839.tar.xz regularization-79dee31236454d56136cea8b63ede769b78d9839.zip | |
Added Ordered Subsets
Diffstat (limited to 'src/Python/ccpi')
| -rw-r--r-- | src/Python/ccpi/reconstruction/FISTAReconstructor.py | 186 | 
1 files changed, 173 insertions, 13 deletions
| diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py index b8e1027..4f1709c 100644 --- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -430,7 +430,13 @@ class FISTAReconstructor():      # prepareForIteration -    def iterate(self, Xin=None): +    def iterate (self, Xin=None): +        if self.getParameter('subset') == 0: +            return self.iterateStandard(Xin) +        else: +            return self.iterateOrderedSubsets(Xin) +         +    def iterateStandard(self, Xin=None):          print ("FISTA Reconstructor: iterate")          if Xin is None:     @@ -613,14 +619,14 @@ class FISTAReconstructor():      def updateLoop(self, i, X, X_old, r_old, t, t_old):          print ("FISTA Reconstructor: update loop")          lambdaR_L1 = self.getParameter('ring_lambda_R_L1') -        if lambdaR_L1 > 0: -            self.r = numpy.max( -                numpy.abs(self.r) - lambdaR_L1 , 0) * \ -                numpy.sign(self.r) +                      t = (1 + numpy.sqrt(1 + 4 * t**2))/2          X_t = X + (((t_old -1)/t) * (X - X_old))          if lambdaR_L1 > 0: +            self.r = numpy.max( +                numpy.abs(self.r) - lambdaR_L1 , 0) * \ +                numpy.sign(self.r)              self.r_x = self.r + \                               (((t_old-1)/t) * (self.r - r_old)) @@ -636,8 +642,8 @@ class FISTAReconstructor():              print (string.format(i,Resid_error[i], self.objective[i]))          return (X , X_t, t) -    def os_iterate(self, Xin=None): -        print ("FISTA Reconstructor: iterate") +    def iterateOS(self, Xin=None): +        print ("FISTA Reconstructor: Ordered Subsets iterate")          if Xin is None:                  if self.getParameter('initialize'): @@ -653,9 +659,163 @@ class FISTAReconstructor():          X_t = X.copy()          # some useful constants -        proj_geom , vol_geom, sino , \ -          SlicesZ, weights , alpha_ring ,\ -          lambdaR_L1 , L_const = self.getParameter( -            ['projector_geometry' , 'output_geometry', -             'input_sinogram', 'SlicesZ' ,  'weights', 'ring_alpha' , -             'ring_lambda_R_L1', 'Lipschitz_constant']) +        proj_geom ,    vol_geom, sino , \ +          SlicesZ,     weights , alpha_ring ,\ +          lambdaR_L1 , L_const , iterFISTA         = self.getParameter( +            ['projector_geometry' , 'output_geometry', 'input_sinogram', +             'SlicesZ' ,            'weights',         'ring_alpha' , +             'ring_lambda_R_L1',    'Lipschitz_constant', +             'number_of_iterations']) + +         +        # errors vector (if the ground truth is given) +        Resid_error = numpy.zeros((iterFISTA)); +        # objective function values vector +        objective = numpy.zeros((iterFISTA));  + +           +        t = 1 + +        ## additional for  +        proj_geomSUB = proj_geom.copy() +        self.residual2 = numpy.zeros(numpy.shape(sino)) +        residual2 = self.residual2 +        sino_updt_FULL = self.residual.copy() +        r_x = self.r.copy() + +        print ("starting iterations") +        ##    % Outer FISTA iterations loop +        for i in range(fistaRecon.getParameter('number_of_iterations')): +            # With OS approach it becomes trickier to correlate independent +            # subsets, hence additional work is required one solution is to +            # work with a full sinogram at times + +            r_old = self.r.copy() +            t_old = t +            SlicesZ, anglesNumb, Detectors = \ +                        numpy.shape(self.getParameter('input_sinogram'))        ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 +            if (i > 1 and lambdaR_L1 > 0) : +                for kkk in range(anglesNumb): +                      +                     residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ +                                           ((sino_updt_FULL[:,kkk,:]).squeeze() - \ +                                            (sino[:,kkk,:]).squeeze() -\ +                                            (alpha_ring * r_x) +                                            ) +                 +                vec = self.residual.sum(axis = 1) +                #if SlicesZ > 1: +                #    vec = vec[:,1,:] # 1 or 0? +                r_x = self.r_x +                # update ring variable +                self.r = (r_x - (1./L_const) * vec).copy() + +            # subset loop +            counterInd = 1 +            geometry_type = self.getParameter('projector_geometry')['type'] +            angles = self.getParameter('projector_geometry')['ProjectionAngles'] + +            for ss in range(self.getParameter('subsets')): +                print ("Subset {0}".format(ss)) +                X_old = X.copy() +                t_old = t +                 +                # the number of projections per subset +                numProjSub = self.getParameter('os_bins')[ss] +                CurrSubIndices = self.getParameter('os_indices')\ +                                 [counterInd:counterInd+numProjSub] +                #print ("Len CurrSubIndices {0}".format(numProjSub)) +                mask = numpy.zeros(numpy.shape(angles), dtype=bool) +                cc = 0 +                for j in range(len(CurrSubIndices)): +                    mask[int(CurrSubIndices[j])] = True +                proj_geomSUB['ProjectionAngles'] = angles[mask] + +                shape = list(numpy.shape(self.getParameter('input_sinogram'))) +                shape[1] = numProjSub +                sino_updt_Sub = numpy.zeros(shape) +                if geometry_type == 'parallel' or \ +                   geometry_type == 'fanflat' or \ +                   geometry_type == 'fanflat_vec' : + +                    for kkk in range(SlicesZ): +                        sino_id, sinoT = astra.creators.create_sino3d_gpu ( +                            X_t[kkk:kkk+1] , proj_geomSUB, vol_geom) +                        sino_updt_Sub[kkk] = sinoT.T.copy() +                         +                else: +                    # for 3D geometry (watch the GPU memory overflow in +                    # ASTRA < 1.8) +                    sino_id, sino_updt_Sub = \ +                         astra.creators.create_sino3d_gpu (X_t, proj_geomSUB, vol_geom) +                     +                astra.matlab.data3d('delete', sino_id) +         +                 +                if lambdaR_L1 > 0 : +                    ## RING REMOVAL +                    print ("ring removal") +                    residualSub = self.ringRemovalOrderedSubsets(sino_updt_Sub, +                                                   sino_updt_FULL) +                else: +                    #PWLS model +                    print ("PWLS model") +                    residualSub = weights[:,CurrSubIndices,:] * \ +                                  ( sino_updt_Sub - \ +                                    sino[:,CurrSubIndices,:].squeeze() ) +                    objective[i] = 0.5 * numpy.linalg.norm(residualSub) + +                # projection/backprojection routine +                if geometry_type == 'parallel' or \ +                   geometry_type == 'fanflat' or \ +                   geometry_type == 'fanflat_vec' : +                    # if geometry is 2D use slice-by-slice projection-backprojection +                    # routine +                    x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32) +                    for kkk in range(SlicesZ): +                         +                        x_id, x_temp[kkk] = \ +                                 astra.creators.create_backprojection3d_gpu( +                                     residualSub[kkk:kkk+1], +                                     proj_geomSUB, vol_geom) +                         +                else: +                    x_id, x_temp = \ +                          astra.creators.create_backprojection3d_gpu( +                              residualSub, proj_geomSUB, vol_geom) + +                astra.matlab.data3d('delete', x_id) +                X = X_t - (1/L_const) * x_temp +                ## REGULARIZATION +                X = self.regularize(X) +             +            # FINAL +            ## Update Loop +            X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old) +            self.setParameter(output_volume=X) +            counterInd = counterInd + numProjSub + +        return X +     +    def ringRemovalOrderedSubsets(self, sino_updt_Sub, sino_updt_FULL): +        residual = self.residual +        r_x = self.r_x +        weights , alpha_ring , sino = \ +                self.getParameter( ['weights', 'ring_alpha', 'input_sinogram']) +        numProjSub = self.getParameter('os_bins')[ss] +        CurrSubIndices = self.getParameter('os_indices')\ +                         [counterInd:counterInd+numProjSub] +        residualSub = numpy.zeros(shape) + +        for kkk in range(numProjSub): +            #print ("ring removal indC ... {0}".format(kkk)) +            indC = int(CurrSubIndices[kkk]) +            residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \ +                (sino_updt_Sub[:,kkk,:].squeeze() - \ +                sino[:,indC,:].squeeze() - alpha_ring * r_x) +            # filling the full sinogram +            sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze() + +        return residualSub + + | 
