From 10aae87e1416d291906b94927acb4aac5737a44e Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Thu, 11 Apr 2019 17:13:08 +0100
Subject: fixing algebra with nested block data containers

---
 .../Python/ccpi/framework/BlockDataContainer.py    | 45 +++++++++++++++++++---
 1 file changed, 40 insertions(+), 5 deletions(-)

diff --git a/Wrappers/Python/ccpi/framework/BlockDataContainer.py b/Wrappers/Python/ccpi/framework/BlockDataContainer.py
index 13663c2..85cd05a 100755
--- a/Wrappers/Python/ccpi/framework/BlockDataContainer.py
+++ b/Wrappers/Python/ccpi/framework/BlockDataContainer.py
@@ -53,9 +53,9 @@ class BlockDataContainer(object):
     def is_compatible(self, other):
         '''basic check if the size of the 2 objects fit'''
         
-        for i in range(len(self.containers)):
-            if type(self.containers[i])==type(self):
-                self = self.containers[i]
+        #for i in range(len(self.containers)):
+        #    if type(self.containers[i])==type(self):
+        #        self = self.containers[i]
         
         if isinstance(other, Number):
             return True   
@@ -71,7 +71,16 @@ class BlockDataContainer(object):
         elif isinstance(other, numpy.ndarray):
             return len(self.containers) == len(other)
         elif issubclass(other.__class__, DataContainer):
-            return self.get_item(0).shape == other.shape
+            ret = True
+            for i, el in enumerate(self.containers):
+                if isinstance(el, BlockDataContainer):
+                    a = el.is_compatible(other)
+                else:
+                    a = el.shape == other.shape
+                print ("current element" , el.shape, "other ", other.shape, "same shape" , a)
+                ret = ret and a
+            return ret
+            #return self.get_item(0).shape == other.shape
         return len(self.containers) == len(other.containers)
 
     def get_item(self, row):
@@ -139,10 +148,36 @@ class BlockDataContainer(object):
             return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape)
         elif issubclass(other.__class__, DataContainer):
             # try to do algebra with one DataContainer. Will raise error if not compatible
-            return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape)
+            if out is not None:
+                kw = kwargs.copy()
+                for i,el in enumerate(self.containers):
+                    kw['out'] = out.get_item(i)
+                    el.divide(other, *args, **kw)
+                return
+            else:
+                return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape)
+        return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other.containers)],
+                          shape=self.shape)
+    def binary_operations(self, operation, other, *args, **kwargs):
+        if not self.is_compatible(other):
+            raise ValueError('Incompatible for divide')
+        out = kwargs.get('out', None)
+        if isinstance(other, Number) or issubclass(other.__class__, DataContainer):
+            # try to do algebra with one DataContainer. Will raise error if not compatible
+            if out is not None:
+                kw = kwargs.copy()
+                for i,el in enumerate(self.containers):
+                    kw['out'] = out.get_item(i)
+                    el.divide(other, *args, **kw)
+                return
+            else:
+                return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape)
+        elif isinstance(other, list) or isinstance(other, numpy.ndarray):
+            return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape)
         return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other.containers)],
                           shape=self.shape)
     
+
     def power(self, other, *args, **kwargs):
         if not self.is_compatible(other):
             raise ValueError('Incompatible for power')
-- 
cgit v1.2.3