summaryrefslogtreecommitdiffstats
path: root/Wrappers
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-03-11 13:33:07 -0400
committerEdoardo Pasca <edo.paskino@gmail.com>2019-03-11 13:33:07 -0400
commit6a76bd07171ccf4e95372e7d84f6b381aad9e557 (patch)
tree834bc1b70b74eb7f57bced06cf6f8222b37ab6f6 /Wrappers
parent78d97a226ede52ccd7386a8bf4097c9f83f6c4a6 (diff)
downloadframework-6a76bd07171ccf4e95372e7d84f6b381aad9e557.tar.gz
framework-6a76bd07171ccf4e95372e7d84f6b381aad9e557.tar.bz2
framework-6a76bd07171ccf4e95372e7d84f6b381aad9e557.tar.xz
framework-6a76bd07171ccf4e95372e7d84f6b381aad9e557.zip
fix initialisation for memopt
Diffstat (limited to 'Wrappers')
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py
index 7794b4d..f1e4132 100755
--- a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py
@@ -51,13 +51,17 @@ class GradientDescent(Algorithm):
def set_up(self, x_init, objective_function, rate):
'''initialisation of the algorithm'''
self.x = x_init.copy()
- if self.memopt:
- self.x_update = x_init.copy()
self.objective_function = objective_function
self.rate = rate
self.loss.append(objective_function(x_init))
self.iteration = 0
-
+ try:
+ self.memopt = self.objective_function.memopt
+ except AttributeError as ae:
+ self.memopt = False
+ if self.memopt:
+ self.x_update = x_init.copy()
+
def update(self):
'''Single iteration'''
if self.memopt:
@@ -65,7 +69,7 @@ class GradientDescent(Algorithm):
self.x_update *= -self.rate
self.x += self.x_update
else:
- self.x += -self.rate * self.objective_function.grad(self.x)
+ self.x += -self.rate * self.objective_function.gradient(self.x)
def update_objective(self):
self.loss.append(self.objective_function(self.x))