From 10aae87e1416d291906b94927acb4aac5737a44e Mon Sep 17 00:00:00 2001 From: Edoardo Pasca <edo.paskino@gmail.com> Date: Thu, 11 Apr 2019 17:13:08 +0100 Subject: fixing algebra with nested block data containers --- .../Python/ccpi/framework/BlockDataContainer.py | 45 +++++++++++++++++++--- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/Wrappers/Python/ccpi/framework/BlockDataContainer.py b/Wrappers/Python/ccpi/framework/BlockDataContainer.py index 13663c2..85cd05a 100755 --- a/Wrappers/Python/ccpi/framework/BlockDataContainer.py +++ b/Wrappers/Python/ccpi/framework/BlockDataContainer.py @@ -53,9 +53,9 @@ class BlockDataContainer(object): def is_compatible(self, other): '''basic check if the size of the 2 objects fit''' - for i in range(len(self.containers)): - if type(self.containers[i])==type(self): - self = self.containers[i] + #for i in range(len(self.containers)): + # if type(self.containers[i])==type(self): + # self = self.containers[i] if isinstance(other, Number): return True @@ -71,7 +71,16 @@ class BlockDataContainer(object): elif isinstance(other, numpy.ndarray): return len(self.containers) == len(other) elif issubclass(other.__class__, DataContainer): - return self.get_item(0).shape == other.shape + ret = True + for i, el in enumerate(self.containers): + if isinstance(el, BlockDataContainer): + a = el.is_compatible(other) + else: + a = el.shape == other.shape + print ("current element" , el.shape, "other ", other.shape, "same shape" , a) + ret = ret and a + return ret + #return self.get_item(0).shape == other.shape return len(self.containers) == len(other.containers) def get_item(self, row): @@ -139,10 +148,36 @@ class BlockDataContainer(object): return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape) elif issubclass(other.__class__, DataContainer): # try to do algebra with one DataContainer. Will raise error if not compatible - return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape) + if out is not None: + kw = kwargs.copy() + for i,el in enumerate(self.containers): + kw['out'] = out.get_item(i) + el.divide(other, *args, **kw) + return + else: + return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape) + return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other.containers)], + shape=self.shape) + def binary_operations(self, operation, other, *args, **kwargs): + if not self.is_compatible(other): + raise ValueError('Incompatible for divide') + out = kwargs.get('out', None) + if isinstance(other, Number) or issubclass(other.__class__, DataContainer): + # try to do algebra with one DataContainer. Will raise error if not compatible + if out is not None: + kw = kwargs.copy() + for i,el in enumerate(self.containers): + kw['out'] = out.get_item(i) + el.divide(other, *args, **kw) + return + else: + return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape) + elif isinstance(other, list) or isinstance(other, numpy.ndarray): + return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape) return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other.containers)], shape=self.shape) + def power(self, other, *args, **kwargs): if not self.is_compatible(other): raise ValueError('Incompatible for power') -- cgit v1.2.3