diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2019-03-11 13:33:07 -0400 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2019-03-11 13:33:07 -0400 |
commit | 6a76bd07171ccf4e95372e7d84f6b381aad9e557 (patch) | |
tree | 834bc1b70b74eb7f57bced06cf6f8222b37ab6f6 /Wrappers | |
parent | 78d97a226ede52ccd7386a8bf4097c9f83f6c4a6 (diff) | |
download | framework-6a76bd07171ccf4e95372e7d84f6b381aad9e557.tar.gz framework-6a76bd07171ccf4e95372e7d84f6b381aad9e557.tar.bz2 framework-6a76bd07171ccf4e95372e7d84f6b381aad9e557.tar.xz framework-6a76bd07171ccf4e95372e7d84f6b381aad9e557.zip |
fix initialisation for memopt
Diffstat (limited to 'Wrappers')
-rwxr-xr-x | Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py index 7794b4d..f1e4132 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py @@ -51,13 +51,17 @@ class GradientDescent(Algorithm): def set_up(self, x_init, objective_function, rate): '''initialisation of the algorithm''' self.x = x_init.copy() - if self.memopt: - self.x_update = x_init.copy() self.objective_function = objective_function self.rate = rate self.loss.append(objective_function(x_init)) self.iteration = 0 - + try: + self.memopt = self.objective_function.memopt + except AttributeError as ae: + self.memopt = False + if self.memopt: + self.x_update = x_init.copy() + def update(self): '''Single iteration''' if self.memopt: @@ -65,7 +69,7 @@ class GradientDescent(Algorithm): self.x_update *= -self.rate self.x += self.x_update else: - self.x += -self.rate * self.objective_function.grad(self.x) + self.x += -self.rate * self.objective_function.gradient(self.x) def update_objective(self): self.loss.append(self.objective_function(self.x)) |