summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Python/ccpi/imaging/Regularizer.py42
1 files changed, 27 insertions, 15 deletions
diff --git a/src/Python/ccpi/imaging/Regularizer.py b/src/Python/ccpi/imaging/Regularizer.py
index 8ab6c6a..23799d6 100644
--- a/src/Python/ccpi/imaging/Regularizer.py
+++ b/src/Python/ccpi/imaging/Regularizer.py
@@ -108,6 +108,8 @@ class Regularizer():
else:
raise Exception('Unknown regularizer algorithm')
+
+ self.acceptedInputKeywords = pars.keys()
return pars
# parsForAlgorithm
@@ -134,17 +136,24 @@ class Regularizer():
raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key))
# setParameter
- def getParameter(self, **kwargs):
- ret = {}
- for key , value in kwargs.items():
- if key in self.pars.keys():
- ret[key] = self.pars[key]
+ def getParameter(self, key):
+ if type(key) is str:
+ if key in self.acceptedInputKeywords:
+ return self.pars[key]
+ else:
+ raise Exception('Unrecongnised parameter: {0} '.format(key) )
+ elif type(key) is list:
+ outpars = []
+ for k in key:
+ outpars.append(self.getParameter(k))
+ return outpars
else:
- raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key))
- # setParameter
+ raise Exception('Unhandled input {0}' .format(str(type(key))))
+ # getParameter
- def __call__(self, input = None, regularization_parameter = None, **kwargs):
+ def __call__(self, input = None, regularization_parameter = None,
+ output_all = False, **kwargs):
'''Actual call for the regularizer.
One can either set the regularization parameters first and then call the
@@ -179,19 +188,19 @@ class Regularizer():
input = self.pars['input']
regularization_parameter = self.pars['regularization_parameter']
if self.algorithm == Regularizer.Algorithm.SplitBregman_TV :
- return self.algorithm(input, regularization_parameter,
+ ret = self.algorithm(input, regularization_parameter,
self.pars['number_of_iterations'],
self.pars['tolerance_constant'],
self.pars['TV_penalty'].value )
elif self.algorithm == Regularizer.Algorithm.FGP_TV :
- return self.algorithm(input, regularization_parameter,
+ ret = self.algorithm(input, regularization_parameter,
self.pars['number_of_iterations'],
self.pars['tolerance_constant'],
self.pars['TV_penalty'].value )
elif self.algorithm == Regularizer.Algorithm.LLT_model :
#LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
# no default
- return self.algorithm(input,
+ ret = self.algorithm(input,
regularization_parameter,
self.pars['time_step'] ,
self.pars['number_of_iterations'],
@@ -200,7 +209,7 @@ class Regularizer():
elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul :
#LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
# no default
- return self.algorithm(input, regularization_parameter,
+ ret = self.algorithm(input, regularization_parameter,
self.pars['searching_window_ratio'] ,
self.pars['similarity_window_ratio'] ,
self.pars['PB_filtering_parameter'])
@@ -208,7 +217,7 @@ class Regularizer():
#LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
# no default
if len(np.shape(input)) == 2:
- return self.algorithm(input, regularization_parameter,
+ ret = self.algorithm(input, regularization_parameter,
self.pars['first_order_term'] ,
self.pars['second_order_term'] ,
self.pars['number_of_iterations'])
@@ -227,11 +236,14 @@ class Regularizer():
output = [out3d]
for i in range(1,len(out)):
output.append(out[i])
- return output
+ ret = output
-
+ if output_all:
+ return ret
+ else:
+ return ret[0]
# __call__