summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-06-06 10:07:37 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2019-06-06 10:07:37 +0100
commit940a1371fdf88e8c9e8230cece6fa1c73842804c (patch)
tree535701d603c3bdb4de86524be0fa3cd3d3aec80d
parent935361ba734c7a2ecae8835d5f6959d32f4c7403 (diff)
downloadframework-940a1371fdf88e8c9e8230cece6fa1c73842804c.tar.gz
framework-940a1371fdf88e8c9e8230cece6fa1c73842804c.tar.bz2
framework-940a1371fdf88e8c9e8230cece6fa1c73842804c.tar.xz
framework-940a1371fdf88e8c9e8230cece6fa1c73842804c.zip
add callback
-rw-r--r--Wrappers/Python/wip/compare_CGLS_algos.py28
-rwxr-xr-xWrappers/Python/wip/fix_test.py31
2 files changed, 41 insertions, 18 deletions
diff --git a/Wrappers/Python/wip/compare_CGLS_algos.py b/Wrappers/Python/wip/compare_CGLS_algos.py
index 119752c..52f3f31 100644
--- a/Wrappers/Python/wip/compare_CGLS_algos.py
+++ b/Wrappers/Python/wip/compare_CGLS_algos.py
@@ -12,6 +12,8 @@ from ccpi.optimisation.algorithms import CGLS as CGLSalg
import numpy as np
import matplotlib.pyplot as plt
+from ccpi.optimisation.functions import Norm2Sq
+
# Choose either a parallel-beam (1=parallel2D) or fan-beam (2=cone2D) test case
test_case = 1
@@ -101,25 +103,29 @@ x_CGLS, it_CGLS, timing_CGLS, criter_CGLS = CGLS(x_init, Aop, b, opt)
#plt.title('CGLS criterion')
#plt.show()
+f = Norm2Sq(Aop, b, c=1.)
+def callback(it, objective, solution):
+ print (objective, f(solution))
# Now CLGS using the algorithm class
CGLS_alg = CGLSalg()
CGLS_alg.set_up(x_init, Aop, b )
-CGLS_alg.max_iteration = 2000
-CGLS_alg.run(opt['iter'])
+CGLS_alg.max_iteration = 500
+CGLS_alg.update_objective_interval = 10
+CGLS_alg.run(300, callback=callback)
x_CGLS_alg = CGLS_alg.get_output()
-#plt.figure()
-#plt.imshow(x_CGLS_alg.as_array())
-#plt.title('CGLS ALG')
-#plt.colorbar()
-#plt.show()
+plt.figure()
+plt.imshow(x_CGLS_alg.as_array())
+plt.title('CGLS ALG')
+plt.colorbar()
+plt.show()
-#plt.figure()
-#plt.semilogy(CGLS_alg.objective)
-#plt.title('CGLS criterion')
-#plt.show()
+plt.figure()
+plt.semilogy(CGLS_alg.objective)
+plt.title('CGLS criterion')
+plt.show()
print(criter_CGLS)
print(CGLS_alg.objective)
diff --git a/Wrappers/Python/wip/fix_test.py b/Wrappers/Python/wip/fix_test.py
index 094f571..5e40d70 100755
--- a/Wrappers/Python/wip/fix_test.py
+++ b/Wrappers/Python/wip/fix_test.py
@@ -1,4 +1,5 @@
import numpy as np
+import numpy
from ccpi.optimisation.operators import *
from ccpi.optimisation.algorithms import *
from ccpi.optimisation.functions import *
@@ -97,10 +98,10 @@ a = VectorData(x_init.as_array(), deep_copy=True)
assert id(x_init.as_array()) != id(a.as_array())
#%%
-f.L = LinearOperator.PowerMethod(A, 25, x_init)[0] * 2
+f.L = LinearOperator.PowerMethod(A, 25, x_init)[0]
print ('f.L', f.L)
-rate = (1 / f.L) / 3
-f.L *= 6
+rate = (1 / f.L) / 6
+f.L *= 12
# Initial guess
#x_init = DataContainer(np.zeros((n, 1)))
@@ -138,7 +139,7 @@ def callback(it, objective, solution):
print (it, objective, solution.as_array())
fa = FISTA(x_init=x_init, f=f, g=g1)
-fa.max_iteration = 1000
+fa.max_iteration = 10000
fa.update_objective_interval = int( fa.max_iteration / 10 )
fa.run(fa.max_iteration, callback = None, verbose=True)
@@ -147,12 +148,28 @@ gd.max_iteration = 100000
gd.update_objective_interval = 10000
gd.run(gd.max_iteration, callback = None, verbose=True)
+cgls = CGLS(x_init= x_init, operator=A, data=b)
+cgls.max_iteration = 200
+cgls.update_objective_interval = 1
+def stop_criterion(alg):
+ try:
+ x = alg.get_last_objective()
+ print (x)
+ a = True if x < numpy.finfo(numpy.float32).eps else False
+ except IndexError as ie:
+ a = False
+ def f ():
+ return a or alg.max_iteration_stop_cryterion()
+ return f
+#cgls.should_stop = stop_criterion(cgls)
+cgls.run(cgls.max_iteration, callback = None, verbose=True)
# Print for comparison
print("FISTA least squares plus 1-norm solution and objective value:")
print(fa.get_output().as_array())
print(fa.get_last_objective())
-print (A.direct(fa.get_output()).as_array())
-print (b.as_array())
-print (A.direct(gd.get_output()).as_array())
+print ("data ", b.as_array())
+print ('FISTA ', A.direct(fa.get_output()).as_array())
+print ('GradientDescent', A.direct(gd.get_output()).as_array())
+print ('CGLS ', A.direct(cgls.get_output()).as_array())