diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Python/ccpi/imaging/Regularizer.py | 42 |
1 files changed, 27 insertions, 15 deletions
diff --git a/src/Python/ccpi/imaging/Regularizer.py b/src/Python/ccpi/imaging/Regularizer.py index 8ab6c6a..23799d6 100644 --- a/src/Python/ccpi/imaging/Regularizer.py +++ b/src/Python/ccpi/imaging/Regularizer.py @@ -108,6 +108,8 @@ class Regularizer(): else: raise Exception('Unknown regularizer algorithm') + + self.acceptedInputKeywords = pars.keys() return pars # parsForAlgorithm @@ -134,17 +136,24 @@ class Regularizer(): raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) # setParameter - def getParameter(self, **kwargs): - ret = {} - for key , value in kwargs.items(): - if key in self.pars.keys(): - ret[key] = self.pars[key] + def getParameter(self, key): + if type(key) is str: + if key in self.acceptedInputKeywords: + return self.pars[key] + else: + raise Exception('Unrecongnised parameter: {0} '.format(key) ) + elif type(key) is list: + outpars = [] + for k in key: + outpars.append(self.getParameter(k)) + return outpars else: - raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) - # setParameter + raise Exception('Unhandled input {0}' .format(str(type(key)))) + # getParameter - def __call__(self, input = None, regularization_parameter = None, **kwargs): + def __call__(self, input = None, regularization_parameter = None, + output_all = False, **kwargs): '''Actual call for the regularizer. One can either set the regularization parameters first and then call the @@ -179,19 +188,19 @@ class Regularizer(): input = self.pars['input'] regularization_parameter = self.pars['regularization_parameter'] if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : - return self.algorithm(input, regularization_parameter, + ret = self.algorithm(input, regularization_parameter, self.pars['number_of_iterations'], self.pars['tolerance_constant'], self.pars['TV_penalty'].value ) elif self.algorithm == Regularizer.Algorithm.FGP_TV : - return self.algorithm(input, regularization_parameter, + ret = self.algorithm(input, regularization_parameter, self.pars['number_of_iterations'], self.pars['tolerance_constant'], self.pars['TV_penalty'].value ) elif self.algorithm == Regularizer.Algorithm.LLT_model : #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) # no default - return self.algorithm(input, + ret = self.algorithm(input, regularization_parameter, self.pars['time_step'] , self.pars['number_of_iterations'], @@ -200,7 +209,7 @@ class Regularizer(): elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) # no default - return self.algorithm(input, regularization_parameter, + ret = self.algorithm(input, regularization_parameter, self.pars['searching_window_ratio'] , self.pars['similarity_window_ratio'] , self.pars['PB_filtering_parameter']) @@ -208,7 +217,7 @@ class Regularizer(): #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) # no default if len(np.shape(input)) == 2: - return self.algorithm(input, regularization_parameter, + ret = self.algorithm(input, regularization_parameter, self.pars['first_order_term'] , self.pars['second_order_term'] , self.pars['number_of_iterations']) @@ -227,11 +236,14 @@ class Regularizer(): output = [out3d] for i in range(1,len(out)): output.append(out[i]) - return output + ret = output - + if output_all: + return ret + else: + return ret[0] # __call__ |