diff options
-rwxr-xr-x | Wrappers/Python/test/test_run_test.py | 31 |
1 files changed, 19 insertions, 12 deletions
diff --git a/Wrappers/Python/test/test_run_test.py b/Wrappers/Python/test/test_run_test.py index 8cef925..c698032 100755 --- a/Wrappers/Python/test/test_run_test.py +++ b/Wrappers/Python/test/test_run_test.py @@ -6,10 +6,10 @@ from ccpi.framework import ImageData from ccpi.framework import AcquisitionData from ccpi.framework import ImageGeometry from ccpi.framework import AcquisitionGeometry -from ccpi.optimisation.algs import FISTA -from ccpi.optimisation.algs import FBPD +from ccpi.optimisation.algorithms import FISTA +#from ccpi.optimisation.algs import FBPD from ccpi.optimisation.funcs import Norm2sq -from ccpi.optimisation.functions import ZeroFun +from ccpi.optimisation.functions import ZeroFunction from ccpi.optimisation.funcs import Norm1 from ccpi.optimisation.funcs import TV2D from ccpi.optimisation.funcs import Norm2 @@ -82,7 +82,7 @@ class TestAlgorithms(unittest.TestCase): opt = {'memopt': True} # Create object instances with the test data A and b. f = Norm2sq(A, b, c=0.5, memopt=True) - g0 = ZeroFun() + g0 = ZeroFunction() # Initial guess x_init = DataContainer(np.zeros((n, 1))) @@ -90,12 +90,15 @@ class TestAlgorithms(unittest.TestCase): f.grad(x_init) # Run FISTA for least squares plus zero function. - x_fista0, it0, timing0, criter0 = FISTA(x_init, f, g0, opt=opt) - + #x_fista0, it0, timing0, criter0 = FISTA(x_init, f, g0, opt=opt) + fa = FISTA(x_init=x_init, f=f, g=g0) + fa.max_iteration = 10 + fa.run(10) + # Print solution and final objective/criterion value for comparison print("FISTA least squares plus zero function solution and objective value:") - print(x_fista0.array) - print(criter0[-1]) + print(fa.get_output()) + print(fa.get_last_objective()) # Compare to CVXPY @@ -143,7 +146,7 @@ class TestAlgorithms(unittest.TestCase): opt = {'memopt': True} # Create object instances with the test data A and b. f = Norm2sq(A, b, c=0.5, memopt=True) - g0 = ZeroFun() + g0 = ZeroFunction() # Initial guess x_init = DataContainer(np.zeros((n, 1))) @@ -155,12 +158,16 @@ class TestAlgorithms(unittest.TestCase): g1.prox(x_init, 0.02) # Combine with least squares and solve using generic FISTA implementation - x_fista1, it1, timing1, criter1 = FISTA(x_init, f, g1, opt=opt) + #x_fista1, it1, timing1, criter1 = FISTA(x_init, f, g1, opt=opt) + fa = FISTA(x_init=x_init, f=f, g=g1) + fa.max_iteration = 10 + fa.run(10) + # Print for comparison print("FISTA least squares plus 1-norm solution and objective value:") - print(x_fista1.as_array().squeeze()) - print(criter1[-1]) + print(fa.get_output()) + print(fa.get_last_objective()) # Compare to CVXPY |