author     Edoardo Pasca <edo.paskino@gmail.com>   2019-06-06 10:18:31 +0100
committer  Edoardo Pasca <edo.paskino@gmail.com>   2019-06-06 10:18:31 +0100
commit     b234f4cf26ee56da94211dc15c9b277c7c29fff4 (patch)
tree       69d0cbd9f782fb771e4d6ee875adb75c07d90f30 /Wrappers
parent     940a1371fdf88e8c9e8230cece6fa1c73842804c (diff)
add prints
Diffstat (limited to 'Wrappers')
-rwxr-xr-x  Wrappers/Python/wip/fix_test.py  38
1 file changed, 15 insertions, 23 deletions
diff --git a/Wrappers/Python/wip/fix_test.py b/Wrappers/Python/wip/fix_test.py
index 5e40d70..316606e 100755
--- a/Wrappers/Python/wip/fix_test.py
+++ b/Wrappers/Python/wip/fix_test.py
@@ -61,8 +61,8 @@ class Norm1(Function):
opt = {'memopt': True}
# Problem data.
-m = 30
-n = 30
+m = 4
+n = 10
np.random.seed(1)
Amat = np.asarray( np.random.randn(m, n), dtype=numpy.float32)
#Amat = np.asarray(np.eye(m), dtype=np.float32) * 2
@@ -78,20 +78,21 @@ print ("A", A.A)
# Change n to equal to m.
vgb = VectorGeometry(m)
vgx = VectorGeometry(n)
-b = vgb.allocate(VectorGeometry.RANDOM, dtype=numpy.float32)
-b.fill(bmat)
+b = vgb.allocate(2, dtype=numpy.float32)
+# b.fill(bmat)
#b = DataContainer(bmat)
# Regularization parameter
lam = 10
opt = {'memopt': True}
# Create object instances with the test data A and b.
-f = Norm2Sq(A, b, c=0.5, memopt=True)
+f = Norm2Sq(A, b, c=1., memopt=True)
#f = FunctionOperatorComposition(A, L2NormSquared(b=bmat))
g0 = ZeroFunction()
#f.L = 30.003
x_init = vgx.allocate(VectorGeometry.RANDOM, dtype=numpy.float32)
+x_initcgls = x_init.copy()
a = VectorData(x_init.as_array(), deep_copy=True)
@@ -136,33 +137,24 @@ print ("x1", x1.as_array())
# Combine with least squares and solve using generic FISTA implementation
#x_fista1, it1, timing1, criter1 = FISTA(x_init, f, g1, opt=opt)
def callback(it, objective, solution):
- print (it, objective, solution.as_array())
+ print (objective, f(solution))
fa = FISTA(x_init=x_init, f=f, g=g1)
-fa.max_iteration = 10000
+fa.max_iteration = 100
fa.update_objective_interval = int( fa.max_iteration / 10 )
fa.run(fa.max_iteration, callback = None, verbose=True)
gd = GradientDescent(x_init=x_init, objective_function=f, rate = rate )
-gd.max_iteration = 100000
-gd.update_objective_interval = 10000
+gd.max_iteration = 100
+gd.update_objective_interval = int( gd.max_iteration / 10 )
gd.run(gd.max_iteration, callback = None, verbose=True)
-cgls = CGLS(x_init= x_init, operator=A, data=b)
-cgls.max_iteration = 200
-cgls.update_objective_interval = 1
-def stop_criterion(alg):
- try:
- x = alg.get_last_objective()
- print (x)
- a = True if x < numpy.finfo(numpy.float32).eps else False
- except IndexError as ie:
- a = False
- def f ():
- return a or alg.max_iteration_stop_cryterion()
- return f
+cgls = CGLS(x_init= x_initcgls, operator=A, data=b)
+cgls.max_iteration = 1000
+cgls.update_objective_interval = 2
+
#cgls.should_stop = stop_criterion(cgls)
-cgls.run(cgls.max_iteration, callback = None, verbose=True)
+cgls.run(10, callback = callback, verbose=True)
# Print for comparison
print("FISTA least squares plus 1-norm solution and objective value:")