diff options
| author | Edoardo Pasca <edo.paskino@gmail.com> | 2019-04-11 17:13:08 +0100 | 
|---|---|---|
| committer | Edoardo Pasca <edo.paskino@gmail.com> | 2019-04-11 17:13:08 +0100 | 
| commit | 10aae87e1416d291906b94927acb4aac5737a44e (patch) | |
| tree | 557c9337eb2f0a09f9879c7a85b68a96a65bc34f | |
| parent | 6ce64e15b13cf7c6ae55cf9bc891980679268ac4 (diff) | |
| download | framework-10aae87e1416d291906b94927acb4aac5737a44e.tar.gz framework-10aae87e1416d291906b94927acb4aac5737a44e.tar.bz2 framework-10aae87e1416d291906b94927acb4aac5737a44e.tar.xz framework-10aae87e1416d291906b94927acb4aac5737a44e.zip | |
fixing algebra with nested block data containers
| -rwxr-xr-x | Wrappers/Python/ccpi/framework/BlockDataContainer.py | 45 | 
1 files changed, 40 insertions, 5 deletions
| diff --git a/Wrappers/Python/ccpi/framework/BlockDataContainer.py b/Wrappers/Python/ccpi/framework/BlockDataContainer.py index 13663c2..85cd05a 100755 --- a/Wrappers/Python/ccpi/framework/BlockDataContainer.py +++ b/Wrappers/Python/ccpi/framework/BlockDataContainer.py @@ -53,9 +53,9 @@ class BlockDataContainer(object):      def is_compatible(self, other):
          '''basic check if the size of the 2 objects fit'''
 -        for i in range(len(self.containers)):
 -            if type(self.containers[i])==type(self):
 -                self = self.containers[i]
 +        #for i in range(len(self.containers)):
 +        #    if type(self.containers[i])==type(self):
 +        #        self = self.containers[i]
          if isinstance(other, Number):
              return True   
 @@ -71,7 +71,16 @@ class BlockDataContainer(object):          elif isinstance(other, numpy.ndarray):
              return len(self.containers) == len(other)
          elif issubclass(other.__class__, DataContainer):
 -            return self.get_item(0).shape == other.shape
 +            ret = True
 +            for i, el in enumerate(self.containers):
 +                if isinstance(el, BlockDataContainer):
 +                    a = el.is_compatible(other)
 +                else:
 +                    a = el.shape == other.shape
 +                print ("current element" , el.shape, "other ", other.shape, "same shape" , a)
 +                ret = ret and a
 +            return ret
 +            #return self.get_item(0).shape == other.shape
          return len(self.containers) == len(other.containers)
      def get_item(self, row):
 @@ -139,10 +148,36 @@ class BlockDataContainer(object):              return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape)
          elif issubclass(other.__class__, DataContainer):
              # try to do algebra with one DataContainer. Will raise error if not compatible
 -            return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape)
 +            if out is not None:
 +                kw = kwargs.copy()
 +                for i,el in enumerate(self.containers):
 +                    kw['out'] = out.get_item(i)
 +                    el.divide(other, *args, **kw)
 +                return
 +            else:
 +                return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape)
 +        return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other.containers)],
 +                          shape=self.shape)
 +    def binary_operations(self, operation, other, *args, **kwargs):
 +        if not self.is_compatible(other):
 +            raise ValueError('Incompatible for divide')
 +        out = kwargs.get('out', None)
 +        if isinstance(other, Number) or issubclass(other.__class__, DataContainer):
 +            # try to do algebra with one DataContainer. Will raise error if not compatible
 +            if out is not None:
 +                kw = kwargs.copy()
 +                for i,el in enumerate(self.containers):
 +                    kw['out'] = out.get_item(i)
 +                    el.divide(other, *args, **kw)
 +                return
 +            else:
 +                return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape)
 +        elif isinstance(other, list) or isinstance(other, numpy.ndarray):
 +            return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape)
          return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other.containers)],
                            shape=self.shape)
 +
      def power(self, other, *args, **kwargs):
          if not self.is_compatible(other):
              raise ValueError('Incompatible for power')
 | 
