summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-06-13 15:20:53 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2019-06-13 15:20:53 +0100
commite162d5ef4897dcdfdb9b27af050939298409301d (patch)
tree746f1d6e3250110e6a47d0b23d49c78d3809e202 /Wrappers/Python
parent3869559b14500fa4d730f084c4645b6c485c647f (diff)
parent4ad11441b6042de7518148d0fa59492313ee5e2e (diff)
downloadframework-e162d5ef4897dcdfdb9b27af050939298409301d.tar.gz
framework-e162d5ef4897dcdfdb9b27af050939298409301d.tar.bz2
framework-e162d5ef4897dcdfdb9b27af050939298409301d.tar.xz
framework-e162d5ef4897dcdfdb9b27af050939298409301d.zip
added test
Diffstat (limited to 'Wrappers/Python')
-rwxr-xr-xWrappers/Python/test/test_run_test.py23
-rwxr-xr-xWrappers/Python/wip/fix_test.py40
2 files changed, 44 insertions, 19 deletions
diff --git a/Wrappers/Python/test/test_run_test.py b/Wrappers/Python/test/test_run_test.py
index 4f66da1..a0db9cb 100755
--- a/Wrappers/Python/test/test_run_test.py
+++ b/Wrappers/Python/test/test_run_test.py
@@ -10,10 +10,9 @@ from ccpi.optimisation.algorithms import FISTA
#from ccpi.optimisation.algs import FBPD
from ccpi.optimisation.functions import Norm2Sq
from ccpi.optimisation.functions import ZeroFunction
+# from ccpi.optimisation.funcs import Norm1
from ccpi.optimisation.functions import L1Norm
-# This was removed
-#from ccpi.optimisation.funcs import Norm2
-#from ccpi.optimisation.funcs import Norm1
+from ccpi.optimisation.funcs import Norm2
from ccpi.optimisation.operators import LinearOperatorMatrix
from ccpi.optimisation.operators import Identity
@@ -83,7 +82,6 @@ class TestAlgorithms(unittest.TestCase):
opt = {'memopt': True}
# Create object instances with the test data A and b.
f = Norm2Sq(A, b, c=0.5, memopt=True)
- f.L = LinearOperator.PowerMethod(A, 10)
g0 = ZeroFunction()
# Initial guess
@@ -126,7 +124,6 @@ class TestAlgorithms(unittest.TestCase):
self.assertTrue(cvx_not_installable)
def test_FISTA_Norm1_cvx(self):
- print ("test_FISTA_Norm1_cvx")
if not cvx_not_installable:
try:
opt = {'memopt': True}
@@ -141,8 +138,11 @@ class TestAlgorithms(unittest.TestCase):
# A = Identity()
# Change n to equal to m.
-
- b = DataContainer(bmat)
+ vgb = VectorGeometry(m)
+ vgx = VectorGeometry(n)
+ b = vgb.allocate()
+ b.fill(bmat)
+ #b = DataContainer(bmat)
# Regularization parameter
lam = 10
@@ -152,10 +152,11 @@ class TestAlgorithms(unittest.TestCase):
g0 = ZeroFunction()
# Initial guess
- x_init = DataContainer(np.zeros((n, 1)))
+ #x_init = DataContainer(np.zeros((n, 1)))
+ x_init = vgx.allocate()
# Create 1-norm object instance
- g1 = lam * L1Norm()
+ g1 = Norm1(lam)
g1(x_init)
g1.prox(x_init, 0.02)
@@ -229,7 +230,7 @@ class TestAlgorithms(unittest.TestCase):
# Create 1-norm object instance
- g1 = lam * L1Norm()
+ g1 = Norm1(lam)
# Compare to CVXPY
@@ -296,7 +297,7 @@ class TestAlgorithms(unittest.TestCase):
# 1-norm regulariser
lam1_denoise = 1.0
- g1_denoise = lam1_denoise * L1Norm()
+ g1_denoise = Norm1(lam1_denoise)
# Initial guess
x_init_denoise = ImageData(np.zeros((N, N)))
diff --git a/Wrappers/Python/wip/fix_test.py b/Wrappers/Python/wip/fix_test.py
index 9eb0a4e..b1006c0 100755
--- a/Wrappers/Python/wip/fix_test.py
+++ b/Wrappers/Python/wip/fix_test.py
@@ -103,13 +103,12 @@ a = VectorData(x_init.as_array(), deep_copy=True)
assert id(x_init.as_array()) != id(a.as_array())
-#%%
-# f.L = LinearOperator.PowerMethod(A, 25, x_init)[0]
-# print ('f.L', f.L)
+
+#f.L = LinearOperator.PowerMethod(A, 25, x_init)[0]
+#print ('f.L', f.L)
rate = (1 / f.L) / 6
-f.L *= 12
-print (f.L)
-# rate = f.L / 1000
+#f.L *= 12
+
# Initial guess
#x_init = DataContainer(np.zeros((n, 1)))
print ('x_init', x_init.as_array())
@@ -154,7 +153,7 @@ fa.update_objective_interval = int( fa.max_iteration / 10 )
fa.run(fa.max_iteration, callback = None, verbose=True)
gd = GradientDescent(x_init=x_init, objective_function=f, rate = rate )
-gd.max_iteration = 10000
+gd.max_iteration = 5000
gd.update_objective_interval = int( gd.max_iteration / 10 )
gd.run(gd.max_iteration, callback = None, verbose=True)
@@ -181,4 +180,29 @@ print ('CGLS ', A.direct(cgls.get_output()).as_array())
cond = numpy.linalg.cond(A.A)
-print ("cond" , cond) \ No newline at end of file
+print ("cond" , cond)
+
+#%%
+try:
+ import cvxpy as cp
+ # Construct the problem.
+ x = cp.Variable(n)
+ objective = cp.Minimize(cp.sum_squares(A.A*x - bmat))
+ prob = cp.Problem(objective)
+ # The optimal objective is returned by prob.solve().
+ result = prob.solve(solver = cp.MOSEK)
+
+ print ('CGLS ', cgls.get_output().as_array())
+ print ('CVX ', x.value)
+
+ print ('FISTA ', fa.get_output().as_array())
+ print ('GD ', gd.get_output().as_array())
+except ImportError as ir:
+ pass
+
+ #%%
+
+
+
+
+