From d82298ee9a6e38ff6e286077f52a694acd58d5db Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 22 Oct 2019 17:43:58 +0100 Subject: Pdhg last objective (#407) * save previous iteration at start of iteration * adds very_verbose to run method * modified test closes #396 --- .../ccpi/optimisation/algorithms/Algorithm.py | 42 +++++++++++++--------- .../Python/ccpi/optimisation/algorithms/PDHG.py | 9 ++--- Wrappers/Python/test/test_algorithms.py | 3 +- 3 files changed, 33 insertions(+), 21 deletions(-) (limited to 'Wrappers/Python') diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py b/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py index f08688d..78ce438 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/Algorithm.py @@ -60,6 +60,7 @@ class Algorithm(object): self.timing = [] self._iteration = [] self.update_objective_interval = kwargs.get('update_objective_interval', 1) + self.x = None def set_up(self, *args, **kwargs): '''Set up the algorithm''' raise NotImplementedError() @@ -109,16 +110,23 @@ class Algorithm(object): '''Returns the solution found''' return self.x - def get_last_loss(self): + def get_last_loss(self, **kwargs): '''Returns the last stored value of the loss function if update_objective_interval is 1 it is the value of the objective at the current iteration. If update_objective_interval > 1 it is the last stored value. ''' - return self.__loss[-1] - def get_last_objective(self): + return_all = kwargs.get('return_all', False) + objective = self.__loss[-1] + if return_all: + return list(objective) + if isinstance(objective, list): + return objective[0] + else: + return objective + def get_last_objective(self, **kwargs): '''alias to get_last_loss''' - return self.get_last_loss() + return self.get_last_loss(**kwargs) def update_objective(self): '''calculates the objective with the current solution''' raise NotImplementedError() @@ -155,7 +163,7 @@ class Algorithm(object): raise ValueError('Update objective interval must be an integer >= 1') else: raise ValueError('Update objective interval must be an integer >= 1') - def run(self, iterations=None, verbose=True, callback=None): + def run(self, iterations=None, verbose=True, callback=None, very_verbose=False): '''run n iterations and update the user with the callback if specified :param iterations: number of iterations to run. If not set the algorithm will @@ -163,30 +171,32 @@ class Algorithm(object): :param verbose: toggles verbose output to screen :param callback: is a function that receives: current iteration number, last objective function value and the current solution + :param very_verbose: bool, useful for algorithms with primal and dual objectives (PDHG), + prints to screen both primal and dual ''' if self.should_stop(): print ("Stop cryterion has been reached.") i = 0 if verbose: - print (self.verbose_header()) + print (self.verbose_header(very_verbose)) if self.iteration == 0: if verbose: - print(self.verbose_output()) + print(self.verbose_output(very_verbose)) for _ in self: if (self.iteration) % self.update_objective_interval == 0: if verbose: - print (self.verbose_output()) + print (self.verbose_output(very_verbose)) if callback is not None: - callback(self.iteration, self.get_last_objective(), self.x) + callback(self.iteration, self.get_last_objective(return_all=very_verbose), self.x) i += 1 if i == iterations: if self.iteration != self._iteration[-1]: self.update_objective() if verbose: - print (self.verbose_output()) + print (self.verbose_output(very_verbose)) break - def verbose_output(self): + def verbose_output(self, verbose=False): '''Creates a nice tabulated output''' timing = self.timing[-self.update_objective_interval-1:-1] self._iteration.append(self.iteration) @@ -198,20 +208,20 @@ class Algorithm(object): self.iteration, self.max_iteration, "{:.3f}".format(t), - self.objective_to_string() + self.objective_to_string(verbose) ) return out - def objective_to_string(self): - el = self.get_last_objective() + def objective_to_string(self, verbose=False): + el = self.get_last_objective(return_all=verbose) if type(el) == list: string = functools.reduce(lambda x,y: x+' {:>13.5e}'.format(y), el[:-1],'') string += '{:>15.5e}'.format(el[-1]) else: string = "{:>20.5e}".format(el) return string - def verbose_header(self): - el = self.get_last_objective() + def verbose_header(self, verbose=False): + el = self.get_last_objective(return_all=verbose) if type(el) == list: out = "{:>9} {:>10} {:>13} {:>13} {:>13} {:>15}\n".format('Iter', 'Max Iter', diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py index 7ed82b2..db1b8dc 100644 --- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py @@ -123,7 +123,10 @@ class PDHG(Algorithm): def update(self): - + # save previous iteration + self.x_old.fill(self.x) + self.y_old.fill(self.y) + # Gradient ascent for the dual variable self.operator.direct(self.xbar, out=self.y_tmp) self.y_tmp *= self.sigma @@ -145,9 +148,7 @@ class PDHG(Algorithm): self.xbar += self.x - self.x_old.fill(self.x) - self.y_old.fill(self.y) - + def update_objective(self): p1 = self.f(self.operator.direct(self.x)) + self.g(self.x) diff --git a/Wrappers/Python/test/test_algorithms.py b/Wrappers/Python/test/test_algorithms.py index 2b38e3f..db13b97 100755 --- a/Wrappers/Python/test/test_algorithms.py +++ b/Wrappers/Python/test/test_algorithms.py @@ -195,6 +195,7 @@ class TestAlgorithms(unittest.TestCase): print ("PDHG Denoising with 3 noises") # adapted from demo PDHG_TV_Color_Denoising.py in CIL-Demos repository + # loader = TestData(data_dir=os.path.join(os.environ['SIRF_INSTALL_PATH'], 'share','ccpi')) # loader = TestData(data_dir=os.path.join(sys.prefix, 'share','ccpi')) loader = TestData() @@ -254,7 +255,7 @@ class TestAlgorithms(unittest.TestCase): pdhg1 = PDHG(f=f1,g=g,operator=operator, tau=tau, sigma=sigma) pdhg1.max_iteration = 2000 pdhg1.update_objective_interval = 200 - pdhg1.run(1000) + pdhg1.run(1000, very_verbose=True) rmse = (pdhg1.get_output() - data).norm() / data.as_array().size print ("RMSE", rmse) -- cgit v1.2.3