diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2019-05-01 16:14:02 +0100 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2019-05-01 16:14:02 +0100 |
commit | da581e6061ebe206e007fe4378cc9d449b67d791 (patch) | |
tree | 479729d5e4c48583f15d8ac86dcfaecf78f4d020 | |
parent | 6d1c24c8fc389365f2dd83ee480c63969b08ce9f (diff) | |
download | framework-da581e6061ebe206e007fe4378cc9d449b67d791.tar.gz framework-da581e6061ebe206e007fe4378cc9d449b67d791.tar.bz2 framework-da581e6061ebe206e007fe4378cc9d449b67d791.tar.xz framework-da581e6061ebe206e007fe4378cc9d449b67d791.zip |
reimplements dot product
following discussion in #273 an implementation of the dot product is
made where we rely on Python to choose an appropriate type for the result
of the reduction (e.g. float64 if the data is float32)
-rwxr-xr-x | Wrappers/Python/ccpi/framework/framework.py | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/Wrappers/Python/ccpi/framework/framework.py b/Wrappers/Python/ccpi/framework/framework.py index e278795..66420b9 100755 --- a/Wrappers/Python/ccpi/framework/framework.py +++ b/Wrappers/Python/ccpi/framework/framework.py @@ -764,13 +764,23 @@ class DataContainer(object): return numpy.sqrt(self.squared_norm()) def dot(self, other, *args, **kwargs): '''return the inner product of 2 DataContainers viewed as vectors''' + method = kwargs.get('method', 'reduce') if self.shape == other.shape: - return (self*other).sum() - #return numpy.dot(self.as_array().ravel(), other.as_array().ravel()) + # return (self*other).sum() + if method == 'numpy': + return numpy.dot(self.as_array().ravel(), other.as_array()) + elif method == 'reduce': + # see https://github.com/vais-ral/CCPi-Framework/pull/273 + # notice that Python seems to be smart enough to use + # the appropriate type to hold the result of the reduction + sf = reduce(lambda x,y: x + y[0]*y[1], + zip(self.as_array().ravel(), + other.as_array().ravel()), + 0) + return sf else: raise ValueError('Shapes are not aligned: {} != {}'.format(self.shape, other.shape)) - - + |