summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python/ccpi
diff options
context:
space:
mode:
Diffstat (limited to 'Wrappers/Python/ccpi')
-rw-r--r--Wrappers/Python/ccpi/reconstruction/AstraDevice.py95
-rw-r--r--Wrappers/Python/ccpi/reconstruction/DeviceModel.py63
-rw-r--r--Wrappers/Python/ccpi/reconstruction/FISTAReconstructor.py882
-rw-r--r--Wrappers/Python/ccpi/reconstruction/Reconstructor.py598
4 files changed, 1638 insertions, 0 deletions
diff --git a/Wrappers/Python/ccpi/reconstruction/AstraDevice.py b/Wrappers/Python/ccpi/reconstruction/AstraDevice.py
new file mode 100644
index 0000000..57435f8
--- /dev/null
+++ b/Wrappers/Python/ccpi/reconstruction/AstraDevice.py
@@ -0,0 +1,95 @@
+import astra
+from ccpi.reconstruction.DeviceModel import DeviceModel
+import numpy
+
+class AstraDevice(DeviceModel):
+ '''Concrete class for Astra Device'''
+
+ def __init__(self,
+ device_type,
+ data_aquisition_geometry,
+ reconstructed_volume_geometry):
+
+ super(AstraDevice, self).__init__(device_type,
+ data_aquisition_geometry,
+ reconstructed_volume_geometry)
+
+ self.type = device_type
+ self.proj_geom = astra.creators.create_proj_geom(
+ device_type,
+ self.acquisition_data_geometry['detectorSpacingX'],
+ self.acquisition_data_geometry['detectorSpacingY'],
+ self.acquisition_data_geometry['cameraX'],
+ self.acquisition_data_geometry['cameraY'],
+ self.acquisition_data_geometry['angles'],
+ )
+
+ self.vol_geom = astra.creators.create_vol_geom(
+ self.reconstructed_volume_geometry['X'],
+ self.reconstructed_volume_geometry['Y'],
+ self.reconstructed_volume_geometry['Z']
+ )
+
+ def doForwardProject(self, volume):
+ '''Forward projects the volume according to the device geometry
+
+Uses Astra-toolbox
+'''
+
+ try:
+ sino_id, y = astra.creators.create_sino3d_gpu(
+ volume, self.proj_geom, self.vol_geom)
+ astra.matlab.data3d('delete', sino_id)
+ return y
+ except Exception as e:
+ print(e)
+ print("Value Error: ", self.proj_geom, self.vol_geom)
+
+ def doBackwardProject(self, projections):
+ '''Backward projects the projections according to the device geometry
+
+Uses Astra-toolbox
+'''
+ idx, volume = \
+ astra.creators.create_backprojection3d_gpu(
+ projections,
+ self.proj_geom,
+ self.vol_geom)
+
+ astra.matlab.data3d('delete', idx)
+ return volume
+
+ def createReducedDevice(self, proj_par={'cameraY' : 1} , vol_par={'Z':1}):
+ '''Create a new device based on the current device by changing some parameter
+
+VERY RISKY'''
+ acquisition_data_geometry = self.acquisition_data_geometry.copy()
+ for k,v in proj_par.items():
+ if k in acquisition_data_geometry.keys():
+ acquisition_data_geometry[k] = v
+ proj_geom = [
+ acquisition_data_geometry['cameraX'],
+ acquisition_data_geometry['cameraY'],
+ acquisition_data_geometry['detectorSpacingX'],
+ acquisition_data_geometry['detectorSpacingY'],
+ acquisition_data_geometry['angles']
+ ]
+
+ reconstructed_volume_geometry = self.reconstructed_volume_geometry.copy()
+ for k,v in vol_par.items():
+ if k in reconstructed_volume_geometry.keys():
+ reconstructed_volume_geometry[k] = v
+
+ vol_geom = [
+ reconstructed_volume_geometry['X'],
+ reconstructed_volume_geometry['Y'],
+ reconstructed_volume_geometry['Z']
+ ]
+ return AstraDevice(self.type, proj_geom, vol_geom)
+
+
+
+if __name__=="main":
+ a = AstraDevice()
+
+
diff --git a/Wrappers/Python/ccpi/reconstruction/DeviceModel.py b/Wrappers/Python/ccpi/reconstruction/DeviceModel.py
new file mode 100644
index 0000000..eeb9a34
--- /dev/null
+++ b/Wrappers/Python/ccpi/reconstruction/DeviceModel.py
@@ -0,0 +1,63 @@
+from abc import ABCMeta, abstractmethod
+from enum import Enum
+
+class DeviceModel(metaclass=ABCMeta):
+ '''Abstract class that defines the device for projection and backprojection
+
+This class defines the methods that must be implemented by concrete classes.
+
+ '''
+
+ class DeviceType(Enum):
+ '''Type of device
+PARALLEL BEAM
+PARALLEL BEAM 3D
+CONE BEAM
+HELICAL'''
+
+ PARALLEL = 'parallel'
+ PARALLEL3D = 'parallel3d'
+ CONE_BEAM = 'cone-beam'
+ HELICAL = 'helical'
+
+ def __init__(self,
+ device_type,
+ data_aquisition_geometry,
+ reconstructed_volume_geometry):
+ '''Initializes the class
+
+Mandatory parameters are:
+device_type from DeviceType Enum
+data_acquisition_geometry: tuple (camera_X, camera_Y, detectorSpacingX,
+ detectorSpacingY, angles)
+reconstructed_volume_geometry: tuple (dimX,dimY,dimZ)
+'''
+ self.device_geometry = device_type
+ self.acquisition_data_geometry = {
+ 'cameraX': data_aquisition_geometry[0],
+ 'cameraY': data_aquisition_geometry[1],
+ 'detectorSpacingX' : data_aquisition_geometry[2],
+ 'detectorSpacingY' : data_aquisition_geometry[3],
+ 'angles' : data_aquisition_geometry[4],}
+ self.reconstructed_volume_geometry = {
+ 'X': reconstructed_volume_geometry[0] ,
+ 'Y': reconstructed_volume_geometry[1] ,
+ 'Z': reconstructed_volume_geometry[2] }
+
+ @abstractmethod
+ def doForwardProject(self, volume):
+ '''Forward projects the volume according to the device geometry'''
+ return NotImplemented
+
+
+ @abstractmethod
+ def doBackwardProject(self, projections):
+ '''Backward projects the projections according to the device geometry'''
+ return NotImplemented
+
+ @abstractmethod
+ def createReducedDevice(self):
+ '''Create a Device to do forward/backward projections on 2D slices'''
+ return NotImplemented
+
+
diff --git a/Wrappers/Python/ccpi/reconstruction/FISTAReconstructor.py b/Wrappers/Python/ccpi/reconstruction/FISTAReconstructor.py
new file mode 100644
index 0000000..e40ad24
--- /dev/null
+++ b/Wrappers/Python/ccpi/reconstruction/FISTAReconstructor.py
@@ -0,0 +1,882 @@
+# -*- coding: utf-8 -*-
+###############################################################################
+#This work is part of the Core Imaging Library developed by
+#Visual Analytics and Imaging System Group of the Science Technology
+#Facilities Council, STFC
+#
+#Copyright 2017 Edoardo Pasca, Srikanth Nagella
+#Copyright 2017 Daniil Kazantsev
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#http://www.apache.org/licenses/LICENSE-2.0
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+###############################################################################
+
+
+
+import numpy
+#from ccpi.reconstruction.parallelbeam import alg
+
+#from ccpi.imaging.Regularizer import Regularizer
+from enum import Enum
+
+import astra
+from ccpi.reconstruction.AstraDevice import AstraDevice
+
+
+
+class FISTAReconstructor():
+ '''FISTA-based reconstruction algorithm using ASTRA-toolbox
+
+ '''
+ # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>>
+ # ___Input___:
+ # params.[] file:
+ # - .proj_geom (geometry of the projector) [required]
+ # - .vol_geom (geometry of the reconstructed object) [required]
+ # - .sino (vectorized in 2D or 3D sinogram) [required]
+ # - .iterFISTA (iterations for the main loop, default 40)
+ # - .L_const (Lipschitz constant, default Power method) )
+ # - .X_ideal (ideal image, if given)
+ # - .weights (statisitcal weights, size of the sinogram)
+ # - .ROI (Region-of-interest, only if X_ideal is given)
+ # - .initialize (a 'warm start' using SIRT method from ASTRA)
+ #----------------Regularization choices------------------------
+ # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter)
+ # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter)
+ # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter)
+ # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04)
+ # - .Regul_Iterations (iterations for the selected penalty, default 25)
+ # - .Regul_tauLLT (time step parameter for LLT term)
+ # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal)
+ # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1)
+ #----------------Visualization parameters------------------------
+ # - .show (visualize reconstruction 1/0, (0 default))
+ # - .maxvalplot (maximum value to use for imshow[0 maxvalplot])
+ # - .slice (for 3D volumes - slice number to imshow)
+ # ___Output___:
+ # 1. X - reconstructed image/volume
+ # 2. output - a structure with
+ # - .Resid_error - residual error (if X_ideal is given)
+ # - .objective: value of the objective function
+ # - .L_const: Lipshitz constant to avoid recalculations
+
+ # References:
+ # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
+ # Problems" by A. Beck and M Teboulle
+ # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo
+ # 3. "A novel tomographic reconstruction method based on the robust
+ # Student's t function for suppressing data outliers" D. Kazantsev et.al.
+ # D. Kazantsev, 2016-17
+ def __init__(self, projector_geometry,
+ output_geometry,
+ input_sinogram,
+ device,
+ **kwargs):
+ # handle parmeters:
+ # obligatory parameters
+ self.pars = dict()
+ self.pars['projector_geometry'] = projector_geometry # proj_geom
+ self.pars['output_geometry'] = output_geometry # vol_geom
+ self.pars['input_sinogram'] = input_sinogram # sino
+ sliceZ, nangles, detectors = numpy.shape(input_sinogram)
+ self.pars['detectors'] = detectors
+ self.pars['number_of_angles'] = nangles
+ self.pars['SlicesZ'] = sliceZ
+ self.pars['output_volume'] = None
+ self.pars['device_model'] = device
+
+ self.use_device = True
+
+ print (self.pars)
+ # handle optional input parameters (at instantiation)
+
+ # Accepted input keywords
+ kw = (
+ # mandatory fields
+ 'projector_geometry',
+ 'output_geometry',
+ 'input_sinogram',
+ 'detectors',
+ 'number_of_angles',
+ 'SlicesZ',
+ # optional fields
+ 'number_of_iterations',
+ 'Lipschitz_constant' ,
+ 'ideal_image' ,
+ 'weights' ,
+ 'region_of_interest' ,
+ 'initialize' ,
+ 'regularizer' ,
+ 'ring_lambda_R_L1',
+ 'ring_alpha',
+ 'subsets',
+ 'output_volume',
+ 'os_subsets',
+ 'os_indices',
+ 'os_bins',
+ 'device_model',
+ 'reduced_device_model')
+ self.acceptedInputKeywords = list(kw)
+
+ # handle keyworded parameters
+ if kwargs is not None:
+ for key, value in kwargs.items():
+ if key in kw:
+ #print("{0} = {1}".format(key, value))
+ self.pars[key] = value
+
+ # set the default values for the parameters if not set
+ if 'number_of_iterations' in kwargs.keys():
+ self.pars['number_of_iterations'] = kwargs['number_of_iterations']
+ else:
+ self.pars['number_of_iterations'] = 40
+ if 'weights' in kwargs.keys():
+ self.pars['weights'] = kwargs['weights']
+ else:
+ self.pars['weights'] = \
+ numpy.ones(numpy.shape(
+ self.pars['input_sinogram']))
+ if 'Lipschitz_constant' in kwargs.keys():
+ self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
+ else:
+ self.pars['Lipschitz_constant'] = None
+
+ if not 'ideal_image' in kwargs.keys():
+ self.pars['ideal_image'] = None
+
+ if not 'region_of_interest'in kwargs.keys() :
+ if self.pars['ideal_image'] == None:
+ self.pars['region_of_interest'] = None
+ else:
+ ## nonzero if the image is larger than m
+ fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1)
+
+ self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0)
+
+ # the regularizer must be a correctly instantiated object
+ if not 'regularizer' in kwargs.keys() :
+ self.pars['regularizer'] = None
+
+ #RING REMOVAL
+ if not 'ring_lambda_R_L1' in kwargs.keys():
+ self.pars['ring_lambda_R_L1'] = 0
+ if not 'ring_alpha' in kwargs.keys():
+ self.pars['ring_alpha'] = 1
+
+ # ORDERED SUBSET
+ if not 'subsets' in kwargs.keys():
+ self.pars['subsets'] = 0
+ else:
+ self.createOrderedSubsets()
+
+ if not 'initialize' in kwargs.keys():
+ self.pars['initialize'] = False
+
+ reduced_device = device.createReducedDevice()
+ self.setParameter(reduced_device_model=reduced_device)
+
+
+
+ def setParameter(self, **kwargs):
+ '''set named parameter for the reconstructor engine
+
+ raises Exception if the named parameter is not recognized
+
+ '''
+ for key , value in kwargs.items():
+ if key in self.acceptedInputKeywords:
+ self.pars[key] = value
+ else:
+ raise Exception('Wrong parameter {0} for '.format(key) +
+ 'reconstructor')
+ # setParameter
+
+ def getParameter(self, key):
+ if type(key) is str:
+ if key in self.acceptedInputKeywords:
+ return self.pars[key]
+ else:
+ raise Exception('Unrecongnised parameter: {0} '.format(key) )
+ elif type(key) is list:
+ outpars = []
+ for k in key:
+ outpars.append(self.getParameter(k))
+ return outpars
+ else:
+ raise Exception('Unhandled input {0}' .format(str(type(key))))
+
+
+ def calculateLipschitzConstantWithPowerMethod(self):
+ ''' using Power method (PM) to establish L constant'''
+
+ N = self.pars['output_geometry']['GridColCount']
+ proj_geom = self.pars['projector_geometry']
+ vol_geom = self.pars['output_geometry']
+ weights = self.pars['weights']
+ SlicesZ = self.pars['SlicesZ']
+
+
+
+ if (proj_geom['type'] == 'parallel') or \
+ (proj_geom['type'] == 'parallel3d'):
+ #% for parallel geometry we can do just one slice
+ #print('Calculating Lipshitz constant for parallel beam geometry...')
+ niter = 5;# % number of iteration for the PM
+ #N = params.vol_geom.GridColCount;
+ #x1 = rand(N,N,1);
+ x1 = numpy.random.rand(1,N,N)
+ #sqweight = sqrt(weights(:,:,1));
+ sqweight = numpy.sqrt(weights[0:1,:,:])
+ proj_geomT = proj_geom.copy();
+ proj_geomT['DetectorRowCount'] = 1;
+ vol_geomT = vol_geom.copy();
+ vol_geomT['GridSliceCount'] = 1;
+
+ #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+
+
+ for i in range(niter):
+ # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT);
+ # s = norm(x1(:));
+ # x1 = x1/s;
+ # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+ # y = sqweight.*y;
+ # astra_mex_data3d('delete', sino_id);
+ # astra_mex_data3d('delete', id);
+ #print ("iteration {0}".format(i))
+
+ sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ proj_geomT,
+ vol_geomT)
+
+ y = (sqweight * y).copy() # element wise multiplication
+
+ #b=fig.add_subplot(2,1,2)
+ #imgplot = plt.imshow(x1[0])
+ #plt.show()
+
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id)
+ del x1
+
+ idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(),
+ proj_geomT,
+ vol_geomT)
+ del y
+
+
+ s = numpy.linalg.norm(x1)
+ ### this line?
+ x1 = (x1/s).copy();
+
+ # ### this line?
+ # sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ # proj_geomT,
+ # vol_geomT);
+ # y = sqweight * y;
+ astra.matlab.data3d('delete', sino_id);
+ astra.matlab.data3d('delete', idx)
+ print ("iteration {0} s= {1}".format(i,s))
+
+ #end
+ del proj_geomT
+ del vol_geomT
+ #plt.show()
+ else:
+ #% divergen beam geometry
+ print('Calculating Lipshitz constant for divergen beam geometry...')
+ niter = 8; #% number of iteration for PM
+ x1 = numpy.random.rand(SlicesZ , N , N);
+ #sqweight = sqrt(weights);
+ sqweight = numpy.sqrt(weights[0])
+
+ sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
+ y = sqweight*y;
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id);
+
+ for i in range(niter):
+ #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom);
+ idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,
+ proj_geom,
+ vol_geom)
+ s = numpy.linalg.norm(x1)
+ ### this line?
+ x1 = x1/s;
+ ### this line?
+ #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom);
+ sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ proj_geom,
+ vol_geom);
+
+ y = sqweight*y;
+ #astra_mex_data3d('delete', sino_id);
+ #astra_mex_data3d('delete', id);
+ astra.matlab.data3d('delete', sino_id);
+ astra.matlab.data3d('delete', idx);
+ #end
+ #clear x1
+ del x1
+
+
+ return s
+
+
+ def setRegularizer(self, regularizer):
+ if regularizer is not None:
+ self.pars['regularizer'] = regularizer
+
+
+ def initialize(self):
+ # convenience variable storage
+ proj_geom = self.pars['projector_geometry']
+ vol_geom = self.pars['output_geometry']
+ sino = self.pars['input_sinogram']
+
+ # a 'warm start' with SIRT method
+ # Create a data object for the reconstruction
+ rec_id = astra.matlab.data3d('create', '-vol',
+ vol_geom);
+
+ #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino);
+ sinogram_id = astra.matlab.data3d('create', '-proj3d',
+ proj_geom,
+ sino)
+
+ sirt_config = astra.astra_dict('SIRT3D_CUDA')
+ sirt_config['ReconstructionDataId' ] = rec_id
+ sirt_config['ProjectionDataId'] = sinogram_id
+
+ sirt = astra.algorithm.create(sirt_config)
+ astra.algorithm.run(sirt, iterations=35)
+ X = astra.matlab.data3d('get', rec_id)
+
+ # clean up memory
+ astra.matlab.data3d('delete', rec_id)
+ astra.matlab.data3d('delete', sinogram_id)
+ astra.algorithm.delete(sirt)
+
+
+
+ return X
+
+ def createOrderedSubsets(self, subsets=None):
+ if subsets is None:
+ try:
+ subsets = self.getParameter('subsets')
+ except Exception():
+ subsets = 0
+ #return subsets
+ else:
+ self.setParameter(subsets=subsets)
+
+
+ angles = self.getParameter('projector_geometry')['ProjectionAngles']
+
+ #binEdges = numpy.linspace(angles.min(),
+ # angles.max(),
+ # subsets + 1)
+ binsDiscr, binEdges = numpy.histogram(angles, bins=subsets)
+ # get rearranged subset indices
+ IndicesReorg = numpy.zeros((numpy.shape(angles)), dtype=numpy.int32)
+ counterM = 0
+ for ii in range(binsDiscr.max()):
+ counter = 0
+ for jj in range(subsets):
+ curr_index = ii + jj + counter
+ #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM))
+ if binsDiscr[jj] > ii:
+ if (counterM < numpy.size(IndicesReorg)):
+ IndicesReorg[counterM] = curr_index
+ counterM = counterM + 1
+
+ counter = counter + binsDiscr[jj] - 1
+
+ # store the OS in parameters
+ self.setParameter(os_subsets=subsets,
+ os_bins=binsDiscr,
+ os_indices=IndicesReorg)
+
+
+ def prepareForIteration(self):
+ print ("FISTA Reconstructor: prepare for iteration")
+
+ self.residual_error = numpy.zeros((self.pars['number_of_iterations']))
+ self.objective = numpy.zeros((self.pars['number_of_iterations']))
+
+ #2D array (for 3D data) of sparse "ring"
+ detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram'])
+ self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float)
+ # another ring variable
+ self.r_x = self.r.copy()
+
+ self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram']))
+
+ if self.getParameter('Lipschitz_constant') is None:
+ self.pars['Lipschitz_constant'] = \
+ self.calculateLipschitzConstantWithPowerMethod()
+ # errors vector (if the ground truth is given)
+ self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations')));
+ # objective function values vector
+ self.objective = numpy.zeros((self.getParameter('number_of_iterations')));
+
+
+ # prepareForIteration
+
+ def iterate (self, Xin=None):
+ if self.getParameter('subsets') == 0:
+ return self.iterateStandard(Xin)
+ else:
+ return self.iterateOrderedSubsets(Xin)
+
+ def iterateStandard(self, Xin=None):
+ print ("FISTA Reconstructor: iterate")
+
+ if Xin is None:
+ if self.getParameter('initialize'):
+ X = self.initialize()
+ else:
+ N = vol_geom['GridColCount']
+ X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+ else:
+ # copy by reference
+ X = Xin
+ # store the output volume in the parameters
+ self.setParameter(output_volume=X)
+ X_t = X.copy()
+ # convenience variable storage
+ proj_geom , vol_geom, sino , \
+ SlicesZ , ring_lambda_R_L1 , weights = \
+ self.getParameter([ 'projector_geometry' ,
+ 'output_geometry',
+ 'input_sinogram',
+ 'SlicesZ' ,
+ 'ring_lambda_R_L1',
+ 'weights'])
+
+ t = 1
+
+ device = self.getParameter('device_model')
+ reduced_device = self.getParameter('reduced_device_model')
+
+ for i in range(self.getParameter('number_of_iterations')):
+ print("iteration", i)
+ X_old = X.copy()
+ t_old = t
+ r_old = self.r.copy()
+ pg = self.getParameter('projector_geometry')['type']
+ if pg == 'parallel' or \
+ pg == 'fanflat' or \
+ pg == 'fanflat_vec':
+ # if the geometry is parallel use slice-by-slice
+ # projection-backprojection routine
+ #sino_updt = zeros(size(sino),'single');
+
+ if self.use_device :
+ self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
+
+ for kkk in range(SlicesZ):
+ self.sino_updt[kkk] = \
+ reduced_device.doForwardProject( X_t[kkk:kkk+1] )
+ else:
+ proj_geomT = proj_geom.copy()
+ proj_geomT['DetectorRowCount'] = 1
+ vol_geomT = vol_geom.copy()
+ vol_geomT['GridSliceCount'] = 1;
+ self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
+ for kkk in range(SlicesZ):
+ sino_id, self.sino_updt[kkk] = \
+ astra.creators.create_sino3d_gpu(
+ X_t[kkk:kkk+1], proj_geomT, vol_geomT)
+ astra.matlab.data3d('delete', sino_id)
+ else:
+ # for divergent 3D geometry (watch the GPU memory overflow in
+ # ASTRA versions < 1.8)
+ #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom);
+
+ if self.use_device:
+ self.sino_updt = device.doForwardProject(X_t)
+ else:
+ sino_id, self.sino_updt = astra.creators.create_sino3d_gpu(
+ X_t, proj_geom, vol_geom)
+ astra.matlab.data3d('delete', sino_id)
+
+
+ ## RING REMOVAL
+ if ring_lambda_R_L1 != 0:
+ self.ringRemoval(i)
+ else:
+ self.residual = weights * (self.sino_updt - sino)
+ self.objective[i] = 0.5 * numpy.linalg.norm(self.residual)
+ #objective(i) = 0.5*norm(residual(:)); % for the objective function output
+ ## Projection/Backprojection Routine
+ X, X_t = self.projectionBackprojection(X, X_t)
+
+ ## REGULARIZATION
+ Y = self.regularize(X)
+ X = Y.copy()
+ ## Update Loop
+ X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old)
+
+ print ("t" , t)
+ print ("X min {0} max {1}".format(X_t.min(),X_t.max()))
+ self.setParameter(output_volume=X)
+ return X
+ ## iterate
+
+ def ringRemoval(self, i):
+ print ("FISTA Reconstructor: ring removal")
+ residual = self.residual
+ lambdaR_L1 , alpha_ring , weights , L_const , sino= \
+ self.getParameter(['ring_lambda_R_L1',
+ 'ring_alpha' , 'weights',
+ 'Lipschitz_constant',
+ 'input_sinogram'])
+ r_x = self.r_x
+ sino_updt = self.sino_updt
+
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(self.getParameter('input_sinogram'))
+ if lambdaR_L1 > 0 :
+ for kkk in range(anglesNumb):
+
+ residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
+ ((sino_updt[:,kkk,:]).squeeze() - \
+ (sino[:,kkk,:]).squeeze() -\
+ (alpha_ring * r_x)
+ )
+ vec = residual.sum(axis = 1)
+ #if SlicesZ > 1:
+ # vec = vec[:,1,:].squeeze()
+ self.r = (r_x - (1./L_const) * vec).copy()
+ self.objective[i] = (0.5 * (residual ** 2).sum())
+
+ def projectionBackprojection(self, X, X_t):
+ print ("FISTA Reconstructor: projection-backprojection routine")
+
+ # a few useful variables
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(self.getParameter('input_sinogram'))
+ residual = self.residual
+ proj_geom , vol_geom , L_const = \
+ self.getParameter(['projector_geometry' ,
+ 'output_geometry',
+ 'Lipschitz_constant'])
+
+ device, reduced_device = self.getParameter(['device_model',
+ 'reduced_device_model'])
+
+ if self.getParameter('projector_geometry')['type'] == 'parallel' or \
+ self.getParameter('projector_geometry')['type'] == 'fanflat' or \
+ self.getParameter('projector_geometry')['type'] == 'fanflat_vec':
+ # if the geometry is parallel use slice-by-slice
+ # projection-backprojection routine
+ #sino_updt = zeros(size(sino),'single');
+ x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32)
+
+ if self.use_device:
+ proj_geomT = proj_geom.copy()
+ proj_geomT['DetectorRowCount'] = 1
+ vol_geomT = vol_geom.copy()
+ vol_geomT['GridSliceCount'] = 1;
+
+ for kkk in range(SlicesZ):
+
+ x_id, x_temp[kkk] = \
+ astra.creators.create_backprojection3d_gpu(
+ residual[kkk:kkk+1],
+ proj_geomT, vol_geomT)
+ astra.matlab.data3d('delete', x_id)
+ else:
+ for kkk in range(SliceZ):
+ x_temp[kkk] = \
+ reduced_device.doBackwardProject(residual[kkk:kkk+1])
+ else:
+ if self.use_device:
+ x_id, x_temp = \
+ astra.creators.create_backprojection3d_gpu(
+ residual, proj_geom, vol_geom)
+ astra.matlab.data3d('delete', x_id)
+ else:
+ x_temp = \
+ device.doBackwardProject(residual)
+
+
+ X = X_t - (1/L_const) * x_temp
+ #astra.matlab.data3d('delete', sino_id)
+ return (X , X_t)
+
+
+ def regularize(self, X , output_all=False):
+ #print ("FISTA Reconstructor: regularize")
+
+ regularizer = self.getParameter('regularizer')
+ if regularizer is not None:
+ return regularizer(input=X,
+ output_all=output_all)
+ else:
+ return X
+
+ def updateLoop(self, i, X, X_old, r_old, t, t_old):
+ print ("FISTA Reconstructor: update loop")
+ lambdaR_L1 = self.getParameter('ring_lambda_R_L1')
+
+ t = (1 + numpy.sqrt(1 + 4 * t**2))/2
+ X_t = X + (((t_old -1)/t) * (X - X_old))
+
+ if lambdaR_L1 > 0:
+ self.r = numpy.max(
+ numpy.abs(self.r) - lambdaR_L1 , 0) * \
+ numpy.sign(self.r)
+ self.r_x = self.r + \
+ (((t_old-1)/t) * (self.r - r_old))
+
+ if self.getParameter('region_of_interest') is None:
+ string = 'Iteration Number {0} | Objective {1} \n'
+ print (string.format( i, self.objective[i]))
+ else:
+ ROI , X_ideal = fistaRecon.getParameter('region_of_interest',
+ 'ideal_image')
+
+ Resid_error[i] = RMSE(X*ROI, X_ideal*ROI)
+ string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
+ print (string.format(i,Resid_error[i], self.objective[i]))
+ return (X , X_t, t)
+
+ def iterateOrderedSubsets(self, Xin=None):
+ print ("FISTA Reconstructor: Ordered Subsets iterate")
+
+ if Xin is None:
+ if self.getParameter('initialize'):
+ X = self.initialize()
+ else:
+ N = vol_geom['GridColCount']
+ X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+ else:
+ # copy by reference
+ X = Xin
+ # store the output volume in the parameters
+ self.setParameter(output_volume=X)
+ X_t = X.copy()
+
+ # some useful constants
+ proj_geom , vol_geom, sino , \
+ SlicesZ, weights , alpha_ring ,\
+ lambdaR_L1 , L_const , iterFISTA = self.getParameter(
+ ['projector_geometry' , 'output_geometry', 'input_sinogram',
+ 'SlicesZ' , 'weights', 'ring_alpha' ,
+ 'ring_lambda_R_L1', 'Lipschitz_constant',
+ 'number_of_iterations'])
+
+
+ # errors vector (if the ground truth is given)
+ Resid_error = numpy.zeros((iterFISTA));
+ # objective function values vector
+ #objective = numpy.zeros((iterFISTA));
+ objective = self.objective
+
+
+ t = 1
+
+ ## additional for
+ proj_geomSUB = proj_geom.copy()
+ self.residual2 = numpy.zeros(numpy.shape(sino))
+ residual2 = self.residual2
+ sino_updt_FULL = self.residual.copy()
+ r_x = self.r.copy()
+
+ print ("starting iterations")
+ ## % Outer FISTA iterations loop
+ for i in range(self.getParameter('number_of_iterations')):
+ # With OS approach it becomes trickier to correlate independent
+ # subsets, hence additional work is required one solution is to
+ # work with a full sinogram at times
+
+ r_old = self.r.copy()
+ t_old = t
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(self.getParameter('input_sinogram')) ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4
+ if (i > 1 and lambdaR_L1 > 0) :
+ for kkk in range(anglesNumb):
+
+ residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
+ ((sino_updt_FULL[:,kkk,:]).squeeze() - \
+ (sino[:,kkk,:]).squeeze() -\
+ (alpha_ring * r_x)
+ )
+
+ vec = self.residual.sum(axis = 1)
+ #if SlicesZ > 1:
+ # vec = vec[:,1,:] # 1 or 0?
+ r_x = self.r_x
+ # update ring variable
+ self.r = (r_x - (1./L_const) * vec).copy()
+
+ # subset loop
+ counterInd = 1
+ geometry_type = self.getParameter('projector_geometry')['type']
+ angles = self.getParameter('projector_geometry')['ProjectionAngles']
+
+ for ss in range(self.getParameter('subsets')):
+ #print ("Subset {0}".format(ss))
+ X_old = X.copy()
+ t_old = t
+
+ # the number of projections per subset
+ numProjSub = self.getParameter('os_bins')[ss]
+ CurrSubIndices = self.getParameter('os_indices')\
+ [counterInd:counterInd+numProjSub]
+ #print ("Len CurrSubIndices {0}".format(numProjSub))
+ mask = numpy.zeros(numpy.shape(angles), dtype=bool)
+ #cc = 0
+ for j in range(len(CurrSubIndices)):
+ mask[int(CurrSubIndices[j])] = True
+ proj_geomSUB['ProjectionAngles'] = angles[mask]
+
+ if self.use_device:
+ device = self.getParameter('device_model')\
+ .createReducedDevice(
+ proj_par={'angles':angles[mask]},
+ vol_par={})
+
+ shape = list(numpy.shape(self.getParameter('input_sinogram')))
+ shape[1] = numProjSub
+ sino_updt_Sub = numpy.zeros(shape)
+ if geometry_type == 'parallel' or \
+ geometry_type == 'fanflat' or \
+ geometry_type == 'fanflat_vec' :
+
+ for kkk in range(SlicesZ):
+ if self.use_device:
+ sinoT = device.doForwardProject(X_t[kkk:kkk+1])
+ else:
+ sino_id, sinoT = astra.creators.create_sino3d_gpu (
+ X_t[kkk:kkk+1] , proj_geomSUB, vol_geom)
+ astra.matlab.data3d('delete', sino_id)
+ sino_updt_Sub[kkk] = sinoT.T.copy()
+
+ else:
+ # for 3D geometry (watch the GPU memory overflow in
+ # ASTRA < 1.8)
+ if self.use_device:
+ sino_updt_Sub = device.doForwardProject(X_t)
+
+ else:
+ sino_id, sino_updt_Sub = \
+ astra.creators.create_sino3d_gpu (X_t, proj_geomSUB, vol_geom)
+
+ astra.matlab.data3d('delete', sino_id)
+
+ #print ("shape(sino_updt_Sub)",numpy.shape(sino_updt_Sub))
+ if lambdaR_L1 > 0 :
+ ## RING REMOVAL
+ #print ("ring removal")
+ residualSub , sino_updt_Sub, sino_updt_FULL = \
+ self.ringRemovalOrderedSubsets(ss,
+ counterInd,
+ sino_updt_Sub,
+ sino_updt_FULL)
+ else:
+ #PWLS model
+ #print ("PWLS model")
+ residualSub = weights[:,CurrSubIndices,:] * \
+ ( sino_updt_Sub - \
+ sino[:,CurrSubIndices,:].squeeze() )
+ objective[i] = 0.5 * numpy.linalg.norm(residualSub)
+
+ # projection/backprojection routine
+ if geometry_type == 'parallel' or \
+ geometry_type == 'fanflat' or \
+ geometry_type == 'fanflat_vec' :
+ # if geometry is 2D use slice-by-slice projection-backprojection
+ # routine
+ x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32)
+ for kkk in range(SlicesZ):
+ if self.use_device:
+ x_temp[kkk] = device.doBackwardProject(
+ residualSub[kkk:kkk+1])
+ else:
+ x_id, x_temp[kkk] = \
+ astra.creators.create_backprojection3d_gpu(
+ residualSub[kkk:kkk+1],
+ proj_geomSUB, vol_geom)
+ astra.matlab.data3d('delete', x_id)
+
+ else:
+ if self.use_device:
+ x_temp = device.doBackwardProject(
+ residualSub)
+ else:
+ x_id, x_temp = \
+ astra.creators.create_backprojection3d_gpu(
+ residualSub, proj_geomSUB, vol_geom)
+
+ astra.matlab.data3d('delete', x_id)
+
+ X = X_t - (1/L_const) * x_temp
+
+ ## REGULARIZATION
+ X = self.regularize(X)
+
+ ## Update subset Loop
+ t = (1 + numpy.sqrt(1 + 4 * t**2))/2
+ X_t = X + (((t_old -1)/t) * (X - X_old))
+ # FINAL
+ ## update iteration loop
+ if lambdaR_L1 > 0:
+ self.r = numpy.max(
+ numpy.abs(self.r) - lambdaR_L1 , 0) * \
+ numpy.sign(self.r)
+ self.r_x = self.r + \
+ (((t_old-1)/t) * (self.r - r_old))
+
+ if self.getParameter('region_of_interest') is None:
+ string = 'Iteration Number {0} | Objective {1} \n'
+ print (string.format( i, self.objective[i]))
+ else:
+ ROI , X_ideal = fistaRecon.getParameter('region_of_interest',
+ 'ideal_image')
+
+ Resid_error[i] = RMSE(X*ROI, X_ideal*ROI)
+ string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
+ print (string.format(i,Resid_error[i], self.objective[i]))
+ print("X min {0} max {1}".format(X.min(),X.max()))
+ self.setParameter(output_volume=X)
+ counterInd = counterInd + numProjSub
+
+ return X
+
+ def ringRemovalOrderedSubsets(self, ss,counterInd,
+ sino_updt_Sub, sino_updt_FULL):
+ residual = self.residual
+ r_x = self.r_x
+ weights , alpha_ring , sino = \
+ self.getParameter( ['weights', 'ring_alpha', 'input_sinogram'])
+ numProjSub = self.getParameter('os_bins')[ss]
+ CurrSubIndices = self.getParameter('os_indices')\
+ [counterInd:counterInd+numProjSub]
+
+ shape = list(numpy.shape(self.getParameter('input_sinogram')))
+ shape[1] = numProjSub
+
+ residualSub = numpy.zeros(shape)
+
+ for kkk in range(numProjSub):
+ #print ("ring removal indC ... {0}".format(kkk))
+ indC = int(CurrSubIndices[kkk])
+ residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \
+ (sino_updt_Sub[:,kkk,:].squeeze() - \
+ sino[:,indC,:].squeeze() - alpha_ring * r_x)
+ # filling the full sinogram
+ sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze()
+
+ return (residualSub , sino_updt_Sub, sino_updt_FULL)
+
+
diff --git a/Wrappers/Python/ccpi/reconstruction/Reconstructor.py b/Wrappers/Python/ccpi/reconstruction/Reconstructor.py
new file mode 100644
index 0000000..ba67327
--- /dev/null
+++ b/Wrappers/Python/ccpi/reconstruction/Reconstructor.py
@@ -0,0 +1,598 @@
+# -*- coding: utf-8 -*-
+###############################################################################
+#This work is part of the Core Imaging Library developed by
+#Visual Analytics and Imaging System Group of the Science Technology
+#Facilities Council, STFC
+#
+#Copyright 2017 Edoardo Pasca, Srikanth Nagella
+#Copyright 2017 Daniil Kazantsev
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#http://www.apache.org/licenses/LICENSE-2.0
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+###############################################################################
+
+
+
+import numpy
+import h5py
+from ccpi.reconstruction.parallelbeam import alg
+
+from Regularizer import Regularizer
+from enum import Enum
+
+import astra
+
+
+class Reconstructor:
+
+ class Algorithm(Enum):
+ CGLS = alg.cgls
+ CGLS_CONV = alg.cgls_conv
+ SIRT = alg.sirt
+ MLEM = alg.mlem
+ CGLS_TICHONOV = alg.cgls_tikhonov
+ CGLS_TVREG = alg.cgls_TVreg
+ FISTA = 'fista'
+
+ def __init__(self, algorithm = None, projection_data = None,
+ angles = None, center_of_rotation = None ,
+ flat_field = None, dark_field = None,
+ iterations = None, resolution = None, isLogScale = False, threads = None,
+ normalized_projection = None):
+
+ self.pars = dict()
+ self.pars['algorithm'] = algorithm
+ self.pars['projection_data'] = projection_data
+ self.pars['normalized_projection'] = normalized_projection
+ self.pars['angles'] = angles
+ self.pars['center_of_rotation'] = numpy.double(center_of_rotation)
+ self.pars['flat_field'] = flat_field
+ self.pars['iterations'] = iterations
+ self.pars['dark_field'] = dark_field
+ self.pars['resolution'] = resolution
+ self.pars['isLogScale'] = isLogScale
+ self.pars['threads'] = threads
+ if (iterations != None):
+ self.pars['iterationValues'] = numpy.zeros((iterations))
+
+ if projection_data != None and dark_field != None and flat_field != None:
+ norm = self.normalize(projection_data, dark_field, flat_field, 0.1)
+ self.pars['normalized_projection'] = norm
+
+
+ def setPars(self, parameters):
+ keys = ['algorithm','projection_data' ,'normalized_projection', \
+ 'angles' , 'center_of_rotation' , 'flat_field', \
+ 'iterations','dark_field' , 'resolution', 'isLogScale' , \
+ 'threads' , 'iterationValues', 'regularize']
+
+ for k in keys:
+ if k not in parameters.keys():
+ self.pars[k] = None
+ else:
+ self.pars[k] = parameters[k]
+
+
+ def sanityCheck(self):
+ projection_data = self.pars['projection_data']
+ dark_field = self.pars['dark_field']
+ flat_field = self.pars['flat_field']
+ angles = self.pars['angles']
+
+ if projection_data != None and dark_field != None and \
+ angles != None and flat_field != None:
+ data_shape = numpy.shape(projection_data)
+ angle_shape = numpy.shape(angles)
+
+ if angle_shape[0] != data_shape[0]:
+ #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \
+ # (angle_shape[0] , data_shape[0]) )
+ return (False , 'Projections and angles dimensions do not match: %d vs %d' % \
+ (angle_shape[0] , data_shape[0]) )
+
+ if data_shape[1:] != numpy.shape(flat_field):
+ #raise Exception('Projection and flat field dimensions do not match')
+ return (False , 'Projection and flat field dimensions do not match')
+ if data_shape[1:] != numpy.shape(dark_field):
+ #raise Exception('Projection and dark field dimensions do not match')
+ return (False , 'Projection and dark field dimensions do not match')
+
+ return (True , '' )
+ elif self.pars['normalized_projection'] != None:
+ data_shape = numpy.shape(self.pars['normalized_projection'])
+ angle_shape = numpy.shape(angles)
+
+ if angle_shape[0] != data_shape[0]:
+ #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \
+ # (angle_shape[0] , data_shape[0]) )
+ return (False , 'Projections and angles dimensions do not match: %d vs %d' % \
+ (angle_shape[0] , data_shape[0]) )
+ else:
+ return (True , '' )
+ else:
+ return (False , 'Not enough data')
+
+ def reconstruct(self, parameters = None):
+ if parameters != None:
+ self.setPars(parameters)
+
+ go , reason = self.sanityCheck()
+ if go:
+ return self._reconstruct()
+ else:
+ raise Exception(reason)
+
+
+ def _reconstruct(self, parameters=None):
+ if parameters!=None:
+ self.setPars(parameters)
+ parameters = self.pars
+
+ if parameters['algorithm'] != None and \
+ parameters['normalized_projection'] != None and \
+ parameters['angles'] != None and \
+ parameters['center_of_rotation'] != None and \
+ parameters['iterations'] != None and \
+ parameters['resolution'] != None and\
+ parameters['threads'] != None and\
+ parameters['isLogScale'] != None:
+
+
+ if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS,
+ Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT):
+ #store parameters
+ self.pars = parameters
+ result = parameters['algorithm'](
+ parameters['normalized_projection'] ,
+ parameters['angles'],
+ parameters['center_of_rotation'],
+ parameters['resolution'],
+ parameters['iterations'],
+ parameters['threads'] ,
+ parameters['isLogScale']
+ )
+ return result
+ elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV,
+ Reconstructor.Algorithm.CGLS_TICHONOV,
+ Reconstructor.Algorithm.CGLS_TVREG) :
+ self.pars = parameters
+ result = parameters['algorithm'](
+ parameters['normalized_projection'] ,
+ parameters['angles'],
+ parameters['center_of_rotation'],
+ parameters['resolution'],
+ parameters['iterations'],
+ parameters['threads'] ,
+ parameters['regularize'],
+ numpy.zeros((parameters['iterations'])),
+ parameters['isLogScale']
+ )
+
+ elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA:
+ pass
+
+ else:
+ if parameters['projection_data'] != None and \
+ parameters['dark_field'] != None and \
+ parameters['flat_field'] != None:
+ norm = self.normalize(parameters['projection_data'],
+ parameters['dark_field'],
+ parameters['flat_field'], 0.1)
+ self.pars['normalized_projection'] = norm
+ return self._reconstruct(parameters)
+
+
+
+ def _normalize(self, projection, dark, flat, def_val=0):
+ a = (projection - dark)
+ b = (flat-dark)
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ c = numpy.true_divide( a, b )
+ c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0
+ return c
+
+ def normalize(self, projections, dark, flat, def_val=0):
+ norm = [self._normalize(projection, dark, flat, def_val) for projection in projections]
+ return numpy.asarray (norm, dtype=numpy.float32)
+
+
+
+class FISTA():
+ '''FISTA-based reconstruction algorithm using ASTRA-toolbox
+
+ '''
+ # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>>
+ # ___Input___:
+ # params.[] file:
+ # - .proj_geom (geometry of the projector) [required]
+ # - .vol_geom (geometry of the reconstructed object) [required]
+ # - .sino (vectorized in 2D or 3D sinogram) [required]
+ # - .iterFISTA (iterations for the main loop, default 40)
+ # - .L_const (Lipschitz constant, default Power method) )
+ # - .X_ideal (ideal image, if given)
+ # - .weights (statisitcal weights, size of the sinogram)
+ # - .ROI (Region-of-interest, only if X_ideal is given)
+ # - .initialize (a 'warm start' using SIRT method from ASTRA)
+ #----------------Regularization choices------------------------
+ # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter)
+ # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter)
+ # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter)
+ # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04)
+ # - .Regul_Iterations (iterations for the selected penalty, default 25)
+ # - .Regul_tauLLT (time step parameter for LLT term)
+ # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal)
+ # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1)
+ #----------------Visualization parameters------------------------
+ # - .show (visualize reconstruction 1/0, (0 default))
+ # - .maxvalplot (maximum value to use for imshow[0 maxvalplot])
+ # - .slice (for 3D volumes - slice number to imshow)
+ # ___Output___:
+ # 1. X - reconstructed image/volume
+ # 2. output - a structure with
+ # - .Resid_error - residual error (if X_ideal is given)
+ # - .objective: value of the objective function
+ # - .L_const: Lipshitz constant to avoid recalculations
+
+ # References:
+ # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
+ # Problems" by A. Beck and M Teboulle
+ # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo
+ # 3. "A novel tomographic reconstruction method based on the robust
+ # Student's t function for suppressing data outliers" D. Kazantsev et.al.
+ # D. Kazantsev, 2016-17
+ def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs):
+ self.params = dict()
+ self.params['projector_geometry'] = projector_geometry
+ self.params['output_geometry'] = output_geometry
+ self.params['input_sinogram'] = input_sinogram
+ detectors, nangles, sliceZ = numpy.shape(input_sinogram)
+ self.params['detectors'] = detectors
+ self.params['number_og_angles'] = nangles
+ self.params['SlicesZ'] = sliceZ
+
+ # Accepted input keywords
+ kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' ,
+ 'weights' , 'region_of_interest' , 'initialize' ,
+ 'regularizer' ,
+ 'ring_lambda_R_L1',
+ 'ring_alpha')
+
+ # handle keyworded parameters
+ if kwargs is not None:
+ for key, value in kwargs.items():
+ if key in kw:
+ #print("{0} = {1}".format(key, value))
+ self.pars[key] = value
+
+ # set the default values for the parameters if not set
+ if 'number_of_iterations' in kwargs.keys():
+ self.pars['number_of_iterations'] = kwargs['number_of_iterations']
+ else:
+ self.pars['number_of_iterations'] = 40
+ if 'weights' in kwargs.keys():
+ self.pars['weights'] = kwargs['weights']
+ else:
+ self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram']))
+ if 'Lipschitz_constant' in kwargs.keys():
+ self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
+ else:
+ self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod()
+
+ if not self.pars['ideal_image'] in kwargs.keys():
+ self.pars['ideal_image'] = None
+
+ if not self.pars['region_of_interest'] :
+ if self.pars['ideal_image'] == None:
+ pass
+ else:
+ self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0)
+
+ if not self.pars['regularizer'] :
+ self.pars['regularizer'] = None
+ else:
+ # the regularizer must be a correctly instantiated object
+ if not self.pars['ring_lambda_R_L1']:
+ self.pars['ring_lambda_R_L1'] = 0
+ if not self.pars['ring_alpha']:
+ self.pars['ring_alpha'] = 1
+
+
+
+
+ def calculateLipschitzConstantWithPowerMethod(self):
+ ''' using Power method (PM) to establish L constant'''
+
+ #N = params.vol_geom.GridColCount
+ N = self.pars['output_geometry'].GridColCount
+ proj_geom = self.params['projector_geometry']
+ vol_geom = self.params['output_geometry']
+ weights = self.pars['weights']
+ SlicesZ = self.pars['SlicesZ']
+
+ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
+ #% for parallel geometry we can do just one slice
+ #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...');
+ niter = 15;# % number of iteration for the PM
+ #N = params.vol_geom.GridColCount;
+ #x1 = rand(N,N,1);
+ x1 = numpy.random.rand(1,N,N)
+ #sqweight = sqrt(weights(:,:,1));
+ sqweight = numpy.sqrt(weights.T[0])
+ proj_geomT = proj_geom.copy();
+ proj_geomT.DetectorRowCount = 1;
+ vol_geomT = vol_geom.copy();
+ vol_geomT['GridSliceCount'] = 1;
+
+
+ for i in range(niter):
+ if i == 0:
+ #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+ sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
+ y = sqweight * y # element wise multiplication
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id)
+
+ idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT);
+ s = numpy.linalg.norm(x1)
+ ### this line?
+ x1 = x1/s;
+ ### this line?
+ sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+ y = sqweight*y;
+ astra.matlab.data3d('delete', sino_id);
+ astra.matlab.data3d('delete', idx);
+ #end
+ del proj_geomT
+ del vol_geomT
+ else
+ #% divergen beam geometry
+ #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...');
+ niter = 8; #% number of iteration for PM
+ x1 = numpy.random.rand(SlicesZ , N , N);
+ #sqweight = sqrt(weights);
+ sqweight = numpy.sqrt(weights.T[0])
+
+ sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
+ y = sqweight*y;
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id);
+
+ for i in range(niter):
+ #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom);
+ idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,
+ proj_geom,
+ vol_geom)
+ s = numpy.linalg.norm(x1)
+ ### this line?
+ x1 = x1/s;
+ ### this line?
+ #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom);
+ sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ proj_geom,
+ vol_geom);
+
+ y = sqweight*y;
+ #astra_mex_data3d('delete', sino_id);
+ #astra_mex_data3d('delete', id);
+ astra.matlab.data3d('delete', sino_id);
+ astra.matlab.data3d('delete', idx);
+ #end
+ #clear x1
+ del x1
+
+ return s
+
+
+ def setRegularizer(self, regularizer):
+ if regularizer
+ self.pars['regularizer'] = regularizer
+
+
+
+
+
+def getEntry(location):
+ for item in nx[location].keys():
+ print (item)
+
+
+print ("Loading Data")
+
+##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif"
+####ind = [i * 1049 for i in range(360)]
+#### use only 360 images
+##images = 200
+##ind = [int(i * 1049 / images) for i in range(images)]
+##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None)
+
+#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs"
+fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs"
+nx = h5py.File(fname, "r")
+
+# the data are stored in a particular location in the hdf5
+for item in nx['entry1/tomo_entry/data'].keys():
+ print (item)
+
+data = nx.get('entry1/tomo_entry/data/rotation_angle')
+angles = numpy.zeros(data.shape)
+data.read_direct(angles)
+print (angles)
+# angles should be in degrees
+
+data = nx.get('entry1/tomo_entry/data/data')
+stack = numpy.zeros(data.shape)
+data.read_direct(stack)
+print (data.shape)
+
+print ("Data Loaded")
+
+
+# Normalize
+data = nx.get('entry1/tomo_entry/instrument/detector/image_key')
+itype = numpy.zeros(data.shape)
+data.read_direct(itype)
+# 2 is dark field
+darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ]
+dark = darks[0]
+for i in range(1, len(darks)):
+ dark += darks[i]
+dark = dark / len(darks)
+#dark[0][0] = dark[0][1]
+
+# 1 is flat field
+flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ]
+flat = flats[0]
+for i in range(1, len(flats)):
+ flat += flats[i]
+flat = flat / len(flats)
+#flat[0][0] = dark[0][1]
+
+
+# 0 is projection data
+proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ]
+angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ]
+angle_proj = numpy.asarray (angle_proj)
+angle_proj = angle_proj.astype(numpy.float32)
+
+# normalized data are
+# norm = (projection - dark)/(flat-dark)
+
+def normalize(projection, dark, flat, def_val=0.1):
+ a = (projection - dark)
+ b = (flat-dark)
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ c = numpy.true_divide( a, b )
+ c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0
+ return c
+
+
+norm = [normalize(projection, dark, flat) for projection in proj]
+norm = numpy.asarray (norm)
+norm = norm.astype(numpy.float32)
+
+#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm,
+# angles = angle_proj, center_of_rotation = 86.2 ,
+# flat_field = flat, dark_field = dark,
+# iterations = 15, resolution = 1, isLogScale = False, threads = 3)
+
+#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj,
+# angles = angle_proj, center_of_rotation = 86.2 ,
+# flat_field = flat, dark_field = dark,
+# iterations = 15, resolution = 1, isLogScale = False, threads = 3)
+#img_cgls = recon.reconstruct()
+#
+#pars = dict()
+#pars['algorithm'] = Reconstructor.Algorithm.SIRT
+#pars['projection_data'] = proj
+#pars['angles'] = angle_proj
+#pars['center_of_rotation'] = numpy.double(86.2)
+#pars['flat_field'] = flat
+#pars['iterations'] = 15
+#pars['dark_field'] = dark
+#pars['resolution'] = 1
+#pars['isLogScale'] = False
+#pars['threads'] = 3
+#
+#img_sirt = recon.reconstruct(pars)
+#
+#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM
+#img_mlem = recon.reconstruct()
+
+############################################################
+############################################################
+#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV
+#recon.pars['regularize'] = numpy.double(0.1)
+#img_cgls_conv = recon.reconstruct()
+
+niterations = 15
+threads = 3
+
+img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+
+iteration_values = numpy.zeros((niterations,))
+img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+ iteration_values, False)
+print ("iteration values %s" % str(iteration_values))
+
+iteration_values = numpy.zeros((niterations,))
+img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+ numpy.double(1e-5), iteration_values , False)
+print ("iteration values %s" % str(iteration_values))
+iteration_values = numpy.zeros((niterations,))
+img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+ numpy.double(1e-5), iteration_values , False)
+print ("iteration values %s" % str(iteration_values))
+
+
+##numpy.save("cgls_recon.npy", img_data)
+import matplotlib.pyplot as plt
+fig, ax = plt.subplots(1,6,sharey=True)
+ax[0].imshow(img_cgls[80])
+ax[0].axis('off') # clear x- and y-axes
+ax[1].imshow(img_sirt[80])
+ax[1].axis('off') # clear x- and y-axes
+ax[2].imshow(img_mlem[80])
+ax[2].axis('off') # clear x- and y-axesplt.show()
+ax[3].imshow(img_cgls_conv[80])
+ax[3].axis('off') # clear x- and y-axesplt.show()
+ax[4].imshow(img_cgls_tikhonov[80])
+ax[4].axis('off') # clear x- and y-axesplt.show()
+ax[5].imshow(img_cgls_TVreg[80])
+ax[5].axis('off') # clear x- and y-axesplt.show()
+
+
+plt.show()
+
+#viewer = edo.CILViewer()
+#viewer.setInputAsNumpy(img_cgls2)
+#viewer.displaySliceActor(0)
+#viewer.startRenderLoop()
+
+import vtk
+
+def NumpyToVTKImageData(numpyarray):
+ if (len(numpy.shape(numpyarray)) == 3):
+ doubleImg = vtk.vtkImageData()
+ shape = numpy.shape(numpyarray)
+ doubleImg.SetDimensions(shape[0], shape[1], shape[2])
+ doubleImg.SetOrigin(0,0,0)
+ doubleImg.SetSpacing(1,1,1)
+ doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1)
+ #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation())
+ doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1)
+
+ for i in range(shape[0]):
+ for j in range(shape[1]):
+ for k in range(shape[2]):
+ doubleImg.SetScalarComponentFromDouble(
+ i,j,k,0, numpyarray[i][j][k])
+ #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) )
+ # rescale to appropriate VTK_UNSIGNED_SHORT
+ stats = vtk.vtkImageAccumulate()
+ stats.SetInputData(doubleImg)
+ stats.Update()
+ iMin = stats.GetMin()[0]
+ iMax = stats.GetMax()[0]
+ scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin)
+
+ shiftScaler = vtk.vtkImageShiftScale ()
+ shiftScaler.SetInputData(doubleImg)
+ shiftScaler.SetScale(scale)
+ shiftScaler.SetShift(iMin)
+ shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT)
+ shiftScaler.Update()
+ return shiftScaler.GetOutput()
+
+#writer = vtk.vtkMetaImageWriter()
+#writer.SetFileName(alg + "_recon.mha")
+#writer.SetInputData(NumpyToVTKImageData(img_cgls2))
+#writer.Write()