diff options
author | Daniil Kazantsev <dkazanc@hotmail.com> | 2018-04-03 12:34:10 +0100 |
---|---|---|
committer | Daniil Kazantsev <dkazanc@hotmail.com> | 2018-04-03 12:34:10 +0100 |
commit | 52ed0437bfbf5bb8a6fdafd65628c4460184fd2d (patch) | |
tree | d8cd588e7ca8992ec2e87ae3819f11a2ff5fd97e /Wrappers/Python/ccpi | |
parent | 3ed9b4129cdab5110e67a4705b3bd52fd9781f8b (diff) | |
download | regularization-52ed0437bfbf5bb8a6fdafd65628c4460184fd2d.tar.gz regularization-52ed0437bfbf5bb8a6fdafd65628c4460184fd2d.tar.bz2 regularization-52ed0437bfbf5bb8a6fdafd65628c4460184fd2d.tar.xz regularization-52ed0437bfbf5bb8a6fdafd65628c4460184fd2d.zip |
cleaning stage, leaving only FGP/ROF and related compilers, demos
Diffstat (limited to 'Wrappers/Python/ccpi')
-rw-r--r-- | Wrappers/Python/ccpi/filters/Regularizer.py | 325 | ||||
-rw-r--r-- | Wrappers/Python/ccpi/reconstruction/AstraDevice.py | 95 | ||||
-rw-r--r-- | Wrappers/Python/ccpi/reconstruction/DeviceModel.py | 63 | ||||
-rw-r--r-- | Wrappers/Python/ccpi/reconstruction/FISTAReconstructor.py | 882 | ||||
-rw-r--r-- | Wrappers/Python/ccpi/reconstruction/Reconstructor.py | 598 |
5 files changed, 0 insertions, 1963 deletions
diff --git a/Wrappers/Python/ccpi/filters/Regularizer.py b/Wrappers/Python/ccpi/filters/Regularizer.py deleted file mode 100644 index 4ca94f2..0000000 --- a/Wrappers/Python/ccpi/filters/Regularizer.py +++ /dev/null @@ -1,325 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Created on Tue Aug 8 14:26:00 2017 - -@author: ofn77899 -""" - -from ccpi.filters import cpu_regularizers -import numpy as np -from enum import Enum -import timeit - -class Regularizer(): - '''Class to handle regularizer algorithms to be used during reconstruction - - Currently 5 CPU (OMP) regularization algorithms are available: - - 1) SplitBregman_TV - 2) FGP_TV - 3) LLT_model - 4) PatchBased_Regul - 5) TGV_PD - - Usage: - the regularizer can be invoked as object or as static method - Depending on the actual regularizer the input parameter may vary, and - a different default setting is defined. - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - - out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, - tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - - out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., - number_of_iterations=30, tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - - A number of optional parameters can be passed or skipped - out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) - - ''' - class Algorithm(Enum): - SplitBregman_TV = cpu_regularizers.SplitBregman_TV - FGP_TV = cpu_regularizers.FGP_TV - LLT_model = cpu_regularizers.LLT_model - PatchBased_Regul = cpu_regularizers.PatchBased_Regul - TGV_PD = cpu_regularizers.TGV_PD - # Algorithm - - class TotalVariationPenalty(Enum): - isotropic = 0 - l1 = 1 - # TotalVariationPenalty - - def __init__(self , algorithm, debug = True): - self.setAlgorithm ( algorithm ) - self.debug = debug - # __init__ - - def setAlgorithm(self, algorithm): - self.algorithm = algorithm - self.pars = self.getDefaultParsForAlgorithm(algorithm) - # setAlgorithm - - def getDefaultParsForAlgorithm(self, algorithm): - pars = dict() - - if algorithm == Regularizer.Algorithm.SplitBregman_TV : - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['number_of_iterations'] = 35 - pars['tolerance_constant'] = 0.0001 - pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic - - elif algorithm == Regularizer.Algorithm.FGP_TV : - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['number_of_iterations'] = 50 - pars['tolerance_constant'] = 0.001 - pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic - - elif algorithm == Regularizer.Algorithm.LLT_model: - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['time_step'] = None - pars['number_of_iterations'] = None - pars['tolerance_constant'] = None - pars['restrictive_Z_smoothing'] = 0 - - elif algorithm == Regularizer.Algorithm.PatchBased_Regul: - pars['algorithm'] = algorithm - pars['input'] = None - pars['searching_window_ratio'] = None - pars['similarity_window_ratio'] = None - pars['PB_filtering_parameter'] = None - pars['regularization_parameter'] = None - - elif algorithm == Regularizer.Algorithm.TGV_PD: - pars['algorithm'] = algorithm - pars['input'] = None - pars['first_order_term'] = None - pars['second_order_term'] = None - pars['number_of_iterations'] = None - pars['regularization_parameter'] = None - - else: - raise Exception('Unknown regularizer algorithm') - - self.acceptedInputKeywords = pars.keys() - - return pars - # parsForAlgorithm - - def setParameter(self, **kwargs): - '''set named parameter for the regularization engine - - raises Exception if the named parameter is not recognized - Typical usage is: - - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - reg.setParameter(input=u0) - reg.setParameter(regularization_parameter=10.) - - it can be also used as - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - reg.setParameter(input=u0 , regularization_parameter=10.) - ''' - - for key , value in kwargs.items(): - if key in self.pars.keys(): - self.pars[key] = value - else: - raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) - # setParameter - - def getParameter(self, key): - if type(key) is str: - if key in self.acceptedInputKeywords: - return self.pars[key] - else: - raise Exception('Unrecongnised parameter: {0} '.format(key) ) - elif type(key) is list: - outpars = [] - for k in key: - outpars.append(self.getParameter(k)) - return outpars - else: - raise Exception('Unhandled input {0}' .format(str(type(key)))) - # getParameter - - - def __call__(self, input = None, regularization_parameter = None, - output_all = False, **kwargs): - '''Actual call for the regularizer. - - One can either set the regularization parameters first and then call the - algorithm or set the regularization parameter during the call (as - is done in the static methods). - ''' - - if kwargs is not None: - for key, value in kwargs.items(): - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - if input is not None: - self.pars['input'] = input - if regularization_parameter is not None: - self.pars['regularization_parameter'] = regularization_parameter - - if self.debug: - print ("--------------------------------------------------") - for key, value in self.pars.items(): - if key== 'algorithm' : - print("{0} = {1}".format(key, value.__name__)) - elif key == 'input': - print("{0} = {1}".format(key, np.shape(value))) - else: - print("{0} = {1}".format(key, value)) - - - if None in self.pars: - raise Exception("Not all parameters have been provided") - - input = self.pars['input'] - regularization_parameter = self.pars['regularization_parameter'] - if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : - ret = self.algorithm(input, regularization_parameter, - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['TV_penalty'].value ) - elif self.algorithm == Regularizer.Algorithm.FGP_TV : - ret = self.algorithm(input, regularization_parameter, - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['TV_penalty'].value ) - elif self.algorithm == Regularizer.Algorithm.LLT_model : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - ret = self.algorithm(input, - regularization_parameter, - self.pars['time_step'] , - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['restrictive_Z_smoothing'] ) - elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - ret = self.algorithm(input, regularization_parameter, - self.pars['searching_window_ratio'] , - self.pars['similarity_window_ratio'] , - self.pars['PB_filtering_parameter']) - elif self.algorithm == Regularizer.Algorithm.TGV_PD : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - if len(np.shape(input)) == 2: - ret = self.algorithm(input, regularization_parameter, - self.pars['first_order_term'] , - self.pars['second_order_term'] , - self.pars['number_of_iterations']) - elif len(np.shape(input)) == 3: - #assuming it's 3D - # run independent calls on each slice - out3d = input.copy() - for i in range(np.shape(input)[0]): - out = self.algorithm(input[i], regularization_parameter, - self.pars['first_order_term'] , - self.pars['second_order_term'] , - self.pars['number_of_iterations']) - # copy the result in the 3D image - out3d[i] = out[0].copy() - # append the rest of the info that the algorithm returns - output = [out3d] - for i in range(1,len(out)): - output.append(out[i]) - ret = output - - - - if output_all: - return ret - else: - return ret[0] - - # __call__ - - @staticmethod - def SplitBregman_TV(input, regularization_parameter , **kwargs): - start_time = timeit.default_timer() - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - a = reg(input, regularization_parameter, **kwargs) - txt = reg.printParametersToString() - txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) - return a, reg.pars, txt - - @staticmethod - def FGP_TV(input, regularization_parameter , **kwargs): - start_time = timeit.default_timer() - reg = Regularizer(Regularizer.Algorithm.FGP_TV) - a = reg(input, regularization_parameter, **kwargs) - txt = reg.printParametersToString() - txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) - return a, reg.pars, txt - - @staticmethod - def LLT_model(input, regularization_parameter , time_step, number_of_iterations, - tolerance_constant, restrictive_Z_smoothing=0): - start_time = timeit.default_timer() - reg = Regularizer(Regularizer.Algorithm.LLT_model) - a = reg(input, regularization_parameter, time_step=time_step, - number_of_iterations=number_of_iterations, - tolerance_constant=tolerance_constant, - restrictive_Z_smoothing=restrictive_Z_smoothing) - - txt = reg.printParametersToString() - txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) - return a, reg.pars, txt - - @staticmethod - def PatchBased_Regul(input, regularization_parameter, - searching_window_ratio, - similarity_window_ratio, - PB_filtering_parameter): - start_time = timeit.default_timer() - reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) - a = reg(input, - regularization_parameter, - searching_window_ratio=searching_window_ratio, - similarity_window_ratio=similarity_window_ratio, - PB_filtering_parameter=PB_filtering_parameter ) - txt = reg.printParametersToString() - txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) - return a, reg.pars, txt - - @staticmethod - def TGV_PD(input, regularization_parameter , first_order_term, - second_order_term, number_of_iterations): - start_time = timeit.default_timer() - - reg = Regularizer(Regularizer.Algorithm.TGV_PD) - a = reg(input, regularization_parameter, - first_order_term=first_order_term, - second_order_term=second_order_term, - number_of_iterations=number_of_iterations) - txt = reg.printParametersToString() - txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) - - - return a, reg.pars, txt - - def printParametersToString(self): - txt = r'' - for key, value in self.pars.items(): - if key== 'algorithm' : - txt += "{0} = {1}".format(key, value.__name__) - elif key == 'input': - txt += "{0} = {1}".format(key, np.shape(value)) - else: - txt += "{0} = {1}".format(key, value) - txt += '\n' - return txt - diff --git a/Wrappers/Python/ccpi/reconstruction/AstraDevice.py b/Wrappers/Python/ccpi/reconstruction/AstraDevice.py deleted file mode 100644 index 57435f8..0000000 --- a/Wrappers/Python/ccpi/reconstruction/AstraDevice.py +++ /dev/null @@ -1,95 +0,0 @@ -import astra -from ccpi.reconstruction.DeviceModel import DeviceModel -import numpy - -class AstraDevice(DeviceModel): - '''Concrete class for Astra Device''' - - def __init__(self, - device_type, - data_aquisition_geometry, - reconstructed_volume_geometry): - - super(AstraDevice, self).__init__(device_type, - data_aquisition_geometry, - reconstructed_volume_geometry) - - self.type = device_type - self.proj_geom = astra.creators.create_proj_geom( - device_type, - self.acquisition_data_geometry['detectorSpacingX'], - self.acquisition_data_geometry['detectorSpacingY'], - self.acquisition_data_geometry['cameraX'], - self.acquisition_data_geometry['cameraY'], - self.acquisition_data_geometry['angles'], - ) - - self.vol_geom = astra.creators.create_vol_geom( - self.reconstructed_volume_geometry['X'], - self.reconstructed_volume_geometry['Y'], - self.reconstructed_volume_geometry['Z'] - ) - - def doForwardProject(self, volume): - '''Forward projects the volume according to the device geometry - -Uses Astra-toolbox -''' - - try: - sino_id, y = astra.creators.create_sino3d_gpu( - volume, self.proj_geom, self.vol_geom) - astra.matlab.data3d('delete', sino_id) - return y - except Exception as e: - print(e) - print("Value Error: ", self.proj_geom, self.vol_geom) - - def doBackwardProject(self, projections): - '''Backward projects the projections according to the device geometry - -Uses Astra-toolbox -''' - idx, volume = \ - astra.creators.create_backprojection3d_gpu( - projections, - self.proj_geom, - self.vol_geom) - - astra.matlab.data3d('delete', idx) - return volume - - def createReducedDevice(self, proj_par={'cameraY' : 1} , vol_par={'Z':1}): - '''Create a new device based on the current device by changing some parameter - -VERY RISKY''' - acquisition_data_geometry = self.acquisition_data_geometry.copy() - for k,v in proj_par.items(): - if k in acquisition_data_geometry.keys(): - acquisition_data_geometry[k] = v - proj_geom = [ - acquisition_data_geometry['cameraX'], - acquisition_data_geometry['cameraY'], - acquisition_data_geometry['detectorSpacingX'], - acquisition_data_geometry['detectorSpacingY'], - acquisition_data_geometry['angles'] - ] - - reconstructed_volume_geometry = self.reconstructed_volume_geometry.copy() - for k,v in vol_par.items(): - if k in reconstructed_volume_geometry.keys(): - reconstructed_volume_geometry[k] = v - - vol_geom = [ - reconstructed_volume_geometry['X'], - reconstructed_volume_geometry['Y'], - reconstructed_volume_geometry['Z'] - ] - return AstraDevice(self.type, proj_geom, vol_geom) - - - -if __name__=="main": - a = AstraDevice() - - diff --git a/Wrappers/Python/ccpi/reconstruction/DeviceModel.py b/Wrappers/Python/ccpi/reconstruction/DeviceModel.py deleted file mode 100644 index eeb9a34..0000000 --- a/Wrappers/Python/ccpi/reconstruction/DeviceModel.py +++ /dev/null @@ -1,63 +0,0 @@ -from abc import ABCMeta, abstractmethod -from enum import Enum - -class DeviceModel(metaclass=ABCMeta): - '''Abstract class that defines the device for projection and backprojection - -This class defines the methods that must be implemented by concrete classes. - - ''' - - class DeviceType(Enum): - '''Type of device -PARALLEL BEAM -PARALLEL BEAM 3D -CONE BEAM -HELICAL''' - - PARALLEL = 'parallel' - PARALLEL3D = 'parallel3d' - CONE_BEAM = 'cone-beam' - HELICAL = 'helical' - - def __init__(self, - device_type, - data_aquisition_geometry, - reconstructed_volume_geometry): - '''Initializes the class - -Mandatory parameters are: -device_type from DeviceType Enum -data_acquisition_geometry: tuple (camera_X, camera_Y, detectorSpacingX, - detectorSpacingY, angles) -reconstructed_volume_geometry: tuple (dimX,dimY,dimZ) -''' - self.device_geometry = device_type - self.acquisition_data_geometry = { - 'cameraX': data_aquisition_geometry[0], - 'cameraY': data_aquisition_geometry[1], - 'detectorSpacingX' : data_aquisition_geometry[2], - 'detectorSpacingY' : data_aquisition_geometry[3], - 'angles' : data_aquisition_geometry[4],} - self.reconstructed_volume_geometry = { - 'X': reconstructed_volume_geometry[0] , - 'Y': reconstructed_volume_geometry[1] , - 'Z': reconstructed_volume_geometry[2] } - - @abstractmethod - def doForwardProject(self, volume): - '''Forward projects the volume according to the device geometry''' - return NotImplemented - - - @abstractmethod - def doBackwardProject(self, projections): - '''Backward projects the projections according to the device geometry''' - return NotImplemented - - @abstractmethod - def createReducedDevice(self): - '''Create a Device to do forward/backward projections on 2D slices''' - return NotImplemented - - diff --git a/Wrappers/Python/ccpi/reconstruction/FISTAReconstructor.py b/Wrappers/Python/ccpi/reconstruction/FISTAReconstructor.py deleted file mode 100644 index e40ad24..0000000 --- a/Wrappers/Python/ccpi/reconstruction/FISTAReconstructor.py +++ /dev/null @@ -1,882 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -#from ccpi.reconstruction.parallelbeam import alg - -#from ccpi.imaging.Regularizer import Regularizer -from enum import Enum - -import astra -from ccpi.reconstruction.AstraDevice import AstraDevice - - - -class FISTAReconstructor(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, - output_geometry, - input_sinogram, - device, - **kwargs): - # handle parmeters: - # obligatory parameters - self.pars = dict() - self.pars['projector_geometry'] = projector_geometry # proj_geom - self.pars['output_geometry'] = output_geometry # vol_geom - self.pars['input_sinogram'] = input_sinogram # sino - sliceZ, nangles, detectors = numpy.shape(input_sinogram) - self.pars['detectors'] = detectors - self.pars['number_of_angles'] = nangles - self.pars['SlicesZ'] = sliceZ - self.pars['output_volume'] = None - self.pars['device_model'] = device - - self.use_device = True - - print (self.pars) - # handle optional input parameters (at instantiation) - - # Accepted input keywords - kw = ( - # mandatory fields - 'projector_geometry', - 'output_geometry', - 'input_sinogram', - 'detectors', - 'number_of_angles', - 'SlicesZ', - # optional fields - 'number_of_iterations', - 'Lipschitz_constant' , - 'ideal_image' , - 'weights' , - 'region_of_interest' , - 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha', - 'subsets', - 'output_volume', - 'os_subsets', - 'os_indices', - 'os_bins', - 'device_model', - 'reduced_device_model') - self.acceptedInputKeywords = list(kw) - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = \ - numpy.ones(numpy.shape( - self.pars['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = None - - if not 'ideal_image' in kwargs.keys(): - self.pars['ideal_image'] = None - - if not 'region_of_interest'in kwargs.keys() : - if self.pars['ideal_image'] == None: - self.pars['region_of_interest'] = None - else: - ## nonzero if the image is larger than m - fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1) - - self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0) - - # the regularizer must be a correctly instantiated object - if not 'regularizer' in kwargs.keys() : - self.pars['regularizer'] = None - - #RING REMOVAL - if not 'ring_lambda_R_L1' in kwargs.keys(): - self.pars['ring_lambda_R_L1'] = 0 - if not 'ring_alpha' in kwargs.keys(): - self.pars['ring_alpha'] = 1 - - # ORDERED SUBSET - if not 'subsets' in kwargs.keys(): - self.pars['subsets'] = 0 - else: - self.createOrderedSubsets() - - if not 'initialize' in kwargs.keys(): - self.pars['initialize'] = False - - reduced_device = device.createReducedDevice() - self.setParameter(reduced_device_model=reduced_device) - - - - def setParameter(self, **kwargs): - '''set named parameter for the reconstructor engine - - raises Exception if the named parameter is not recognized - - ''' - for key , value in kwargs.items(): - if key in self.acceptedInputKeywords: - self.pars[key] = value - else: - raise Exception('Wrong parameter {0} for '.format(key) + - 'reconstructor') - # setParameter - - def getParameter(self, key): - if type(key) is str: - if key in self.acceptedInputKeywords: - return self.pars[key] - else: - raise Exception('Unrecongnised parameter: {0} '.format(key) ) - elif type(key) is list: - outpars = [] - for k in key: - outpars.append(self.getParameter(k)) - return outpars - else: - raise Exception('Unhandled input {0}' .format(str(type(key)))) - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - N = self.pars['output_geometry']['GridColCount'] - proj_geom = self.pars['projector_geometry'] - vol_geom = self.pars['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - - - if (proj_geom['type'] == 'parallel') or \ - (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - #print('Calculating Lipshitz constant for parallel beam geometry...') - niter = 5;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights[0:1,:,:]) - proj_geomT = proj_geom.copy(); - proj_geomT['DetectorRowCount'] = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - - - for i in range(niter): - # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); - # s = norm(x1(:)); - # x1 = x1/s; - # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - # y = sqweight.*y; - # astra_mex_data3d('delete', sino_id); - # astra_mex_data3d('delete', id); - #print ("iteration {0}".format(i)) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geomT, - vol_geomT) - - y = (sqweight * y).copy() # element wise multiplication - - #b=fig.add_subplot(2,1,2) - #imgplot = plt.imshow(x1[0]) - #plt.show() - - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - del x1 - - idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), - proj_geomT, - vol_geomT) - del y - - - s = numpy.linalg.norm(x1) - ### this line? - x1 = (x1/s).copy(); - - # ### this line? - # sino_id, y = astra.creators.create_sino3d_gpu(x1, - # proj_geomT, - # vol_geomT); - # y = sqweight * y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx) - print ("iteration {0} s= {1}".format(i,s)) - - #end - del proj_geomT - del vol_geomT - #plt.show() - else: - #% divergen beam geometry - print('Calculating Lipshitz constant for divergen beam geometry...') - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - - return s - - - def setRegularizer(self, regularizer): - if regularizer is not None: - self.pars['regularizer'] = regularizer - - - def initialize(self): - # convenience variable storage - proj_geom = self.pars['projector_geometry'] - vol_geom = self.pars['output_geometry'] - sino = self.pars['input_sinogram'] - - # a 'warm start' with SIRT method - # Create a data object for the reconstruction - rec_id = astra.matlab.data3d('create', '-vol', - vol_geom); - - #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino); - sinogram_id = astra.matlab.data3d('create', '-proj3d', - proj_geom, - sino) - - sirt_config = astra.astra_dict('SIRT3D_CUDA') - sirt_config['ReconstructionDataId' ] = rec_id - sirt_config['ProjectionDataId'] = sinogram_id - - sirt = astra.algorithm.create(sirt_config) - astra.algorithm.run(sirt, iterations=35) - X = astra.matlab.data3d('get', rec_id) - - # clean up memory - astra.matlab.data3d('delete', rec_id) - astra.matlab.data3d('delete', sinogram_id) - astra.algorithm.delete(sirt) - - - - return X - - def createOrderedSubsets(self, subsets=None): - if subsets is None: - try: - subsets = self.getParameter('subsets') - except Exception(): - subsets = 0 - #return subsets - else: - self.setParameter(subsets=subsets) - - - angles = self.getParameter('projector_geometry')['ProjectionAngles'] - - #binEdges = numpy.linspace(angles.min(), - # angles.max(), - # subsets + 1) - binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) - # get rearranged subset indices - IndicesReorg = numpy.zeros((numpy.shape(angles)), dtype=numpy.int32) - counterM = 0 - for ii in range(binsDiscr.max()): - counter = 0 - for jj in range(subsets): - curr_index = ii + jj + counter - #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) - if binsDiscr[jj] > ii: - if (counterM < numpy.size(IndicesReorg)): - IndicesReorg[counterM] = curr_index - counterM = counterM + 1 - - counter = counter + binsDiscr[jj] - 1 - - # store the OS in parameters - self.setParameter(os_subsets=subsets, - os_bins=binsDiscr, - os_indices=IndicesReorg) - - - def prepareForIteration(self): - print ("FISTA Reconstructor: prepare for iteration") - - self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) - self.objective = numpy.zeros((self.pars['number_of_iterations'])) - - #2D array (for 3D data) of sparse "ring" - detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram']) - self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float) - # another ring variable - self.r_x = self.r.copy() - - self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) - - if self.getParameter('Lipschitz_constant') is None: - self.pars['Lipschitz_constant'] = \ - self.calculateLipschitzConstantWithPowerMethod() - # errors vector (if the ground truth is given) - self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations'))); - # objective function values vector - self.objective = numpy.zeros((self.getParameter('number_of_iterations'))); - - - # prepareForIteration - - def iterate (self, Xin=None): - if self.getParameter('subsets') == 0: - return self.iterateStandard(Xin) - else: - return self.iterateOrderedSubsets(Xin) - - def iterateStandard(self, Xin=None): - print ("FISTA Reconstructor: iterate") - - if Xin is None: - if self.getParameter('initialize'): - X = self.initialize() - else: - N = vol_geom['GridColCount'] - X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) - else: - # copy by reference - X = Xin - # store the output volume in the parameters - self.setParameter(output_volume=X) - X_t = X.copy() - # convenience variable storage - proj_geom , vol_geom, sino , \ - SlicesZ , ring_lambda_R_L1 , weights = \ - self.getParameter([ 'projector_geometry' , - 'output_geometry', - 'input_sinogram', - 'SlicesZ' , - 'ring_lambda_R_L1', - 'weights']) - - t = 1 - - device = self.getParameter('device_model') - reduced_device = self.getParameter('reduced_device_model') - - for i in range(self.getParameter('number_of_iterations')): - print("iteration", i) - X_old = X.copy() - t_old = t - r_old = self.r.copy() - pg = self.getParameter('projector_geometry')['type'] - if pg == 'parallel' or \ - pg == 'fanflat' or \ - pg == 'fanflat_vec': - # if the geometry is parallel use slice-by-slice - # projection-backprojection routine - #sino_updt = zeros(size(sino),'single'); - - if self.use_device : - self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) - - for kkk in range(SlicesZ): - self.sino_updt[kkk] = \ - reduced_device.doForwardProject( X_t[kkk:kkk+1] ) - else: - proj_geomT = proj_geom.copy() - proj_geomT['DetectorRowCount'] = 1 - vol_geomT = vol_geom.copy() - vol_geomT['GridSliceCount'] = 1; - self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) - for kkk in range(SlicesZ): - sino_id, self.sino_updt[kkk] = \ - astra.creators.create_sino3d_gpu( - X_t[kkk:kkk+1], proj_geomT, vol_geomT) - astra.matlab.data3d('delete', sino_id) - else: - # for divergent 3D geometry (watch the GPU memory overflow in - # ASTRA versions < 1.8) - #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); - - if self.use_device: - self.sino_updt = device.doForwardProject(X_t) - else: - sino_id, self.sino_updt = astra.creators.create_sino3d_gpu( - X_t, proj_geom, vol_geom) - astra.matlab.data3d('delete', sino_id) - - - ## RING REMOVAL - if ring_lambda_R_L1 != 0: - self.ringRemoval(i) - else: - self.residual = weights * (self.sino_updt - sino) - self.objective[i] = 0.5 * numpy.linalg.norm(self.residual) - #objective(i) = 0.5*norm(residual(:)); % for the objective function output - ## Projection/Backprojection Routine - X, X_t = self.projectionBackprojection(X, X_t) - - ## REGULARIZATION - Y = self.regularize(X) - X = Y.copy() - ## Update Loop - X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old) - - print ("t" , t) - print ("X min {0} max {1}".format(X_t.min(),X_t.max())) - self.setParameter(output_volume=X) - return X - ## iterate - - def ringRemoval(self, i): - print ("FISTA Reconstructor: ring removal") - residual = self.residual - lambdaR_L1 , alpha_ring , weights , L_const , sino= \ - self.getParameter(['ring_lambda_R_L1', - 'ring_alpha' , 'weights', - 'Lipschitz_constant', - 'input_sinogram']) - r_x = self.r_x - sino_updt = self.sino_updt - - SlicesZ, anglesNumb, Detectors = \ - numpy.shape(self.getParameter('input_sinogram')) - if lambdaR_L1 > 0 : - for kkk in range(anglesNumb): - - residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ - ((sino_updt[:,kkk,:]).squeeze() - \ - (sino[:,kkk,:]).squeeze() -\ - (alpha_ring * r_x) - ) - vec = residual.sum(axis = 1) - #if SlicesZ > 1: - # vec = vec[:,1,:].squeeze() - self.r = (r_x - (1./L_const) * vec).copy() - self.objective[i] = (0.5 * (residual ** 2).sum()) - - def projectionBackprojection(self, X, X_t): - print ("FISTA Reconstructor: projection-backprojection routine") - - # a few useful variables - SlicesZ, anglesNumb, Detectors = \ - numpy.shape(self.getParameter('input_sinogram')) - residual = self.residual - proj_geom , vol_geom , L_const = \ - self.getParameter(['projector_geometry' , - 'output_geometry', - 'Lipschitz_constant']) - - device, reduced_device = self.getParameter(['device_model', - 'reduced_device_model']) - - if self.getParameter('projector_geometry')['type'] == 'parallel' or \ - self.getParameter('projector_geometry')['type'] == 'fanflat' or \ - self.getParameter('projector_geometry')['type'] == 'fanflat_vec': - # if the geometry is parallel use slice-by-slice - # projection-backprojection routine - #sino_updt = zeros(size(sino),'single'); - x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) - - if self.use_device: - proj_geomT = proj_geom.copy() - proj_geomT['DetectorRowCount'] = 1 - vol_geomT = vol_geom.copy() - vol_geomT['GridSliceCount'] = 1; - - for kkk in range(SlicesZ): - - x_id, x_temp[kkk] = \ - astra.creators.create_backprojection3d_gpu( - residual[kkk:kkk+1], - proj_geomT, vol_geomT) - astra.matlab.data3d('delete', x_id) - else: - for kkk in range(SliceZ): - x_temp[kkk] = \ - reduced_device.doBackwardProject(residual[kkk:kkk+1]) - else: - if self.use_device: - x_id, x_temp = \ - astra.creators.create_backprojection3d_gpu( - residual, proj_geom, vol_geom) - astra.matlab.data3d('delete', x_id) - else: - x_temp = \ - device.doBackwardProject(residual) - - - X = X_t - (1/L_const) * x_temp - #astra.matlab.data3d('delete', sino_id) - return (X , X_t) - - - def regularize(self, X , output_all=False): - #print ("FISTA Reconstructor: regularize") - - regularizer = self.getParameter('regularizer') - if regularizer is not None: - return regularizer(input=X, - output_all=output_all) - else: - return X - - def updateLoop(self, i, X, X_old, r_old, t, t_old): - print ("FISTA Reconstructor: update loop") - lambdaR_L1 = self.getParameter('ring_lambda_R_L1') - - t = (1 + numpy.sqrt(1 + 4 * t**2))/2 - X_t = X + (((t_old -1)/t) * (X - X_old)) - - if lambdaR_L1 > 0: - self.r = numpy.max( - numpy.abs(self.r) - lambdaR_L1 , 0) * \ - numpy.sign(self.r) - self.r_x = self.r + \ - (((t_old-1)/t) * (self.r - r_old)) - - if self.getParameter('region_of_interest') is None: - string = 'Iteration Number {0} | Objective {1} \n' - print (string.format( i, self.objective[i])) - else: - ROI , X_ideal = fistaRecon.getParameter('region_of_interest', - 'ideal_image') - - Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) - string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' - print (string.format(i,Resid_error[i], self.objective[i])) - return (X , X_t, t) - - def iterateOrderedSubsets(self, Xin=None): - print ("FISTA Reconstructor: Ordered Subsets iterate") - - if Xin is None: - if self.getParameter('initialize'): - X = self.initialize() - else: - N = vol_geom['GridColCount'] - X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) - else: - # copy by reference - X = Xin - # store the output volume in the parameters - self.setParameter(output_volume=X) - X_t = X.copy() - - # some useful constants - proj_geom , vol_geom, sino , \ - SlicesZ, weights , alpha_ring ,\ - lambdaR_L1 , L_const , iterFISTA = self.getParameter( - ['projector_geometry' , 'output_geometry', 'input_sinogram', - 'SlicesZ' , 'weights', 'ring_alpha' , - 'ring_lambda_R_L1', 'Lipschitz_constant', - 'number_of_iterations']) - - - # errors vector (if the ground truth is given) - Resid_error = numpy.zeros((iterFISTA)); - # objective function values vector - #objective = numpy.zeros((iterFISTA)); - objective = self.objective - - - t = 1 - - ## additional for - proj_geomSUB = proj_geom.copy() - self.residual2 = numpy.zeros(numpy.shape(sino)) - residual2 = self.residual2 - sino_updt_FULL = self.residual.copy() - r_x = self.r.copy() - - print ("starting iterations") - ## % Outer FISTA iterations loop - for i in range(self.getParameter('number_of_iterations')): - # With OS approach it becomes trickier to correlate independent - # subsets, hence additional work is required one solution is to - # work with a full sinogram at times - - r_old = self.r.copy() - t_old = t - SlicesZ, anglesNumb, Detectors = \ - numpy.shape(self.getParameter('input_sinogram')) ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 - if (i > 1 and lambdaR_L1 > 0) : - for kkk in range(anglesNumb): - - residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ - ((sino_updt_FULL[:,kkk,:]).squeeze() - \ - (sino[:,kkk,:]).squeeze() -\ - (alpha_ring * r_x) - ) - - vec = self.residual.sum(axis = 1) - #if SlicesZ > 1: - # vec = vec[:,1,:] # 1 or 0? - r_x = self.r_x - # update ring variable - self.r = (r_x - (1./L_const) * vec).copy() - - # subset loop - counterInd = 1 - geometry_type = self.getParameter('projector_geometry')['type'] - angles = self.getParameter('projector_geometry')['ProjectionAngles'] - - for ss in range(self.getParameter('subsets')): - #print ("Subset {0}".format(ss)) - X_old = X.copy() - t_old = t - - # the number of projections per subset - numProjSub = self.getParameter('os_bins')[ss] - CurrSubIndices = self.getParameter('os_indices')\ - [counterInd:counterInd+numProjSub] - #print ("Len CurrSubIndices {0}".format(numProjSub)) - mask = numpy.zeros(numpy.shape(angles), dtype=bool) - #cc = 0 - for j in range(len(CurrSubIndices)): - mask[int(CurrSubIndices[j])] = True - proj_geomSUB['ProjectionAngles'] = angles[mask] - - if self.use_device: - device = self.getParameter('device_model')\ - .createReducedDevice( - proj_par={'angles':angles[mask]}, - vol_par={}) - - shape = list(numpy.shape(self.getParameter('input_sinogram'))) - shape[1] = numProjSub - sino_updt_Sub = numpy.zeros(shape) - if geometry_type == 'parallel' or \ - geometry_type == 'fanflat' or \ - geometry_type == 'fanflat_vec' : - - for kkk in range(SlicesZ): - if self.use_device: - sinoT = device.doForwardProject(X_t[kkk:kkk+1]) - else: - sino_id, sinoT = astra.creators.create_sino3d_gpu ( - X_t[kkk:kkk+1] , proj_geomSUB, vol_geom) - astra.matlab.data3d('delete', sino_id) - sino_updt_Sub[kkk] = sinoT.T.copy() - - else: - # for 3D geometry (watch the GPU memory overflow in - # ASTRA < 1.8) - if self.use_device: - sino_updt_Sub = device.doForwardProject(X_t) - - else: - sino_id, sino_updt_Sub = \ - astra.creators.create_sino3d_gpu (X_t, proj_geomSUB, vol_geom) - - astra.matlab.data3d('delete', sino_id) - - #print ("shape(sino_updt_Sub)",numpy.shape(sino_updt_Sub)) - if lambdaR_L1 > 0 : - ## RING REMOVAL - #print ("ring removal") - residualSub , sino_updt_Sub, sino_updt_FULL = \ - self.ringRemovalOrderedSubsets(ss, - counterInd, - sino_updt_Sub, - sino_updt_FULL) - else: - #PWLS model - #print ("PWLS model") - residualSub = weights[:,CurrSubIndices,:] * \ - ( sino_updt_Sub - \ - sino[:,CurrSubIndices,:].squeeze() ) - objective[i] = 0.5 * numpy.linalg.norm(residualSub) - - # projection/backprojection routine - if geometry_type == 'parallel' or \ - geometry_type == 'fanflat' or \ - geometry_type == 'fanflat_vec' : - # if geometry is 2D use slice-by-slice projection-backprojection - # routine - x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32) - for kkk in range(SlicesZ): - if self.use_device: - x_temp[kkk] = device.doBackwardProject( - residualSub[kkk:kkk+1]) - else: - x_id, x_temp[kkk] = \ - astra.creators.create_backprojection3d_gpu( - residualSub[kkk:kkk+1], - proj_geomSUB, vol_geom) - astra.matlab.data3d('delete', x_id) - - else: - if self.use_device: - x_temp = device.doBackwardProject( - residualSub) - else: - x_id, x_temp = \ - astra.creators.create_backprojection3d_gpu( - residualSub, proj_geomSUB, vol_geom) - - astra.matlab.data3d('delete', x_id) - - X = X_t - (1/L_const) * x_temp - - ## REGULARIZATION - X = self.regularize(X) - - ## Update subset Loop - t = (1 + numpy.sqrt(1 + 4 * t**2))/2 - X_t = X + (((t_old -1)/t) * (X - X_old)) - # FINAL - ## update iteration loop - if lambdaR_L1 > 0: - self.r = numpy.max( - numpy.abs(self.r) - lambdaR_L1 , 0) * \ - numpy.sign(self.r) - self.r_x = self.r + \ - (((t_old-1)/t) * (self.r - r_old)) - - if self.getParameter('region_of_interest') is None: - string = 'Iteration Number {0} | Objective {1} \n' - print (string.format( i, self.objective[i])) - else: - ROI , X_ideal = fistaRecon.getParameter('region_of_interest', - 'ideal_image') - - Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) - string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' - print (string.format(i,Resid_error[i], self.objective[i])) - print("X min {0} max {1}".format(X.min(),X.max())) - self.setParameter(output_volume=X) - counterInd = counterInd + numProjSub - - return X - - def ringRemovalOrderedSubsets(self, ss,counterInd, - sino_updt_Sub, sino_updt_FULL): - residual = self.residual - r_x = self.r_x - weights , alpha_ring , sino = \ - self.getParameter( ['weights', 'ring_alpha', 'input_sinogram']) - numProjSub = self.getParameter('os_bins')[ss] - CurrSubIndices = self.getParameter('os_indices')\ - [counterInd:counterInd+numProjSub] - - shape = list(numpy.shape(self.getParameter('input_sinogram'))) - shape[1] = numProjSub - - residualSub = numpy.zeros(shape) - - for kkk in range(numProjSub): - #print ("ring removal indC ... {0}".format(kkk)) - indC = int(CurrSubIndices[kkk]) - residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \ - (sino_updt_Sub[:,kkk,:].squeeze() - \ - sino[:,indC,:].squeeze() - alpha_ring * r_x) - # filling the full sinogram - sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze() - - return (residualSub , sino_updt_Sub, sino_updt_FULL) - - diff --git a/Wrappers/Python/ccpi/reconstruction/Reconstructor.py b/Wrappers/Python/ccpi/reconstruction/Reconstructor.py deleted file mode 100644 index 2ad8a44..0000000 --- a/Wrappers/Python/ccpi/reconstruction/Reconstructor.py +++ /dev/null @@ -1,598 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -import h5py -from ccpi.reconstruction.parallelbeam import alg - -from Regularizer import Regularizer -from enum import Enum - -import astra - - -class Reconstructor: - - class Algorithm(Enum): - CGLS = alg.cgls - CGLS_CONV = alg.cgls_conv - SIRT = alg.sirt - MLEM = alg.mlem - CGLS_TICHONOV = alg.cgls_tikhonov - CGLS_TVREG = alg.cgls_TVreg - FISTA = 'fista' - - def __init__(self, algorithm = None, projection_data = None, - angles = None, center_of_rotation = None , - flat_field = None, dark_field = None, - iterations = None, resolution = None, isLogScale = False, threads = None, - normalized_projection = None): - - self.pars = dict() - self.pars['algorithm'] = algorithm - self.pars['projection_data'] = projection_data - self.pars['normalized_projection'] = normalized_projection - self.pars['angles'] = angles - self.pars['center_of_rotation'] = numpy.double(center_of_rotation) - self.pars['flat_field'] = flat_field - self.pars['iterations'] = iterations - self.pars['dark_field'] = dark_field - self.pars['resolution'] = resolution - self.pars['isLogScale'] = isLogScale - self.pars['threads'] = threads - if (iterations != None): - self.pars['iterationValues'] = numpy.zeros((iterations)) - - if projection_data != None and dark_field != None and flat_field != None: - norm = self.normalize(projection_data, dark_field, flat_field, 0.1) - self.pars['normalized_projection'] = norm - - - def setPars(self, parameters): - keys = ['algorithm','projection_data' ,'normalized_projection', \ - 'angles' , 'center_of_rotation' , 'flat_field', \ - 'iterations','dark_field' , 'resolution', 'isLogScale' , \ - 'threads' , 'iterationValues', 'regularize'] - - for k in keys: - if k not in parameters.keys(): - self.pars[k] = None - else: - self.pars[k] = parameters[k] - - - def sanityCheck(self): - projection_data = self.pars['projection_data'] - dark_field = self.pars['dark_field'] - flat_field = self.pars['flat_field'] - angles = self.pars['angles'] - - if projection_data != None and dark_field != None and \ - angles != None and flat_field != None: - data_shape = numpy.shape(projection_data) - angle_shape = numpy.shape(angles) - - if angle_shape[0] != data_shape[0]: - #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ - # (angle_shape[0] , data_shape[0]) ) - return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ - (angle_shape[0] , data_shape[0]) ) - - if data_shape[1:] != numpy.shape(flat_field): - #raise Exception('Projection and flat field dimensions do not match') - return (False , 'Projection and flat field dimensions do not match') - if data_shape[1:] != numpy.shape(dark_field): - #raise Exception('Projection and dark field dimensions do not match') - return (False , 'Projection and dark field dimensions do not match') - - return (True , '' ) - elif self.pars['normalized_projection'] != None: - data_shape = numpy.shape(self.pars['normalized_projection']) - angle_shape = numpy.shape(angles) - - if angle_shape[0] != data_shape[0]: - #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ - # (angle_shape[0] , data_shape[0]) ) - return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ - (angle_shape[0] , data_shape[0]) ) - else: - return (True , '' ) - else: - return (False , 'Not enough data') - - def reconstruct(self, parameters = None): - if parameters != None: - self.setPars(parameters) - - go , reason = self.sanityCheck() - if go: - return self._reconstruct() - else: - raise Exception(reason) - - - def _reconstruct(self, parameters=None): - if parameters!=None: - self.setPars(parameters) - parameters = self.pars - - if parameters['algorithm'] != None and \ - parameters['normalized_projection'] != None and \ - parameters['angles'] != None and \ - parameters['center_of_rotation'] != None and \ - parameters['iterations'] != None and \ - parameters['resolution'] != None and\ - parameters['threads'] != None and\ - parameters['isLogScale'] != None: - - - if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, - Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): - #store parameters - self.pars = parameters - result = parameters['algorithm']( - parameters['normalized_projection'] , - parameters['angles'], - parameters['center_of_rotation'], - parameters['resolution'], - parameters['iterations'], - parameters['threads'] , - parameters['isLogScale'] - ) - return result - elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, - Reconstructor.Algorithm.CGLS_TICHONOV, - Reconstructor.Algorithm.CGLS_TVREG) : - self.pars = parameters - result = parameters['algorithm']( - parameters['normalized_projection'] , - parameters['angles'], - parameters['center_of_rotation'], - parameters['resolution'], - parameters['iterations'], - parameters['threads'] , - parameters['regularize'], - numpy.zeros((parameters['iterations'])), - parameters['isLogScale'] - ) - - elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: - pass - - else: - if parameters['projection_data'] != None and \ - parameters['dark_field'] != None and \ - parameters['flat_field'] != None: - norm = self.normalize(parameters['projection_data'], - parameters['dark_field'], - parameters['flat_field'], 0.1) - self.pars['normalized_projection'] = norm - return self._reconstruct(parameters) - - - - def _normalize(self, projection, dark, flat, def_val=0): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - def normalize(self, projections, dark, flat, def_val=0): - norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] - return numpy.asarray (norm, dtype=numpy.float32) - - - -class FISTA(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ - - # Accepted input keywords - kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , - 'weights' , 'region_of_interest' , 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha') - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() - - if not self.pars['ideal_image'] in kwargs.keys(): - self.pars['ideal_image'] = None - - if not self.pars['region_of_interest'] : - if self.pars['ideal_image'] == None: - pass - else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : - self.pars['regularizer'] = None - else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 - - - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) - proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - - for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - y = sqweight*y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - del proj_geomT - del vol_geomT - else: - #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - return s - - - def setRegularizer(self, regularizer): - #if regularizer - self.pars['regularizer'] = regularizer - - - - - -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -nx = h5py.File(fname, "r") - -# the data are stored in a particular location in the hdf5 -for item in nx['entry1/tomo_entry/data'].keys(): - print (item) - -data = nx.get('entry1/tomo_entry/data/rotation_angle') -angles = numpy.zeros(data.shape) -data.read_direct(angles) -print (angles) -# angles should be in degrees - -data = nx.get('entry1/tomo_entry/data/data') -stack = numpy.zeros(data.shape) -data.read_direct(stack) -print (data.shape) - -print ("Data Loaded") - - -# Normalize -data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -itype = numpy.zeros(data.shape) -data.read_direct(itype) -# 2 is dark field -darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -dark = darks[0] -for i in range(1, len(darks)): - dark += darks[i] -dark = dark / len(darks) -#dark[0][0] = dark[0][1] - -# 1 is flat field -flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -flat = flats[0] -for i in range(1, len(flats)): - flat += flats[i] -flat = flat / len(flats) -#flat[0][0] = dark[0][1] - - -# 0 is projection data -proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = numpy.asarray (angle_proj) -angle_proj = angle_proj.astype(numpy.float32) - -# normalized data are -# norm = (projection - dark)/(flat-dark) - -def normalize(projection, dark, flat, def_val=0.1): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - -norm = [normalize(projection, dark, flat) for projection in proj] -norm = numpy.asarray (norm) -norm = norm.astype(numpy.float32) - -#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) - -#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) -#img_cgls = recon.reconstruct() -# -#pars = dict() -#pars['algorithm'] = Reconstructor.Algorithm.SIRT -#pars['projection_data'] = proj -#pars['angles'] = angle_proj -#pars['center_of_rotation'] = numpy.double(86.2) -#pars['flat_field'] = flat -#pars['iterations'] = 15 -#pars['dark_field'] = dark -#pars['resolution'] = 1 -#pars['isLogScale'] = False -#pars['threads'] = 3 -# -#img_sirt = recon.reconstruct(pars) -# -#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM -#img_mlem = recon.reconstruct() - -############################################################ -############################################################ -#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV -#recon.pars['regularize'] = numpy.double(0.1) -#img_cgls_conv = recon.reconstruct() - -niterations = 15 -threads = 3 - -img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - iteration_values, False) -print ("iteration values %s" % str(iteration_values)) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) -iteration_values = numpy.zeros((niterations,)) -img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) - - -##numpy.save("cgls_recon.npy", img_data) -import matplotlib.pyplot as plt -fig, ax = plt.subplots(1,6,sharey=True) -ax[0].imshow(img_cgls[80]) -ax[0].axis('off') # clear x- and y-axes -ax[1].imshow(img_sirt[80]) -ax[1].axis('off') # clear x- and y-axes -ax[2].imshow(img_mlem[80]) -ax[2].axis('off') # clear x- and y-axesplt.show() -ax[3].imshow(img_cgls_conv[80]) -ax[3].axis('off') # clear x- and y-axesplt.show() -ax[4].imshow(img_cgls_tikhonov[80]) -ax[4].axis('off') # clear x- and y-axesplt.show() -ax[5].imshow(img_cgls_TVreg[80]) -ax[5].axis('off') # clear x- and y-axesplt.show() - - -plt.show() - -#viewer = edo.CILViewer() -#viewer.setInputAsNumpy(img_cgls2) -#viewer.displaySliceActor(0) -#viewer.startRenderLoop() - -import vtk - -def NumpyToVTKImageData(numpyarray): - if (len(numpy.shape(numpyarray)) == 3): - doubleImg = vtk.vtkImageData() - shape = numpy.shape(numpyarray) - doubleImg.SetDimensions(shape[0], shape[1], shape[2]) - doubleImg.SetOrigin(0,0,0) - doubleImg.SetSpacing(1,1,1) - doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) - #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) - doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) - - for i in range(shape[0]): - for j in range(shape[1]): - for k in range(shape[2]): - doubleImg.SetScalarComponentFromDouble( - i,j,k,0, numpyarray[i][j][k]) - #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) - # rescale to appropriate VTK_UNSIGNED_SHORT - stats = vtk.vtkImageAccumulate() - stats.SetInputData(doubleImg) - stats.Update() - iMin = stats.GetMin()[0] - iMax = stats.GetMax()[0] - scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) - - shiftScaler = vtk.vtkImageShiftScale () - shiftScaler.SetInputData(doubleImg) - shiftScaler.SetScale(scale) - shiftScaler.SetShift(iMin) - shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) - shiftScaler.Update() - return shiftScaler.GetOutput() - -#writer = vtk.vtkMetaImageWriter() -#writer.SetFileName(alg + "_recon.mha") -#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) -#writer.Write() |