diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Python/CMakeLists.txt | 111 | ||||
-rw-r--r-- | src/Python/ccpi/fista/FISTAReconstructor.py | 609 | ||||
-rw-r--r-- | src/Python/ccpi/fista/Reconstructor.py | 425 | ||||
-rw-r--r-- | src/Python/ccpi/fista/__init__.py | 0 | ||||
-rw-r--r-- | src/Python/ccpi/reconstruction/FISTAReconstructor.py | 602 | ||||
-rw-r--r-- | src/Python/compile-fista.bat.in | 7 | ||||
-rw-r--r-- | src/Python/compile-fista.sh.in | 9 | ||||
-rw-r--r-- | src/Python/conda-recipe/meta.yaml | 2 | ||||
-rw-r--r-- | src/Python/fista-recipe/build.sh | 10 | ||||
-rw-r--r-- | src/Python/fista-recipe/meta.yaml | 29 | ||||
-rw-r--r-- | src/Python/setup-fista.py.in | 27 | ||||
-rw-r--r-- | src/Python/setup.py.in | 4 | ||||
-rw-r--r-- | src/Python/test/test_reconstructor-os.py (renamed from src/Python/test_reconstructor-os.py) | 31 | ||||
-rw-r--r-- | src/Python/test/test_reconstructor.py | 309 |
14 files changed, 940 insertions, 1235 deletions
diff --git a/src/Python/CMakeLists.txt b/src/Python/CMakeLists.txt index b464059..66630cb 100644 --- a/src/Python/CMakeLists.txt +++ b/src/Python/CMakeLists.txt @@ -12,21 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -message("CIL VERSION " ${CIL_VERSION}) - - # variables that must be set for conda compilation #PREFIX=C:\Apps\Miniconda2\envs\cil\Library #LIBRARY_INC=C:\\Apps\\Miniconda2\\envs\\cil\\Library\\include set (NUMPY_VERSION 1.12) -#set (PYTHON_VERSION 3.5) - -#https://groups.google.com/a/continuum.io/forum/#!topic/anaconda/R9gWjl09UFs -#set (CONDA_ENVIRONMENT "cil") ## Tries to parse the output of conda env list to determine the current ## active conda environment +message ("Trying to determine your active conda environment...") execute_process(COMMAND "conda" "env" "list" OUTPUT_VARIABLE _CONDA_ENVS RESULT_VARIABLE _CONDA_RESULT @@ -44,7 +38,7 @@ execute_process(COMMAND "conda" "env" "list" endif() endforeach() else() - message("conda result false" ${_CONDA_ERR}) + message(FATAL_ERROR "ERROR with conda command " ${_CONDA_ERR}) endif() if (${CONDA_ENVIRONMENT} AND ${CONDA_ENVIRONMENT_PATH}) @@ -55,24 +49,28 @@ else() message("Using current conda environmnet path " ${CONDA_ENVIRONMENT_PATH}) endif() - - message("CIL VERSION " ${CIL_VERSION}) # set the Python variables for the Conda environment include(FindAnacondaEnvironment.cmake) findPythonForAnacondaEnvironment(${CONDA_ENVIRONMENT_PATH}) + message("Python found " ${PYTHON_VERSION_STRING}) message("Python found Major " ${PYTHON_VERSION_MAJOR}) message("Python found Minor " ${PYTHON_VERSION_MINOR}) + findPythonPackagesPath() message("PYTHON_PACKAGES_FOUND " ${PYTHON_PACKAGES_PATH}) -# copy the Pyhon files of the package -file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging/) -file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi) -file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) -file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/Regularizer.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) +## CACHE SOME VARIABLES ## +set (CONDA_ENVIRONMENT ${CONDA_ENVIRONMENT} CACHE INTERNAL "active conda environment" FORCE) +set (CONDA_ENVIRONMENT_PATH ${CONDA_ENVIRONMENT_PATH} CACHE INTERNAL "active conda environment" FORCE) + +set (PYTHON_VERSION_STRING ${PYTHON_VERSION_STRING} CACHE INTERNAL "conda environment Python version string" FORCE) +set (PYTHON_VERSION_MAJOR ${PYTHON_VERSION_MAJOR} CACHE INTERNAL "conda environment Python version major" FORCE) +set (PYTHON_VERSION_MINOR ${PYTHON_VERSION_MINOR} CACHE INTERNAL "conda environment Python version minor" FORCE) +set (PYTHON_VERSION_PATCH ${PYTHON_VERSION_PATCH} CACHE INTERNAL "conda environment Python version patch" FORCE) +set (PYTHON_PACKAGES_PATH ${PYTHON_PACKAGES_PATH} CACHE INTERNAL "conda environment Python packages path" FORCE) if (WIN32) #set (CONDA_ENVIRONMENT_PATH "C:\\Apps\\Miniconda2\\envs\\${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory") @@ -84,21 +82,92 @@ elseif (UNIX) set (CONDA_ENVIRONMENT_LIBRARY_INC "${CONDA_ENVIRONMENT_PREFIX}/include" CACHE PATH "env dir") endif() +######### CONFIGURE REGULARIZER PACKAGE ############# + +# copy the Pyhon files of the package regularizer +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging/) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi) +# regularizers +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/Regularizer.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) + # Copy and configure the relative conda build and recipes configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup.py) file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe) file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/meta.yaml DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe) if (WIN32) - file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/) - configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile.bat) + + file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/) + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile.bat) + +elseif(UNIX) + + message ("We are on UNIX") + file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/) + # assumes we will use bash + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile.sh) + +endif() + +########## CONFIGURE FISTA RECONSTRUCTOR PACKAGE +# fista reconstructor +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/FISTAReconstructor.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction) + +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup-fista.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup-fista.py) +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/meta.yaml DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe) + +if (WIN32) + + file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe/) + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile-fista.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile-fista.bat) + elseif(UNIX) - message ("We are on UNIX") - file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/) - # assumes we will use bash - configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile.sh) + message ("We are on UNIX") + file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe/) + # assumes we will use bash + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile-fista.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile-fista.sh) endif() +############################# TARGETS + +########################## REGULARIZER PACKAGE ############################### + +# runs cmake on the build tree to update the code from source +add_custom_target(update_code + COMMAND ${CMAKE_COMMAND} + ARGS ${CMAKE_SOURCE_DIR} + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + ) + + +add_custom_target(fista + COMMAND bash + compile-fista.sh + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${update_code} + ) + +add_custom_target(regularizers + COMMAND bash + compile.sh + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${update_code} + ) + +add_custom_target(install-fista + COMMAND conda + install --force --use-local ccpi-fista=${CIL_VERSION} -c ccpi -c conda-forge + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${fista}) + +add_custom_target(install-regularizers + COMMAND conda + install --force --use-local ccpi-regularizers=${CIL_VERSION} -c ccpi -c conda-forge + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${fista}) ### add tests #add_executable(RegularizersTest ) diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py deleted file mode 100644 index 85bfac5..0000000 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ /dev/null @@ -1,609 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -#from ccpi.reconstruction.parallelbeam import alg - -#from ccpi.imaging.Regularizer import Regularizer -from enum import Enum - -import astra - - - -class FISTAReconstructor(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, - **kwargs): - # handle parmeters: - # obligatory parameters - self.pars = dict() - self.pars['projector_geometry'] = projector_geometry # proj_geom - self.pars['output_geometry'] = output_geometry # vol_geom - self.pars['input_sinogram'] = input_sinogram # sino - sliceZ, nangles, detectors = numpy.shape(input_sinogram) - self.pars['detectors'] = detectors - self.pars['number_of_angles'] = nangles - self.pars['SlicesZ'] = sliceZ - self.pars['output_volume'] = None - - print (self.pars) - # handle optional input parameters (at instantiation) - - # Accepted input keywords - kw = ( - # mandatory fields - 'projector_geometry', - 'output_geometry', - 'input_sinogram', - 'detectors', - 'number_of_angles', - 'SlicesZ', - # optional fields - 'number_of_iterations', - 'Lipschitz_constant' , - 'ideal_image' , - 'weights' , - 'region_of_interest' , - 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha', - 'subsets', - 'output_volume', - 'os_subsets', - 'os_indices', - 'os_bins') - self.acceptedInputKeywords = list(kw) - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = \ - numpy.ones(numpy.shape( - self.pars['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = None - - if not 'ideal_image' in kwargs.keys(): - self.pars['ideal_image'] = None - - if not 'region_of_interest'in kwargs.keys() : - if self.pars['ideal_image'] == None: - self.pars['region_of_interest'] = None - else: - ## nonzero if the image is larger than m - fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1) - - self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0) - - # the regularizer must be a correctly instantiated object - if not 'regularizer' in kwargs.keys() : - self.pars['regularizer'] = None - - #RING REMOVAL - if not 'ring_lambda_R_L1' in kwargs.keys(): - self.pars['ring_lambda_R_L1'] = 0 - if not 'ring_alpha' in kwargs.keys(): - self.pars['ring_alpha'] = 1 - - # ORDERED SUBSET - if not 'subsets' in kwargs.keys(): - self.pars['subsets'] = 0 - else: - self.createOrderedSubsets() - - if not 'initialize' in kwargs.keys(): - self.pars['initialize'] = False - - - - - def setParameter(self, **kwargs): - '''set named parameter for the reconstructor engine - - raises Exception if the named parameter is not recognized - - ''' - for key , value in kwargs.items(): - if key in self.acceptedInputKeywords: - self.pars[key] = value - else: - raise Exception('Wrong parameter {0} for '.format(key) + - 'reconstructor') - # setParameter - - def getParameter(self, key): - if type(key) is str: - if key in self.acceptedInputKeywords: - return self.pars[key] - else: - raise Exception('Unrecongnised parameter: {0} '.format(key) ) - elif type(key) is list: - outpars = [] - for k in key: - outpars.append(self.getParameter(k)) - return outpars - else: - raise Exception('Unhandled input {0}' .format(str(type(key)))) - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - N = self.pars['output_geometry']['GridColCount'] - proj_geom = self.pars['projector_geometry'] - vol_geom = self.pars['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - - - if (proj_geom['type'] == 'parallel') or \ - (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - #print('Calculating Lipshitz constant for parallel beam geometry...') - niter = 5;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights[0]) - proj_geomT = proj_geom.copy(); - proj_geomT['DetectorRowCount'] = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - - - for i in range(niter): - # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); - # s = norm(x1(:)); - # x1 = x1/s; - # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - # y = sqweight.*y; - # astra_mex_data3d('delete', sino_id); - # astra_mex_data3d('delete', id); - #print ("iteration {0}".format(i)) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geomT, - vol_geomT) - - y = (sqweight * y).copy() # element wise multiplication - - #b=fig.add_subplot(2,1,2) - #imgplot = plt.imshow(x1[0]) - #plt.show() - - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - del x1 - - idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), - proj_geomT, - vol_geomT) - del y - - - s = numpy.linalg.norm(x1) - ### this line? - x1 = (x1/s).copy(); - - # ### this line? - # sino_id, y = astra.creators.create_sino3d_gpu(x1, - # proj_geomT, - # vol_geomT); - # y = sqweight * y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx) - print ("iteration {0} s= {1}".format(i,s)) - - #end - del proj_geomT - del vol_geomT - #plt.show() - else: - #% divergen beam geometry - print('Calculating Lipshitz constant for divergen beam geometry...') - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - - return s - - - def setRegularizer(self, regularizer): - if regularizer is not None: - self.pars['regularizer'] = regularizer - - - def initialize(self): - # convenience variable storage - proj_geom = self.pars['projector_geometry'] - vol_geom = self.pars['output_geometry'] - sino = self.pars['input_sinogram'] - - # a 'warm start' with SIRT method - # Create a data object for the reconstruction - rec_id = astra.matlab.data3d('create', '-vol', - vol_geom); - - #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino); - sinogram_id = astra.matlab.data3d('create', '-proj3d', - proj_geom, - sino) - - sirt_config = astra.astra_dict('SIRT3D_CUDA') - sirt_config['ReconstructionDataId' ] = rec_id - sirt_config['ProjectionDataId'] = sinogram_id - - sirt = astra.algorithm.create(sirt_config) - astra.algorithm.run(sirt, iterations=35) - X = astra.matlab.data3d('get', rec_id) - - # clean up memory - astra.matlab.data3d('delete', rec_id) - astra.matlab.data3d('delete', sinogram_id) - astra.algorithm.delete(sirt) - - - - return X - - def createOrderedSubsets(self, subsets=None): - if subsets is None: - try: - subsets = self.getParameter('subsets') - except Exception(): - subsets = 0 - #return subsets - - angles = self.getParameter('projector_geometry')['ProjectionAngles'] - - #binEdges = numpy.linspace(angles.min(), - # angles.max(), - # subsets + 1) - binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) - # get rearranged subset indices - IndicesReorg = numpy.zeros((numpy.shape(angles))) - counterM = 0 - for ii in range(binsDiscr.max()): - counter = 0 - for jj in range(subsets): - curr_index = ii + jj + counter - #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) - if binsDiscr[jj] > ii: - if (counterM < numpy.size(IndicesReorg)): - IndicesReorg[counterM] = curr_index - counterM = counterM + 1 - - counter = counter + binsDiscr[jj] - 1 - - # store the OS in parameters - self.setParameter(os_subsets=subsets, - os_bins=binsDiscr, - os_indices=IndicesReorg) - - - def prepareForIteration(self): - print ("FISTA Reconstructor: prepare for iteration") - - self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) - self.objective = numpy.zeros((self.pars['number_of_iterations'])) - - #2D array (for 3D data) of sparse "ring" - detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram']) - self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float) - # another ring variable - self.r_x = self.r.copy() - - self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) - - if self.getParameter('Lipschitz_constant') is None: - self.pars['Lipschitz_constant'] = \ - self.calculateLipschitzConstantWithPowerMethod() - # errors vector (if the ground truth is given) - self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations'))); - # objective function values vector - self.objective = numpy.zeros((self.getParameter('number_of_iterations'))); - - - # prepareForIteration - - def iterate(self, Xin=None): - print ("FISTA Reconstructor: iterate") - - if Xin is None: - if self.getParameter('initialize'): - X = self.initialize() - else: - N = vol_geom['GridColCount'] - X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) - else: - # copy by reference - X = Xin - # store the output volume in the parameters - self.setParameter(output_volume=X) - X_t = X.copy() - # convenience variable storage - proj_geom , vol_geom, sino , \ - SlicesZ = self.getParameter([ 'projector_geometry' , - 'output_geometry', - 'input_sinogram', - 'SlicesZ' ]) - - t = 1 - - for i in range(self.getParameter('number_of_iterations')): - X_old = X.copy() - t_old = t - r_old = self.r.copy() - if self.getParameter('projector_geometry')['type'] == 'parallel' or \ - self.getParameter('projector_geometry')['type'] == 'fanflat' or \ - self.getParameter('projector_geometry')['type'] == 'fanflat_vec': - # if the geometry is parallel use slice-by-slice - # projection-backprojection routine - #sino_updt = zeros(size(sino),'single'); - proj_geomT = proj_geom.copy() - proj_geomT['DetectorRowCount'] = 1 - vol_geomT = vol_geom.copy() - vol_geomT['GridSliceCount'] = 1; - self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) - for kkk in range(SlicesZ): - sino_id, self.sino_updt[kkk] = \ - astra.creators.create_sino3d_gpu( - X_t[kkk:kkk+1], proj_geomT, vol_geomT) - astra.matlab.data3d('delete', sino_id) - else: - # for divergent 3D geometry (watch the GPU memory overflow in - # ASTRA versions < 1.8) - #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); - sino_id, self.sino_updt = astra.creators.create_sino3d_gpu( - X_t, proj_geom, vol_geom) - - - ## RING REMOVAL - self.ringRemoval(i) - ## Projection/Backprojection Routine - self.projectionBackprojection(X, X_t) - astra.matlab.data3d('delete', sino_id) - ## REGULARIZATION - X = self.regularize(X) - ## Update Loop - X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old) - self.setParameter(output_volume=X) - return X - ## iterate - - def ringRemoval(self, i): - print ("FISTA Reconstructor: ring removal") - residual = self.residual - lambdaR_L1 , alpha_ring , weights , L_const , sino= \ - self.getParameter(['ring_lambda_R_L1', - 'ring_alpha' , 'weights', - 'Lipschitz_constant', - 'input_sinogram']) - r_x = self.r_x - sino_updt = self.sino_updt - - SlicesZ, anglesNumb, Detectors = \ - numpy.shape(self.getParameter('input_sinogram')) - if lambdaR_L1 > 0 : - for kkk in range(anglesNumb): - - residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ - ((sino_updt[:,kkk,:]).squeeze() - \ - (sino[:,kkk,:]).squeeze() -\ - (alpha_ring * r_x) - ) - vec = residual.sum(axis = 1) - #if SlicesZ > 1: - # vec = vec[:,1,:].squeeze() - self.r = (r_x - (1./L_const) * vec).copy() - self.objective[i] = (0.5 * (residual ** 2).sum()) - - def projectionBackprojection(self, X, X_t): - print ("FISTA Reconstructor: projection-backprojection routine") - - # a few useful variables - SlicesZ, anglesNumb, Detectors = \ - numpy.shape(self.getParameter('input_sinogram')) - residual = self.residual - proj_geom , vol_geom , L_const = \ - self.getParameter(['projector_geometry' , - 'output_geometry', - 'Lipschitz_constant']) - - - if self.getParameter('projector_geometry')['type'] == 'parallel' or \ - self.getParameter('projector_geometry')['type'] == 'fanflat' or \ - self.getParameter('projector_geometry')['type'] == 'fanflat_vec': - # if the geometry is parallel use slice-by-slice - # projection-backprojection routine - #sino_updt = zeros(size(sino),'single'); - proj_geomT = proj_geom.copy() - proj_geomT['DetectorRowCount'] = 1 - vol_geomT = vol_geom.copy() - vol_geomT['GridSliceCount'] = 1; - x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) - - for kkk in range(SlicesZ): - - x_id, x_temp[kkk] = \ - astra.creators.create_backprojection3d_gpu( - residual[kkk:kkk+1], - proj_geomT, vol_geomT) - astra.matlab.data3d('delete', x_id) - else: - x_id, x_temp = \ - astra.creators.create_backprojection3d_gpu( - residual, proj_geom, vol_geom) - - X = X_t - (1/L_const) * x_temp - #astra.matlab.data3d('delete', sino_id) - astra.matlab.data3d('delete', x_id) - - def regularize(self, X): - print ("FISTA Reconstructor: regularize") - - regularizer = self.getParameter('regularizer') - if regularizer is not None: - return regularizer(input=X) - else: - return X - - def updateLoop(self, i, X, X_old, r_old, t, t_old): - print ("FISTA Reconstructor: update loop") - lambdaR_L1 = self.getParameter('ring_lambda_R_L1') - if lambdaR_L1 > 0: - self.r = numpy.max( - numpy.abs(self.r) - lambdaR_L1 , 0) * \ - numpy.sign(self.r) - t = (1 + numpy.sqrt(1 + 4 * t**2))/2 - X_t = X + (((t_old -1)/t) * (X - X_old)) - - if lambdaR_L1 > 0: - self.r_x = self.r + \ - (((t_old-1)/t) * (self.r - r_old)) - - if self.getParameter('region_of_interest') is None: - string = 'Iteration Number {0} | Objective {1} \n' - print (string.format( i, self.objective[i])) - else: - ROI , X_ideal = fistaRecon.getParameter('region_of_interest', - 'ideal_image') - - Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) - string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' - print (string.format(i,Resid_error[i], self.objective[i])) - return (X , X_t, t) - - def os_iterate(self, Xin=None): - print ("FISTA Reconstructor: iterate") - - if Xin is None: - if self.getParameter('initialize'): - X = self.initialize() - else: - N = vol_geom['GridColCount'] - X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) - else: - # copy by reference - X = Xin - # store the output volume in the parameters - self.setParameter(output_volume=X) - X_t = X.copy() - - # some useful constants - proj_geom , vol_geom, sino , \ - SlicesZ, weights , alpha_ring , - lambdaR_L1 , L_const = self.getParameter( - ['projector_geometry' , 'output_geometry', - 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' , - 'ring_lambda_R_L1', 'Lipschitz_constant']) diff --git a/src/Python/ccpi/fista/Reconstructor.py b/src/Python/ccpi/fista/Reconstructor.py deleted file mode 100644 index d29ac0d..0000000 --- a/src/Python/ccpi/fista/Reconstructor.py +++ /dev/null @@ -1,425 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -import h5py -from ccpi.reconstruction.parallelbeam import alg - -from Regularizer import Regularizer -from enum import Enum - -import astra - - - -class FISTAReconstructor(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ - - # Accepted input keywords - kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , - 'weights' , 'region_of_interest' , 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha') - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() - - if not self.pars['ideal_image'] in kwargs.keys(): - self.pars['ideal_image'] = None - - if not self.pars['region_of_interest'] : - if self.pars['ideal_image'] == None: - pass - else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : - self.pars['regularizer'] = None - else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 - - - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) - proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - - for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - y = sqweight*y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - del proj_geomT - del vol_geomT - else - #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - return s - - - def setRegularizer(self, regularizer): - if regularizer - self.pars['regularizer'] = regularizer - - - - - -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -nx = h5py.File(fname, "r") - -# the data are stored in a particular location in the hdf5 -for item in nx['entry1/tomo_entry/data'].keys(): - print (item) - -data = nx.get('entry1/tomo_entry/data/rotation_angle') -angles = numpy.zeros(data.shape) -data.read_direct(angles) -print (angles) -# angles should be in degrees - -data = nx.get('entry1/tomo_entry/data/data') -stack = numpy.zeros(data.shape) -data.read_direct(stack) -print (data.shape) - -print ("Data Loaded") - - -# Normalize -data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -itype = numpy.zeros(data.shape) -data.read_direct(itype) -# 2 is dark field -darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -dark = darks[0] -for i in range(1, len(darks)): - dark += darks[i] -dark = dark / len(darks) -#dark[0][0] = dark[0][1] - -# 1 is flat field -flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -flat = flats[0] -for i in range(1, len(flats)): - flat += flats[i] -flat = flat / len(flats) -#flat[0][0] = dark[0][1] - - -# 0 is projection data -proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = numpy.asarray (angle_proj) -angle_proj = angle_proj.astype(numpy.float32) - -# normalized data are -# norm = (projection - dark)/(flat-dark) - -def normalize(projection, dark, flat, def_val=0.1): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - -norm = [normalize(projection, dark, flat) for projection in proj] -norm = numpy.asarray (norm) -norm = norm.astype(numpy.float32) - -#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) - -#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) -#img_cgls = recon.reconstruct() -# -#pars = dict() -#pars['algorithm'] = Reconstructor.Algorithm.SIRT -#pars['projection_data'] = proj -#pars['angles'] = angle_proj -#pars['center_of_rotation'] = numpy.double(86.2) -#pars['flat_field'] = flat -#pars['iterations'] = 15 -#pars['dark_field'] = dark -#pars['resolution'] = 1 -#pars['isLogScale'] = False -#pars['threads'] = 3 -# -#img_sirt = recon.reconstruct(pars) -# -#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM -#img_mlem = recon.reconstruct() - -############################################################ -############################################################ -#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV -#recon.pars['regularize'] = numpy.double(0.1) -#img_cgls_conv = recon.reconstruct() - -niterations = 15 -threads = 3 - -img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - iteration_values, False) -print ("iteration values %s" % str(iteration_values)) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) -iteration_values = numpy.zeros((niterations,)) -img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) - - -##numpy.save("cgls_recon.npy", img_data) -import matplotlib.pyplot as plt -fig, ax = plt.subplots(1,6,sharey=True) -ax[0].imshow(img_cgls[80]) -ax[0].axis('off') # clear x- and y-axes -ax[1].imshow(img_sirt[80]) -ax[1].axis('off') # clear x- and y-axes -ax[2].imshow(img_mlem[80]) -ax[2].axis('off') # clear x- and y-axesplt.show() -ax[3].imshow(img_cgls_conv[80]) -ax[3].axis('off') # clear x- and y-axesplt.show() -ax[4].imshow(img_cgls_tikhonov[80]) -ax[4].axis('off') # clear x- and y-axesplt.show() -ax[5].imshow(img_cgls_TVreg[80]) -ax[5].axis('off') # clear x- and y-axesplt.show() - - -plt.show() - -#viewer = edo.CILViewer() -#viewer.setInputAsNumpy(img_cgls2) -#viewer.displaySliceActor(0) -#viewer.startRenderLoop() - -import vtk - -def NumpyToVTKImageData(numpyarray): - if (len(numpy.shape(numpyarray)) == 3): - doubleImg = vtk.vtkImageData() - shape = numpy.shape(numpyarray) - doubleImg.SetDimensions(shape[0], shape[1], shape[2]) - doubleImg.SetOrigin(0,0,0) - doubleImg.SetSpacing(1,1,1) - doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) - #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) - doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) - - for i in range(shape[0]): - for j in range(shape[1]): - for k in range(shape[2]): - doubleImg.SetScalarComponentFromDouble( - i,j,k,0, numpyarray[i][j][k]) - #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) - # rescale to appropriate VTK_UNSIGNED_SHORT - stats = vtk.vtkImageAccumulate() - stats.SetInputData(doubleImg) - stats.Update() - iMin = stats.GetMin()[0] - iMax = stats.GetMax()[0] - scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) - - shiftScaler = vtk.vtkImageShiftScale () - shiftScaler.SetInputData(doubleImg) - shiftScaler.SetScale(scale) - shiftScaler.SetShift(iMin) - shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) - shiftScaler.Update() - return shiftScaler.GetOutput() - -#writer = vtk.vtkMetaImageWriter() -#writer.SetFileName(alg + "_recon.mha") -#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) -#writer.Write() diff --git a/src/Python/ccpi/fista/__init__.py b/src/Python/ccpi/fista/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/src/Python/ccpi/fista/__init__.py +++ /dev/null diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py index ea96b53..c903712 100644 --- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -21,10 +21,9 @@ import numpy -import h5py #from ccpi.reconstruction.parallelbeam import alg -from ccpi.imaging.Regularizer import Regularizer +#from ccpi.imaging.Regularizer import Regularizer from enum import Enum import astra @@ -74,18 +73,34 @@ class FISTAReconstructor(): # 3. "A novel tomographic reconstruction method based on the robust # Student's t function for suppressing data outliers" D. Kazantsev et.al. # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ + def __init__(self, projector_geometry, output_geometry, input_sinogram, + **kwargs): + # handle parmeters: + # obligatory parameters + self.pars = dict() + self.pars['projector_geometry'] = projector_geometry # proj_geom + self.pars['output_geometry'] = output_geometry # vol_geom + self.pars['input_sinogram'] = input_sinogram # sino + sliceZ, nangles, detectors = numpy.shape(input_sinogram) + self.pars['detectors'] = detectors + self.pars['number_of_angles'] = nangles + self.pars['SlicesZ'] = sliceZ + self.pars['output_volume'] = None + + print (self.pars) + # handle optional input parameters (at instantiation) # Accepted input keywords - kw = ('number_of_iterations', + kw = ( + # mandatory fields + 'projector_geometry', + 'output_geometry', + 'input_sinogram', + 'detectors', + 'number_of_angles', + 'SlicesZ', + # optional fields + 'number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , 'weights' , @@ -93,7 +108,13 @@ class FISTAReconstructor(): 'initialize' , 'regularizer' , 'ring_lambda_R_L1', - 'ring_alpha') + 'ring_alpha', + 'subsets', + 'output_volume', + 'os_subsets', + 'os_indices', + 'os_bins') + self.acceptedInputKeywords = list(kw) # handle keyworded parameters if kwargs is not None: @@ -110,85 +131,160 @@ class FISTAReconstructor(): if 'weights' in kwargs.keys(): self.pars['weights'] = kwargs['weights'] else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + self.pars['weights'] = \ + numpy.ones(numpy.shape( + self.pars['input_sinogram'])) if 'Lipschitz_constant' in kwargs.keys(): self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + self.pars['Lipschitz_constant'] = None - if not self.pars['ideal_image'] in kwargs.keys(): + if not 'ideal_image' in kwargs.keys(): self.pars['ideal_image'] = None - if not self.pars['region_of_interest'] : + if not 'region_of_interest'in kwargs.keys() : if self.pars['ideal_image'] == None: - pass + self.pars['region_of_interest'] = None else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : + ## nonzero if the image is larger than m + fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1) + + self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0) + + # the regularizer must be a correctly instantiated object + if not 'regularizer' in kwargs.keys() : self.pars['regularizer'] = None + + #RING REMOVAL + if not 'ring_lambda_R_L1' in kwargs.keys(): + self.pars['ring_lambda_R_L1'] = 0 + if not 'ring_alpha' in kwargs.keys(): + self.pars['ring_alpha'] = 1 + + # ORDERED SUBSET + if not 'subsets' in kwargs.keys(): + self.pars['subsets'] = 0 else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 + self.createOrderedSubsets() + + if not 'initialize' in kwargs.keys(): + self.pars['initialize'] = False + + def setParameter(self, **kwargs): + '''set named parameter for the reconstructor engine + + raises Exception if the named parameter is not recognized + ''' + for key , value in kwargs.items(): + if key in self.acceptedInputKeywords: + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for '.format(key) + + 'reconstructor') + # setParameter + + def getParameter(self, key): + if type(key) is str: + if key in self.acceptedInputKeywords: + return self.pars[key] + else: + raise Exception('Unrecongnised parameter: {0} '.format(key) ) + elif type(key) is list: + outpars = [] + for k in key: + outpars.append(self.getParameter(k)) + return outpars + else: + raise Exception('Unhandled input {0}' .format(str(type(key)))) + + def calculateLipschitzConstantWithPowerMethod(self): ''' using Power method (PM) to establish L constant''' - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] + N = self.pars['output_geometry']['GridColCount'] + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] weights = self.pars['weights'] SlicesZ = self.pars['SlicesZ'] - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + + + if (proj_geom['type'] == 'parallel') or \ + (proj_geom['type'] == 'parallel3d'): #% for parallel geometry we can do just one slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM + #print('Calculating Lipshitz constant for parallel beam geometry...') + niter = 5;# % number of iteration for the PM #N = params.vol_geom.GridColCount; #x1 = rand(N,N,1); x1 = numpy.random.rand(1,N,N) #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) + sqweight = numpy.sqrt(weights[0]) proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; + proj_geomT['DetectorRowCount'] = 1; vol_geomT = vol_geom.copy(); vol_geomT['GridSliceCount'] = 1; + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) + # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); + # s = norm(x1(:)); + # x1 = x1/s; + # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + # y = sqweight.*y; + # astra_mex_data3d('delete', sino_id); + # astra_mex_data3d('delete', id); + #print ("iteration {0}".format(i)) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geomT, + vol_geomT) + + y = (sqweight * y).copy() # element wise multiplication + + #b=fig.add_subplot(2,1,2) + #imgplot = plt.imshow(x1[0]) + #plt.show() + + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + del x1 - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), + proj_geomT, + vol_geomT) + del y + + s = numpy.linalg.norm(x1) ### this line? - x1 = x1/s; - ### this line? - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight*y; + x1 = (x1/s).copy(); + + # ### this line? + # sino_id, y = astra.creators.create_sino3d_gpu(x1, + # proj_geomT, + # vol_geomT); + # y = sqweight * y; astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); + astra.matlab.data3d('delete', idx) + print ("iteration {0} s= {1}".format(i,s)) + #end del proj_geomT del vol_geomT + #plt.show() else: #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + print('Calculating Lipshitz constant for divergen beam geometry...') niter = 8; #% number of iteration for PM x1 = numpy.random.rand(SlicesZ , N , N); #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) + sqweight = numpy.sqrt(weights[0]) sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); y = sqweight*y; @@ -217,6 +313,7 @@ class FISTAReconstructor(): #end #clear x1 del x1 + return s @@ -225,130 +322,291 @@ class FISTAReconstructor(): if regularizer is not None: self.pars['regularizer'] = regularizer + + def initialize(self): + # convenience variable storage + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] + sino = self.pars['input_sinogram'] + + # a 'warm start' with SIRT method + # Create a data object for the reconstruction + rec_id = astra.matlab.data3d('create', '-vol', + vol_geom); + + #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino); + sinogram_id = astra.matlab.data3d('create', '-proj3d', + proj_geom, + sino) + + sirt_config = astra.astra_dict('SIRT3D_CUDA') + sirt_config['ReconstructionDataId' ] = rec_id + sirt_config['ProjectionDataId'] = sinogram_id + + sirt = astra.algorithm.create(sirt_config) + astra.algorithm.run(sirt, iterations=35) + X = astra.matlab.data3d('get', rec_id) + + # clean up memory + astra.matlab.data3d('delete', rec_id) + astra.matlab.data3d('delete', sinogram_id) + astra.algorithm.delete(sirt) + + + + return X + + def createOrderedSubsets(self, subsets=None): + if subsets is None: + try: + subsets = self.getParameter('subsets') + except Exception(): + subsets = 0 + #return subsets + else: + self.setParameter(subsets=subsets) + + + angles = self.getParameter('projector_geometry')['ProjectionAngles'] + + #binEdges = numpy.linspace(angles.min(), + # angles.max(), + # subsets + 1) + binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) + # get rearranged subset indices + IndicesReorg = numpy.zeros((numpy.shape(angles)), dtype=numpy.int32) + counterM = 0 + for ii in range(binsDiscr.max()): + counter = 0 + for jj in range(subsets): + curr_index = ii + jj + counter + #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) + if binsDiscr[jj] > ii: + if (counterM < numpy.size(IndicesReorg)): + IndicesReorg[counterM] = curr_index + counterM = counterM + 1 + + counter = counter + binsDiscr[jj] - 1 + + # store the OS in parameters + self.setParameter(os_subsets=subsets, + os_bins=binsDiscr, + os_indices=IndicesReorg) + + + def prepareForIteration(self): + print ("FISTA Reconstructor: prepare for iteration") + + self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) + self.objective = numpy.zeros((self.pars['number_of_iterations'])) + + #2D array (for 3D data) of sparse "ring" + detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram']) + self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float) + # another ring variable + self.r_x = self.r.copy() + + self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) + + if self.getParameter('Lipschitz_constant') is None: + self.pars['Lipschitz_constant'] = \ + self.calculateLipschitzConstantWithPowerMethod() + # errors vector (if the ground truth is given) + self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations'))); + # objective function values vector + self.objective = numpy.zeros((self.getParameter('number_of_iterations'))); + + + # prepareForIteration + + def iterate(self, Xin=None): + print ("FISTA Reconstructor: iterate") + + if Xin is None: + if self.getParameter('initialize'): + X = self.initialize() + else: + N = vol_geom['GridColCount'] + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + # copy by reference + X = Xin + # store the output volume in the parameters + self.setParameter(output_volume=X) + X_t = X.copy() + # convenience variable storage + proj_geom , vol_geom, sino , \ + SlicesZ = self.getParameter([ 'projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ' ]) + + t = 1 + + for i in range(self.getParameter('number_of_iterations')): + X_old = X.copy() + t_old = t + r_old = self.r.copy() + if self.getParameter('projector_geometry')['type'] == 'parallel' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat_vec': + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; + self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) + for kkk in range(SlicesZ): + sino_id, self.sino_updt[kkk] = \ + astra.creators.create_sino3d_gpu( + X_t[kkk:kkk+1], proj_geomT, vol_geomT) + astra.matlab.data3d('delete', sino_id) + else: + # for divergent 3D geometry (watch the GPU memory overflow in + # ASTRA versions < 1.8) + #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); + sino_id, self.sino_updt = astra.creators.create_sino3d_gpu( + X_t, proj_geom, vol_geom) + + + ## RING REMOVAL + self.ringRemoval(i) + ## Projection/Backprojection Routine + self.projectionBackprojection(X, X_t) + astra.matlab.data3d('delete', sino_id) + ## REGULARIZATION + X = self.regularize(X) + ## Update Loop + X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old) + self.setParameter(output_volume=X) + return X + ## iterate - + def ringRemoval(self, i): + print ("FISTA Reconstructor: ring removal") + residual = self.residual + lambdaR_L1 , alpha_ring , weights , L_const , sino= \ + self.getParameter(['ring_lambda_R_L1', + 'ring_alpha' , 'weights', + 'Lipschitz_constant', + 'input_sinogram']) + r_x = self.r_x + sino_updt = self.sino_updt + + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(self.getParameter('input_sinogram')) + if lambdaR_L1 > 0 : + for kkk in range(anglesNumb): + + residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ + ((sino_updt[:,kkk,:]).squeeze() - \ + (sino[:,kkk,:]).squeeze() -\ + (alpha_ring * r_x) + ) + vec = residual.sum(axis = 1) + #if SlicesZ > 1: + # vec = vec[:,1,:].squeeze() + self.r = (r_x - (1./L_const) * vec).copy() + self.objective[i] = (0.5 * (residual ** 2).sum()) + def projectionBackprojection(self, X, X_t): + print ("FISTA Reconstructor: projection-backprojection routine") + + # a few useful variables + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(self.getParameter('input_sinogram')) + residual = self.residual + proj_geom , vol_geom , L_const = \ + self.getParameter(['projector_geometry' , + 'output_geometry', + 'Lipschitz_constant']) + + + if self.getParameter('projector_geometry')['type'] == 'parallel' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat_vec': + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; + x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) + + for kkk in range(SlicesZ): + + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residual[kkk:kkk+1], + proj_geomT, vol_geomT) + astra.matlab.data3d('delete', x_id) + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residual, proj_geom, vol_geom) + + X = X_t - (1/L_const) * x_temp + #astra.matlab.data3d('delete', sino_id) + astra.matlab.data3d('delete', x_id) -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" -##nx = h5py.File(fname, "r") -## -### the data are stored in a particular location in the hdf5 -##for item in nx['entry1/tomo_entry/data'].keys(): -## print (item) -## -##data = nx.get('entry1/tomo_entry/data/rotation_angle') -##angles = numpy.zeros(data.shape) -##data.read_direct(angles) -##print (angles) -### angles should be in degrees -## -##data = nx.get('entry1/tomo_entry/data/data') -##stack = numpy.zeros(data.shape) -##data.read_direct(stack) -##print (data.shape) -## -##print ("Data Loaded") -## -## -### Normalize -##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -##itype = numpy.zeros(data.shape) -##data.read_direct(itype) -### 2 is dark field -##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -##dark = darks[0] -##for i in range(1, len(darks)): -## dark += darks[i] -##dark = dark / len(darks) -###dark[0][0] = dark[0][1] -## -### 1 is flat field -##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -##flat = flats[0] -##for i in range(1, len(flats)): -## flat += flats[i] -##flat = flat / len(flats) -###flat[0][0] = dark[0][1] -## -## -### 0 is projection data -##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -##angle_proj = numpy.asarray (angle_proj) -##angle_proj = angle_proj.astype(numpy.float32) -## -### normalized data are -### norm = (projection - dark)/(flat-dark) -## -##def normalize(projection, dark, flat, def_val=0.1): -## a = (projection - dark) -## b = (flat-dark) -## with numpy.errstate(divide='ignore', invalid='ignore'): -## c = numpy.true_divide( a, b ) -## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 -## return c -## -## -##norm = [normalize(projection, dark, flat) for projection in proj] -##norm = numpy.asarray (norm) -##norm = norm.astype(numpy.float32) - - -##niterations = 15 -##threads = 3 -## -##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -## -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## iteration_values, False) -##print ("iteration values %s" % str(iteration_values)) -## -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## numpy.double(1e-5), iteration_values , False) -##print ("iteration values %s" % str(iteration_values)) -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## numpy.double(1e-5), iteration_values , False) -##print ("iteration values %s" % str(iteration_values)) -## -## -####numpy.save("cgls_recon.npy", img_data) -##import matplotlib.pyplot as plt -##fig, ax = plt.subplots(1,6,sharey=True) -##ax[0].imshow(img_cgls[80]) -##ax[0].axis('off') # clear x- and y-axes -##ax[1].imshow(img_sirt[80]) -##ax[1].axis('off') # clear x- and y-axes -##ax[2].imshow(img_mlem[80]) -##ax[2].axis('off') # clear x- and y-axesplt.show() -##ax[3].imshow(img_cgls_conv[80]) -##ax[3].axis('off') # clear x- and y-axesplt.show() -##ax[4].imshow(img_cgls_tikhonov[80]) -##ax[4].axis('off') # clear x- and y-axesplt.show() -##ax[5].imshow(img_cgls_TVreg[80]) -##ax[5].axis('off') # clear x- and y-axesplt.show() -## -## -##plt.show() -## + def regularize(self, X): + print ("FISTA Reconstructor: regularize") + + regularizer = self.getParameter('regularizer') + if regularizer is not None: + return regularizer(input=X) + else: + return X + + def updateLoop(self, i, X, X_old, r_old, t, t_old): + print ("FISTA Reconstructor: update loop") + lambdaR_L1 = self.getParameter('ring_lambda_R_L1') + if lambdaR_L1 > 0: + self.r = numpy.max( + numpy.abs(self.r) - lambdaR_L1 , 0) * \ + numpy.sign(self.r) + t = (1 + numpy.sqrt(1 + 4 * t**2))/2 + X_t = X + (((t_old -1)/t) * (X - X_old)) + + if lambdaR_L1 > 0: + self.r_x = self.r + \ + (((t_old-1)/t) * (self.r - r_old)) + + if self.getParameter('region_of_interest') is None: + string = 'Iteration Number {0} | Objective {1} \n' + print (string.format( i, self.objective[i])) + else: + ROI , X_ideal = fistaRecon.getParameter('region_of_interest', + 'ideal_image') + + Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) + string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' + print (string.format(i,Resid_error[i], self.objective[i])) + return (X , X_t, t) + + def os_iterate(self, Xin=None): + print ("FISTA Reconstructor: iterate") + + if Xin is None: + if self.getParameter('initialize'): + X = self.initialize() + else: + N = vol_geom['GridColCount'] + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + # copy by reference + X = Xin + # store the output volume in the parameters + self.setParameter(output_volume=X) + X_t = X.copy() + # some useful constants + proj_geom , vol_geom, sino , \ + SlicesZ, weights , alpha_ring ,\ + lambdaR_L1 , L_const = self.getParameter( + ['projector_geometry' , 'output_geometry', + 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' , + 'ring_lambda_R_L1', 'Lipschitz_constant']) diff --git a/src/Python/compile-fista.bat.in b/src/Python/compile-fista.bat.in new file mode 100644 index 0000000..b1db686 --- /dev/null +++ b/src/Python/compile-fista.bat.in @@ -0,0 +1,7 @@ +set CIL_VERSION=@CIL_VERSION@ + +set PREFIX=@CONDA_ENVIRONMENT_PREFIX@ +set LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@ + +REM activate @CONDA_ENVIRONMENT@ +conda build fista-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi -c conda-forge diff --git a/src/Python/compile-fista.sh.in b/src/Python/compile-fista.sh.in new file mode 100644 index 0000000..267f014 --- /dev/null +++ b/src/Python/compile-fista.sh.in @@ -0,0 +1,9 @@ +#!/bin/sh +# compile within the right conda environment +#module load python/anaconda +#source activate @CONDA_ENVIRONMENT@ + +export CIL_VERSION=@CIL_VERSION@ +export LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@ + +conda build fista-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi diff --git a/src/Python/conda-recipe/meta.yaml b/src/Python/conda-recipe/meta.yaml index c5b7a89..7068e9d 100644 --- a/src/Python/conda-recipe/meta.yaml +++ b/src/Python/conda-recipe/meta.yaml @@ -1,5 +1,5 @@ package: - name: ccpi-fista + name: ccpi-regularizers version: {{ environ['CIL_VERSION'] }} diff --git a/src/Python/fista-recipe/build.sh b/src/Python/fista-recipe/build.sh new file mode 100644 index 0000000..e3f3552 --- /dev/null +++ b/src/Python/fista-recipe/build.sh @@ -0,0 +1,10 @@ +if [ -z "$CIL_VERSION" ]; then + echo "Need to set CIL_VERSION" + exit 1 +fi +mkdir "$SRC_DIR/ccpifista" +cp -r "$RECIPE_DIR/.." "$SRC_DIR/ccpifista" + +cd $SRC_DIR/ccpifista + +$PYTHON setup-fista.py install diff --git a/src/Python/fista-recipe/meta.yaml b/src/Python/fista-recipe/meta.yaml new file mode 100644 index 0000000..265541f --- /dev/null +++ b/src/Python/fista-recipe/meta.yaml @@ -0,0 +1,29 @@ +package: + name: ccpi-fista + version: {{ environ['CIL_VERSION'] }} + + +build: + preserve_egg_dir: False + script_env: + - CIL_VERSION +# number: 0 + +requirements: + build: + - python + - numpy + - setuptools + + run: + - python + - numpy + #- astra-toolbox + - ccpi-regularizers + + + +about: + home: http://www.ccpi.ac.uk + license: Apache v.2.0 license + summary: 'CCPi Core Imaging Library (Viewer)' diff --git a/src/Python/setup-fista.py.in b/src/Python/setup-fista.py.in new file mode 100644 index 0000000..c5c9f4d --- /dev/null +++ b/src/Python/setup-fista.py.in @@ -0,0 +1,27 @@ +from distutils.core import setup +#from setuptools import setup, find_packages +import os + +cil_version=os.environ['CIL_VERSION'] +if cil_version == '': + print("Please set the environmental variable CIL_VERSION") + sys.exit(1) + +setup( + name="ccpi-fista", + version=cil_version, + packages=['ccpi','ccpi.reconstruction'], + install_requires=['numpy'], + + zip_safe = False, + + # metadata for upload to PyPI + author="Edoardo Pasca", + author_email="edo.paskino@gmail.com", + description='CCPi Core Imaging Library - FISTA Reconstructor module', + license="Apache v2.0", + keywords="tomography interative reconstruction", + url="http://www.ccpi.ac.uk", # project home page, if any + + # could also include long_description, download_url, classifiers, etc. +) diff --git a/src/Python/setup.py.in b/src/Python/setup.py.in index 0a1f4ad..12e8af1 100644 --- a/src/Python/setup.py.in +++ b/src/Python/setup.py.in @@ -44,7 +44,7 @@ else: setup( name='ccpi', - description='CCPi Core Imaging Library - FISTA Reconstruction Module', + description='CCPi Core Imaging Library - Image Regularizers', version=cil_version, cmdclass = {'build_ext': build_ext}, ext_modules = [Extension("ccpi.imaging.cpu_regularizers", @@ -65,3 +65,5 @@ setup( zip_safe = False, packages = {'ccpi','ccpi.imaging'}, ) + + diff --git a/src/Python/test_reconstructor-os.py b/src/Python/test/test_reconstructor-os.py index aee70a4..6c82ae0 100644 --- a/src/Python/test_reconstructor-os.py +++ b/src/Python/test/test_reconstructor-os.py @@ -9,9 +9,10 @@ Based on DemoRD2.m import h5py import numpy -from ccpi.fista.FISTAReconstructor import FISTAReconstructor +from ccpi.reconstruction.FISTAReconstructor import FISTAReconstructor import astra import matplotlib.pyplot as plt +from ccpi.imaging.Regularizer import Regularizer def RMSE(signal1, signal2): '''RMSE Root Mean Squared Error''' @@ -76,9 +77,18 @@ fistaRecon.setParameter(Lipschitz_constant = 767893952.0) fistaRecon.setParameter(ring_alpha = 21) fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) + +reg = Regularizer(Regularizer.Algorithm.LLT_model) +reg.setParameter(regularization_parameter=25, + time_step=0.0003, + tolerance_constant=0.0001, + number_of_iterations=300) + ## Ordered subset if True: subsets = 16 + fistaRecon.setParameter(subsets=subsets) + fistaRecon.createOrderedSubsets() angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles'] #binEdges = numpy.linspace(angles.min(), # angles.max(), @@ -146,6 +156,7 @@ if True: fistaRecon.residual2 = numpy.zeros(numpy.shape(fistaRecon.pars['input_sinogram'])) residual2 = fistaRecon.residual2 sino_updt_FULL = fistaRecon.residual.copy() + r_x = fistaRecon.r.copy() print ("starting iterations") ## % Outer FISTA iterations loop @@ -206,8 +217,13 @@ if True: # the number of projections per subset numProjSub = fistaRecon.getParameter('os_bins')[ss] CurrSubIndices = fistaRecon.getParameter('os_indices')\ - [counterInd:counterInd+numProjSub-1] - proj_geomSUB['ProjectionAngles'] = angles[CurrSubIndeces] + [counterInd:counterInd+numProjSub] + #print ("Len CurrSubIndices {0}".format(numProjSub)) + mask = numpy.zeros(numpy.shape(angles), dtype=bool) + cc = 0 + for j in range(len(CurrSubIndices)): + mask[int(CurrSubIndices[j])] = True + proj_geomSUB['ProjectionAngles'] = angles[mask] shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram'))) shape[1] = numProjSub @@ -246,7 +262,8 @@ if True: ## sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram ## end for kkk in range(numProjSub): - indC = CurrSubIndices[kkk] + #print ("ring removal indC ... {0}".format(kkk)) + indC = int(CurrSubIndices[kkk]) residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \ (sino_updt_Sub[:,kkk,:].squeeze() - \ sino[:,indC,:].squeeze() - alpha_ring * r_x) @@ -288,7 +305,8 @@ if True: # regularizer = fistaRecon.getParameter('regularizer') # for slices: # out = regularizer(input=X) - print ("skipping regularizer") + print ("regularizer") + #X = reg(input=X) ## FINAL @@ -312,7 +330,8 @@ if True: Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' print (string.format(i,Resid_error[i], objective[i])) - + + numpy.save("X_out_os.npy", X) else: fistaRecon = FISTAReconstructor(proj_geom, diff --git a/src/Python/test/test_reconstructor.py b/src/Python/test/test_reconstructor.py new file mode 100644 index 0000000..3342301 --- /dev/null +++ b/src/Python/test/test_reconstructor.py @@ -0,0 +1,309 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +Based on DemoRD2.m +""" + +import h5py +import numpy + +from ccpi.reconstruction.FISTAReconstructor import FISTAReconstructor +import astra +import matplotlib.pyplot as plt + +def RMSE(signal1, signal2): + '''RMSE Root Mean Squared Error''' + if numpy.shape(signal1) == numpy.shape(signal2): + err = (signal1 - signal2) + err = numpy.sum( err * err )/numpy.size(signal1); # MSE + err = sqrt(err); # RMSE + return err + else: + raise Exception('Input signals must have the same shape') + +filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D'), dtype="float32") +Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32") +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad'), dtype="float32") +recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] + +Z_slices = 20 +det_row_count = Z_slices +# next definition is just for consistency of naming +det_col_count = size_det + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX + + +proj_geom = astra.creators.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + det_col_count, + angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +image_size_x = recon_size +image_size_y = recon_size +image_size_z = Z_slices +vol_geom = astra.creators.create_vol_geom( image_size_x, + image_size_y, + image_size_z) + +## First pass the arguments to the FISTAReconstructor and test the +## Lipschitz constant + +fistaRecon = FISTAReconstructor(proj_geom, + vol_geom, + Sino3D , + weights=Weights3D) + +print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +fistaRecon.setParameter(number_of_iterations = 12) +fistaRecon.setParameter(Lipschitz_constant = 767893952.0) +fistaRecon.setParameter(ring_alpha = 21) +fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) + +reg = Regularizer(Regularizer.Algorithm.LLT_model) +reg.setParameter(regularization_parameter=25, + time_step=0.0003, + tolerance_constant=0.0001, + number_of_iterations=300) +fistaRecon.setParameter(regularizer = reg) + +## Ordered subset +if False: + subsets = 16 + angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles'] + #binEdges = numpy.linspace(angles.min(), + # angles.max(), + # subsets + 1) + binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) + # get rearranged subset indices + IndicesReorg = numpy.zeros((numpy.shape(angles))) + counterM = 0 + for ii in range(binsDiscr.max()): + counter = 0 + for jj in range(subsets): + curr_index = ii + jj + counter + #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) + if binsDiscr[jj] > ii: + if (counterM < numpy.size(IndicesReorg)): + IndicesReorg[counterM] = curr_index + counterM = counterM + 1 + + counter = counter + binsDiscr[jj] - 1 + + +if False: + print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) + print ("prepare for iteration") + fistaRecon.prepareForIteration() + + + + print("initializing ...") + if False: + # if X doesn't exist + #N = params.vol_geom.GridColCount + N = vol_geom['GridColCount'] + print ("N " + str(N)) + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + #X = fistaRecon.initialize() + X = numpy.load("X.npy") + + print (numpy.shape(X)) + X_t = X.copy() + print ("initialized") + proj_geom , vol_geom, sino , \ + SlicesZ = fistaRecon.getParameter(['projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ']) + + #fistaRecon.setParameter(number_of_iterations = 3) + iterFISTA = fistaRecon.getParameter('number_of_iterations') + # errors vector (if the ground truth is given) + Resid_error = numpy.zeros((iterFISTA)); + # objective function values vector + objective = numpy.zeros((iterFISTA)); + + + t = 1 + + + print ("starting iterations") +## % Outer FISTA iterations loop + for i in range(fistaRecon.getParameter('number_of_iterations')): + X_old = X.copy() + t_old = t + r_old = fistaRecon.r.copy() + if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or \ + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec' : + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; + sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) + for kkk in range(SlicesZ): + sino_id, sino_updt[kkk] = \ + astra.creators.create_sino3d_gpu( + X_t[kkk:kkk+1], proj_geom, vol_geom) + astra.matlab.data3d('delete', sino_id) + else: + # for divergent 3D geometry (watch the GPU memory overflow in + # ASTRA versions < 1.8) + #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); + sino_id, sino_updt = astra.creators.create_sino3d_gpu( + X_t, proj_geom, vol_geom) + + ## RING REMOVAL + residual = fistaRecon.residual + lambdaR_L1 , alpha_ring , weights , L_const= \ + fistaRecon.getParameter(['ring_lambda_R_L1', + 'ring_alpha' , 'weights', + 'Lipschitz_constant']) + r_x = fistaRecon.r_x + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(fistaRecon.getParameter('input_sinogram')) + if lambdaR_L1 > 0 : + print ("ring removal") + for kkk in range(anglesNumb): + + residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ + ((sino_updt[:,kkk,:]).squeeze() - \ + (sino[:,kkk,:]).squeeze() -\ + (alpha_ring * r_x) + ) + vec = residual.sum(axis = 1) + #if SlicesZ > 1: + # vec = vec[:,1,:].squeeze() + fistaRecon.r = (r_x - (1./L_const) * vec).copy() + objective[i] = (0.5 * (residual ** 2).sum()) +## % the ring removal part (Group-Huber fidelity) +## for kkk = 1:anglesNumb +## residual(:,kkk,:) = squeeze(weights(:,kkk,:)).* +## (squeeze(sino_updt(:,kkk,:)) - +## (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); +## end +## vec = sum(residual,2); +## if (SlicesZ > 1) +## vec = squeeze(vec(:,1,:)); +## end +## r = r_x - (1./L_const).*vec; +## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output + + + + # Projection/Backprojection Routine + if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or\ + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec': + x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) + print ("Projection/Backprojection Routine") + for kkk in range(SlicesZ): + + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residual[kkk:kkk+1], + proj_geomT, vol_geomT) + astra.matlab.data3d('delete', x_id) + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residual, proj_geom, vol_geom) + + X = X_t - (1/L_const) * x_temp + astra.matlab.data3d('delete', sino_id) + astra.matlab.data3d('delete', x_id) + + + ## REGULARIZATION + ## SKIPPING FOR NOW + ## Should be simpli + # regularizer = fistaRecon.getParameter('regularizer') + # for slices: + # out = regularizer(input=X) + print ("skipping regularizer") + + + ## FINAL + print ("final") + lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1') + if lambdaR_L1 > 0: + fistaRecon.r = numpy.max( + numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \ + numpy.sign(fistaRecon.r) + t = (1 + numpy.sqrt(1 + 4 * t**2))/2 + X_t = X + (((t_old -1)/t) * (X - X_old)) + + if lambdaR_L1 > 0: + fistaRecon.r_x = fistaRecon.r + \ + (((t_old-1)/t) * (fistaRecon.r - r_old)) + + if fistaRecon.getParameter('region_of_interest') is None: + string = 'Iteration Number {0} | Objective {1} \n' + print (string.format( i, objective[i])) + else: + ROI , X_ideal = fistaRecon.getParameter('region_of_interest', + 'ideal_image') + + Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) + string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' + print (string.format(i,Resid_error[i], objective[i])) + +## if (lambdaR_L1 > 0) +## r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector +## end +## +## t = (1 + sqrt(1 + 4*t^2))/2; % updating t +## X_t = X + ((t_old-1)/t).*(X - X_old); % updating X +## +## if (lambdaR_L1 > 0) +## r_x = r + ((t_old-1)/t).*(r - r_old); % updating r +## end +## +## if (show == 1) +## figure(10); imshow(X(:,:,slice), [0 maxvalplot]); +## if (lambdaR_L1 > 0) +## figure(11); plot(r); title('Rings offset vector') +## end +## pause(0.01); +## end +## if (strcmp(X_ideal, 'none' ) == 0) +## Resid_error(i) = RMSE(X(ROI), X_ideal(ROI)); +## fprintf('%s %i %s %s %.4f %s %s %f \n', 'Iteration Number:', i, '|', 'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i)); +## else +## fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); +## end +else: + fistaRecon = FISTAReconstructor(proj_geom, + vol_geom, + Sino3D , + weights=Weights3D) + + print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) + fistaRecon.setParameter(number_of_iterations = 12) + fistaRecon.setParameter(Lipschitz_constant = 767893952.0) + fistaRecon.setParameter(ring_alpha = 21) + fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) + fistaRecon.prepareForIteration() + X = fistaRecon.iterate(numpy.load("X.npy")) + numpy.save("X_out.npy", X) |