summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-06-25 22:49:52 +0100
committerGitHub <noreply@github.com>2019-06-25 22:49:52 +0100
commit57382bc824cf38ea1ce0b1bf62e4b3d39c0296be (patch)
tree91865551e65827af9cd2171d370957fa2d90e1ce
parent7df55839877aa36a93e1326bc08071993ece949c (diff)
downloadframework-plugins-57382bc824cf38ea1ce0b1bf62e4b3d39c0296be.tar.gz
framework-plugins-57382bc824cf38ea1ce0b1bf62e4b3d39c0296be.tar.bz2
framework-plugins-57382bc824cf38ea1ce0b1bf62e4b3d39c0296be.tar.xz
framework-plugins-57382bc824cf38ea1ce0b1bf62e4b3d39c0296be.zip
Update regularisers.py (#27)
-rw-r--r--Wrappers/Python/ccpi/plugins/regularisers.py36
1 file changed, 25 insertions, 11 deletions
diff --git a/Wrappers/Python/ccpi/plugins/regularisers.py b/Wrappers/Python/ccpi/plugins/regularisers.py
index f665a04..77543f9 100644
--- a/Wrappers/Python/ccpi/plugins/regularisers.py
+++ b/Wrappers/Python/ccpi/plugins/regularisers.py
@@ -43,11 +43,16 @@ class ROF_TV(Function):
'number_of_iterations' :self.iterationsTV ,\
'time_marching_parameter':self.time_marchstep}
- out = regularisers.ROF_TV(pars['input'],
+ res = regularisers.ROF_TV(pars['input'],
pars['regularization_parameter'],
pars['number_of_iterations'],
pars['time_marching_parameter'], self.device)
- return DataContainer(out)
+ if out is not None:
+ out.fill(res)
+ else:
+ out = x.copy()
+ out.fill(res)
+ return out
class FGP_TV(Function):
def __init__(self,lambdaReg,iterationsTV,tolerance,methodTV,nonnegativity,printing,device):
@@ -63,7 +68,7 @@ class FGP_TV(Function):
# evaluate objective function of TV gradient
EnergyValTV = TV_ENERGY(np.asarray(x.as_array(), dtype=np.float32), np.asarray(x.as_array(), dtype=np.float32), self.lambdaReg, 2)
return 0.5*EnergyValTV[0]
- def prox(self,x,tau):
+ def proximal(self,x,tau, out=None):
pars = {'algorithm' : FGP_TV, \
'input' : np.asarray(x.as_array(), dtype=np.float32),\
'regularization_parameter':self.lambdaReg*tau, \
@@ -73,16 +78,20 @@ class FGP_TV(Function):
'nonneg': self.nonnegativity ,\
'printingOut': self.printing}
- out = regularisers.FGP_TV(pars['input'],
+ res = regularisers.FGP_TV(pars['input'],
pars['regularization_parameter'],
pars['number_of_iterations'],
pars['tolerance_constant'],
pars['methodTV'],
pars['nonneg'],
- pars['printingOut'], self.device)
- return DataContainer(out)
-
-
+ self.device)
+ if out is not None:
+ out.fill(res)
+ else:
+ out = x.copy()
+ out.fill(res)
+ return out
+
class SB_TV(Function):
def __init__(self,lambdaReg,iterationsTV,tolerance,methodTV,printing,device):
# set parameters
@@ -96,7 +105,7 @@ class SB_TV(Function):
# evaluate objective function of TV gradient
EnergyValTV = TV_ENERGY(np.asarray(x.as_array(), dtype=np.float32), np.asarray(x.as_array(), dtype=np.float32), self.lambdaReg, 2)
return 0.5*EnergyValTV[0]
- def prox(self,x,tau):
+ def proximal(self,x,tau, out=None):
pars = {'algorithm' : SB_TV, \
'input' : np.asarray(x.as_array(), dtype=np.float32),\
'regularization_parameter':self.lambdaReg*tau, \
@@ -105,10 +114,15 @@ class SB_TV(Function):
'methodTV': self.methodTV ,\
'printingOut': self.printing}
- out = regularisers.SB_TV(pars['input'],
+ res = regularisers.SB_TV(pars['input'],
pars['regularization_parameter'],
pars['number_of_iterations'],
pars['tolerance_constant'],
pars['methodTV'],
pars['printingOut'], self.device)
- return DataContainer(out)
+ if out is not None:
+ out.fill(res)
+ else:
+ out = x.copy()
+ out.fill(res)
+ return out