diff options
Diffstat (limited to 'Wrappers')
-rwxr-xr-x | Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py index 7794b4d..f1e4132 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py @@ -51,13 +51,17 @@ class GradientDescent(Algorithm): def set_up(self, x_init, objective_function, rate): '''initialisation of the algorithm''' self.x = x_init.copy() - if self.memopt: - self.x_update = x_init.copy() self.objective_function = objective_function self.rate = rate self.loss.append(objective_function(x_init)) self.iteration = 0 - + try: + self.memopt = self.objective_function.memopt + except AttributeError as ae: + self.memopt = False + if self.memopt: + self.x_update = x_init.copy() + def update(self): '''Single iteration''' if self.memopt: @@ -65,7 +69,7 @@ class GradientDescent(Algorithm): self.x_update *= -self.rate self.x += self.x_update else: - self.x += -self.rate * self.objective_function.grad(self.x) + self.x += -self.rate * self.objective_function.gradient(self.x) def update_objective(self): self.loss.append(self.objective_function(self.x)) |