summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rwxr-xr-xWrappers/Python/test/test_run_test.py31
1 files changed, 19 insertions, 12 deletions
diff --git a/Wrappers/Python/test/test_run_test.py b/Wrappers/Python/test/test_run_test.py
index 8cef925..c698032 100755
--- a/Wrappers/Python/test/test_run_test.py
+++ b/Wrappers/Python/test/test_run_test.py
@@ -6,10 +6,10 @@ from ccpi.framework import ImageData
from ccpi.framework import AcquisitionData
from ccpi.framework import ImageGeometry
from ccpi.framework import AcquisitionGeometry
-from ccpi.optimisation.algs import FISTA
-from ccpi.optimisation.algs import FBPD
+from ccpi.optimisation.algorithms import FISTA
+#from ccpi.optimisation.algs import FBPD
from ccpi.optimisation.funcs import Norm2sq
-from ccpi.optimisation.functions import ZeroFun
+from ccpi.optimisation.functions import ZeroFunction
from ccpi.optimisation.funcs import Norm1
from ccpi.optimisation.funcs import TV2D
from ccpi.optimisation.funcs import Norm2
@@ -82,7 +82,7 @@ class TestAlgorithms(unittest.TestCase):
opt = {'memopt': True}
# Create object instances with the test data A and b.
f = Norm2sq(A, b, c=0.5, memopt=True)
- g0 = ZeroFun()
+ g0 = ZeroFunction()
# Initial guess
x_init = DataContainer(np.zeros((n, 1)))
@@ -90,12 +90,15 @@ class TestAlgorithms(unittest.TestCase):
f.grad(x_init)
# Run FISTA for least squares plus zero function.
- x_fista0, it0, timing0, criter0 = FISTA(x_init, f, g0, opt=opt)
-
+ #x_fista0, it0, timing0, criter0 = FISTA(x_init, f, g0, opt=opt)
+ fa = FISTA(x_init=x_init, f=f, g=g0)
+ fa.max_iteration = 10
+ fa.run(10)
+
# Print solution and final objective/criterion value for comparison
print("FISTA least squares plus zero function solution and objective value:")
- print(x_fista0.array)
- print(criter0[-1])
+ print(fa.get_output())
+ print(fa.get_last_objective())
# Compare to CVXPY
@@ -143,7 +146,7 @@ class TestAlgorithms(unittest.TestCase):
opt = {'memopt': True}
# Create object instances with the test data A and b.
f = Norm2sq(A, b, c=0.5, memopt=True)
- g0 = ZeroFun()
+ g0 = ZeroFunction()
# Initial guess
x_init = DataContainer(np.zeros((n, 1)))
@@ -155,12 +158,16 @@ class TestAlgorithms(unittest.TestCase):
g1.prox(x_init, 0.02)
# Combine with least squares and solve using generic FISTA implementation
- x_fista1, it1, timing1, criter1 = FISTA(x_init, f, g1, opt=opt)
+ #x_fista1, it1, timing1, criter1 = FISTA(x_init, f, g1, opt=opt)
+ fa = FISTA(x_init=x_init, f=f, g=g1)
+ fa.max_iteration = 10
+ fa.run(10)
+
# Print for comparison
print("FISTA least squares plus 1-norm solution and objective value:")
- print(x_fista1.as_array().squeeze())
- print(criter1[-1])
+ print(fa.get_output())
+ print(fa.get_last_objective())
# Compare to CVXPY