summaryrefslogtreecommitdiffstats
path: root/Wrappers
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-05-03 14:10:51 +0100
committerGitHub <noreply@github.com>2019-05-03 14:10:51 +0100
commit3441b56e1ba887c71e54eaea7a0a71e44c58c5b1 (patch)
treeb8b572bee346799e313c8dd8e084fa54a327fa3f /Wrappers
parent376e12c567f045169119f87f82efc196570753ad (diff)
parent82d94d608ea639c0aa8aefb80cc97c5d8b1ba2cb (diff)
downloadframework-3441b56e1ba887c71e54eaea7a0a71e44c58c5b1.tar.gz
framework-3441b56e1ba887c71e54eaea7a0a71e44c58c5b1.tar.bz2
framework-3441b56e1ba887c71e54eaea7a0a71e44c58c5b1.tar.xz
framework-3441b56e1ba887c71e54eaea7a0a71e44c58c5b1.zip
Merge pull request #273 from vais-ral/cgls_bug_beast
Fix dot product bug to fix new CGLS algorithm closes #239
Diffstat (limited to 'Wrappers')
-rwxr-xr-xWrappers/Python/ccpi/framework/framework.py20
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algs.py9
-rwxr-xr-xWrappers/Python/test/test_DataContainer.py5
-rw-r--r--Wrappers/Python/wip/compare_CGLS_algos.py127
4 files changed, 151 insertions, 10 deletions
diff --git a/Wrappers/Python/ccpi/framework/framework.py b/Wrappers/Python/ccpi/framework/framework.py
index ffc91ae..7236e0e 100755
--- a/Wrappers/Python/ccpi/framework/framework.py
+++ b/Wrappers/Python/ccpi/framework/framework.py
@@ -764,12 +764,26 @@ class DataContainer(object):
return numpy.sqrt(self.squared_norm())
def dot(self, other, *args, **kwargs):
'''return the inner product of 2 DataContainers viewed as vectors'''
+ method = kwargs.get('method', 'reduce')
+ if method not in ['numpy','reduce']:
+ raise ValueError('dot: specified method not valid. Expecting numpy or reduce got {} '.format(
+ method))
if self.shape == other.shape:
- return numpy.dot(self.as_array().ravel(), other.as_array().ravel())
+ # return (self*other).sum()
+ if method == 'numpy':
+ return numpy.dot(self.as_array().ravel(), other.as_array())
+ elif method == 'reduce':
+ # see https://github.com/vais-ral/CCPi-Framework/pull/273
+ # notice that Python seems to be smart enough to use
+ # the appropriate type to hold the result of the reduction
+ sf = reduce(lambda x,y: x + y[0]*y[1],
+ zip(self.as_array().ravel(),
+ other.as_array().ravel()),
+ 0)
+ return sf
else:
raise ValueError('Shapes are not aligned: {} != {}'.format(self.shape, other.shape))
-
-
+
diff --git a/Wrappers/Python/ccpi/optimisation/algs.py b/Wrappers/Python/ccpi/optimisation/algs.py
index 89b5519..f5ba85e 100755
--- a/Wrappers/Python/ccpi/optimisation/algs.py
+++ b/Wrappers/Python/ccpi/optimisation/algs.py
@@ -20,13 +20,8 @@
import numpy
import time
-from ccpi.optimisation.functions import Function
-from ccpi.optimisation.functions import ZeroFunction
-from ccpi.framework import ImageData
-from ccpi.framework import AcquisitionData
-from ccpi.optimisation.spdhg import spdhg
-from ccpi.optimisation.spdhg import KullbackLeibler
-from ccpi.optimisation.spdhg import KullbackLeiblerConvexConjugate
+
+
def FISTA(x_init, f=None, g=None, opt=None):
'''Fast Iterative Shrinkage-Thresholding Algorithm
diff --git a/Wrappers/Python/test/test_DataContainer.py b/Wrappers/Python/test/test_DataContainer.py
index 8e8ab87..e92d4c6 100755
--- a/Wrappers/Python/test/test_DataContainer.py
+++ b/Wrappers/Python/test/test_DataContainer.py
@@ -455,6 +455,11 @@ class TestDataContainer(unittest.TestCase):
self.assertTrue(False)
except ValueError as ve:
self.assertTrue(True)
+
+ print ("test dot numpy")
+ n0 = (ds0 * ds1).sum()
+ n1 = ds0.as_array().ravel().dot(ds1.as_array().ravel())
+ self.assertEqual(n0, n1)
diff --git a/Wrappers/Python/wip/compare_CGLS_algos.py b/Wrappers/Python/wip/compare_CGLS_algos.py
new file mode 100644
index 0000000..119752c
--- /dev/null
+++ b/Wrappers/Python/wip/compare_CGLS_algos.py
@@ -0,0 +1,127 @@
+# This demo illustrates how to use the SIRT algorithm without and with
+# nonnegativity and box constraints. The ASTRA 2D projectors are used.
+
+# First make all imports
+from ccpi.framework import ImageData, ImageGeometry, AcquisitionGeometry, \
+ AcquisitionData
+from ccpi.optimisation.algs import FISTA, FBPD, CGLS, SIRT
+from ccpi.astra.operators import AstraProjectorSimple
+
+from ccpi.optimisation.algorithms import CGLS as CGLSalg
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+# Choose either a parallel-beam (1=parallel2D) or fan-beam (2=cone2D) test case
+test_case = 1
+
+# Set up phantom size NxN by creating ImageGeometry, initialising the
+# ImageData object with this geometry and empty array and finally put some
+# data into its array, and display as image.
+N = 128
+ig = ImageGeometry(voxel_num_x=N,voxel_num_y=N)
+Phantom = ImageData(geometry=ig)
+
+x = Phantom.as_array()
+x[round(N/4):round(3*N/4),round(N/4):round(3*N/4)] = 0.5
+x[round(N/8):round(7*N/8),round(3*N/8):round(5*N/8)] = 1
+
+#plt.figure()
+#plt.imshow(x)
+#plt.title('Phantom image')
+#plt.show()
+
+# Set up AcquisitionGeometry object to hold the parameters of the measurement
+# setup geometry: # Number of angles, the actual angles from 0 to
+# pi for parallel beam and 0 to 2pi for fanbeam, set the width of a detector
+# pixel relative to an object pixel, the number of detector pixels, and the
+# source-origin and origin-detector distance (here the origin-detector distance
+# set to 0 to simulate a "virtual detector" with same detector pixel size as
+# object pixel size).
+angles_num = 20
+det_w = 1.0
+det_num = N
+SourceOrig = 200
+OrigDetec = 0
+
+if test_case==1:
+ angles = np.linspace(0,np.pi,angles_num,endpoint=False)
+ ag = AcquisitionGeometry('parallel',
+ '2D',
+ angles,
+ det_num,det_w)
+elif test_case==2:
+ angles = np.linspace(0,2*np.pi,angles_num,endpoint=False)
+ ag = AcquisitionGeometry('cone',
+ '2D',
+ angles,
+ det_num,
+ det_w,
+ dist_source_center=SourceOrig,
+ dist_center_detector=OrigDetec)
+else:
+ NotImplemented
+
+# Set up Operator object combining the ImageGeometry and AcquisitionGeometry
+# wrapping calls to ASTRA as well as specifying whether to use CPU or GPU.
+Aop = AstraProjectorSimple(ig, ag, 'cpu')
+
+# Forward and backprojection are available as methods direct and adjoint. Here
+# generate test data b and do simple backprojection to obtain z.
+b = Aop.direct(Phantom)
+z = Aop.adjoint(b)
+
+#plt.figure()
+#plt.imshow(b.array)
+#plt.title('Simulated data')
+#plt.show()
+
+#plt.figure()
+#plt.imshow(z.array)
+#plt.title('Backprojected data')
+#plt.show()
+
+# Using the test data b, different reconstruction methods can now be set up as
+# demonstrated in the rest of this file. In general all methods need an initial
+# guess and some algorithm options to be set:
+x_init = ImageData(np.zeros(x.shape),geometry=ig)
+opt = {'tol': 1e-4, 'iter': 7}
+
+# First a CGLS reconstruction using the function version of CGLS can be done:
+x_CGLS, it_CGLS, timing_CGLS, criter_CGLS = CGLS(x_init, Aop, b, opt)
+
+#plt.figure()
+#plt.imshow(x_CGLS.array)
+#plt.title('CGLS')
+#plt.colorbar()
+#plt.show()
+
+#plt.figure()
+#plt.semilogy(criter_CGLS)
+#plt.title('CGLS criterion')
+#plt.show()
+
+
+
+# Now CLGS using the algorithm class
+CGLS_alg = CGLSalg()
+CGLS_alg.set_up(x_init, Aop, b )
+CGLS_alg.max_iteration = 2000
+CGLS_alg.run(opt['iter'])
+x_CGLS_alg = CGLS_alg.get_output()
+
+#plt.figure()
+#plt.imshow(x_CGLS_alg.as_array())
+#plt.title('CGLS ALG')
+#plt.colorbar()
+#plt.show()
+
+#plt.figure()
+#plt.semilogy(CGLS_alg.objective)
+#plt.title('CGLS criterion')
+#plt.show()
+
+print(criter_CGLS)
+print(CGLS_alg.objective)
+
+print((x_CGLS - x_CGLS_alg).norm()) \ No newline at end of file