diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2019-06-06 10:07:37 +0100 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2019-06-06 10:07:37 +0100 |
commit | 940a1371fdf88e8c9e8230cece6fa1c73842804c (patch) | |
tree | 535701d603c3bdb4de86524be0fa3cd3d3aec80d | |
parent | 935361ba734c7a2ecae8835d5f6959d32f4c7403 (diff) | |
download | framework-940a1371fdf88e8c9e8230cece6fa1c73842804c.tar.gz framework-940a1371fdf88e8c9e8230cece6fa1c73842804c.tar.bz2 framework-940a1371fdf88e8c9e8230cece6fa1c73842804c.tar.xz framework-940a1371fdf88e8c9e8230cece6fa1c73842804c.zip |
add callback
-rw-r--r-- | Wrappers/Python/wip/compare_CGLS_algos.py | 28 | ||||
-rwxr-xr-x | Wrappers/Python/wip/fix_test.py | 31 |
2 files changed, 41 insertions, 18 deletions
diff --git a/Wrappers/Python/wip/compare_CGLS_algos.py b/Wrappers/Python/wip/compare_CGLS_algos.py index 119752c..52f3f31 100644 --- a/Wrappers/Python/wip/compare_CGLS_algos.py +++ b/Wrappers/Python/wip/compare_CGLS_algos.py @@ -12,6 +12,8 @@ from ccpi.optimisation.algorithms import CGLS as CGLSalg import numpy as np import matplotlib.pyplot as plt +from ccpi.optimisation.functions import Norm2Sq + # Choose either a parallel-beam (1=parallel2D) or fan-beam (2=cone2D) test case test_case = 1 @@ -101,25 +103,29 @@ x_CGLS, it_CGLS, timing_CGLS, criter_CGLS = CGLS(x_init, Aop, b, opt) #plt.title('CGLS criterion') #plt.show() +f = Norm2Sq(Aop, b, c=1.) +def callback(it, objective, solution): + print (objective, f(solution)) # Now CLGS using the algorithm class CGLS_alg = CGLSalg() CGLS_alg.set_up(x_init, Aop, b ) -CGLS_alg.max_iteration = 2000 -CGLS_alg.run(opt['iter']) +CGLS_alg.max_iteration = 500 +CGLS_alg.update_objective_interval = 10 +CGLS_alg.run(300, callback=callback) x_CGLS_alg = CGLS_alg.get_output() -#plt.figure() -#plt.imshow(x_CGLS_alg.as_array()) -#plt.title('CGLS ALG') -#plt.colorbar() -#plt.show() +plt.figure() +plt.imshow(x_CGLS_alg.as_array()) +plt.title('CGLS ALG') +plt.colorbar() +plt.show() -#plt.figure() -#plt.semilogy(CGLS_alg.objective) -#plt.title('CGLS criterion') -#plt.show() +plt.figure() +plt.semilogy(CGLS_alg.objective) +plt.title('CGLS criterion') +plt.show() print(criter_CGLS) print(CGLS_alg.objective) diff --git a/Wrappers/Python/wip/fix_test.py b/Wrappers/Python/wip/fix_test.py index 094f571..5e40d70 100755 --- a/Wrappers/Python/wip/fix_test.py +++ b/Wrappers/Python/wip/fix_test.py @@ -1,4 +1,5 @@ import numpy as np +import numpy from ccpi.optimisation.operators import * from ccpi.optimisation.algorithms import * from ccpi.optimisation.functions import * @@ -97,10 +98,10 @@ a = VectorData(x_init.as_array(), deep_copy=True) assert id(x_init.as_array()) != id(a.as_array()) #%% -f.L = LinearOperator.PowerMethod(A, 25, x_init)[0] * 2 +f.L = LinearOperator.PowerMethod(A, 25, x_init)[0] print ('f.L', f.L) -rate = (1 / f.L) / 3 -f.L *= 6 +rate = (1 / f.L) / 6 +f.L *= 12 # Initial guess #x_init = DataContainer(np.zeros((n, 1))) @@ -138,7 +139,7 @@ def callback(it, objective, solution): print (it, objective, solution.as_array()) fa = FISTA(x_init=x_init, f=f, g=g1) -fa.max_iteration = 1000 +fa.max_iteration = 10000 fa.update_objective_interval = int( fa.max_iteration / 10 ) fa.run(fa.max_iteration, callback = None, verbose=True) @@ -147,12 +148,28 @@ gd.max_iteration = 100000 gd.update_objective_interval = 10000 gd.run(gd.max_iteration, callback = None, verbose=True) +cgls = CGLS(x_init= x_init, operator=A, data=b) +cgls.max_iteration = 200 +cgls.update_objective_interval = 1 +def stop_criterion(alg): + try: + x = alg.get_last_objective() + print (x) + a = True if x < numpy.finfo(numpy.float32).eps else False + except IndexError as ie: + a = False + def f (): + return a or alg.max_iteration_stop_cryterion() + return f +#cgls.should_stop = stop_criterion(cgls) +cgls.run(cgls.max_iteration, callback = None, verbose=True) # Print for comparison print("FISTA least squares plus 1-norm solution and objective value:") print(fa.get_output().as_array()) print(fa.get_last_objective()) -print (A.direct(fa.get_output()).as_array()) -print (b.as_array()) -print (A.direct(gd.get_output()).as_array()) +print ("data ", b.as_array()) +print ('FISTA ', A.direct(fa.get_output()).as_array()) +print ('GradientDescent', A.direct(gd.get_output()).as_array()) +print ('CGLS ', A.direct(cgls.get_output()).as_array()) |