summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-03-21 15:16:39 +0000
committerEdoardo Pasca <edo.paskino@gmail.com>2019-03-21 15:16:39 +0000
commit1b34498aaa93b95925991258fe542b62a9155aff (patch)
treed1e8e8e00c2df27525bee0f6747788eb27b5b96f /Wrappers/Python
parent6c6c8474bd869467e7f381c21cef195fc7250045 (diff)
downloadframework-1b34498aaa93b95925991258fe542b62a9155aff.tar.gz
framework-1b34498aaa93b95925991258fe542b62a9155aff.tar.bz2
framework-1b34498aaa93b95925991258fe542b62a9155aff.tar.xz
framework-1b34498aaa93b95925991258fe542b62a9155aff.zip
BlockDataContainer can do algebra with DataContainers
Diffstat (limited to 'Wrappers/Python')
-rwxr-xr-xWrappers/Python/ccpi/framework/BlockDataContainer.py27
-rwxr-xr-xWrappers/Python/test/test_BlockDataContainer.py30
2 files changed, 49 insertions, 8 deletions
diff --git a/Wrappers/Python/ccpi/framework/BlockDataContainer.py b/Wrappers/Python/ccpi/framework/BlockDataContainer.py
index 358ba2d..f29f839 100755
--- a/Wrappers/Python/ccpi/framework/BlockDataContainer.py
+++ b/Wrappers/Python/ccpi/framework/BlockDataContainer.py
@@ -12,6 +12,7 @@ from __future__ import unicode_literals
import numpy
from numbers import Number
import functools
+from ccpi.framework import DataContainer
#from ccpi.framework import AcquisitionData, ImageData
#from ccpi.optimisation.operators import Operator, LinearOperator
@@ -64,6 +65,8 @@ class BlockDataContainer(object):
return len(self.containers) == len(other)
elif isinstance(other, numpy.ndarray):
return self.shape == other.shape
+ elif issubclass(other.__class__, DataContainer):
+ return self.get_item(0).shape == other.shape
return len(self.containers) == len(other.containers)
def get_item(self, row):
@@ -75,24 +78,33 @@ class BlockDataContainer(object):
return self.get_item(row)
def add(self, other, *args, **kwargs):
- assert self.is_compatible(other)
+ if not self.is_compatible(other):
+ raise ValueError('Incompatible for add')
out = kwargs.get('out', None)
#print ("args" , *args)
if isinstance(other, Number):
return type(self)(*[ el.add(other, *args, **kwargs) for el in self.containers], shape=self.shape)
elif isinstance(other, list) or isinstance(other, numpy.ndarray):
- return type(self)(*[ el.add(ot, out, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape)
+ return type(self)(*[ el.add(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape)
+ elif issubclass(other.__class__, DataContainer):
+ # try to do algebra with one DataContainer. Will raise error if not compatible
+ return type(self)(*[ el.add(other, *args, **kwargs) for el in self.containers], shape=self.shape)
+
return type(self)(
*[ el.add(ot, out, *args, **kwargs) for el,ot in zip(self.containers,other.containers)],
shape=self.shape)
def subtract(self, other, *args, **kwargs):
- assert self.is_compatible(other)
+ if not self.is_compatible(other):
+ raise ValueError('Incompatible for add')
out = kwargs.get('out', None)
if isinstance(other, Number):
return type(self)(*[ el.subtract(other, out, *args, **kwargs) for el in self.containers], shape=self.shape)
elif isinstance(other, list) or isinstance(other, numpy.ndarray):
return type(self)(*[ el.subtract(ot, out, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape)
+ elif issubclass(other.__class__, DataContainer):
+ # try to do algebra with one DataContainer. Will raise error if not compatible
+ return type(self)(*[ el.subtract(other, *args, **kwargs) for el in self.containers], shape=self.shape)
return type(self)(*[ el.subtract(ot, out, *args, **kwargs) for el,ot in zip(self.containers,other.containers)],
shape=self.shape)
@@ -105,6 +117,9 @@ class BlockDataContainer(object):
return type(self)(*[ el.multiply(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape)
elif isinstance(other, numpy.ndarray):
return type(self)(*[ el.multiply(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape)
+ elif issubclass(other.__class__, DataContainer):
+ # try to do algebra with one DataContainer. Will raise error if not compatible
+ return type(self)(*[ el.multiply(other, *args, **kwargs) for el in self.containers], shape=self.shape)
return type(self)(*[ el.multiply(ot, *args, **kwargs) for el,ot in zip(self.containers,other.containers)],
shape=self.shape)
@@ -115,6 +130,9 @@ class BlockDataContainer(object):
return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape)
elif isinstance(other, list) or isinstance(other, numpy.ndarray):
return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape)
+ elif issubclass(other.__class__, DataContainer):
+ # try to do algebra with one DataContainer. Will raise error if not compatible
+ return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape)
return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other.containers)],
shape=self.shape)
@@ -138,13 +156,10 @@ class BlockDataContainer(object):
## unary operations
def abs(self, *args, **kwargs):
- out = kwargs.get('out', None)
return type(self)(*[ el.abs(*args, **kwargs) for el in self.containers], shape=self.shape)
def sign(self, *args, **kwargs):
- out = kwargs.get('out', None)
return type(self)(*[ el.sign(*args, **kwargs) for el in self.containers], shape=self.shape)
def sqrt(self, *args, **kwargs):
- out = kwargs.get('out', None)
return type(self)(*[ el.sqrt(*args, **kwargs) for el in self.containers], shape=self.shape)
def conjugate(self, out=None):
return type(self)(*[el.conjugate() for el in self.containers], shape=self.shape)
diff --git a/Wrappers/Python/test/test_BlockDataContainer.py b/Wrappers/Python/test/test_BlockDataContainer.py
index 6c0bede..51d07fa 100755
--- a/Wrappers/Python/test/test_BlockDataContainer.py
+++ b/Wrappers/Python/test/test_BlockDataContainer.py
@@ -95,7 +95,7 @@ class TestBlockDataContainer(unittest.TestCase):
def test_BlockDataContainer(self):
print ("test block data container")
ig0 = ImageGeometry(2,3,4)
- ig1 = ImageGeometry(2,3,4)
+ ig1 = ImageGeometry(2,3,5)
data0 = ImageData(geometry=ig0)
data1 = ImageData(geometry=ig1) + 1
@@ -105,7 +105,33 @@ class TestBlockDataContainer(unittest.TestCase):
cp0 = BlockDataContainer(data0,data1)
cp1 = BlockDataContainer(data2,data3)
- #
+
+ cp2 = BlockDataContainer(data0+1, data2+1)
+ d = cp2 + data0
+ self.assertEqual(d.get_item(0).as_array()[0][0][0], 1)
+ try:
+ d = cp2 + data1
+ self.assertTrue(False)
+ except ValueError as ve:
+ print (ve)
+ self.assertTrue(True)
+ d = cp2 - data0
+ self.assertEqual(d.get_item(0).as_array()[0][0][0], 1)
+ try:
+ d = cp2 - data1
+ self.assertTrue(False)
+ except ValueError as ve:
+ print (ve)
+ self.assertTrue(True)
+ d = cp2 * data2
+ self.assertEqual(d.get_item(0).as_array()[0][0][0], 2)
+ try:
+ d = cp2 * data1
+ self.assertTrue(False)
+ except ValueError as ve:
+ print (ve)
+ self.assertTrue(True)
+
a = [ (el, ot) for el,ot in zip(cp0.containers,cp1.containers)]
print (a[0][0].shape)
#cp2 = BlockDataContainer(*a)