summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-04-15 16:11:16 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2019-04-15 16:11:16 +0100
commite0ec99b8a8a0e55a53531612da38c378790bbb60 (patch)
tree066974b10382d25cba1dee8fe9ad9058086c6d10 /Wrappers/Python
parent617f2e71dd34b3c1fe2997ffbaeefd7f030ec3aa (diff)
downloadframework-e0ec99b8a8a0e55a53531612da38c378790bbb60.tar.gz
framework-e0ec99b8a8a0e55a53531612da38c378790bbb60.tar.bz2
framework-e0ec99b8a8a0e55a53531612da38c378790bbb60.tar.xz
framework-e0ec99b8a8a0e55a53531612da38c378790bbb60.zip
use new algorithm class
Diffstat (limited to 'Wrappers/Python')
-rwxr-xr-xWrappers/Python/test/test_run_test.py31
1 files changed, 19 insertions, 12 deletions
diff --git a/Wrappers/Python/test/test_run_test.py b/Wrappers/Python/test/test_run_test.py
index 8cef925..c698032 100755
--- a/Wrappers/Python/test/test_run_test.py
+++ b/Wrappers/Python/test/test_run_test.py
@@ -6,10 +6,10 @@ from ccpi.framework import ImageData
from ccpi.framework import AcquisitionData
from ccpi.framework import ImageGeometry
from ccpi.framework import AcquisitionGeometry
-from ccpi.optimisation.algs import FISTA
-from ccpi.optimisation.algs import FBPD
+from ccpi.optimisation.algorithms import FISTA
+#from ccpi.optimisation.algs import FBPD
from ccpi.optimisation.funcs import Norm2sq
-from ccpi.optimisation.functions import ZeroFun
+from ccpi.optimisation.functions import ZeroFunction
from ccpi.optimisation.funcs import Norm1
from ccpi.optimisation.funcs import TV2D
from ccpi.optimisation.funcs import Norm2
@@ -82,7 +82,7 @@ class TestAlgorithms(unittest.TestCase):
opt = {'memopt': True}
# Create object instances with the test data A and b.
f = Norm2sq(A, b, c=0.5, memopt=True)
- g0 = ZeroFun()
+ g0 = ZeroFunction()
# Initial guess
x_init = DataContainer(np.zeros((n, 1)))
@@ -90,12 +90,15 @@ class TestAlgorithms(unittest.TestCase):
f.grad(x_init)
# Run FISTA for least squares plus zero function.
- x_fista0, it0, timing0, criter0 = FISTA(x_init, f, g0, opt=opt)
-
+ #x_fista0, it0, timing0, criter0 = FISTA(x_init, f, g0, opt=opt)
+ fa = FISTA(x_init=x_init, f=f, g=g0)
+ fa.max_iteration = 10
+ fa.run(10)
+
# Print solution and final objective/criterion value for comparison
print("FISTA least squares plus zero function solution and objective value:")
- print(x_fista0.array)
- print(criter0[-1])
+ print(fa.get_output())
+ print(fa.get_last_objective())
# Compare to CVXPY
@@ -143,7 +146,7 @@ class TestAlgorithms(unittest.TestCase):
opt = {'memopt': True}
# Create object instances with the test data A and b.
f = Norm2sq(A, b, c=0.5, memopt=True)
- g0 = ZeroFun()
+ g0 = ZeroFunction()
# Initial guess
x_init = DataContainer(np.zeros((n, 1)))
@@ -155,12 +158,16 @@ class TestAlgorithms(unittest.TestCase):
g1.prox(x_init, 0.02)
# Combine with least squares and solve using generic FISTA implementation
- x_fista1, it1, timing1, criter1 = FISTA(x_init, f, g1, opt=opt)
+ #x_fista1, it1, timing1, criter1 = FISTA(x_init, f, g1, opt=opt)
+ fa = FISTA(x_init=x_init, f=f, g=g1)
+ fa.max_iteration = 10
+ fa.run(10)
+
# Print for comparison
print("FISTA least squares plus 1-norm solution and objective value:")
- print(x_fista1.as_array().squeeze())
- print(criter1[-1])
+ print(fa.get_output())
+ print(fa.get_last_objective())
# Compare to CVXPY