summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-02-20 15:05:07 +0000
committerEdoardo Pasca <edo.paskino@gmail.com>2019-02-20 15:05:07 +0000
commit5317bf21b45433313907c8f4d6331230c2c8349f (patch)
treed68f2e7b43a06bfc3eca33832134299eb5bcdb24
parent10c52f5eda45b412ca8859a04950df62745acbe8 (diff)
downloadframework-5317bf21b45433313907c8f4d6331230c2c8349f.tar.gz
framework-5317bf21b45433313907c8f4d6331230c2c8349f.tar.bz2
framework-5317bf21b45433313907c8f4d6331230c2c8349f.tar.xz
framework-5317bf21b45433313907c8f4d6331230c2c8349f.zip
add default stop criterion and run method
-rw-r--r--Wrappers/Python/ccpi/optimisation/Algorithms.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/Algorithms.py b/Wrappers/Python/ccpi/optimisation/Algorithms.py
index 448a7b1..9115e6e 100644
--- a/Wrappers/Python/ccpi/optimisation/Algorithms.py
+++ b/Wrappers/Python/ccpi/optimisation/Algorithms.py
@@ -43,8 +43,8 @@ class Algorithm(object):
raise NotImplementedError()
def should_stop(self):
- '''stopping cryterion'''
- raise NotImplementedError()
+ '''default stopping cryterion: number of iterations'''
+ return self.iteration >= self.max_iteration
def __iter__(self):
return self
@@ -58,6 +58,7 @@ class Algorithm(object):
time0 = time.time()
self.update()
self.timing.append( time.time() - time0 )
+ # TODO update every N iterations
self.update_objective()
self.iteration += 1
def get_output(self):
@@ -66,12 +67,17 @@ class Algorithm(object):
def get_current_loss(self):
'''Returns the current value of the loss function'''
return self.__loss[-1]
+ def get_current_objective(self):
+ return self.get_current_loss()
def update_objective(self):
raise NotImplementedError()
@property
def loss(self):
return self.__loss
@property
+ def objective(self):
+ return self.__loss
+ @property
def max_iteration(self):
return self.__max_iteration
@max_iteration.setter
@@ -198,11 +204,7 @@ class FISTA(Algorithm):
self.invL = 1/f.L
self.t_old = 1
-
- def should_stop(self):
- '''stopping cryterion, currently only based on number of iterations'''
- return self.iteration >= self.max_iteration
-
+
def update(self):
# algorithm loop
#for it in range(0, max_iter):