diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2019-06-25 22:49:52 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-06-25 22:49:52 +0100 |
commit | 57382bc824cf38ea1ce0b1bf62e4b3d39c0296be (patch) | |
tree | 91865551e65827af9cd2171d370957fa2d90e1ce | |
parent | 7df55839877aa36a93e1326bc08071993ece949c (diff) | |
download | framework-plugins-57382bc824cf38ea1ce0b1bf62e4b3d39c0296be.tar.gz framework-plugins-57382bc824cf38ea1ce0b1bf62e4b3d39c0296be.tar.bz2 framework-plugins-57382bc824cf38ea1ce0b1bf62e4b3d39c0296be.tar.xz framework-plugins-57382bc824cf38ea1ce0b1bf62e4b3d39c0296be.zip |
Update regularisers.py (#27)
-rw-r--r-- | Wrappers/Python/ccpi/plugins/regularisers.py | 36 |
1 files changed, 25 insertions, 11 deletions
diff --git a/Wrappers/Python/ccpi/plugins/regularisers.py b/Wrappers/Python/ccpi/plugins/regularisers.py index f665a04..77543f9 100644 --- a/Wrappers/Python/ccpi/plugins/regularisers.py +++ b/Wrappers/Python/ccpi/plugins/regularisers.py @@ -43,11 +43,16 @@ class ROF_TV(Function): 'number_of_iterations' :self.iterationsTV ,\ 'time_marching_parameter':self.time_marchstep} - out = regularisers.ROF_TV(pars['input'], + res = regularisers.ROF_TV(pars['input'], pars['regularization_parameter'], pars['number_of_iterations'], pars['time_marching_parameter'], self.device) - return DataContainer(out) + if out is not None: + out.fill(res) + else: + out = x.copy() + out.fill(res) + return out class FGP_TV(Function): def __init__(self,lambdaReg,iterationsTV,tolerance,methodTV,nonnegativity,printing,device): @@ -63,7 +68,7 @@ class FGP_TV(Function): # evaluate objective function of TV gradient EnergyValTV = TV_ENERGY(np.asarray(x.as_array(), dtype=np.float32), np.asarray(x.as_array(), dtype=np.float32), self.lambdaReg, 2) return 0.5*EnergyValTV[0] - def prox(self,x,tau): + def proximal(self,x,tau, out=None): pars = {'algorithm' : FGP_TV, \ 'input' : np.asarray(x.as_array(), dtype=np.float32),\ 'regularization_parameter':self.lambdaReg*tau, \ @@ -73,16 +78,20 @@ class FGP_TV(Function): 'nonneg': self.nonnegativity ,\ 'printingOut': self.printing} - out = regularisers.FGP_TV(pars['input'], + res = regularisers.FGP_TV(pars['input'], pars['regularization_parameter'], pars['number_of_iterations'], pars['tolerance_constant'], pars['methodTV'], pars['nonneg'], - pars['printingOut'], self.device) - return DataContainer(out) - - + self.device) + if out is not None: + out.fill(res) + else: + out = x.copy() + out.fill(res) + return out + class SB_TV(Function): def __init__(self,lambdaReg,iterationsTV,tolerance,methodTV,printing,device): # set parameters @@ -96,7 +105,7 @@ class SB_TV(Function): # evaluate objective function of TV gradient EnergyValTV = TV_ENERGY(np.asarray(x.as_array(), dtype=np.float32), np.asarray(x.as_array(), dtype=np.float32), self.lambdaReg, 2) return 0.5*EnergyValTV[0] - def prox(self,x,tau): + def proximal(self,x,tau, out=None): pars = {'algorithm' : SB_TV, \ 'input' : np.asarray(x.as_array(), dtype=np.float32),\ 'regularization_parameter':self.lambdaReg*tau, \ @@ -105,10 +114,15 @@ class SB_TV(Function): 'methodTV': self.methodTV ,\ 'printingOut': self.printing} - out = regularisers.SB_TV(pars['input'], + res = regularisers.SB_TV(pars['input'], pars['regularization_parameter'], pars['number_of_iterations'], pars['tolerance_constant'], pars['methodTV'], pars['printingOut'], self.device) - return DataContainer(out) + if out is not None: + out.fill(res) + else: + out = x.copy() + out.fill(res) + return out |