summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorSuren A. Chilingaryan <csa@suren.me>2020-03-27 20:25:10 +0100
committerSuren A. Chilingaryan <csa@suren.me>2020-03-27 20:25:10 +0100
commitda77a606f5d48ff31d72816d858736954f4585aa (patch)
tree4f003e2bee0f22ac77a2c339d9c8abb7c8b10615 /Wrappers/Python
parentadf4163c145e6ddc16899a92a06c3282f144d88c (diff)
downloadframework-fast_tnv.tar.gz
framework-fast_tnv.tar.bz2
framework-fast_tnv.tar.xz
framework-fast_tnv.zip
Cache computed lipschitz constant, seed random number generator to ensure that different runs give exactly the same resultfast_tnv
Diffstat (limited to 'Wrappers/Python')
-rw-r--r--Wrappers/Python/ccpi/optimisation/operators/Operator.py31
1 files changed, 27 insertions, 4 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/operators/Operator.py b/Wrappers/Python/ccpi/optimisation/operators/Operator.py
index d49bc1a..4e2d04d 100644
--- a/Wrappers/Python/ccpi/optimisation/operators/Operator.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/Operator.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from numbers import Number
import numpy
import functools
+from hashlib import md5
class Operator(object):
'''Operator that maps from a space X -> Y'''
@@ -126,11 +127,13 @@ class LinearOperator(Operator):
:returns: tuple with: L, list of L at each iteration, the data the iteration worked on.
'''
+
# Initialise random
- if x_init is None:
- x0 = operator.domain_geometry().allocate('random')
- else:
- x0 = x_init.copy()
+ x0 = operator.domain_geometry().allocate('random', seed=1)
+ #if x_init is None:
+ # x0 = operator.domain_geometry().allocate('random', seed=1)
+ #else:
+ # x0 = x_init.copy()
x1 = operator.domain_geometry().allocate()
y_tmp = operator.range_geometry().allocate()
@@ -158,9 +161,29 @@ class LinearOperator(Operator):
:parameter force: forces the recalculation of the norm
:type force: boolean, default :code:`False`
'''
+
+ fname = md5("{} {}".format(self.domain_geometry(), self.range_geometry()).encode('utf-8')).hexdigest()
+ fname = "/tmp/ccpi_cache_opnorm_{}".format(fname)
+ try:
+ f = open(fname)
+ s1 = float(f.read())
+ print ("Returning norm {} from cache".format(s1))
+ return s1
+ except:
+ pass
+
x0 = kwargs.get('x_init', None)
iterations = kwargs.get('iterations', 25)
s1, sall, svec = LinearOperator.PowerMethod(self, iterations, x_init=x0)
+
+ try:
+ print(fname)
+ f = open(fname, "w")
+ f.write("{}".format(s1))
+ except:
+ pass
+
+
return s1
@staticmethod