summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-04-11 17:13:08 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2019-04-11 17:13:08 +0100
commit10aae87e1416d291906b94927acb4aac5737a44e (patch)
tree557c9337eb2f0a09f9879c7a85b68a96a65bc34f /Wrappers/Python
parent6ce64e15b13cf7c6ae55cf9bc891980679268ac4 (diff)
downloadframework-10aae87e1416d291906b94927acb4aac5737a44e.tar.gz
framework-10aae87e1416d291906b94927acb4aac5737a44e.tar.bz2
framework-10aae87e1416d291906b94927acb4aac5737a44e.tar.xz
framework-10aae87e1416d291906b94927acb4aac5737a44e.zip
fixing algebra with nested block data containers
Diffstat (limited to 'Wrappers/Python')
-rwxr-xr-xWrappers/Python/ccpi/framework/BlockDataContainer.py45
1 files changed, 40 insertions, 5 deletions
diff --git a/Wrappers/Python/ccpi/framework/BlockDataContainer.py b/Wrappers/Python/ccpi/framework/BlockDataContainer.py
index 13663c2..85cd05a 100755
--- a/Wrappers/Python/ccpi/framework/BlockDataContainer.py
+++ b/Wrappers/Python/ccpi/framework/BlockDataContainer.py
@@ -53,9 +53,9 @@ class BlockDataContainer(object):
def is_compatible(self, other):
'''basic check if the size of the 2 objects fit'''
- for i in range(len(self.containers)):
- if type(self.containers[i])==type(self):
- self = self.containers[i]
+ #for i in range(len(self.containers)):
+ # if type(self.containers[i])==type(self):
+ # self = self.containers[i]
if isinstance(other, Number):
return True
@@ -71,7 +71,16 @@ class BlockDataContainer(object):
elif isinstance(other, numpy.ndarray):
return len(self.containers) == len(other)
elif issubclass(other.__class__, DataContainer):
- return self.get_item(0).shape == other.shape
+ ret = True
+ for i, el in enumerate(self.containers):
+ if isinstance(el, BlockDataContainer):
+ a = el.is_compatible(other)
+ else:
+ a = el.shape == other.shape
+ print ("current element" , el.shape, "other ", other.shape, "same shape" , a)
+ ret = ret and a
+ return ret
+ #return self.get_item(0).shape == other.shape
return len(self.containers) == len(other.containers)
def get_item(self, row):
@@ -139,10 +148,36 @@ class BlockDataContainer(object):
return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape)
elif issubclass(other.__class__, DataContainer):
# try to do algebra with one DataContainer. Will raise error if not compatible
- return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape)
+ if out is not None:
+ kw = kwargs.copy()
+ for i,el in enumerate(self.containers):
+ kw['out'] = out.get_item(i)
+ el.divide(other, *args, **kw)
+ return
+ else:
+ return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape)
+ return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other.containers)],
+ shape=self.shape)
+ def binary_operations(self, operation, other, *args, **kwargs):
+ if not self.is_compatible(other):
+ raise ValueError('Incompatible for divide')
+ out = kwargs.get('out', None)
+ if isinstance(other, Number) or issubclass(other.__class__, DataContainer):
+ # try to do algebra with one DataContainer. Will raise error if not compatible
+ if out is not None:
+ kw = kwargs.copy()
+ for i,el in enumerate(self.containers):
+ kw['out'] = out.get_item(i)
+ el.divide(other, *args, **kw)
+ return
+ else:
+ return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape)
+ elif isinstance(other, list) or isinstance(other, numpy.ndarray):
+ return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape)
return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other.containers)],
shape=self.shape)
+
def power(self, other, *args, **kwargs):
if not self.is_compatible(other):
raise ValueError('Incompatible for power')