summaryrefslogtreecommitdiffstats
path: root/src/Python/ccpi
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2017-10-13 16:48:24 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2017-10-13 16:48:24 +0100
commit49c4a595c58d296c3a4b2f7fd480e9c64f638897 (patch)
treea00282935e76ac741f581ca30dbcaac67afc1fb1 /src/Python/ccpi
parentf7e1cf04f791898737bc15b0eb437abc2c5d9305 (diff)
downloadregularization-49c4a595c58d296c3a4b2f7fd480e9c64f638897.tar.gz
regularization-49c4a595c58d296c3a4b2f7fd480e9c64f638897.tar.bz2
regularization-49c4a595c58d296c3a4b2f7fd480e9c64f638897.tar.xz
regularization-49c4a595c58d296c3a4b2f7fd480e9c64f638897.zip
Added setParameter
minor beautification of code
Diffstat (limited to 'src/Python/ccpi')
-rw-r--r--src/Python/ccpi/fista/FISTAReconstructor.py164
1 files changed, 34 insertions, 130 deletions
diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py
index 1e76815..cbd27da 100644
--- a/src/Python/ccpi/fista/FISTAReconstructor.py
+++ b/src/Python/ccpi/fista/FISTAReconstructor.py
@@ -73,7 +73,8 @@ class FISTAReconstructor():
# 3. "A novel tomographic reconstruction method based on the robust
# Student's t function for suppressing data outliers" D. Kazantsev et.al.
# D. Kazantsev, 2016-17
- def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs):
+ def __init__(self, projector_geometry, output_geometry, input_sinogram,
+ **kwargs):
# handle parmeters:
# obligatory parameters
self.pars = dict()
@@ -98,6 +99,7 @@ class FISTAReconstructor():
'regularizer' ,
'ring_lambda_R_L1',
'ring_alpha')
+ self.acceptedInputKeywords = kw
# handle keyworded parameters
if kwargs is not None:
@@ -114,11 +116,14 @@ class FISTAReconstructor():
if 'weights' in kwargs.keys():
self.pars['weights'] = kwargs['weights']
else:
- self.pars['weights'] = numpy.ones(numpy.shape(self.pars['input_sinogram']))
+ self.pars['weights'] = \
+ numpy.ones(numpy.shape(
+ self.pars['input_sinogram']))
if 'Lipschitz_constant' in kwargs.keys():
self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
else:
- self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod()
+ self.pars['Lipschitz_constant'] = \
+ self.calculateLipschitzConstantWithPowerMethod()
if not 'ideal_image' in kwargs.keys():
self.pars['ideal_image'] = None
@@ -127,7 +132,8 @@ class FISTAReconstructor():
if self.pars['ideal_image'] == None:
pass
else:
- self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0)
+ self.pars['region_of_interest'] = numpy.nonzero(
+ self.pars['ideal_image']>0.0)
if not 'regularizer' in kwargs.keys() :
self.pars['regularizer'] = None
@@ -140,7 +146,29 @@ class FISTAReconstructor():
+ def setParameter(self, **kwargs):
+ '''set named parameter for the regularization engine
+ raises Exception if the named parameter is not recognized
+ Typical usage is:
+
+ reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+ reg.setParameter(input=u0)
+ reg.setParameter(regularization_parameter=10.)
+
+ it can be also used as
+ reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+ reg.setParameter(input=u0 , regularization_parameter=10.)
+ '''
+
+ for key , value in kwargs.items():
+ if key in self.acceptedInputKeywords.keys():
+ self.pars[key] = value
+ else:
+ raise Exception('Wrong parameter {0} for '.format(key) +
+ 'Reconstruction algorithm')
+ # setParameter
+
def calculateLipschitzConstantWithPowerMethod(self):
''' using Power method (PM) to establish L constant'''
@@ -152,7 +180,8 @@ class FISTAReconstructor():
- if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
+ if (proj_geom['type'] == 'parallel') or \
+ (proj_geom['type'] == 'parallel3d'):
#% for parallel geometry we can do just one slice
#print('Calculating Lipshitz constant for parallel beam geometry...')
niter = 5;# % number of iteration for the PM
@@ -262,128 +291,3 @@ class FISTAReconstructor():
-
-
-def getEntry(location, nx):
- for item in nx[location].keys():
- print (item)
-
-
-print ("Loading Data")
-
-##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif"
-####ind = [i * 1049 for i in range(360)]
-#### use only 360 images
-##images = 200
-##ind = [int(i * 1049 / images) for i in range(images)]
-##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None)
-
-#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs"
-#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs"
-##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5"
-##nx = h5py.File(fname, "r")
-##
-### the data are stored in a particular location in the hdf5
-##for item in nx['entry1/tomo_entry/data'].keys():
-## print (item)
-##
-##data = nx.get('entry1/tomo_entry/data/rotation_angle')
-##angles = numpy.zeros(data.shape)
-##data.read_direct(angles)
-##print (angles)
-### angles should be in degrees
-##
-##data = nx.get('entry1/tomo_entry/data/data')
-##stack = numpy.zeros(data.shape)
-##data.read_direct(stack)
-##print (data.shape)
-##
-##print ("Data Loaded")
-##
-##
-### Normalize
-##data = nx.get('entry1/tomo_entry/instrument/detector/image_key')
-##itype = numpy.zeros(data.shape)
-##data.read_direct(itype)
-### 2 is dark field
-##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ]
-##dark = darks[0]
-##for i in range(1, len(darks)):
-## dark += darks[i]
-##dark = dark / len(darks)
-###dark[0][0] = dark[0][1]
-##
-### 1 is flat field
-##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ]
-##flat = flats[0]
-##for i in range(1, len(flats)):
-## flat += flats[i]
-##flat = flat / len(flats)
-###flat[0][0] = dark[0][1]
-##
-##
-### 0 is projection data
-##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ]
-##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ]
-##angle_proj = numpy.asarray (angle_proj)
-##angle_proj = angle_proj.astype(numpy.float32)
-##
-### normalized data are
-### norm = (projection - dark)/(flat-dark)
-##
-##def normalize(projection, dark, flat, def_val=0.1):
-## a = (projection - dark)
-## b = (flat-dark)
-## with numpy.errstate(divide='ignore', invalid='ignore'):
-## c = numpy.true_divide( a, b )
-## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0
-## return c
-##
-##
-##norm = [normalize(projection, dark, flat) for projection in proj]
-##norm = numpy.asarray (norm)
-##norm = norm.astype(numpy.float32)
-
-
-##niterations = 15
-##threads = 3
-##
-##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
-##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
-##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
-##
-##iteration_values = numpy.zeros((niterations,))
-##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
-## iteration_values, False)
-##print ("iteration values %s" % str(iteration_values))
-##
-##iteration_values = numpy.zeros((niterations,))
-##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
-## numpy.double(1e-5), iteration_values , False)
-##print ("iteration values %s" % str(iteration_values))
-##iteration_values = numpy.zeros((niterations,))
-##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
-## numpy.double(1e-5), iteration_values , False)
-##print ("iteration values %s" % str(iteration_values))
-##
-##
-####numpy.save("cgls_recon.npy", img_data)
-##import matplotlib.pyplot as plt
-##fig, ax = plt.subplots(1,6,sharey=True)
-##ax[0].imshow(img_cgls[80])
-##ax[0].axis('off') # clear x- and y-axes
-##ax[1].imshow(img_sirt[80])
-##ax[1].axis('off') # clear x- and y-axes
-##ax[2].imshow(img_mlem[80])
-##ax[2].axis('off') # clear x- and y-axesplt.show()
-##ax[3].imshow(img_cgls_conv[80])
-##ax[3].axis('off') # clear x- and y-axesplt.show()
-##ax[4].imshow(img_cgls_tikhonov[80])
-##ax[4].axis('off') # clear x- and y-axesplt.show()
-##ax[5].imshow(img_cgls_TVreg[80])
-##ax[5].axis('off') # clear x- and y-axesplt.show()
-##
-##
-##plt.show()
-##
-