From c8eeb3b9f202c16535f3c056a09fb74f638c43f2 Mon Sep 17 00:00:00 2001 From: Vaggelis Date: Wed, 3 Apr 2019 00:10:52 +0100 Subject: add precond test blockOperator --- .../ccpi/optimisation/operators/BlockOperator.py | 34 ++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) (limited to 'Wrappers/Python') diff --git a/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py b/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py index 19da3d4..752fd21 100755 --- a/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py +++ b/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py @@ -236,7 +236,21 @@ class BlockOperator(Operator): return ImageData(tmp) else: return BlockDataContainer(*res) - + + def sum_abs_col(self): + + res = [] + for row in range(self.shape[0]): + for col in range(self.shape[1]): + if col == 0: + prod = self.get_item(row, col).sum_abs_col() + else: + prod += self.get_item(row, col).sum_abs_col() + res.append(prod) + + return BlockDataContainer(*res) + + if __name__ == '__main__': @@ -247,7 +261,7 @@ if __name__ == '__main__': from ccpi.optimisation.operators import Operator, LinearOperator - M, N= 4, 3 + M, N = 4, 3 ig = ImageGeometry(M, N) arr = ig.allocate('random_int') G = Gradient(ig) @@ -263,11 +277,27 @@ if __name__ == '__main__': d1 = abs(Gx.matrix()).toarray().sum(axis=0) d2 = abs(Gy.matrix()).toarray().sum(axis=0) d3 = abs(Id.matrix()).toarray().sum(axis=0) + d_res = numpy.reshape(d1 + d2 + d3, ig.shape, 'F') print(d_res) + z1 = abs(Gx.matrix()).toarray().sum(axis=1) + z2 = abs(Gy.matrix()).toarray().sum(axis=1) + z3 = abs(Id.matrix()).toarray().sum(axis=1) + + z_res = BlockDataContainer(BlockDataContainer(ImageData(numpy.reshape(z2, ig.shape, 'F')),\ + ImageData(numpy.reshape(z1, ig.shape, 'F'))),\ + ImageData(numpy.reshape(z3, ig.shape, 'F'))) + + ttt = B.sum_abs_col() + + numpy.testing.assert_array_almost_equal(z_res[0][0].as_array(), ttt[0][0].as_array(), decimal=4) + numpy.testing.assert_array_almost_equal(z_res[0][1].as_array(), ttt[0][1].as_array(), decimal=4) + numpy.testing.assert_array_almost_equal(z_res[1].as_array(), ttt[1].as_array(), decimal=4) + + -- cgit v1.2.3