diff options
Diffstat (limited to 'Wrappers/Python')
| -rwxr-xr-x | Wrappers/Python/ccpi/framework/framework.py | 246 | ||||
| -rw-r--r-- | Wrappers/Python/ccpi/io/reader.py | 21 | ||||
| -rwxr-xr-x | Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py | 3 | ||||
| -rw-r--r-- | Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py | 2 | ||||
| -rwxr-xr-x | Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py | 6 | ||||
| -rwxr-xr-x | Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py | 2 | ||||
| -rwxr-xr-x | Wrappers/Python/ccpi/optimisation/algs.py | 2 | ||||
| -rwxr-xr-x | Wrappers/Python/test/test_DataContainer.py | 16 | ||||
| -rw-r--r-- | Wrappers/Python/test/test_functions.py | 4 | ||||
| -rwxr-xr-x | Wrappers/Python/test/test_run_test.py | 31 | 
10 files changed, 197 insertions, 136 deletions
| diff --git a/Wrappers/Python/ccpi/framework/framework.py b/Wrappers/Python/ccpi/framework/framework.py index af4139b..7516447 100755 --- a/Wrappers/Python/ccpi/framework/framework.py +++ b/Wrappers/Python/ccpi/framework/framework.py @@ -772,61 +772,18 @@ class DataContainer(object):  class ImageData(DataContainer):      '''DataContainer for holding 2D or 3D DataContainer'''      __container_priority__ = 1 +     +          def __init__(self,                    array = None,                    deep_copy=False,                    dimension_labels=None,                    **kwargs): -        self.geometry = None +        self.geometry = kwargs.get('geometry', None)          if array is None: -            if 'geometry' in kwargs.keys(): -                geometry  = kwargs['geometry'] -                self.geometry = geometry -                channels  = geometry.channels -                horiz_x   = geometry.voxel_num_x -                horiz_y   = geometry.voxel_num_y -                vert      = 1 if geometry.voxel_num_z is None\ -                              else geometry.voxel_num_z # this should be 1 for 2D -                if dimension_labels is None: -                    if channels > 1: -                        if vert > 1: -                            shape = (channels, vert, horiz_y, horiz_x) -                            dim_labels = [ImageGeometry.CHANNEL,  -                                          ImageGeometry.VERTICAL, -                                          ImageGeometry.HORIZONTAL_Y, -                                          ImageGeometry.HORIZONTAL_X] -                        else: -                            shape = (channels , horiz_y, horiz_x) -                            dim_labels = [ImageGeometry.CHANNEL, -                                          ImageGeometry.HORIZONTAL_Y, -                                          ImageGeometry.HORIZONTAL_X] -                    else: -                        if vert > 1: -                            shape = (vert, horiz_y, horiz_x) -                            dim_labels = [ImageGeometry.VERTICAL, -                                          ImageGeometry.HORIZONTAL_Y, -                                          ImageGeometry.HORIZONTAL_X] -                        else: -                            shape = (horiz_y, horiz_x) -                            dim_labels = [ImageGeometry.HORIZONTAL_Y, -                                          ImageGeometry.HORIZONTAL_X] -                    dimension_labels = dim_labels -                else: -                    shape = [] -                    for dim in dimension_labels: -                        if dim == ImageGeometry.CHANNEL: -                            shape.append(channels) -                        elif dim == ImageGeometry.HORIZONTAL_Y: -                            shape.append(horiz_y) -                        elif dim == ImageGeometry.VERTICAL: -                            shape.append(vert) -                        elif dim == ImageGeometry.HORIZONTAL_X: -                            shape.append(horiz_x) -                    if len(shape) != len(dimension_labels): -                        raise ValueError('Missing {0} axes'.format( -                                len(dimension_labels) - len(shape))) -                    shape = tuple(shape) +            if self.geometry is not None: +                shape, dimension_labels = self.get_shape_labels(self.geometry)                  array = numpy.zeros( shape , dtype=numpy.float32)                   super(ImageData, self).__init__(array, deep_copy, @@ -836,6 +793,11 @@ class ImageData(DataContainer):                  raise ValueError('Please pass either a DataContainer, ' +\                                   'a numpy array or a geometry')          else: +            if self.geometry is not None: +                shape, labels = self.get_shape_labels(self.geometry, dimension_labels) +                if array.shape != shape: +                    raise ValueError('Shape mismatch {} {}'.format(shape, array.shape)) +                          if issubclass(type(array) , DataContainer):                  # if the array is a DataContainer get the info from there                  if not ( array.number_of_dimensions == 2 or \ @@ -890,78 +852,85 @@ class ImageData(DataContainer):              #out.geometry = self.recalculate_geometry(dimensions , **kw)              out.geometry = self.geometry              return out -                         + +    def get_shape_labels(self, geometry, dimension_labels=None): +        channels  = geometry.channels +        horiz_x   = geometry.voxel_num_x +        horiz_y   = geometry.voxel_num_y +        vert      = 1 if geometry.voxel_num_z is None\ +                      else geometry.voxel_num_z # this should be 1 for 2D +        if dimension_labels is None: +            if channels > 1: +                if vert > 1: +                    shape = (channels, vert, horiz_y, horiz_x) +                    dim_labels = [ImageGeometry.CHANNEL,  +                                  ImageGeometry.VERTICAL, +                                  ImageGeometry.HORIZONTAL_Y, +                                  ImageGeometry.HORIZONTAL_X] +                else: +                    shape = (channels , horiz_y, horiz_x) +                    dim_labels = [ImageGeometry.CHANNEL, +                                  ImageGeometry.HORIZONTAL_Y, +                                  ImageGeometry.HORIZONTAL_X] +            else: +                if vert > 1: +                    shape = (vert, horiz_y, horiz_x) +                    dim_labels = [ImageGeometry.VERTICAL, +                                  ImageGeometry.HORIZONTAL_Y, +                                  ImageGeometry.HORIZONTAL_X] +                else: +                    shape = (horiz_y, horiz_x) +                    dim_labels = [ImageGeometry.HORIZONTAL_Y, +                                  ImageGeometry.HORIZONTAL_X] +            dimension_labels = dim_labels +        else: +            shape = [] +            for i in range(len(dimension_labels)): +                dim = dimension_labels[i] +                if dim == ImageGeometry.CHANNEL: +                    shape.append(channels) +                elif dim == ImageGeometry.HORIZONTAL_Y: +                    shape.append(horiz_y) +                elif dim == ImageGeometry.VERTICAL: +                    shape.append(vert) +                elif dim == ImageGeometry.HORIZONTAL_X: +                    shape.append(horiz_x) +            if len(shape) != len(dimension_labels): +                raise ValueError('Missing {0} axes {1} shape {2}'.format( +                        len(dimension_labels) - len(shape), dimension_labels, shape)) +            shape = tuple(shape) +             +        return (shape, dimension_labels) +                              class AcquisitionData(DataContainer):      '''DataContainer for holding 2D or 3D sinogram'''      __container_priority__ = 1 +     +          def __init__(self,                    array = None,                    deep_copy=True,                    dimension_labels=None,                    **kwargs): -        self.geometry = None +        self.geometry = kwargs.get('geometry', None)          if array is None:              if 'geometry' in kwargs.keys():                  geometry      = kwargs['geometry']                  self.geometry = geometry -                channels      = geometry.channels -                horiz         = geometry.pixel_num_h -                vert          = geometry.pixel_num_v -                angles        = geometry.angles -                num_of_angles = numpy.shape(angles)[0] -                if dimension_labels is None: -                    if channels > 1: -                        if vert > 1: -                            shape = (channels, num_of_angles , vert, horiz) -                            dim_labels = [AcquisitionGeometry.CHANNEL, -                                          AcquisitionGeometry.ANGLE, -                                          AcquisitionGeometry.VERTICAL, -                                          AcquisitionGeometry.HORIZONTAL] -                        else: -                            shape = (channels , num_of_angles, horiz) -                            dim_labels = [AcquisitionGeometry.CHANNEL, -                                          AcquisitionGeometry.ANGLE, -                                          AcquisitionGeometry.HORIZONTAL] -                    else: -                        if vert > 1: -                            shape = (num_of_angles, vert, horiz) -                            dim_labels = [AcquisitionGeometry.ANGLE, -                                          AcquisitionGeometry.VERTICAL, -                                          AcquisitionGeometry.HORIZONTAL -                                          ] -                        else: -                            shape = (num_of_angles, horiz) -                            dim_labels = [AcquisitionGeometry.ANGLE, -                                          AcquisitionGeometry.HORIZONTAL -                                          ] -                     -                    dimension_labels = dim_labels -                else: -                    shape = [] -                    for dim in dimension_labels: -                        if dim == AcquisitionGeometry.CHANNEL: -                            shape.append(channels) -                        elif dim == AcquisitionGeometry.ANGLE: -                            shape.append(num_of_angles) -                        elif dim == AcquisitionGeometry.VERTICAL: -                            shape.append(vert) -                        elif dim == AcquisitionGeometry.HORIZONTAL: -                            shape.append(horiz) -                    if len(shape) != len(dimension_labels): -                        raise ValueError('Missing {0} axes.\nExpected{1} got {2}'\ -                            .format( -                                len(dimension_labels) - len(shape), -                                dimension_labels, shape)  -                            ) -                    shape = tuple(shape) +                shape, dimension_labels = self.get_shape_labels(geometry, dimension_labels) +                                  array = numpy.zeros( shape , dtype=numpy.float32)                   super(AcquisitionData, self).__init__(array, deep_copy,                                   dimension_labels, **kwargs)          else: -             +            if self.geometry is not None: +                shape, labels = self.get_shape_labels(self.geometry, dimension_labels) +                if array.shape != shape: +                    raise ValueError('Shape mismatch {} {}'.format(shape, array.shape)) +                                  if issubclass(type(array) ,DataContainer):                  # if the array is a DataContainer get the info from there                  if not ( array.number_of_dimensions == 2 or \ @@ -982,19 +951,78 @@ class AcquisitionData(DataContainer):                  if dimension_labels is None:                      if array.ndim == 4: -                        dimension_labels = ['channel' ,'angle' , 'vertical' ,  -                                      'horizontal'] +                        dimension_labels = [AcquisitionGeometry.CHANNEL, +                                            AcquisitionGeometry.ANGLE, +                                            AcquisitionGeometry.VERTICAL, +                                            AcquisitionGeometry.HORIZONTAL]                      elif array.ndim == 3: -                        dimension_labels = ['angle' , 'vertical' ,  -                                      'horizontal'] +                        dimension_labels = [AcquisitionGeometry.ANGLE, +                                            AcquisitionGeometry.VERTICAL, +                                            AcquisitionGeometry.HORIZONTAL]                      else: -                        dimension_labels = ['angle' ,  -                                      'horizontal']    -                 -                #DataContainer.__init__(self, array, deep_copy, dimension_labels, **kwargs) +                        dimension_labels = [AcquisitionGeometry.ANGLE, +                                            AcquisitionGeometry.HORIZONTAL] +                  super(AcquisitionData, self).__init__(array, deep_copy,                        dimension_labels, **kwargs) +    def get_shape_labels(self, geometry, dimension_labels=None): +        channels      = geometry.channels +        horiz         = geometry.pixel_num_h +        vert          = geometry.pixel_num_v +        angles        = geometry.angles +        num_of_angles = numpy.shape(angles)[0] +         +        if dimension_labels is None: +            if channels > 1: +                if vert > 1: +                    shape = (channels, num_of_angles , vert, horiz) +                    dim_labels = [AcquisitionGeometry.CHANNEL, +                                  AcquisitionGeometry.ANGLE, +                                  AcquisitionGeometry.VERTICAL, +                                  AcquisitionGeometry.HORIZONTAL] +                else: +                    shape = (channels , num_of_angles, horiz) +                    dim_labels = [AcquisitionGeometry.CHANNEL, +                                  AcquisitionGeometry.ANGLE, +                                  AcquisitionGeometry.HORIZONTAL] +            else: +                if vert > 1: +                    shape = (num_of_angles, vert, horiz) +                    dim_labels = [AcquisitionGeometry.ANGLE, +                                  AcquisitionGeometry.VERTICAL, +                                  AcquisitionGeometry.HORIZONTAL +                                  ] +                else: +                    shape = (num_of_angles, horiz) +                    dim_labels = [AcquisitionGeometry.ANGLE, +                                  AcquisitionGeometry.HORIZONTAL +                                  ] +             +            dimension_labels = dim_labels +        else: +            shape = [] +            for i in range(len(dimension_labels)): +                dim = dimension_labels[i] +                 +                if dim == AcquisitionGeometry.CHANNEL: +                    shape.append(channels) +                elif dim == AcquisitionGeometry.ANGLE: +                    shape.append(num_of_angles) +                elif dim == AcquisitionGeometry.VERTICAL: +                    shape.append(vert) +                elif dim == AcquisitionGeometry.HORIZONTAL: +                    shape.append(horiz) +            if len(shape) != len(dimension_labels): +                raise ValueError('Missing {0} axes.\nExpected{1} got {2}'\ +                    .format( +                        len(dimension_labels) - len(shape), +                        dimension_labels, shape)  +                    ) +            shape = tuple(shape) +        return (shape, dimension_labels) +     +                  class DataProcessor(object): diff --git a/Wrappers/Python/ccpi/io/reader.py b/Wrappers/Python/ccpi/io/reader.py index 856f5e0..07e3bf9 100644 --- a/Wrappers/Python/ccpi/io/reader.py +++ b/Wrappers/Python/ccpi/io/reader.py @@ -241,26 +241,37 @@ class NexusReader(object):                      pass
                  dims = file[self.data_path].shape
                  if ymin is None and ymax is None:
 -                    data = np.array(file[self.data_path])
 +                    
 +                    try:
 +                        image_keys = self.get_image_keys()
 +                        print ("image_keys", image_keys)
 +                        projections = np.array(file[self.data_path])
 +                        data = projections[image_keys==0]
 +                    except KeyError as ke:
 +                        print (ke)
 +                        data = np.array(file[self.data_path])
 +                    
                  else:
 +                    image_keys = self.get_image_keys()
 +                    print ("image_keys", image_keys)
 +                    projections = np.array(file[self.data_path])[image_keys==0]
                      if ymin is None:
                          ymin = 0
                          if ymax > dims[1]:
                              raise ValueError('ymax out of range')
 -                        data = np.array(file[self.data_path][:,:ymax,:])
 +                        data = projections[:,:ymax,:]
                      elif ymax is None:        
                          ymax = dims[1]
                          if ymin < 0:
                              raise ValueError('ymin out of range')
 -                        data = np.array(file[self.data_path][:,ymin:,:])
 +                        data = projections[:,ymin:,:]
                      else:
                          if ymax > dims[1]:
                              raise ValueError('ymax out of range')
                          if ymin < 0:
                              raise ValueError('ymin out of range')
 -                        data = np.array(file[self.data_path]
 -                            [: , ymin:ymax , :] )
 +                        data = projections[: , ymin:ymax , :] 
          except:
              print("Error reading nexus file")
 diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py index 7194eb8..e65bc89 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/CGLS.py @@ -23,7 +23,6 @@ Created on Thu Feb 21 11:11:23 2019  """  from ccpi.optimisation.algorithms import Algorithm -#from collections.abc import Iterable  class CGLS(Algorithm):      '''Conjugate Gradient Least Squares algorithm @@ -84,4 +83,4 @@ class CGLS(Algorithm):          self.d = s + beta*self.d      def update_objective(self): -        self.loss.append(self.r.squared_norm())
\ No newline at end of file +        self.loss.append(self.r.squared_norm()) diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py b/Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py index 445ba7a..aa07359 100644 --- a/Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/FBPD.py @@ -23,7 +23,7 @@ Created on Thu Feb 21 11:09:03 2019  """  from ccpi.optimisation.algorithms import Algorithm -from ccpi.optimisation.functions import ZeroFun +from ccpi.optimisation.functions import ZeroFunction  class FBPD(Algorithm):      '''FBPD Algorithm diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py index 93ba178..064cb33 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py @@ -6,7 +6,7 @@ Created on Thu Feb 21 11:07:30 2019  """  from ccpi.optimisation.algorithms import Algorithm -from ccpi.optimisation.functions import ZeroFun +from ccpi.optimisation.functions import ZeroFunction  import numpy  class FISTA(Algorithm): @@ -46,11 +46,11 @@ class FISTA(Algorithm):          # default inputs          if f   is None:  -            self.f = ZeroFun() +            self.f = ZeroFunction()          else:              self.f = f          if g   is None: -            g = ZeroFun() +            g = ZeroFunction()              self.g = g          else:              self.g = g diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py index f1e4132..14763c5 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/GradientDescent.py @@ -73,4 +73,4 @@ class GradientDescent(Algorithm):      def update_objective(self):          self.loss.append(self.objective_function(self.x)) -        
\ No newline at end of file +         diff --git a/Wrappers/Python/ccpi/optimisation/algs.py b/Wrappers/Python/ccpi/optimisation/algs.py index 6b6ae2c..c142eda 100755 --- a/Wrappers/Python/ccpi/optimisation/algs.py +++ b/Wrappers/Python/ccpi/optimisation/algs.py @@ -21,7 +21,7 @@ import numpy  import time  from ccpi.optimisation.functions import Function -from ccpi.optimisation.functions import ZeroFun +from ccpi.optimisation.functions import ZeroFunction  from ccpi.framework import ImageData   from ccpi.framework import AcquisitionData  from ccpi.optimisation.spdhg import spdhg  diff --git a/Wrappers/Python/test/test_DataContainer.py b/Wrappers/Python/test/test_DataContainer.py index 8edfd8b..40cd244 100755 --- a/Wrappers/Python/test/test_DataContainer.py +++ b/Wrappers/Python/test/test_DataContainer.py @@ -494,6 +494,14 @@ class TestDataContainer(unittest.TestCase):          self.assertEqual(order[0], image.dimension_labels[0])          self.assertEqual(order[1], image.dimension_labels[1])          self.assertEqual(order[2], image.dimension_labels[2]) +         +        ig = ImageGeometry(2,3,2) +        try: +            z = ImageData(numpy.random.randint(10, size=(2,3)), geometry=ig) +            self.assertTrue(False) +        except ValueError as ve: +            print (ve) +            self.assertTrue(True)          #vgeometry.allocate('')      def test_AcquisitionGeometry_allocate(self): @@ -525,6 +533,14 @@ class TestDataContainer(unittest.TestCase):          self.assertEqual(order[1], sino.dimension_labels[1])          self.assertEqual(order[2], sino.dimension_labels[2])          self.assertEqual(order[2], sino.dimension_labels[2]) +         +         +        try: +            z = AcquisitionData(numpy.random.randint(10, size=(2,3)), geometry=ageometry) +            self.assertTrue(False) +        except ValueError as ve: +            print (ve) +            self.assertTrue(True)      def assertNumpyArrayEqual(self, first, second):          res = True diff --git a/Wrappers/Python/test/test_functions.py b/Wrappers/Python/test/test_functions.py index 22721fa..05bdd7a 100644 --- a/Wrappers/Python/test/test_functions.py +++ b/Wrappers/Python/test/test_functions.py @@ -26,7 +26,7 @@ from ccpi.optimisation.funcs import Norm2sq  # from ccpi.optimisation.functions.L2NormSquared import SimpleL2NormSq, L2NormSq  # from ccpi.optimisation.functions.L1Norm import SimpleL1Norm, L1Norm  #from ccpi.optimisation.functions import mixed_L12Norm -from ccpi.optimisation.functions import ZeroFun +from ccpi.optimisation.functions import ZeroFunction  from ccpi.optimisation.functions import FunctionOperatorComposition  import unittest @@ -329,7 +329,7 @@ class TestFunction(unittest.TestCase):          a1 = f_no_scaled(U)          a2 = f_scaled(U) -        self.assertAlmostEqual(a1,a2) +        self.assertNumpyArrayAlmostEqual(a1.as_array(),a2.as_array())          tmp = [ el**2 for el in U.containers ] diff --git a/Wrappers/Python/test/test_run_test.py b/Wrappers/Python/test/test_run_test.py index 8cef925..c698032 100755 --- a/Wrappers/Python/test/test_run_test.py +++ b/Wrappers/Python/test/test_run_test.py @@ -6,10 +6,10 @@ from ccpi.framework import ImageData  from ccpi.framework import AcquisitionData  from ccpi.framework import ImageGeometry  from ccpi.framework import AcquisitionGeometry -from ccpi.optimisation.algs import FISTA -from ccpi.optimisation.algs import FBPD +from ccpi.optimisation.algorithms import FISTA +#from ccpi.optimisation.algs import FBPD  from ccpi.optimisation.funcs import Norm2sq -from ccpi.optimisation.functions import ZeroFun +from ccpi.optimisation.functions import ZeroFunction  from ccpi.optimisation.funcs import Norm1  from ccpi.optimisation.funcs import TV2D  from ccpi.optimisation.funcs import Norm2 @@ -82,7 +82,7 @@ class TestAlgorithms(unittest.TestCase):                  opt = {'memopt': True}                  # Create object instances with the test data A and b.                  f = Norm2sq(A, b, c=0.5, memopt=True) -                g0 = ZeroFun() +                g0 = ZeroFunction()                  # Initial guess                  x_init = DataContainer(np.zeros((n, 1))) @@ -90,12 +90,15 @@ class TestAlgorithms(unittest.TestCase):                  f.grad(x_init)                  # Run FISTA for least squares plus zero function. -                x_fista0, it0, timing0, criter0 = FISTA(x_init, f, g0, opt=opt) - +                #x_fista0, it0, timing0, criter0 = FISTA(x_init, f, g0, opt=opt) +                fa = FISTA(x_init=x_init, f=f, g=g0) +                fa.max_iteration = 10 +                fa.run(10) +                                  # Print solution and final objective/criterion value for comparison                  print("FISTA least squares plus zero function solution and objective value:") -                print(x_fista0.array) -                print(criter0[-1]) +                print(fa.get_output()) +                print(fa.get_last_objective())                  # Compare to CVXPY @@ -143,7 +146,7 @@ class TestAlgorithms(unittest.TestCase):                  opt = {'memopt': True}                  # Create object instances with the test data A and b.                  f = Norm2sq(A, b, c=0.5, memopt=True) -                g0 = ZeroFun() +                g0 = ZeroFunction()                  # Initial guess                  x_init = DataContainer(np.zeros((n, 1))) @@ -155,12 +158,16 @@ class TestAlgorithms(unittest.TestCase):                  g1.prox(x_init, 0.02)                  # Combine with least squares and solve using generic FISTA implementation -                x_fista1, it1, timing1, criter1 = FISTA(x_init, f, g1, opt=opt) +                #x_fista1, it1, timing1, criter1 = FISTA(x_init, f, g1, opt=opt) +                fa = FISTA(x_init=x_init, f=f, g=g1) +                fa.max_iteration = 10 +                fa.run(10) +                                  # Print for comparison                  print("FISTA least squares plus 1-norm solution and objective value:") -                print(x_fista1.as_array().squeeze()) -                print(criter1[-1]) +                print(fa.get_output()) +                print(fa.get_last_objective())                  # Compare to CVXPY | 
