summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
Diffstat (limited to 'Wrappers/Python')
-rwxr-xr-xWrappers/Python/ccpi/optimisation/operators/BlockOperator.py34
1 files changed, 32 insertions, 2 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py b/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py
index 19da3d4..752fd21 100755
--- a/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py
@@ -236,7 +236,21 @@ class BlockOperator(Operator):
return ImageData(tmp)
else:
return BlockDataContainer(*res)
-
+
+ def sum_abs_col(self):
+
+ res = []
+ for row in range(self.shape[0]):
+ for col in range(self.shape[1]):
+ if col == 0:
+ prod = self.get_item(row, col).sum_abs_col()
+ else:
+ prod += self.get_item(row, col).sum_abs_col()
+ res.append(prod)
+
+ return BlockDataContainer(*res)
+
+
if __name__ == '__main__':
@@ -247,7 +261,7 @@ if __name__ == '__main__':
from ccpi.optimisation.operators import Operator, LinearOperator
- M, N= 4, 3
+ M, N = 4, 3
ig = ImageGeometry(M, N)
arr = ig.allocate('random_int')
G = Gradient(ig)
@@ -263,11 +277,27 @@ if __name__ == '__main__':
d1 = abs(Gx.matrix()).toarray().sum(axis=0)
d2 = abs(Gy.matrix()).toarray().sum(axis=0)
d3 = abs(Id.matrix()).toarray().sum(axis=0)
+
d_res = numpy.reshape(d1 + d2 + d3, ig.shape, 'F')
print(d_res)
+ z1 = abs(Gx.matrix()).toarray().sum(axis=1)
+ z2 = abs(Gy.matrix()).toarray().sum(axis=1)
+ z3 = abs(Id.matrix()).toarray().sum(axis=1)
+
+ z_res = BlockDataContainer(BlockDataContainer(ImageData(numpy.reshape(z2, ig.shape, 'F')),\
+ ImageData(numpy.reshape(z1, ig.shape, 'F'))),\
+ ImageData(numpy.reshape(z3, ig.shape, 'F')))
+
+ ttt = B.sum_abs_col()
+
+ numpy.testing.assert_array_almost_equal(z_res[0][0].as_array(), ttt[0][0].as_array(), decimal=4)
+ numpy.testing.assert_array_almost_equal(z_res[0][1].as_array(), ttt[0][1].as_array(), decimal=4)
+ numpy.testing.assert_array_almost_equal(z_res[1].as_array(), ttt[1].as_array(), decimal=4)
+
+