From 73fed4964d81f1f47a0b6ecbe66517f569327b27 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:31:16 +0100 Subject: initial commit of Reconstructor.py --- src/Python/ccpi/reconstruction/Reconstructor.py | 598 ++++++++++++++++++++++++ 1 file changed, 598 insertions(+) create mode 100644 src/Python/ccpi/reconstruction/Reconstructor.py (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/reconstruction/Reconstructor.py b/src/Python/ccpi/reconstruction/Reconstructor.py new file mode 100644 index 0000000..ba67327 --- /dev/null +++ b/src/Python/ccpi/reconstruction/Reconstructor.py @@ -0,0 +1,598 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + +class Reconstructor: + + class Algorithm(Enum): + CGLS = alg.cgls + CGLS_CONV = alg.cgls_conv + SIRT = alg.sirt + MLEM = alg.mlem + CGLS_TICHONOV = alg.cgls_tikhonov + CGLS_TVREG = alg.cgls_TVreg + FISTA = 'fista' + + def __init__(self, algorithm = None, projection_data = None, + angles = None, center_of_rotation = None , + flat_field = None, dark_field = None, + iterations = None, resolution = None, isLogScale = False, threads = None, + normalized_projection = None): + + self.pars = dict() + self.pars['algorithm'] = algorithm + self.pars['projection_data'] = projection_data + self.pars['normalized_projection'] = normalized_projection + self.pars['angles'] = angles + self.pars['center_of_rotation'] = numpy.double(center_of_rotation) + self.pars['flat_field'] = flat_field + self.pars['iterations'] = iterations + self.pars['dark_field'] = dark_field + self.pars['resolution'] = resolution + self.pars['isLogScale'] = isLogScale + self.pars['threads'] = threads + if (iterations != None): + self.pars['iterationValues'] = numpy.zeros((iterations)) + + if projection_data != None and dark_field != None and flat_field != None: + norm = self.normalize(projection_data, dark_field, flat_field, 0.1) + self.pars['normalized_projection'] = norm + + + def setPars(self, parameters): + keys = ['algorithm','projection_data' ,'normalized_projection', \ + 'angles' , 'center_of_rotation' , 'flat_field', \ + 'iterations','dark_field' , 'resolution', 'isLogScale' , \ + 'threads' , 'iterationValues', 'regularize'] + + for k in keys: + if k not in parameters.keys(): + self.pars[k] = None + else: + self.pars[k] = parameters[k] + + + def sanityCheck(self): + projection_data = self.pars['projection_data'] + dark_field = self.pars['dark_field'] + flat_field = self.pars['flat_field'] + angles = self.pars['angles'] + + if projection_data != None and dark_field != None and \ + angles != None and flat_field != None: + data_shape = numpy.shape(projection_data) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + + if data_shape[1:] != numpy.shape(flat_field): + #raise Exception('Projection and flat field dimensions do not match') + return (False , 'Projection and flat field dimensions do not match') + if data_shape[1:] != numpy.shape(dark_field): + #raise Exception('Projection and dark field dimensions do not match') + return (False , 'Projection and dark field dimensions do not match') + + return (True , '' ) + elif self.pars['normalized_projection'] != None: + data_shape = numpy.shape(self.pars['normalized_projection']) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + else: + return (True , '' ) + else: + return (False , 'Not enough data') + + def reconstruct(self, parameters = None): + if parameters != None: + self.setPars(parameters) + + go , reason = self.sanityCheck() + if go: + return self._reconstruct() + else: + raise Exception(reason) + + + def _reconstruct(self, parameters=None): + if parameters!=None: + self.setPars(parameters) + parameters = self.pars + + if parameters['algorithm'] != None and \ + parameters['normalized_projection'] != None and \ + parameters['angles'] != None and \ + parameters['center_of_rotation'] != None and \ + parameters['iterations'] != None and \ + parameters['resolution'] != None and\ + parameters['threads'] != None and\ + parameters['isLogScale'] != None: + + + if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, + Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): + #store parameters + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['isLogScale'] + ) + return result + elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, + Reconstructor.Algorithm.CGLS_TICHONOV, + Reconstructor.Algorithm.CGLS_TVREG) : + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['regularize'], + numpy.zeros((parameters['iterations'])), + parameters['isLogScale'] + ) + + elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: + pass + + else: + if parameters['projection_data'] != None and \ + parameters['dark_field'] != None and \ + parameters['flat_field'] != None: + norm = self.normalize(parameters['projection_data'], + parameters['dark_field'], + parameters['flat_field'], 0.1) + self.pars['normalized_projection'] = norm + return self._reconstruct(parameters) + + + + def _normalize(self, projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + def normalize(self, projections, dark, flat, def_val=0): + norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] + return numpy.asarray (norm, dtype=numpy.float32) + + + +class FISTA(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + return shiftScaler.GetOutput() + +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") +#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() -- cgit v1.2.3 From 105b57a1e98c2bb7b3bf94c43b6c669925ebb1b9 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:01:28 +0100 Subject: added viewer for testing --- src/Python/ccpi/viewer/CILViewer.py | 361 +++++++ src/Python/ccpi/viewer/CILViewer2D.py | 1126 ++++++++++++++++++++ src/Python/ccpi/viewer/QVTKWidget.py | 340 ++++++ src/Python/ccpi/viewer/QVTKWidget2.py | 84 ++ src/Python/ccpi/viewer/__init__.py | 1 + .../viewer/__pycache__/CILViewer.cpython-35.pyc | Bin 0 -> 10542 bytes .../viewer/__pycache__/CILViewer2D.cpython-35.pyc | Bin 0 -> 35633 bytes .../viewer/__pycache__/QVTKWidget.cpython-35.pyc | Bin 0 -> 10099 bytes .../viewer/__pycache__/QVTKWidget2.cpython-35.pyc | Bin 0 -> 1316 bytes .../viewer/__pycache__/__init__.cpython-35.pyc | Bin 0 -> 210 bytes src/Python/ccpi/viewer/embedvtk.py | 75 ++ 11 files changed, 1987 insertions(+) create mode 100644 src/Python/ccpi/viewer/CILViewer.py create mode 100644 src/Python/ccpi/viewer/CILViewer2D.py create mode 100644 src/Python/ccpi/viewer/QVTKWidget.py create mode 100644 src/Python/ccpi/viewer/QVTKWidget2.py create mode 100644 src/Python/ccpi/viewer/__init__.py create mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/embedvtk.py (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/viewer/CILViewer.py b/src/Python/ccpi/viewer/CILViewer.py new file mode 100644 index 0000000..efcf8be --- /dev/null +++ b/src/Python/ccpi/viewer/CILViewer.py @@ -0,0 +1,361 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Edoardo Pasca +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import vtk +import numpy +import math +from vtk.util import numpy_support + +SLICE_ORIENTATION_XY = 2 # Z +SLICE_ORIENTATION_XZ = 1 # Y +SLICE_ORIENTATION_YZ = 0 # X + + + +class CILViewer(): + '''Simple 3D Viewer based on VTK classes''' + + def __init__(self, dimx=600,dimy=600): + '''creates the rendering pipeline''' + + # create a rendering window and renderer + self.ren = vtk.vtkRenderer() + self.renWin = vtk.vtkRenderWindow() + self.renWin.SetSize(dimx,dimy) + self.renWin.AddRenderer(self.ren) + + # img 3D as slice + self.img3D = None + self.sliceno = 0 + self.sliceOrientation = SLICE_ORIENTATION_XY + self.sliceActor = None + self.voi = None + self.wl = None + self.ia = None + self.sliceActorNo = 0 + # create a renderwindowinteractor + self.iren = vtk.vtkRenderWindowInteractor() + self.iren.SetRenderWindow(self.renWin) + + self.style = vtk.vtkInteractorStyleTrackballCamera() + self.iren.SetInteractorStyle(self.style) + + self.ren.SetBackground(.1, .2, .4) + + self.actors = {} + self.iren.RemoveObservers('MouseWheelForwardEvent') + self.iren.RemoveObservers('MouseWheelBackwardEvent') + + self.iren.AddObserver('MouseWheelForwardEvent', self.mouseInteraction, 1.0) + self.iren.AddObserver('MouseWheelBackwardEvent', self.mouseInteraction, 1.0) + + self.iren.RemoveObservers('KeyPressEvent') + self.iren.AddObserver('KeyPressEvent', self.keyPress, 1.0) + + + self.iren.Initialize() + + + + def getRenderer(self): + '''returns the renderer''' + return self.ren + + def getRenderWindow(self): + '''returns the render window''' + return self.renWin + + def getInteractor(self): + '''returns the render window interactor''' + return self.iren + + def getCamera(self): + '''returns the active camera''' + return self.ren.GetActiveCamera() + + def createPolyDataActor(self, polydata): + '''returns an actor for a given polydata''' + mapper = vtk.vtkPolyDataMapper() + if vtk.VTK_MAJOR_VERSION <= 5: + mapper.SetInput(polydata) + else: + mapper.SetInputData(polydata) + + # actor + actor = vtk.vtkActor() + actor.SetMapper(mapper) + #actor.GetProperty().SetOpacity(0.8) + return actor + + def setPolyDataActor(self, actor): + '''displays the given polydata''' + + self.ren.AddActor(actor) + + self.actors[len(self.actors)+1] = [actor, True] + self.iren.Initialize() + self.renWin.Render() + + def displayPolyData(self, polydata): + self.setPolyDataActor(self.createPolyDataActor(polydata)) + + def hideActor(self, actorno): + '''Hides an actor identified by its number in the list of actors''' + try: + if self.actors[actorno][1]: + self.ren.RemoveActor(self.actors[actorno][0]) + self.actors[actorno][1] = False + except KeyError as ke: + print ("Warning Actor not present") + + def showActor(self, actorno, actor = None): + '''Shows hidden actor identified by its number in the list of actors''' + try: + if not self.actors[actorno][1]: + self.ren.AddActor(self.actors[actorno][0]) + self.actors[actorno][1] = True + return actorno + except KeyError as ke: + # adds it to the actors if not there already + if actor != None: + self.ren.AddActor(actor) + self.actors[len(self.actors)+1] = [actor, True] + return len(self.actors) + + def addActor(self, actor): + '''Adds an actor to the render''' + return self.showActor(0, actor) + + + def saveRender(self, filename, renWin=None): + '''Save the render window to PNG file''' + # screenshot code: + w2if = vtk.vtkWindowToImageFilter() + if renWin == None: + renWin = self.renWin + w2if.SetInput(renWin) + w2if.Update() + + writer = vtk.vtkPNGWriter() + writer.SetFileName("%s.png" % (filename)) + writer.SetInputConnection(w2if.GetOutputPort()) + writer.Write() + + + def startRenderLoop(self): + self.iren.Start() + + + def setupObservers(self, interactor): + interactor.RemoveObservers('LeftButtonPressEvent') + interactor.AddObserver('LeftButtonPressEvent', self.mouseInteraction) + interactor.Initialize() + + + def mouseInteraction(self, interactor, event): + if event == 'MouseWheelForwardEvent': + maxSlice = self.img3D.GetDimensions()[self.sliceOrientation] + if (self.sliceno + 1 < maxSlice): + self.hideActor(self.sliceActorNo) + self.sliceno = self.sliceno + 1 + self.displaySliceActor(self.sliceno) + else: + minSlice = 0 + if (self.sliceno - 1 > minSlice): + self.hideActor(self.sliceActorNo) + self.sliceno = self.sliceno - 1 + self.displaySliceActor(self.sliceno) + + + def keyPress(self, interactor, event): + #print ("Pressed key %s" % interactor.GetKeyCode()) + # Slice Orientation + if interactor.GetKeyCode() == "x": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_YZ + self.sliceno = int(self.img3D.GetDimensions()[1] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + elif interactor.GetKeyCode() == "y": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_XZ + self.sliceno = int(self.img3D.GetDimensions()[1] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + elif interactor.GetKeyCode() == "z": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_XY + self.sliceno = int(self.img3D.GetDimensions()[2] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + if interactor.GetKeyCode() == "X": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_YZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("x") + self.keyPress(interactor, event) + elif interactor.GetKeyCode() == "Y": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("y") + self.keyPress(interactor, event) + elif interactor.GetKeyCode() == "Z": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XY] = math.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("z") + self.keyPress(interactor, event) + else : + print ("Unhandled event %s" % interactor.GetKeyCode()) + + + + def setInput3DData(self, imageData): + self.img3D = imageData + + def setInputAsNumpy(self, numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + self.img3D = shiftScaler.GetOutput() + + def displaySliceActor(self, sliceno = 0): + self.sliceno = sliceno + first = False + + self.sliceActor , self.voi, self.wl , self.ia = \ + self.getSliceActor(self.img3D, + sliceno, + self.sliceActor, + self.voi, + self.wl, + self.ia) + no = self.showActor(self.sliceActorNo, self.sliceActor) + self.sliceActorNo = no + + self.iren.Initialize() + self.renWin.Render() + + return self.sliceActorNo + + + def getSliceActor(self, + imageData , + sliceno=0, + imageActor=None , + voi=None, + windowLevel=None, + imageAccumulate=None): + '''Slices a 3D volume and then creates an actor to be rendered''' + if (voi==None): + voi = vtk.vtkExtractVOI() + #voi = vtk.vtkImageClip() + voi.SetInputData(imageData) + #select one slice in Z + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = sliceno + extent[self.sliceOrientation * 2 + 1] = sliceno + voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + voi.Update() + # set window/level for all slices + if imageAccumulate == None: + imageAccumulate = vtk.vtkImageAccumulate() + + if (windowLevel == None): + windowLevel = vtk.vtkImageMapToWindowLevelColors() + imageAccumulate.SetInputData(imageData) + imageAccumulate.Update() + cmax = imageAccumulate.GetMax()[0] + cmin = imageAccumulate.GetMin()[0] + windowLevel.SetLevel((cmax+cmin)/2) + windowLevel.SetWindow(cmax-cmin) + + windowLevel.SetInputData(voi.GetOutput()) + windowLevel.Update() + + if imageActor == None: + imageActor = vtk.vtkImageActor() + imageActor.SetInputData(windowLevel.GetOutput()) + imageActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + imageActor.Update() + return (imageActor , voi, windowLevel, imageAccumulate) + + + # Set interpolation on + def setInterpolateOn(self): + self.sliceActor.SetInterpolate(True) + self.renWin.Render() + + # Set interpolation off + def setInterpolateOff(self): + self.sliceActor.SetInterpolate(False) + self.renWin.Render() \ No newline at end of file diff --git a/src/Python/ccpi/viewer/CILViewer2D.py b/src/Python/ccpi/viewer/CILViewer2D.py new file mode 100644 index 0000000..c1629af --- /dev/null +++ b/src/Python/ccpi/viewer/CILViewer2D.py @@ -0,0 +1,1126 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Edoardo Pasca +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import vtk +import numpy +from vtk.util import numpy_support , vtkImageImportFromArray +from enum import Enum + +SLICE_ORIENTATION_XY = 2 # Z +SLICE_ORIENTATION_XZ = 1 # Y +SLICE_ORIENTATION_YZ = 0 # X + +CONTROL_KEY = 8 +SHIFT_KEY = 4 +ALT_KEY = -128 + + +# Converter class +class Converter(): + + # Utility functions to transform numpy arrays to vtkImageData and viceversa + @staticmethod + def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): + '''Creates a vtkImageImportFromArray object and returns it. + + It handles the different axis order from numpy to VTK''' + importer = vtkImageImportFromArray.vtkImageImportFromArray() + importer.SetArray(numpy.transpose(nparray).copy()) + importer.SetDataSpacing(spacing) + importer.SetDataOrigin(origin) + return importer + + @staticmethod + def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): + '''Converts a 3D numpy array to a vtkImageData''' + importer = Converter.numpy2vtkImporter(nparray, spacing, origin) + importer.Update() + return importer.GetOutput() + + @staticmethod + def vtk2numpy(imgdata): + '''Converts the VTK data to 3D numpy array''' + img_data = numpy_support.vtk_to_numpy( + imgdata.GetPointData().GetScalars()) + + dims = imgdata.GetDimensions() + dims = (dims[2],dims[1],dims[0]) + data3d = numpy.reshape(img_data, dims) + + return numpy.transpose(data3d).copy() + + @staticmethod + def tiffStack2numpy(filename, indices, + extent = None , sampleRate = None ,\ + flatField = None, darkField = None): + '''Converts a stack of TIFF files to numpy array. + + filename must contain the whole path. The filename is supposed to be named and + have a suffix with the ordinal file number, i.e. /path/to/projection_%03d.tif + + indices are the suffix, generally an increasing number + + Optionally extracts only a selection of the 2D images and (optionally) + normalizes. + ''' + + stack = vtk.vtkImageData() + reader = vtk.vtkTIFFReader() + voi = vtk.vtkExtractVOI() + + #directory = "C:\\Users\\ofn77899\\Documents\\CCPi\\IMAT\\20170419_crabtomo\\crabtomo\\" + + stack_image = numpy.asarray([]) + nreduced = len(indices) + + for num in range(len(indices)): + fn = filename % indices[num] + print ("resampling %s" % ( fn ) ) + reader.SetFileName(fn) + reader.Update() + print (reader.GetOutput().GetScalarTypeAsString()) + if num == 0: + if (extent == None): + sliced = reader.GetOutput().GetExtent() + stack.SetExtent(sliced[0],sliced[1], sliced[2],sliced[3], 0, nreduced-1) + else: + sliced = extent + voi.SetVOI(extent) + + if sampleRate is not None: + voi.SetSampleRate(sampleRate) + ext = numpy.asarray([(sliced[2*i+1] - sliced[2*i])/sampleRate[i] for i in range(3)], dtype=int) + print ("ext {0}".format(ext)) + stack.SetExtent(0, ext[0] , 0, ext[1], 0, nreduced-1) + else: + stack.SetExtent(0, sliced[1] - sliced[0] , 0, sliced[3]-sliced[2], 0, nreduced-1) + if (flatField != None and darkField != None): + stack.AllocateScalars(vtk.VTK_FLOAT, 1) + else: + stack.AllocateScalars(reader.GetOutput().GetScalarType(), 1) + print ("Image Size: %d" % ((sliced[1]+1)*(sliced[3]+1) )) + stack_image = Converter.vtk2numpy(stack) + print ("Stack shape %s" % str(numpy.shape(stack_image))) + + if extent!=None: + voi.SetInputData(reader.GetOutput()) + voi.Update() + img = voi.GetOutput() + else: + img = reader.GetOutput() + + theSlice = Converter.vtk2numpy(img).T[0] + if darkField != None and flatField != None: + print("Try to normalize") + #if numpy.shape(darkField) == numpy.shape(flatField) and numpy.shape(flatField) == numpy.shape(theSlice): + theSlice = Converter.normalize(theSlice, darkField, flatField, 0.01) + print (theSlice.dtype) + + + print ("Slice shape %s" % str(numpy.shape(theSlice))) + stack_image.T[num] = theSlice.copy() + + return stack_image + + @staticmethod + def normalize(projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + + +## Utility functions to transform numpy arrays to vtkImageData and viceversa +#def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): +# return Converter.numpy2vtkImporter(nparray, spacing, origin) +# +#def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): +# return Converter.numpy2vtk(nparray, spacing, origin) +# +#def vtk2numpy(imgdata): +# return Converter.vtk2numpy(imgdata) +# +#def tiffStack2numpy(filename, indices): +# return Converter.tiffStack2numpy(filename, indices) + +class ViewerEvent(Enum): + # left button + PICK_EVENT = 0 + # alt + right button + move + WINDOW_LEVEL_EVENT = 1 + # shift + right button + ZOOM_EVENT = 2 + # control + right button + PAN_EVENT = 3 + # control + left button + CREATE_ROI_EVENT = 4 + # alt + left button + DELETE_ROI_EVENT = 5 + # release button + NO_EVENT = -1 + + +#class CILInteractorStyle(vtk.vtkInteractorStyleTrackballCamera): +class CILInteractorStyle(vtk.vtkInteractorStyleImage): + + def __init__(self, callback): + vtk.vtkInteractorStyleImage.__init__(self) + self.callback = callback + self._viewer = callback + priority = 1.0 + +# self.AddObserver("MouseWheelForwardEvent" , callback.OnMouseWheelForward , priority) +# self.AddObserver("MouseWheelBackwardEvent" , callback.OnMouseWheelBackward, priority) +# self.AddObserver('KeyPressEvent', callback.OnKeyPress, priority) +# self.AddObserver('LeftButtonPressEvent', callback.OnLeftButtonPressEvent, priority) +# self.AddObserver('RightButtonPressEvent', callback.OnRightButtonPressEvent, priority) +# self.AddObserver('LeftButtonReleaseEvent', callback.OnLeftButtonReleaseEvent, priority) +# self.AddObserver('RightButtonReleaseEvent', callback.OnRightButtonReleaseEvent, priority) +# self.AddObserver('MouseMoveEvent', callback.OnMouseMoveEvent, priority) + + self.AddObserver("MouseWheelForwardEvent" , self.OnMouseWheelForward , priority) + self.AddObserver("MouseWheelBackwardEvent" , self.OnMouseWheelBackward, priority) + self.AddObserver('KeyPressEvent', self.OnKeyPress, priority) + self.AddObserver('LeftButtonPressEvent', self.OnLeftButtonPressEvent, priority) + self.AddObserver('RightButtonPressEvent', self.OnRightButtonPressEvent, priority) + self.AddObserver('LeftButtonReleaseEvent', self.OnLeftButtonReleaseEvent, priority) + self.AddObserver('RightButtonReleaseEvent', self.OnRightButtonReleaseEvent, priority) + self.AddObserver('MouseMoveEvent', self.OnMouseMoveEvent, priority) + + self.InitialEventPosition = (0,0) + + + def SetInitialEventPosition(self, xy): + self.InitialEventPosition = xy + + def GetInitialEventPosition(self): + return self.InitialEventPosition + + def GetKeyCode(self): + return self.GetInteractor().GetKeyCode() + + def SetKeyCode(self, keycode): + self.GetInteractor().SetKeyCode(keycode) + + def GetControlKey(self): + return self.GetInteractor().GetControlKey() == CONTROL_KEY + + def GetShiftKey(self): + return self.GetInteractor().GetShiftKey() == SHIFT_KEY + + def GetAltKey(self): + return self.GetInteractor().GetAltKey() == ALT_KEY + + def GetEventPosition(self): + return self.GetInteractor().GetEventPosition() + + def GetEventPositionInWorldCoordinates(self): + pass + + def GetDeltaEventPosition(self): + x,y = self.GetInteractor().GetEventPosition() + return (x - self.InitialEventPosition[0] , y - self.InitialEventPosition[1]) + + def Dolly(self, factor): + self.callback.camera.Dolly(factor) + self.callback.ren.ResetCameraClippingRange() + + def GetDimensions(self): + return self._viewer.img3D.GetDimensions() + + def GetInputData(self): + return self._viewer.img3D + + def GetSliceOrientation(self): + return self._viewer.sliceOrientation + + def SetSliceOrientation(self, orientation): + self._viewer.sliceOrientation = orientation + + def GetActiveSlice(self): + return self._viewer.sliceno + + def SetActiveSlice(self, sliceno): + self._viewer.sliceno = sliceno + + def UpdatePipeline(self, reset = False): + self._viewer.updatePipeline(reset) + + def GetActiveCamera(self): + return self._viewer.ren.GetActiveCamera() + + def SetActiveCamera(self, camera): + self._viewer.ren.SetActiveCamera(camera) + + def ResetCamera(self): + self._viewer.ren.ResetCamera() + + def Render(self): + self._viewer.renWin.Render() + + def UpdateSliceActor(self): + self._viewer.sliceActor.Update() + + def AdjustCamera(self): + self._viewer.AdjustCamera() + + def SaveRender(self, filename): + self._viewer.SaveRender(filename) + + def GetRenderWindow(self): + return self._viewer.renWin + + def GetRenderer(self): + return self._viewer.ren + + def GetROIWidget(self): + return self._viewer.ROIWidget + + def SetViewerEvent(self, event): + self._viewer.event = event + + def GetViewerEvent(self): + return self._viewer.event + + def SetInitialCameraPosition(self, position): + self._viewer.InitialCameraPosition = position + + def GetInitialCameraPosition(self): + return self._viewer.InitialCameraPosition + + def SetInitialLevel(self, level): + self._viewer.InitialLevel = level + + def GetInitialLevel(self): + return self._viewer.InitialLevel + + def SetInitialWindow(self, window): + self._viewer.InitialWindow = window + + def GetInitialWindow(self): + return self._viewer.InitialWindow + + def GetWindowLevel(self): + return self._viewer.wl + + def SetROI(self, roi): + self._viewer.ROI = roi + + def GetROI(self): + return self._viewer.ROI + + def UpdateCornerAnnotation(self, text, corner): + self._viewer.updateCornerAnnotation(text, corner) + + def GetPicker(self): + return self._viewer.picker + + def GetCornerAnnotation(self): + return self._viewer.cornerAnnotation + + def UpdateROIHistogram(self): + self._viewer.updateROIHistogram() + + + ############### Handle events + def OnMouseWheelForward(self, interactor, event): + maxSlice = self.GetDimensions()[self.GetSliceOrientation()] + shift = interactor.GetShiftKey() + advance = 1 + if shift: + advance = 10 + + if (self.GetActiveSlice() + advance < maxSlice): + self.SetActiveSlice(self.GetActiveSlice() + advance) + + self.UpdatePipeline() + else: + print ("maxSlice %d request %d" % (maxSlice, self.GetActiveSlice() + 1 )) + + def OnMouseWheelBackward(self, interactor, event): + minSlice = 0 + shift = interactor.GetShiftKey() + advance = 1 + if shift: + advance = 10 + if (self.GetActiveSlice() - advance >= minSlice): + self.SetActiveSlice( self.GetActiveSlice() - advance) + self.UpdatePipeline() + else: + print ("minSlice %d request %d" % (minSlice, self.GetActiveSlice() + 1 )) + + def OnKeyPress(self, interactor, event): + #print ("Pressed key %s" % interactor.GetKeyCode()) + # Slice Orientation + if interactor.GetKeyCode() == "X": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_YZ ) + self.SetActiveSlice( int(self.GetDimensions()[1] / 2) ) + self.UpdatePipeline(True) + elif interactor.GetKeyCode() == "Y": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_XZ ) + self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[1] / 2) ) + self.UpdatePipeline(True) + elif interactor.GetKeyCode() == "Z": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_XY ) + self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[2] / 2) ) + self.UpdatePipeline(True) + if interactor.GetKeyCode() == "x": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_YZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.SetActiveCamera(camera) + self.Render() + interactor.SetKeyCode("X") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "y": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.SetActiveCamera(camera) + self.Render() + interactor.SetKeyCode("Y") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "z": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XY] = numpy.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,1,0) + self.SetActiveCamera(camera) + self.ResetCamera() + self.Render() + interactor.SetKeyCode("Z") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "a": + # reset color/window + cmax = self._viewer.ia.GetMax()[0] + cmin = self._viewer.ia.GetMin()[0] + + self.SetInitialLevel( (cmax+cmin)/2 ) + self.SetInitialWindow( cmax-cmin ) + + self.GetWindowLevel().SetLevel(self.GetInitialLevel()) + self.GetWindowLevel().SetWindow(self.GetInitialWindow()) + + self.GetWindowLevel().Update() + + self.UpdateSliceActor() + self.AdjustCamera() + self.Render() + + elif interactor.GetKeyCode() == "s": + filename = "current_render" + self.SaveRender(filename) + elif interactor.GetKeyCode() == "q": + print ("Terminating by pressing q %s" % (interactor.GetKeyCode(), )) + interactor.SetKeyCode("e") + self.OnKeyPress(interactor, event) + else : + #print ("Unhandled event %s" % (interactor.GetKeyCode(), ))) + pass + + def OnLeftButtonPressEvent(self, interactor, event): + alt = interactor.GetAltKey() + shift = interactor.GetShiftKey() + ctrl = interactor.GetControlKey() +# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) +# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) +# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) + + interactor.SetInitialEventPosition(interactor.GetEventPosition()) + + if ctrl and not (alt and shift): + self.SetViewerEvent( ViewerEvent.CREATE_ROI_EVENT ) + wsize = self.GetRenderWindow().GetSize() + position = interactor.GetEventPosition() + self.GetROIWidget().GetBorderRepresentation().SetPosition((position[0]/wsize[0] - 0.05) , (position[1]/wsize[1] - 0.05)) + self.GetROIWidget().GetBorderRepresentation().SetPosition2( (0.1) , (0.1)) + + self.GetROIWidget().On() + self.SetDisplayHistogram(True) + self.Render() + print ("Event %s is CREATE_ROI_EVENT" % (event)) + elif alt and not (shift and ctrl): + self.SetViewerEvent( ViewerEvent.DELETE_ROI_EVENT ) + self.GetROIWidget().Off() + self._viewer.updateCornerAnnotation("", 1, False) + self.SetDisplayHistogram(False) + self.Render() + print ("Event %s is DELETE_ROI_EVENT" % (event)) + elif not (ctrl and alt and shift): + self.SetViewerEvent ( ViewerEvent.PICK_EVENT ) + self.HandlePickEvent(interactor, event) + print ("Event %s is PICK_EVENT" % (event)) + + + def SetDisplayHistogram(self, display): + if display: + if (self._viewer.displayHistogram == 0): + self.GetRenderer().AddActor(self._viewer.histogramPlotActor) + self.firstHistogram = 1 + self.Render() + + self._viewer.histogramPlotActor.VisibilityOn() + self._viewer.displayHistogram = True + else: + self._viewer.histogramPlotActor.VisibilityOff() + self._viewer.displayHistogram = False + + + def OnLeftButtonReleaseEvent(self, interactor, event): + if self.GetViewerEvent() == ViewerEvent.CREATE_ROI_EVENT: + #bc = self.ROIWidget.GetBorderRepresentation().GetPositionCoordinate() + #print (bc.GetValue()) + self.OnROIModifiedEvent(interactor, event) + + elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: + self.HandlePickEvent(interactor, event) + + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + def OnRightButtonPressEvent(self, interactor, event): + alt = interactor.GetAltKey() + shift = interactor.GetShiftKey() + ctrl = interactor.GetControlKey() +# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) +# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) +# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) + + interactor.SetInitialEventPosition(interactor.GetEventPosition()) + + + if alt and not (ctrl and shift): + self.SetViewerEvent( ViewerEvent.WINDOW_LEVEL_EVENT ) + print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) + self.HandleWindowLevel(interactor, event) + elif shift and not (ctrl and alt): + self.SetViewerEvent( ViewerEvent.ZOOM_EVENT ) + self.SetInitialCameraPosition( self.GetActiveCamera().GetPosition()) + print ("Event %s is ZOOM_EVENT" % (event)) + elif ctrl and not (shift and alt): + self.SetViewerEvent (ViewerEvent.PAN_EVENT ) + self.SetInitialCameraPosition ( self.GetActiveCamera().GetPosition() ) + print ("Event %s is PAN_EVENT" % (event)) + + def OnRightButtonReleaseEvent(self, interactor, event): + print (event) + if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: + self.SetInitialLevel( self.GetWindowLevel().GetLevel() ) + self.SetInitialWindow ( self.GetWindowLevel().GetWindow() ) + elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT or \ + self.GetViewerEvent() == ViewerEvent.PAN_EVENT: + self.SetInitialCameraPosition( () ) + + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + + def OnROIModifiedEvent(self, interactor, event): + + #print ("ROI EVENT " + event) + p1 = self.GetROIWidget().GetBorderRepresentation().GetPositionCoordinate() + p2 = self.GetROIWidget().GetBorderRepresentation().GetPosition2Coordinate() + wsize = self.GetRenderWindow().GetSize() + + #print (p1.GetValue()) + #print (p2.GetValue()) + pp1 = [p1.GetValue()[0] * wsize[0] , p1.GetValue()[1] * wsize[1] , 0.0] + pp2 = [p2.GetValue()[0] * wsize[0] + pp1[0] , p2.GetValue()[1] * wsize[1] + pp1[1] , 0.0] + vox1 = self.viewport2imageCoordinate(pp1) + vox2 = self.viewport2imageCoordinate(pp2) + + self.SetROI( (vox1 , vox2) ) + roi = self.GetROI() + print ("Pixel1 %d,%d,%d Value %f" % vox1 ) + print ("Pixel2 %d,%d,%d Value %f" % vox2 ) + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + print ("slice orientation : XY") + x = abs(roi[1][0] - roi[0][0]) + y = abs(roi[1][1] - roi[0][1]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + print ("slice orientation : XY") + x = abs(roi[1][0] - roi[0][0]) + y = abs(roi[1][2] - roi[0][2]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + print ("slice orientation : XY") + x = abs(roi[1][1] - roi[0][1]) + y = abs(roi[1][2] - roi[0][2]) + + text = "ROI: %d x %d, %.2f kp" % (x,y,float(x*y)/1024.) + print (text) + self.UpdateCornerAnnotation(text, 1) + self.UpdateROIHistogram() + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + def viewport2imageCoordinate(self, viewerposition): + #Determine point index + + self.GetPicker().Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) + pickPosition = list(self.GetPicker().GetPickPosition()) + pickPosition[self.GetSliceOrientation()] = \ + self.GetInputData().GetSpacing()[self.GetSliceOrientation()] * self.GetActiveSlice() + \ + self.GetInputData().GetOrigin()[self.GetSliceOrientation()] + print ("Pick Position " + str (pickPosition)) + + if (pickPosition != [0,0,0]): + dims = self.GetInputData().GetDimensions() + print (dims) + spac = self.GetInputData().GetSpacing() + orig = self.GetInputData().GetOrigin() + imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] + + pixelValue = self.GetInputData().GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) + return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) + else: + return (0,0,0,0) + + + + + def OnMouseMoveEvent(self, interactor, event): + if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: + print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) + self.HandleWindowLevel(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: + self.HandlePickEvent(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT: + self.HandleZoomEvent(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.PAN_EVENT: + self.HandlePanEvent(interactor, event) + + + def HandleZoomEvent(self, interactor, event): + dx,dy = interactor.GetDeltaEventPosition() + size = self.GetRenderWindow().GetSize() + dy = - 4 * dy / size[1] + + print ("distance: " + str(self.GetActiveCamera().GetDistance())) + + print ("\ndy: %f\ncamera dolly %f\n" % (dy, 1 + dy)) + + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + #print ("current position " + str(self.InitialCameraPosition)) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + camera.SetPosition(self.GetInitialCameraPosition()) + newposition = [i for i in self.GetInitialCameraPosition()] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + dist = newposition[SLICE_ORIENTATION_XY] * ( 1 + dy ) + newposition[SLICE_ORIENTATION_XY] *= ( 1 + dy ) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + newposition[SLICE_ORIENTATION_XZ] *= ( 1 + dy ) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + newposition[SLICE_ORIENTATION_YZ] *= ( 1 + dy ) + #print ("new position " + str(newposition)) + camera.SetPosition(newposition) + self.SetActiveCamera(camera) + + self.Render() + + print ("distance after: " + str(self.GetActiveCamera().GetDistance())) + + def HandlePanEvent(self, interactor, event): + x,y = interactor.GetEventPosition() + x0,y0 = interactor.GetInitialEventPosition() + + ic = self.viewport2imageCoordinate((x,y)) + ic0 = self.viewport2imageCoordinate((x0,y0)) + + dx = 4 *( ic[0] - ic0[0]) + dy = 4* (ic[1] - ic0[1]) + + camera = vtk.vtkCamera() + #print ("current position " + str(self.InitialCameraPosition)) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + camera.SetPosition(self.GetInitialCameraPosition()) + newposition = [i for i in self.GetInitialCameraPosition()] + newfocalpoint = [i for i in self.GetActiveCamera().GetFocalPoint()] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + newposition[0] -= dx + newposition[1] -= dy + newfocalpoint[0] = newposition[0] + newfocalpoint[1] = newposition[1] + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + newposition[0] -= dx + newposition[2] -= dy + newfocalpoint[0] = newposition[0] + newfocalpoint[2] = newposition[2] + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + newposition[1] -= dx + newposition[2] -= dy + newfocalpoint[2] = newposition[2] + newfocalpoint[1] = newposition[1] + #print ("new position " + str(newposition)) + camera.SetFocalPoint(newfocalpoint) + camera.SetPosition(newposition) + self.SetActiveCamera(camera) + + self.Render() + + def HandleWindowLevel(self, interactor, event): + dx,dy = interactor.GetDeltaEventPosition() + print ("Event delta %d %d" % (dx,dy)) + size = self.GetRenderWindow().GetSize() + + dx = 4 * dx / size[0] + dy = 4 * dy / size[1] + window = self.GetInitialWindow() + level = self.GetInitialLevel() + + if abs(window) > 0.01: + dx = dx * window + else: + dx = dx * (lambda x: -0.01 if x <0 else 0.01)(window); + + if abs(level) > 0.01: + dy = dy * level + else: + dy = dy * (lambda x: -0.01 if x <0 else 0.01)(level) + + + # Abs so that direction does not flip + + if window < 0.0: + dx = -1*dx + if level < 0.0: + dy = -1*dy + + # Compute new window level + + newWindow = dx + window + newLevel = level - dy + + # Stay away from zero and really + + if abs(newWindow) < 0.01: + newWindow = 0.01 * (lambda x: -1 if x <0 else 1)(newWindow) + + if abs(newLevel) < 0.01: + newLevel = 0.01 * (lambda x: -1 if x <0 else 1)(newLevel) + + self.GetWindowLevel().SetWindow(newWindow) + self.GetWindowLevel().SetLevel(newLevel) + + self.GetWindowLevel().Update() + self.UpdateSliceActor() + self.AdjustCamera() + + self.Render() + + def HandlePickEvent(self, interactor, event): + position = interactor.GetEventPosition() + #print ("PICK " + str(position)) + vox = self.viewport2imageCoordinate(position) + #print ("Pixel %d,%d,%d Value %f" % vox ) + self._viewer.cornerAnnotation.VisibilityOn() + self.UpdateCornerAnnotation("[%d,%d,%d] : %.2f" % vox , 0) + self.Render() + +############################################################################### + + + +class CILViewer2D(): + '''Simple Interactive Viewer based on VTK classes''' + + def __init__(self, dimx=600,dimy=600, ren=None, renWin=None,iren=None): + '''creates the rendering pipeline''' + # create a rendering window and renderer + if ren == None: + self.ren = vtk.vtkRenderer() + else: + self.ren = ren + if renWin == None: + self.renWin = vtk.vtkRenderWindow() + else: + self.renWin = renWin + if iren == None: + self.iren = vtk.vtkRenderWindowInteractor() + else: + self.iren = iren + + self.renWin.SetSize(dimx,dimy) + self.renWin.AddRenderer(self.ren) + + self.style = CILInteractorStyle(self) + + self.iren.SetInteractorStyle(self.style) + self.iren.SetRenderWindow(self.renWin) + self.iren.Initialize() + self.ren.SetBackground(.1, .2, .4) + + self.camera = vtk.vtkCamera() + self.camera.ParallelProjectionOn() + self.ren.SetActiveCamera(self.camera) + + # data + self.img3D = None + self.sliceno = 0 + self.sliceOrientation = SLICE_ORIENTATION_XY + + #Actors + self.sliceActor = vtk.vtkImageActor() + self.voi = vtk.vtkExtractVOI() + self.wl = vtk.vtkImageMapToWindowLevelColors() + self.ia = vtk.vtkImageAccumulate() + self.sliceActorNo = 0 + + #initial Window/Level + self.InitialLevel = 0 + self.InitialWindow = 0 + + #ViewerEvent + self.event = ViewerEvent.NO_EVENT + + # ROI Widget + self.ROIWidget = vtk.vtkBorderWidget() + self.ROIWidget.SetInteractor(self.iren) + self.ROIWidget.CreateDefaultRepresentation() + self.ROIWidget.GetBorderRepresentation().GetBorderProperty().SetColor(0,1,0) + self.ROIWidget.AddObserver(vtk.vtkWidgetEvent.Select, self.style.OnROIModifiedEvent, 1.0) + + # edge points of the ROI + self.ROI = () + + #picker + self.picker = vtk.vtkPropPicker() + self.picker.PickFromListOn() + self.picker.AddPickList(self.sliceActor) + + self.iren.SetPicker(self.picker) + + # corner annotation + self.cornerAnnotation = vtk.vtkCornerAnnotation() + self.cornerAnnotation.SetMaximumFontSize(12); + self.cornerAnnotation.PickableOff(); + self.cornerAnnotation.VisibilityOff(); + self.cornerAnnotation.GetTextProperty().ShadowOn(); + self.cornerAnnotation.SetLayerNumber(1); + + + + # cursor doesn't show up + self.cursor = vtk.vtkCursor2D() + self.cursorMapper = vtk.vtkPolyDataMapper2D() + self.cursorActor = vtk.vtkActor2D() + self.cursor.SetModelBounds(-10, 10, -10, 10, 0, 0) + self.cursor.SetFocalPoint(0, 0, 0) + self.cursor.AllOff() + self.cursor.AxesOn() + self.cursorActor.PickableOff() + self.cursorActor.VisibilityOn() + self.cursorActor.GetProperty().SetColor(1, 1, 1) + self.cursorActor.SetLayerNumber(1) + self.cursorMapper.SetInputData(self.cursor.GetOutput()) + self.cursorActor.SetMapper(self.cursorMapper) + + # Zoom + self.InitialCameraPosition = () + + # XY Plot actor for histogram + self.displayHistogram = False + self.firstHistogram = 0 + self.roiIA = vtk.vtkImageAccumulate() + self.roiVOI = vtk.vtkExtractVOI() + self.histogramPlotActor = vtk.vtkXYPlotActor() + self.histogramPlotActor.ExchangeAxesOff(); + self.histogramPlotActor.SetXLabelFormat( "%g" ) + self.histogramPlotActor.SetXLabelFormat( "%g" ) + self.histogramPlotActor.SetAdjustXLabels(3) + self.histogramPlotActor.SetXTitle( "Level" ) + self.histogramPlotActor.SetYTitle( "N" ) + self.histogramPlotActor.SetXValuesToValue() + self.histogramPlotActor.SetPlotColor(0, (0,1,1) ) + self.histogramPlotActor.SetPosition(0.6,0.6) + self.histogramPlotActor.SetPosition2(0.4,0.4) + + + + def GetInteractor(self): + return self.iren + + def GetRenderer(self): + return self.ren + + def setInput3DData(self, imageData): + self.img3D = imageData + self.installPipeline() + + def setInputAsNumpy(self, numpyarray, origin=(0,0,0), spacing=(1.,1.,1.), + rescale=True, dtype=vtk.VTK_UNSIGNED_SHORT): + importer = Converter.numpy2vtkImporter(numpyarray, spacing, origin) + importer.Update() + + if rescale: + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(importer.GetOutput()) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + if (iMax - iMin == 0): + scale = 1 + else: + if dtype == vtk.VTK_UNSIGNED_SHORT: + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + elif dtype == vtk.VTK_UNSIGNED_INT: + scale = vtk.VTK_UNSIGNED_INT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(importer.GetOutput()) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(-iMin) + shiftScaler.SetOutputScalarType(dtype) + shiftScaler.Update() + self.img3D = shiftScaler.GetOutput() + else: + self.img3D = importer.GetOutput() + + self.installPipeline() + + def displaySlice(self, sliceno = 0): + self.sliceno = sliceno + + self.updatePipeline() + + self.renWin.Render() + + return self.sliceActorNo + + def updatePipeline(self, resetcamera = False): + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = self.sliceno + extent[self.sliceOrientation * 2 + 1] = self.sliceno + self.voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + self.voi.Update() + self.ia.Update() + self.wl.Update() + self.sliceActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + self.sliceActor.Update() + + self.updateCornerAnnotation("Slice %d/%d" % (self.sliceno + 1 , self.img3D.GetDimensions()[self.sliceOrientation])) + + if self.displayHistogram: + self.updateROIHistogram() + + self.AdjustCamera(resetcamera) + + self.renWin.Render() + + + def installPipeline(self): + '''Slices a 3D volume and then creates an actor to be rendered''' + + self.ren.AddViewProp(self.cornerAnnotation) + + self.voi.SetInputData(self.img3D) + #select one slice in Z + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = self.sliceno + extent[self.sliceOrientation * 2 + 1] = self.sliceno + self.voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + self.voi.Update() + # set window/level for current slices + + + self.wl = vtk.vtkImageMapToWindowLevelColors() + self.ia.SetInputData(self.voi.GetOutput()) + self.ia.Update() + cmax = self.ia.GetMax()[0] + cmin = self.ia.GetMin()[0] + + self.InitialLevel = (cmax+cmin)/2 + self.InitialWindow = cmax-cmin + + + self.wl.SetLevel(self.InitialLevel) + self.wl.SetWindow(self.InitialWindow) + + self.wl.SetInputData(self.voi.GetOutput()) + self.wl.Update() + + self.sliceActor.SetInputData(self.wl.GetOutput()) + self.sliceActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + self.sliceActor.Update() + self.sliceActor.SetInterpolate(False) + self.ren.AddActor(self.sliceActor) + self.ren.ResetCamera() + self.ren.Render() + + self.AdjustCamera() + + self.ren.AddViewProp(self.cursorActor) + self.cursorActor.VisibilityOn() + + self.iren.Initialize() + self.renWin.Render() + #self.iren.Start() + + def AdjustCamera(self, resetcamera = False): + self.ren.ResetCameraClippingRange() + if resetcamera: + self.ren.ResetCamera() + + + def getROI(self): + return self.ROI + + def getROIExtent(self): + p0 = self.ROI[0] + p1 = self.ROI[1] + return (p0[0], p1[0],p0[1],p1[1],p0[2],p1[2]) + + ############### Handle events are moved to the interactor style + + + def viewport2imageCoordinate(self, viewerposition): + #Determine point index + + self.picker.Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) + pickPosition = list(self.picker.GetPickPosition()) + pickPosition[self.sliceOrientation] = \ + self.img3D.GetSpacing()[self.sliceOrientation] * self.sliceno + \ + self.img3D.GetOrigin()[self.sliceOrientation] + print ("Pick Position " + str (pickPosition)) + + if (pickPosition != [0,0,0]): + dims = self.img3D.GetDimensions() + print (dims) + spac = self.img3D.GetSpacing() + orig = self.img3D.GetOrigin() + imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] + + pixelValue = self.img3D.GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) + return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) + else: + return (0,0,0,0) + + + + def GetRenderWindow(self): + return self.renWin + + + def startRenderLoop(self): + self.iren.Start() + + def GetSliceOrientation(self): + return self.sliceOrientation + + def GetActiveSlice(self): + return self.sliceno + + def updateCornerAnnotation(self, text , idx=0, visibility=True): + if visibility: + self.cornerAnnotation.VisibilityOn() + else: + self.cornerAnnotation.VisibilityOff() + + self.cornerAnnotation.SetText(idx, text) + self.iren.Render() + + def saveRender(self, filename, renWin=None): + '''Save the render window to PNG file''' + # screenshot code: + w2if = vtk.vtkWindowToImageFilter() + if renWin == None: + renWin = self.renWin + w2if.SetInput(renWin) + w2if.Update() + + writer = vtk.vtkPNGWriter() + writer.SetFileName("%s.png" % (filename)) + writer.SetInputConnection(w2if.GetOutputPort()) + writer.Write() + + def updateROIHistogram(self): + + extent = [0 for i in range(6)] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + print ("slice orientation : XY") + extent[0] = self.ROI[0][0] + extent[1] = self.ROI[1][0] + extent[2] = self.ROI[0][1] + extent[3] = self.ROI[1][1] + extent[4] = self.GetActiveSlice() + extent[5] = self.GetActiveSlice()+1 + #y = abs(roi[1][1] - roi[0][1]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + print ("slice orientation : XY") + extent[0] = self.ROI[0][0] + extent[1] = self.ROI[1][0] + #x = abs(roi[1][0] - roi[0][0]) + extent[4] = self.ROI[0][2] + extent[5] = self.ROI[1][2] + #y = abs(roi[1][2] - roi[0][2]) + extent[2] = self.GetActiveSlice() + extent[3] = self.GetActiveSlice()+1 + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + print ("slice orientation : XY") + extent[2] = self.ROI[0][1] + extent[3] = self.ROI[1][1] + #x = abs(roi[1][1] - roi[0][1]) + extent[4] = self.ROI[0][2] + extent[5] = self.ROI[1][2] + #y = abs(roi[1][2] - roi[0][2]) + extent[0] = self.GetActiveSlice() + extent[1] = self.GetActiveSlice()+1 + + self.roiVOI.SetVOI(extent) + self.roiVOI.SetInputData(self.img3D) + self.roiVOI.Update() + irange = self.roiVOI.GetOutput().GetScalarRange() + + self.roiIA.SetInputData(self.roiVOI.GetOutput()) + self.roiIA.IgnoreZeroOff() + self.roiIA.SetComponentExtent(0,int(irange[1]-irange[0]-1),0,0,0,0 ) + self.roiIA.SetComponentOrigin( int(irange[0]),0,0 ); + self.roiIA.SetComponentSpacing( 1,0,0 ); + self.roiIA.Update() + + self.histogramPlotActor.AddDataSetInputConnection(self.roiIA.GetOutputPort()) + self.histogramPlotActor.SetXRange(irange[0],irange[1]) + + self.histogramPlotActor.SetYRange( self.roiIA.GetOutput().GetScalarRange() ) + + \ No newline at end of file diff --git a/src/Python/ccpi/viewer/QVTKWidget.py b/src/Python/ccpi/viewer/QVTKWidget.py new file mode 100644 index 0000000..906786b --- /dev/null +++ b/src/Python/ccpi/viewer/QVTKWidget.py @@ -0,0 +1,340 @@ +################################################################################ +# File: QVTKWidget.py +# Author: Edoardo Pasca +# Description: PyVE Viewer Qt widget +# +# License: +# This file is part of PyVE. PyVE is an open-source image +# analysis and visualization environment focused on medical +# imaging. More info at http://pyve.sourceforge.net +# +# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# +# Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. Neither name of Edoardo Pasca or Lukas +# Batteau nor the names of any contributors may be used to endorse +# or promote products derived from this software without specific +# prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, +# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. +# +# CHANGE HISTORY +# +# 20120118 Edoardo Pasca Initial version +# +############################################################################### + +import os +from PyQt5 import QtCore, QtGui, QtWidgets +#import itk +import vtk +#from viewer import PyveViewer +from ccpi.viewer.CILViewer2D import CILViewer2D , Converter + +class QVTKWidget(QtWidgets.QWidget): + + """ A QVTKWidget for Python and Qt.""" + + # Map between VTK and Qt cursors. + _CURSOR_MAP = { + 0: QtCore.Qt.ArrowCursor, # VTK_CURSOR_DEFAULT + 1: QtCore.Qt.ArrowCursor, # VTK_CURSOR_ARROW + 2: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZENE + 3: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZENWSE + 4: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZESW + 5: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZESE + 6: QtCore.Qt.SizeVerCursor, # VTK_CURSOR_SIZENS + 7: QtCore.Qt.SizeHorCursor, # VTK_CURSOR_SIZEWE + 8: QtCore.Qt.SizeAllCursor, # VTK_CURSOR_SIZEALL + 9: QtCore.Qt.PointingHandCursor, # VTK_CURSOR_HAND + 10: QtCore.Qt.CrossCursor, # VTK_CURSOR_CROSSHAIR + } + + def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): + # the current button + self._ActiveButton = QtCore.Qt.NoButton + + # private attributes + self.__oldFocus = None + self.__saveX = 0 + self.__saveY = 0 + self.__saveModifiers = QtCore.Qt.NoModifier + self.__saveButtons = QtCore.Qt.NoButton + self.__timeframe = 0 + + # create qt-level widget + QtWidgets.QWidget.__init__(self, parent, wflags|QtCore.Qt.MSWindowsOwnDC) + + # Link to PyVE Viewer + self._PyveViewer = CILViewer2D() + #self._Viewer = self._PyveViewer._vtkPyveViewer + + self._Iren = self._PyveViewer.GetInteractor() + #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() + self._RenderWindow = self._PyveViewer.GetRenderWindow() + #self._RenderWindow = self._Viewer.GetRenderWindow() + + self._Iren.Register(self._RenderWindow) + self._Iren.SetRenderWindow(self._RenderWindow) + self._RenderWindow.SetWindowInfo(str(int(self.winId()))) + + # do all the necessary qt setup + self.setAttribute(QtCore.Qt.WA_OpaquePaintEvent) + self.setAttribute(QtCore.Qt.WA_PaintOnScreen) + self.setMouseTracking(True) # get all mouse events + self.setFocusPolicy(QtCore.Qt.WheelFocus) + self.setSizePolicy(QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)) + + self._Timer = QtCore.QTimer(self) + #self.connect(self._Timer, QtCore.pyqtSignal('timeout()'), self.TimerEvent) + + self._Iren.AddObserver('CreateTimerEvent', self.CreateTimer) + self._Iren.AddObserver('DestroyTimerEvent', self.DestroyTimer) + self._Iren.GetRenderWindow().AddObserver('CursorChangedEvent', + self.CursorChangedEvent) + + # Destructor + def __del__(self): + self._Iren.UnRegister(self._RenderWindow) + #QtWidgets.QWidget.__del__(self) + + # Display image data + def SetInput(self, imageData): + self._PyveViewer.setInput3DData(imageData) + + # GetInteractor + def GetInteractor(self): + return self._Iren + + # Display image data + def GetPyveViewer(self): + return self._PyveViewer + + def __getattr__(self, attr): + """Makes the object behave like a vtkGenericRenderWindowInteractor""" + print (attr) + if attr == '__vtk__': + return lambda t=self._Iren: t + elif hasattr(self._Iren, attr): + return getattr(self._Iren, attr) +# else: +# raise AttributeError( self.__class__.__name__ + \ +# " has no attribute named " + attr ) + + def CreateTimer(self, obj, evt): + self._Timer.start(10) + + def DestroyTimer(self, obj, evt): + self._Timer.stop() + return 1 + + def TimerEvent(self): + self._Iren.InvokeEvent("TimerEvent") + + def CursorChangedEvent(self, obj, evt): + """Called when the CursorChangedEvent fires on the render window.""" + # This indirection is needed since when the event fires, the current + # cursor is not yet set so we defer this by which time the current + # cursor should have been set. + QtCore.QTimer.singleShot(0, self.ShowCursor) + + def HideCursor(self): + """Hides the cursor.""" + self.setCursor(QtCore.Qt.BlankCursor) + + def ShowCursor(self): + """Shows the cursor.""" + vtk_cursor = self._Iren.GetRenderWindow().GetCurrentCursor() + qt_cursor = self._CURSOR_MAP.get(vtk_cursor, QtCore.Qt.ArrowCursor) + self.setCursor(qt_cursor) + + def sizeHint(self): + return QtCore.QSize(400, 400) + + def paintEngine(self): + return None + + def paintEvent(self, ev): + self._RenderWindow.Render() + + def resizeEvent(self, ev): + self._RenderWindow.Render() + w = self.width() + h = self.height() + + self._RenderWindow.SetSize(w, h) + self._Iren.SetSize(w, h) + + def _GetCtrlShiftAlt(self, ev): + ctrl = shift = alt = False + + if hasattr(ev, 'modifiers'): + if ev.modifiers() & QtCore.Qt.ShiftModifier: + shift = True + if ev.modifiers() & QtCore.Qt.ControlModifier: + ctrl = True + if ev.modifiers() & QtCore.Qt.AltModifier: + alt = True + else: + if self.__saveModifiers & QtCore.Qt.ShiftModifier: + shift = True + if self.__saveModifiers & QtCore.Qt.ControlModifier: + ctrl = True + if self.__saveModifiers & QtCore.Qt.AltModifier: + alt = True + + return ctrl, shift, alt + + def enterEvent(self, ev): + if not self.hasFocus(): + self.__oldFocus = self.focusWidget() + self.setFocus() + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("EnterEvent") + + def leaveEvent(self, ev): + if self.__saveButtons == QtCore.Qt.NoButton and self.__oldFocus: + self.__oldFocus.setFocus() + self.__oldFocus = None + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("LeaveEvent") + + def mousePressEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + repeat = 0 + if ev.type() == QtCore.QEvent.MouseButtonDblClick: + repeat = 1 + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), repeat, None) + + self._Iren.SetAltKey(alt) + self._ActiveButton = ev.button() + + if self._ActiveButton == QtCore.Qt.LeftButton: + self._Iren.InvokeEvent("LeftButtonPressEvent") + elif self._ActiveButton == QtCore.Qt.RightButton: + self._Iren.InvokeEvent("RightButtonPressEvent") + elif self._ActiveButton == QtCore.Qt.MidButton: + self._Iren.InvokeEvent("MiddleButtonPressEvent") + + def mouseReleaseEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + + if self._ActiveButton == QtCore.Qt.LeftButton: + self._Iren.InvokeEvent("LeftButtonReleaseEvent") + elif self._ActiveButton == QtCore.Qt.RightButton: + self._Iren.InvokeEvent("RightButtonReleaseEvent") + elif self._ActiveButton == QtCore.Qt.MidButton: + self._Iren.InvokeEvent("MiddleButtonReleaseEvent") + + def mouseMoveEvent(self, ev): + self.__saveModifiers = ev.modifiers() + self.__saveButtons = ev.buttons() + self.__saveX = ev.x() + self.__saveY = ev.y() + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("MouseMoveEvent") + + def keyPressEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + if ev.key() < 256: + key = str(ev.text()) + else: + key = chr(0) + + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, key, 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("KeyPressEvent") + self._Iren.InvokeEvent("CharEvent") + + def keyReleaseEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + if ev.key() < 256: + key = chr(ev.key()) + else: + key = chr(0) + + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, key, 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("KeyReleaseEvent") + + def wheelEvent(self, ev): + print ("angleDeltaX %d" % ev.angleDelta().x()) + print ("angleDeltaY %d" % ev.angleDelta().y()) + if ev.angleDelta().y() >= 0: + self._Iren.InvokeEvent("MouseWheelForwardEvent") + else: + self._Iren.InvokeEvent("MouseWheelBackwardEvent") + + def GetRenderWindow(self): + return self._RenderWindow + + def Render(self): + self.update() + + +def QVTKExample(): + """A simple example that uses the QVTKWidget class.""" + + # every QT app needs an app + app = QtWidgets.QApplication(['PyVE QVTKWidget Example']) + page_VTK = QtWidgets.QWidget() + page_VTK.resize(500,500) + layout = QtWidgets.QVBoxLayout(page_VTK) + # create the widget + widget = QVTKWidget(parent=None) + layout.addWidget(widget) + + #reader = vtk.vtkPNGReader() + #reader.SetFileName("F:\Diagnostics\Images\PyVE\VTKData\Data\camscene.png") + reader = vtk.vtkMetaImageReader() + reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") + reader.Update() + + widget.SetInput(reader.GetOutput()) + + # show the widget + page_VTK.show() + # start event processing + app.exec_() + +if __name__ == "__main__": + QVTKExample() diff --git a/src/Python/ccpi/viewer/QVTKWidget2.py b/src/Python/ccpi/viewer/QVTKWidget2.py new file mode 100644 index 0000000..e32e1c2 --- /dev/null +++ b/src/Python/ccpi/viewer/QVTKWidget2.py @@ -0,0 +1,84 @@ +################################################################################ +# File: QVTKWidget.py +# Author: Edoardo Pasca +# Description: PyVE Viewer Qt widget +# +# License: +# This file is part of PyVE. PyVE is an open-source image +# analysis and visualization environment focused on medical +# imaging. More info at http://pyve.sourceforge.net +# +# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# +# Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. Neither name of Edoardo Pasca or Lukas +# Batteau nor the names of any contributors may be used to endorse +# or promote products derived from this software without specific +# prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, +# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. +# +# CHANGE HISTORY +# +# 20120118 Edoardo Pasca Initial version +# +############################################################################### + +import os +from PyQt5 import QtCore, QtGui, QtWidgets +#import itk +import vtk +#from viewer import PyveViewer +from ccpi.viewer.CILViewer2D import CILViewer2D , Converter +from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor + +class QVTKWidget(QVTKRenderWindowInteractor): + + + def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): + kw = dict() + super().__init__(parent, **kw) + + + # Link to PyVE Viewer + self._PyveViewer = CILViewer2D(400,400) + #self._Viewer = self._PyveViewer._vtkPyveViewer + + self._Iren = self._PyveViewer.GetInteractor() + kw['iren'] = self._Iren + #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() + self._RenderWindow = self._PyveViewer.GetRenderWindow() + #self._RenderWindow = self._Viewer.GetRenderWindow() + kw['rw'] = self._RenderWindow + + + + + def GetInteractor(self): + return self._Iren + + # Display image data + def SetInput(self, imageData): + self._PyveViewer.setInput3DData(imageData) + \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__init__.py b/src/Python/ccpi/viewer/__init__.py new file mode 100644 index 0000000..946188b --- /dev/null +++ b/src/Python/ccpi/viewer/__init__.py @@ -0,0 +1 @@ +from ccpi.viewer.CILViewer import CILViewer \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc new file mode 100644 index 0000000..711f77a Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc new file mode 100644 index 0000000..77c2ca8 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc new file mode 100644 index 0000000..3d11b87 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc new file mode 100644 index 0000000..2fa2eaf Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc new file mode 100644 index 0000000..fcea537 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/embedvtk.py b/src/Python/ccpi/viewer/embedvtk.py new file mode 100644 index 0000000..b5eb0a7 --- /dev/null +++ b/src/Python/ccpi/viewer/embedvtk.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Jul 27 12:18:58 2017 + +@author: ofn77899 +""" + +#!/usr/bin/env python + +import sys +import vtk +from PyQt5 import QtCore, QtWidgets +from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor +import QVTKWidget2 + +class MainWindow(QtWidgets.QMainWindow): + + def __init__(self, parent = None): + QtWidgets.QMainWindow.__init__(self, parent) + + self.frame = QtWidgets.QFrame() + + self.vl = QtWidgets.QVBoxLayout() +# self.vtkWidget = QVTKRenderWindowInteractor(self.frame) + + self.vtkWidget = QVTKWidget2.QVTKWidget(self.frame) + self.iren = self.vtkWidget.GetInteractor() + self.vl.addWidget(self.vtkWidget) + + + + + self.ren = vtk.vtkRenderer() + self.vtkWidget.GetRenderWindow().AddRenderer(self.ren) +# self.iren = self.vtkWidget.GetRenderWindow().GetInteractor() +# +# # Create source +# source = vtk.vtkSphereSource() +# source.SetCenter(0, 0, 0) +# source.SetRadius(5.0) +# +# # Create a mapper +# mapper = vtk.vtkPolyDataMapper() +# mapper.SetInputConnection(source.GetOutputPort()) +# +# # Create an actor +# actor = vtk.vtkActor() +# actor.SetMapper(mapper) +# +# self.ren.AddActor(actor) +# +# self.ren.ResetCamera() +# + self.frame.setLayout(self.vl) + self.setCentralWidget(self.frame) + reader = vtk.vtkMetaImageReader() + reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") + reader.Update() + + self.vtkWidget.SetInput(reader.GetOutput()) + + #self.vktWidget.Initialize() + #self.vktWidget.Start() + + self.show() + #self.iren.Initialize() + + +if __name__ == "__main__": + + app = QtWidgets.QApplication(sys.argv) + + window = MainWindow() + + sys.exit(app.exec_()) \ No newline at end of file -- cgit v1.2.3 From ad62962697509d977087c25d24a3ff083d9c4308 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:08:32 +0100 Subject: initial revision --- src/Python/ccpi/imaging/Regularizer.py | 322 +++++++++++++++++++++++++++++++++ 1 file changed, 322 insertions(+) create mode 100644 src/Python/ccpi/imaging/Regularizer.py (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/imaging/Regularizer.py b/src/Python/ccpi/imaging/Regularizer.py new file mode 100644 index 0000000..fb9ae08 --- /dev/null +++ b/src/Python/ccpi/imaging/Regularizer.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Aug 8 14:26:00 2017 + +@author: ofn77899 +""" + +from ccpi.imaging import cpu_regularizers +import numpy as np +from enum import Enum +import timeit + +class Regularizer(): + '''Class to handle regularizer algorithms to be used during reconstruction + + Currently 5 CPU (OMP) regularization algorithms are available: + + 1) SplitBregman_TV + 2) FGP_TV + 3) LLT_model + 4) PatchBased_Regul + 5) TGV_PD + + Usage: + the regularizer can be invoked as object or as static method + Depending on the actual regularizer the input parameter may vary, and + a different default setting is defined. + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + + out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., + number_of_iterations=30, tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + A number of optional parameters can be passed or skipped + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + + ''' + class Algorithm(Enum): + SplitBregman_TV = cpu_regularizers.SplitBregman_TV + FGP_TV = cpu_regularizers.FGP_TV + LLT_model = cpu_regularizers.LLT_model + PatchBased_Regul = cpu_regularizers.PatchBased_Regul + TGV_PD = cpu_regularizers.TGV_PD + # Algorithm + + class TotalVariationPenalty(Enum): + isotropic = 0 + l1 = 1 + # TotalVariationPenalty + + def __init__(self , algorithm, debug = True): + self.setAlgorithm ( algorithm ) + self.debug = debug + # __init__ + + def setAlgorithm(self, algorithm): + self.algorithm = algorithm + self.pars = self.getDefaultParsForAlgorithm(algorithm) + # setAlgorithm + + def getDefaultParsForAlgorithm(self, algorithm): + pars = dict() + + if algorithm == Regularizer.Algorithm.SplitBregman_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 35 + pars['tolerance_constant'] = 0.0001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.FGP_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 50 + pars['tolerance_constant'] = 0.001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.LLT_model: + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['time_step'] = None + pars['number_of_iterations'] = None + pars['tolerance_constant'] = None + pars['restrictive_Z_smoothing'] = 0 + + elif algorithm == Regularizer.Algorithm.PatchBased_Regul: + pars['algorithm'] = algorithm + pars['input'] = None + pars['searching_window_ratio'] = None + pars['similarity_window_ratio'] = None + pars['PB_filtering_parameter'] = None + pars['regularization_parameter'] = None + + elif algorithm == Regularizer.Algorithm.TGV_PD: + pars['algorithm'] = algorithm + pars['input'] = None + pars['first_order_term'] = None + pars['second_order_term'] = None + pars['number_of_iterations'] = None + pars['regularization_parameter'] = None + + else: + raise Exception('Unknown regularizer algorithm') + + return pars + # parsForAlgorithm + + def setParameter(self, **kwargs): + '''set named parameter for the regularization engine + + raises Exception if the named parameter is not recognized + Typical usage is: + + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + + it can be also used as + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0 , regularization_parameter=10.) + ''' + + for key , value in kwargs.items(): + if key in self.pars.keys(): + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + def getParameter(self, **kwargs): + ret = {} + for key , value in kwargs.items(): + if key in self.pars.keys(): + ret[key] = self.pars[key] + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + + def __call__(self, input = None, regularization_parameter = None, **kwargs): + '''Actual call for the regularizer. + + One can either set the regularization parameters first and then call the + algorithm or set the regularization parameter during the call (as + is done in the static methods). + ''' + + if kwargs is not None: + for key, value in kwargs.items(): + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + if input is not None: + self.pars['input'] = input + if regularization_parameter is not None: + self.pars['regularization_parameter'] = regularization_parameter + + if self.debug: + print ("--------------------------------------------------") + for key, value in self.pars.items(): + if key== 'algorithm' : + print("{0} = {1}".format(key, value.__name__)) + elif key == 'input': + print("{0} = {1}".format(key, np.shape(value))) + else: + print("{0} = {1}".format(key, value)) + + + if None in self.pars: + raise Exception("Not all parameters have been provided") + + input = self.pars['input'] + regularization_parameter = self.pars['regularization_parameter'] + if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.FGP_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.LLT_model : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['searching_window_ratio'] , + self.pars['similarity_window_ratio'] , + self.pars['PB_filtering_parameter']) + elif self.algorithm == Regularizer.Algorithm.TGV_PD : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + if len(np.shape(input)) == 2: + return self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + elif len(np.shape(input)) == 3: + #assuming it's 3D + # run independent calls on each slice + out3d = input.copy() + for i in range(np.shape(input)[2]): + out = self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + # copy the result in the 3D image + out3d.T[i] = out[0].copy() + # append the rest of the info that the algorithm returns + output = [out3d] + for i in range(1,len(out)): + output.append(out[i]) + return output + + + + + + # __call__ + + @staticmethod + def SplitBregman_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def FGP_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def LLT_model(input, regularization_parameter , time_step, number_of_iterations, + tolerance_constant, restrictive_Z_smoothing=0): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.LLT_model) + out = list( reg(input, regularization_parameter, time_step=time_step, + number_of_iterations=number_of_iterations, + tolerance_constant=tolerance_constant, + restrictive_Z_smoothing=restrictive_Z_smoothing) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def PatchBased_Regul(input, regularization_parameter, + searching_window_ratio, + similarity_window_ratio, + PB_filtering_parameter): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) + out = list( reg(input, + regularization_parameter, + searching_window_ratio=searching_window_ratio, + similarity_window_ratio=similarity_window_ratio, + PB_filtering_parameter=PB_filtering_parameter ) + ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def TGV_PD(input, regularization_parameter , first_order_term, + second_order_term, number_of_iterations): + start_time = timeit.default_timer() + + reg = Regularizer(Regularizer.Algorithm.TGV_PD) + out = list( reg(input, regularization_parameter, + first_order_term=first_order_term, + second_order_term=second_order_term, + number_of_iterations=number_of_iterations) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + + return out + + def printParametersToString(self): + txt = r'' + for key, value in self.pars.items(): + if key== 'algorithm' : + txt += "{0} = {1}".format(key, value.__name__) + elif key == 'input': + txt += "{0} = {1}".format(key, np.shape(value)) + else: + txt += "{0} = {1}".format(key, value) + txt += '\n' + return txt + -- cgit v1.2.3 From 56915cc00ded38d24c23b9ab1a0717d52d430ddd Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 16:54:59 +0100 Subject: initial revision for testing --- .../ccpi/reconstruction/FISTAReconstructor.py | 354 +++++++++++++++++++++ 1 file changed, 354 insertions(+) create mode 100644 src/Python/ccpi/reconstruction/FISTAReconstructor.py (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py new file mode 100644 index 0000000..ea96b53 --- /dev/null +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -0,0 +1,354 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +#from ccpi.reconstruction.parallelbeam import alg + +from ccpi.imaging.Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', + 'Lipschitz_constant' , + 'ideal_image' , + 'weights' , + 'region_of_interest' , + 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else: + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer is not None: + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" +##nx = h5py.File(fname, "r") +## +### the data are stored in a particular location in the hdf5 +##for item in nx['entry1/tomo_entry/data'].keys(): +## print (item) +## +##data = nx.get('entry1/tomo_entry/data/rotation_angle') +##angles = numpy.zeros(data.shape) +##data.read_direct(angles) +##print (angles) +### angles should be in degrees +## +##data = nx.get('entry1/tomo_entry/data/data') +##stack = numpy.zeros(data.shape) +##data.read_direct(stack) +##print (data.shape) +## +##print ("Data Loaded") +## +## +### Normalize +##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +##itype = numpy.zeros(data.shape) +##data.read_direct(itype) +### 2 is dark field +##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +##dark = darks[0] +##for i in range(1, len(darks)): +## dark += darks[i] +##dark = dark / len(darks) +###dark[0][0] = dark[0][1] +## +### 1 is flat field +##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +##flat = flats[0] +##for i in range(1, len(flats)): +## flat += flats[i] +##flat = flat / len(flats) +###flat[0][0] = dark[0][1] +## +## +### 0 is projection data +##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = numpy.asarray (angle_proj) +##angle_proj = angle_proj.astype(numpy.float32) +## +### normalized data are +### norm = (projection - dark)/(flat-dark) +## +##def normalize(projection, dark, flat, def_val=0.1): +## a = (projection - dark) +## b = (flat-dark) +## with numpy.errstate(divide='ignore', invalid='ignore'): +## c = numpy.true_divide( a, b ) +## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 +## return c +## +## +##norm = [normalize(projection, dark, flat) for projection in proj] +##norm = numpy.asarray (norm) +##norm = norm.astype(numpy.float32) + + +##niterations = 15 +##threads = 3 +## +##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## iteration_values, False) +##print ("iteration values %s" % str(iteration_values)) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +## +## +####numpy.save("cgls_recon.npy", img_data) +##import matplotlib.pyplot as plt +##fig, ax = plt.subplots(1,6,sharey=True) +##ax[0].imshow(img_cgls[80]) +##ax[0].axis('off') # clear x- and y-axes +##ax[1].imshow(img_sirt[80]) +##ax[1].axis('off') # clear x- and y-axes +##ax[2].imshow(img_mlem[80]) +##ax[2].axis('off') # clear x- and y-axesplt.show() +##ax[3].imshow(img_cgls_conv[80]) +##ax[3].axis('off') # clear x- and y-axesplt.show() +##ax[4].imshow(img_cgls_tikhonov[80]) +##ax[4].axis('off') # clear x- and y-axesplt.show() +##ax[5].imshow(img_cgls_TVreg[80]) +##ax[5].axis('off') # clear x- and y-axesplt.show() +## +## +##plt.show() +## + -- cgit v1.2.3 From 7111d98258becca09e4c93e3c66edb7d524d6463 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:42:05 +0100 Subject: initial revision --- src/Python/ccpi/imaging/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/Python/ccpi/imaging/__init__.py (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/imaging/__init__.py b/src/Python/ccpi/imaging/__init__.py new file mode 100644 index 0000000..e69de29 -- cgit v1.2.3 From 3f26b1d8ab3a632ceca97bdf04225008f9163684 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:42:27 +0100 Subject: initial revision --- src/Python/ccpi/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/Python/ccpi/__init__.py (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/__init__.py b/src/Python/ccpi/__init__.py new file mode 100644 index 0000000..e69de29 -- cgit v1.2.3 From 391473269674bc98697eabac0b4fb2bd89f5d85e Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 25 Aug 2017 16:55:48 +0100 Subject: Reorganized code with new fista package name --- src/Python/ccpi/fista/FISTAReconstructor.py | 389 ++++++++++++++ src/Python/ccpi/fista/FISTAReconstructor.pyc | Bin 0 -> 3804 bytes src/Python/ccpi/fista/FISTAReconstructor.py~ | 349 ++++++++++++ src/Python/ccpi/fista/Reconstructor.py | 425 +++++++++++++++ src/Python/ccpi/fista/Reconstructor.py~ | 598 +++++++++++++++++++++ src/Python/ccpi/fista/__init__.py | 0 src/Python/ccpi/fista/__init__.pyc | Bin 0 -> 189 bytes .../__pycache__/FISTAReconstructor.cpython-35.pyc | Bin 0 -> 3641 bytes .../ccpi/fista/__pycache__/__init__.cpython-35.pyc | Bin 0 -> 185 bytes 9 files changed, 1761 insertions(+) create mode 100644 src/Python/ccpi/fista/FISTAReconstructor.py create mode 100644 src/Python/ccpi/fista/FISTAReconstructor.pyc create mode 100644 src/Python/ccpi/fista/FISTAReconstructor.py~ create mode 100644 src/Python/ccpi/fista/Reconstructor.py create mode 100644 src/Python/ccpi/fista/Reconstructor.py~ create mode 100644 src/Python/ccpi/fista/__init__.py create mode 100644 src/Python/ccpi/fista/__init__.pyc create mode 100644 src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc create mode 100644 src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py new file mode 100644 index 0000000..1e76815 --- /dev/null +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -0,0 +1,389 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +#from ccpi.reconstruction.parallelbeam import alg + +#from ccpi.imaging.Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + # handle parmeters: + # obligatory parameters + self.pars = dict() + self.pars['projector_geometry'] = projector_geometry + self.pars['output_geometry'] = output_geometry + self.pars['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.pars['detectors'] = detectors + self.pars['number_og_angles'] = nangles + self.pars['SlicesZ'] = sliceZ + + print (self.pars) + # handle optional input parameters (at instantiation) + + # Accepted input keywords + kw = ('number_of_iterations', + 'Lipschitz_constant' , + 'ideal_image' , + 'weights' , + 'region_of_interest' , + 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.pars['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not 'ideal_image' in kwargs.keys(): + self.pars['ideal_image'] = None + + if not 'region_of_interest'in kwargs.keys() : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not 'regularizer' in kwargs.keys() : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not 'ring_lambda_R_L1' in kwargs.keys(): + self.pars['ring_lambda_R_L1'] = 0 + if not 'ring_alpha' in kwargs.keys(): + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + N = self.pars['output_geometry']['GridColCount'] + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #print('Calculating Lipshitz constant for parallel beam geometry...') + niter = 5;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights[0]) + proj_geomT = proj_geom.copy(); + proj_geomT['DetectorRowCount'] = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + + + for i in range(niter): + # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); + # s = norm(x1(:)); + # x1 = x1/s; + # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + # y = sqweight.*y; + # astra_mex_data3d('delete', sino_id); + # astra_mex_data3d('delete', id); + #print ("iteration {0}".format(i)) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geomT, + vol_geomT) + + y = (sqweight * y).copy() # element wise multiplication + + #b=fig.add_subplot(2,1,2) + #imgplot = plt.imshow(x1[0]) + #plt.show() + + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + del x1 + + idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), + proj_geomT, + vol_geomT) + del y + + + s = numpy.linalg.norm(x1) + ### this line? + x1 = (x1/s).copy(); + + # ### this line? + # sino_id, y = astra.creators.create_sino3d_gpu(x1, + # proj_geomT, + # vol_geomT); + # y = sqweight * y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx) + print ("iteration {0} s= {1}".format(i,s)) + + #end + del proj_geomT + del vol_geomT + #plt.show() + else: + #% divergen beam geometry + print('Calculating Lipshitz constant for divergen beam geometry...') + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + + return s + + + def setRegularizer(self, regularizer): + if regularizer is not None: + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location, nx): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" +##nx = h5py.File(fname, "r") +## +### the data are stored in a particular location in the hdf5 +##for item in nx['entry1/tomo_entry/data'].keys(): +## print (item) +## +##data = nx.get('entry1/tomo_entry/data/rotation_angle') +##angles = numpy.zeros(data.shape) +##data.read_direct(angles) +##print (angles) +### angles should be in degrees +## +##data = nx.get('entry1/tomo_entry/data/data') +##stack = numpy.zeros(data.shape) +##data.read_direct(stack) +##print (data.shape) +## +##print ("Data Loaded") +## +## +### Normalize +##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +##itype = numpy.zeros(data.shape) +##data.read_direct(itype) +### 2 is dark field +##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +##dark = darks[0] +##for i in range(1, len(darks)): +## dark += darks[i] +##dark = dark / len(darks) +###dark[0][0] = dark[0][1] +## +### 1 is flat field +##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +##flat = flats[0] +##for i in range(1, len(flats)): +## flat += flats[i] +##flat = flat / len(flats) +###flat[0][0] = dark[0][1] +## +## +### 0 is projection data +##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = numpy.asarray (angle_proj) +##angle_proj = angle_proj.astype(numpy.float32) +## +### normalized data are +### norm = (projection - dark)/(flat-dark) +## +##def normalize(projection, dark, flat, def_val=0.1): +## a = (projection - dark) +## b = (flat-dark) +## with numpy.errstate(divide='ignore', invalid='ignore'): +## c = numpy.true_divide( a, b ) +## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 +## return c +## +## +##norm = [normalize(projection, dark, flat) for projection in proj] +##norm = numpy.asarray (norm) +##norm = norm.astype(numpy.float32) + + +##niterations = 15 +##threads = 3 +## +##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## iteration_values, False) +##print ("iteration values %s" % str(iteration_values)) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +## +## +####numpy.save("cgls_recon.npy", img_data) +##import matplotlib.pyplot as plt +##fig, ax = plt.subplots(1,6,sharey=True) +##ax[0].imshow(img_cgls[80]) +##ax[0].axis('off') # clear x- and y-axes +##ax[1].imshow(img_sirt[80]) +##ax[1].axis('off') # clear x- and y-axes +##ax[2].imshow(img_mlem[80]) +##ax[2].axis('off') # clear x- and y-axesplt.show() +##ax[3].imshow(img_cgls_conv[80]) +##ax[3].axis('off') # clear x- and y-axesplt.show() +##ax[4].imshow(img_cgls_tikhonov[80]) +##ax[4].axis('off') # clear x- and y-axesplt.show() +##ax[5].imshow(img_cgls_TVreg[80]) +##ax[5].axis('off') # clear x- and y-axesplt.show() +## +## +##plt.show() +## + diff --git a/src/Python/ccpi/fista/FISTAReconstructor.pyc b/src/Python/ccpi/fista/FISTAReconstructor.pyc new file mode 100644 index 0000000..ecc4d7d Binary files /dev/null and b/src/Python/ccpi/fista/FISTAReconstructor.pyc differ diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py~ b/src/Python/ccpi/fista/FISTAReconstructor.py~ new file mode 100644 index 0000000..6c7024d --- /dev/null +++ b/src/Python/ccpi/fista/FISTAReconstructor.py~ @@ -0,0 +1,349 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +#from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + + diff --git a/src/Python/ccpi/fista/Reconstructor.py b/src/Python/ccpi/fista/Reconstructor.py new file mode 100644 index 0000000..d29ac0d --- /dev/null +++ b/src/Python/ccpi/fista/Reconstructor.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + return shiftScaler.GetOutput() + +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") +#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() diff --git a/src/Python/ccpi/fista/Reconstructor.py~ b/src/Python/ccpi/fista/Reconstructor.py~ new file mode 100644 index 0000000..ba67327 --- /dev/null +++ b/src/Python/ccpi/fista/Reconstructor.py~ @@ -0,0 +1,598 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + +class Reconstructor: + + class Algorithm(Enum): + CGLS = alg.cgls + CGLS_CONV = alg.cgls_conv + SIRT = alg.sirt + MLEM = alg.mlem + CGLS_TICHONOV = alg.cgls_tikhonov + CGLS_TVREG = alg.cgls_TVreg + FISTA = 'fista' + + def __init__(self, algorithm = None, projection_data = None, + angles = None, center_of_rotation = None , + flat_field = None, dark_field = None, + iterations = None, resolution = None, isLogScale = False, threads = None, + normalized_projection = None): + + self.pars = dict() + self.pars['algorithm'] = algorithm + self.pars['projection_data'] = projection_data + self.pars['normalized_projection'] = normalized_projection + self.pars['angles'] = angles + self.pars['center_of_rotation'] = numpy.double(center_of_rotation) + self.pars['flat_field'] = flat_field + self.pars['iterations'] = iterations + self.pars['dark_field'] = dark_field + self.pars['resolution'] = resolution + self.pars['isLogScale'] = isLogScale + self.pars['threads'] = threads + if (iterations != None): + self.pars['iterationValues'] = numpy.zeros((iterations)) + + if projection_data != None and dark_field != None and flat_field != None: + norm = self.normalize(projection_data, dark_field, flat_field, 0.1) + self.pars['normalized_projection'] = norm + + + def setPars(self, parameters): + keys = ['algorithm','projection_data' ,'normalized_projection', \ + 'angles' , 'center_of_rotation' , 'flat_field', \ + 'iterations','dark_field' , 'resolution', 'isLogScale' , \ + 'threads' , 'iterationValues', 'regularize'] + + for k in keys: + if k not in parameters.keys(): + self.pars[k] = None + else: + self.pars[k] = parameters[k] + + + def sanityCheck(self): + projection_data = self.pars['projection_data'] + dark_field = self.pars['dark_field'] + flat_field = self.pars['flat_field'] + angles = self.pars['angles'] + + if projection_data != None and dark_field != None and \ + angles != None and flat_field != None: + data_shape = numpy.shape(projection_data) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + + if data_shape[1:] != numpy.shape(flat_field): + #raise Exception('Projection and flat field dimensions do not match') + return (False , 'Projection and flat field dimensions do not match') + if data_shape[1:] != numpy.shape(dark_field): + #raise Exception('Projection and dark field dimensions do not match') + return (False , 'Projection and dark field dimensions do not match') + + return (True , '' ) + elif self.pars['normalized_projection'] != None: + data_shape = numpy.shape(self.pars['normalized_projection']) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + else: + return (True , '' ) + else: + return (False , 'Not enough data') + + def reconstruct(self, parameters = None): + if parameters != None: + self.setPars(parameters) + + go , reason = self.sanityCheck() + if go: + return self._reconstruct() + else: + raise Exception(reason) + + + def _reconstruct(self, parameters=None): + if parameters!=None: + self.setPars(parameters) + parameters = self.pars + + if parameters['algorithm'] != None and \ + parameters['normalized_projection'] != None and \ + parameters['angles'] != None and \ + parameters['center_of_rotation'] != None and \ + parameters['iterations'] != None and \ + parameters['resolution'] != None and\ + parameters['threads'] != None and\ + parameters['isLogScale'] != None: + + + if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, + Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): + #store parameters + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['isLogScale'] + ) + return result + elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, + Reconstructor.Algorithm.CGLS_TICHONOV, + Reconstructor.Algorithm.CGLS_TVREG) : + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['regularize'], + numpy.zeros((parameters['iterations'])), + parameters['isLogScale'] + ) + + elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: + pass + + else: + if parameters['projection_data'] != None and \ + parameters['dark_field'] != None and \ + parameters['flat_field'] != None: + norm = self.normalize(parameters['projection_data'], + parameters['dark_field'], + parameters['flat_field'], 0.1) + self.pars['normalized_projection'] = norm + return self._reconstruct(parameters) + + + + def _normalize(self, projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + def normalize(self, projections, dark, flat, def_val=0): + norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] + return numpy.asarray (norm, dtype=numpy.float32) + + + +class FISTA(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + return shiftScaler.GetOutput() + +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") +#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() diff --git a/src/Python/ccpi/fista/__init__.py b/src/Python/ccpi/fista/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/Python/ccpi/fista/__init__.pyc b/src/Python/ccpi/fista/__init__.pyc new file mode 100644 index 0000000..719e264 Binary files /dev/null and b/src/Python/ccpi/fista/__init__.pyc differ diff --git a/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc new file mode 100644 index 0000000..84f16e2 Binary files /dev/null and b/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc differ diff --git a/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc new file mode 100644 index 0000000..90c23ff Binary files /dev/null and b/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc differ -- cgit v1.2.3 From 64c0b9e7a1bcfc54e6ed8b57274d53c3ed9bb950 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 25 Aug 2017 17:03:17 +0100 Subject: use refactore code --- src/Python/ccpi/fista/FISTAReconstructor.pyc | Bin 3804 -> 0 bytes src/Python/ccpi/fista/FISTAReconstructor.py~ | 349 ------------ src/Python/ccpi/fista/Reconstructor.py~ | 598 --------------------- src/Python/ccpi/fista/__init__.pyc | Bin 189 -> 0 bytes .../__pycache__/FISTAReconstructor.cpython-35.pyc | Bin 3641 -> 0 bytes .../ccpi/fista/__pycache__/__init__.cpython-35.pyc | Bin 185 -> 0 bytes 6 files changed, 947 deletions(-) delete mode 100644 src/Python/ccpi/fista/FISTAReconstructor.pyc delete mode 100644 src/Python/ccpi/fista/FISTAReconstructor.py~ delete mode 100644 src/Python/ccpi/fista/Reconstructor.py~ delete mode 100644 src/Python/ccpi/fista/__init__.pyc delete mode 100644 src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc delete mode 100644 src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc (limited to 'src/Python/ccpi') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.pyc b/src/Python/ccpi/fista/FISTAReconstructor.pyc deleted file mode 100644 index ecc4d7d..0000000 Binary files a/src/Python/ccpi/fista/FISTAReconstructor.pyc and /dev/null differ diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py~ b/src/Python/ccpi/fista/FISTAReconstructor.py~ deleted file mode 100644 index 6c7024d..0000000 --- a/src/Python/ccpi/fista/FISTAReconstructor.py~ +++ /dev/null @@ -1,349 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -import h5py -#from ccpi.reconstruction.parallelbeam import alg - -from Regularizer import Regularizer -from enum import Enum - -import astra - - - -class FISTAReconstructor(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ - - # Accepted input keywords - kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , - 'weights' , 'region_of_interest' , 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha') - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() - - if not self.pars['ideal_image'] in kwargs.keys(): - self.pars['ideal_image'] = None - - if not self.pars['region_of_interest'] : - if self.pars['ideal_image'] == None: - pass - else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : - self.pars['regularizer'] = None - else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 - - - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) - proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - - for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - y = sqweight*y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - del proj_geomT - del vol_geomT - else - #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - return s - - - def setRegularizer(self, regularizer): - if regularizer - self.pars['regularizer'] = regularizer - - - - - -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -nx = h5py.File(fname, "r") - -# the data are stored in a particular location in the hdf5 -for item in nx['entry1/tomo_entry/data'].keys(): - print (item) - -data = nx.get('entry1/tomo_entry/data/rotation_angle') -angles = numpy.zeros(data.shape) -data.read_direct(angles) -print (angles) -# angles should be in degrees - -data = nx.get('entry1/tomo_entry/data/data') -stack = numpy.zeros(data.shape) -data.read_direct(stack) -print (data.shape) - -print ("Data Loaded") - - -# Normalize -data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -itype = numpy.zeros(data.shape) -data.read_direct(itype) -# 2 is dark field -darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -dark = darks[0] -for i in range(1, len(darks)): - dark += darks[i] -dark = dark / len(darks) -#dark[0][0] = dark[0][1] - -# 1 is flat field -flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -flat = flats[0] -for i in range(1, len(flats)): - flat += flats[i] -flat = flat / len(flats) -#flat[0][0] = dark[0][1] - - -# 0 is projection data -proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = numpy.asarray (angle_proj) -angle_proj = angle_proj.astype(numpy.float32) - -# normalized data are -# norm = (projection - dark)/(flat-dark) - -def normalize(projection, dark, flat, def_val=0.1): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - -norm = [normalize(projection, dark, flat) for projection in proj] -norm = numpy.asarray (norm) -norm = norm.astype(numpy.float32) - - -niterations = 15 -threads = 3 - -img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - iteration_values, False) -print ("iteration values %s" % str(iteration_values)) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) -iteration_values = numpy.zeros((niterations,)) -img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) - - -##numpy.save("cgls_recon.npy", img_data) -import matplotlib.pyplot as plt -fig, ax = plt.subplots(1,6,sharey=True) -ax[0].imshow(img_cgls[80]) -ax[0].axis('off') # clear x- and y-axes -ax[1].imshow(img_sirt[80]) -ax[1].axis('off') # clear x- and y-axes -ax[2].imshow(img_mlem[80]) -ax[2].axis('off') # clear x- and y-axesplt.show() -ax[3].imshow(img_cgls_conv[80]) -ax[3].axis('off') # clear x- and y-axesplt.show() -ax[4].imshow(img_cgls_tikhonov[80]) -ax[4].axis('off') # clear x- and y-axesplt.show() -ax[5].imshow(img_cgls_TVreg[80]) -ax[5].axis('off') # clear x- and y-axesplt.show() - - -plt.show() - - diff --git a/src/Python/ccpi/fista/Reconstructor.py~ b/src/Python/ccpi/fista/Reconstructor.py~ deleted file mode 100644 index ba67327..0000000 --- a/src/Python/ccpi/fista/Reconstructor.py~ +++ /dev/null @@ -1,598 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -import h5py -from ccpi.reconstruction.parallelbeam import alg - -from Regularizer import Regularizer -from enum import Enum - -import astra - - -class Reconstructor: - - class Algorithm(Enum): - CGLS = alg.cgls - CGLS_CONV = alg.cgls_conv - SIRT = alg.sirt - MLEM = alg.mlem - CGLS_TICHONOV = alg.cgls_tikhonov - CGLS_TVREG = alg.cgls_TVreg - FISTA = 'fista' - - def __init__(self, algorithm = None, projection_data = None, - angles = None, center_of_rotation = None , - flat_field = None, dark_field = None, - iterations = None, resolution = None, isLogScale = False, threads = None, - normalized_projection = None): - - self.pars = dict() - self.pars['algorithm'] = algorithm - self.pars['projection_data'] = projection_data - self.pars['normalized_projection'] = normalized_projection - self.pars['angles'] = angles - self.pars['center_of_rotation'] = numpy.double(center_of_rotation) - self.pars['flat_field'] = flat_field - self.pars['iterations'] = iterations - self.pars['dark_field'] = dark_field - self.pars['resolution'] = resolution - self.pars['isLogScale'] = isLogScale - self.pars['threads'] = threads - if (iterations != None): - self.pars['iterationValues'] = numpy.zeros((iterations)) - - if projection_data != None and dark_field != None and flat_field != None: - norm = self.normalize(projection_data, dark_field, flat_field, 0.1) - self.pars['normalized_projection'] = norm - - - def setPars(self, parameters): - keys = ['algorithm','projection_data' ,'normalized_projection', \ - 'angles' , 'center_of_rotation' , 'flat_field', \ - 'iterations','dark_field' , 'resolution', 'isLogScale' , \ - 'threads' , 'iterationValues', 'regularize'] - - for k in keys: - if k not in parameters.keys(): - self.pars[k] = None - else: - self.pars[k] = parameters[k] - - - def sanityCheck(self): - projection_data = self.pars['projection_data'] - dark_field = self.pars['dark_field'] - flat_field = self.pars['flat_field'] - angles = self.pars['angles'] - - if projection_data != None and dark_field != None and \ - angles != None and flat_field != None: - data_shape = numpy.shape(projection_data) - angle_shape = numpy.shape(angles) - - if angle_shape[0] != data_shape[0]: - #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ - # (angle_shape[0] , data_shape[0]) ) - return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ - (angle_shape[0] , data_shape[0]) ) - - if data_shape[1:] != numpy.shape(flat_field): - #raise Exception('Projection and flat field dimensions do not match') - return (False , 'Projection and flat field dimensions do not match') - if data_shape[1:] != numpy.shape(dark_field): - #raise Exception('Projection and dark field dimensions do not match') - return (False , 'Projection and dark field dimensions do not match') - - return (True , '' ) - elif self.pars['normalized_projection'] != None: - data_shape = numpy.shape(self.pars['normalized_projection']) - angle_shape = numpy.shape(angles) - - if angle_shape[0] != data_shape[0]: - #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ - # (angle_shape[0] , data_shape[0]) ) - return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ - (angle_shape[0] , data_shape[0]) ) - else: - return (True , '' ) - else: - return (False , 'Not enough data') - - def reconstruct(self, parameters = None): - if parameters != None: - self.setPars(parameters) - - go , reason = self.sanityCheck() - if go: - return self._reconstruct() - else: - raise Exception(reason) - - - def _reconstruct(self, parameters=None): - if parameters!=None: - self.setPars(parameters) - parameters = self.pars - - if parameters['algorithm'] != None and \ - parameters['normalized_projection'] != None and \ - parameters['angles'] != None and \ - parameters['center_of_rotation'] != None and \ - parameters['iterations'] != None and \ - parameters['resolution'] != None and\ - parameters['threads'] != None and\ - parameters['isLogScale'] != None: - - - if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, - Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): - #store parameters - self.pars = parameters - result = parameters['algorithm']( - parameters['normalized_projection'] , - parameters['angles'], - parameters['center_of_rotation'], - parameters['resolution'], - parameters['iterations'], - parameters['threads'] , - parameters['isLogScale'] - ) - return result - elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, - Reconstructor.Algorithm.CGLS_TICHONOV, - Reconstructor.Algorithm.CGLS_TVREG) : - self.pars = parameters - result = parameters['algorithm']( - parameters['normalized_projection'] , - parameters['angles'], - parameters['center_of_rotation'], - parameters['resolution'], - parameters['iterations'], - parameters['threads'] , - parameters['regularize'], - numpy.zeros((parameters['iterations'])), - parameters['isLogScale'] - ) - - elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: - pass - - else: - if parameters['projection_data'] != None and \ - parameters['dark_field'] != None and \ - parameters['flat_field'] != None: - norm = self.normalize(parameters['projection_data'], - parameters['dark_field'], - parameters['flat_field'], 0.1) - self.pars['normalized_projection'] = norm - return self._reconstruct(parameters) - - - - def _normalize(self, projection, dark, flat, def_val=0): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - def normalize(self, projections, dark, flat, def_val=0): - norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] - return numpy.asarray (norm, dtype=numpy.float32) - - - -class FISTA(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ - - # Accepted input keywords - kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , - 'weights' , 'region_of_interest' , 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha') - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() - - if not self.pars['ideal_image'] in kwargs.keys(): - self.pars['ideal_image'] = None - - if not self.pars['region_of_interest'] : - if self.pars['ideal_image'] == None: - pass - else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : - self.pars['regularizer'] = None - else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 - - - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) - proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - - for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - y = sqweight*y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - del proj_geomT - del vol_geomT - else - #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - return s - - - def setRegularizer(self, regularizer): - if regularizer - self.pars['regularizer'] = regularizer - - - - - -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -nx = h5py.File(fname, "r") - -# the data are stored in a particular location in the hdf5 -for item in nx['entry1/tomo_entry/data'].keys(): - print (item) - -data = nx.get('entry1/tomo_entry/data/rotation_angle') -angles = numpy.zeros(data.shape) -data.read_direct(angles) -print (angles) -# angles should be in degrees - -data = nx.get('entry1/tomo_entry/data/data') -stack = numpy.zeros(data.shape) -data.read_direct(stack) -print (data.shape) - -print ("Data Loaded") - - -# Normalize -data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -itype = numpy.zeros(data.shape) -data.read_direct(itype) -# 2 is dark field -darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -dark = darks[0] -for i in range(1, len(darks)): - dark += darks[i] -dark = dark / len(darks) -#dark[0][0] = dark[0][1] - -# 1 is flat field -flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -flat = flats[0] -for i in range(1, len(flats)): - flat += flats[i] -flat = flat / len(flats) -#flat[0][0] = dark[0][1] - - -# 0 is projection data -proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = numpy.asarray (angle_proj) -angle_proj = angle_proj.astype(numpy.float32) - -# normalized data are -# norm = (projection - dark)/(flat-dark) - -def normalize(projection, dark, flat, def_val=0.1): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - -norm = [normalize(projection, dark, flat) for projection in proj] -norm = numpy.asarray (norm) -norm = norm.astype(numpy.float32) - -#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) - -#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) -#img_cgls = recon.reconstruct() -# -#pars = dict() -#pars['algorithm'] = Reconstructor.Algorithm.SIRT -#pars['projection_data'] = proj -#pars['angles'] = angle_proj -#pars['center_of_rotation'] = numpy.double(86.2) -#pars['flat_field'] = flat -#pars['iterations'] = 15 -#pars['dark_field'] = dark -#pars['resolution'] = 1 -#pars['isLogScale'] = False -#pars['threads'] = 3 -# -#img_sirt = recon.reconstruct(pars) -# -#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM -#img_mlem = recon.reconstruct() - -############################################################ -############################################################ -#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV -#recon.pars['regularize'] = numpy.double(0.1) -#img_cgls_conv = recon.reconstruct() - -niterations = 15 -threads = 3 - -img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - iteration_values, False) -print ("iteration values %s" % str(iteration_values)) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) -iteration_values = numpy.zeros((niterations,)) -img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) - - -##numpy.save("cgls_recon.npy", img_data) -import matplotlib.pyplot as plt -fig, ax = plt.subplots(1,6,sharey=True) -ax[0].imshow(img_cgls[80]) -ax[0].axis('off') # clear x- and y-axes -ax[1].imshow(img_sirt[80]) -ax[1].axis('off') # clear x- and y-axes -ax[2].imshow(img_mlem[80]) -ax[2].axis('off') # clear x- and y-axesplt.show() -ax[3].imshow(img_cgls_conv[80]) -ax[3].axis('off') # clear x- and y-axesplt.show() -ax[4].imshow(img_cgls_tikhonov[80]) -ax[4].axis('off') # clear x- and y-axesplt.show() -ax[5].imshow(img_cgls_TVreg[80]) -ax[5].axis('off') # clear x- and y-axesplt.show() - - -plt.show() - -#viewer = edo.CILViewer() -#viewer.setInputAsNumpy(img_cgls2) -#viewer.displaySliceActor(0) -#viewer.startRenderLoop() - -import vtk - -def NumpyToVTKImageData(numpyarray): - if (len(numpy.shape(numpyarray)) == 3): - doubleImg = vtk.vtkImageData() - shape = numpy.shape(numpyarray) - doubleImg.SetDimensions(shape[0], shape[1], shape[2]) - doubleImg.SetOrigin(0,0,0) - doubleImg.SetSpacing(1,1,1) - doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) - #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) - doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) - - for i in range(shape[0]): - for j in range(shape[1]): - for k in range(shape[2]): - doubleImg.SetScalarComponentFromDouble( - i,j,k,0, numpyarray[i][j][k]) - #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) - # rescale to appropriate VTK_UNSIGNED_SHORT - stats = vtk.vtkImageAccumulate() - stats.SetInputData(doubleImg) - stats.Update() - iMin = stats.GetMin()[0] - iMax = stats.GetMax()[0] - scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) - - shiftScaler = vtk.vtkImageShiftScale () - shiftScaler.SetInputData(doubleImg) - shiftScaler.SetScale(scale) - shiftScaler.SetShift(iMin) - shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) - shiftScaler.Update() - return shiftScaler.GetOutput() - -#writer = vtk.vtkMetaImageWriter() -#writer.SetFileName(alg + "_recon.mha") -#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) -#writer.Write() diff --git a/src/Python/ccpi/fista/__init__.pyc b/src/Python/ccpi/fista/__init__.pyc deleted file mode 100644 index 719e264..0000000 Binary files a/src/Python/ccpi/fista/__init__.pyc and /dev/null differ diff --git a/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc deleted file mode 100644 index 84f16e2..0000000 Binary files a/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc and /dev/null differ diff --git a/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc deleted file mode 100644 index 90c23ff..0000000 Binary files a/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc and /dev/null differ -- cgit v1.2.3