summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Wrappers/Python/ccpi/optimisation/operators/GradientOperator.py40
1 files changed, 32 insertions, 8 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/operators/GradientOperator.py b/Wrappers/Python/ccpi/optimisation/operators/GradientOperator.py
index 6ffaf70..60978be 100644
--- a/Wrappers/Python/ccpi/optimisation/operators/GradientOperator.py
+++ b/Wrappers/Python/ccpi/optimisation/operators/GradientOperator.py
@@ -14,23 +14,47 @@ from ccpi.optimisation.operators import FiniteDiff, SparseFiniteDiff
#%%
class Gradient(LinearOperator):
-
+ CORRELATION_SPACE = "Space"
+ CORRELATION_SPACECHANNEL = "SpaceChannels"
+ # Grad_order = ['channels', 'direction_z', 'direction_y', 'direction_x']
+ # Grad_order = ['channels', 'direction_y', 'direction_x']
+ # Grad_order = ['direction_z', 'direction_y', 'direction_x']
+ # Grad_order = ['channels', 'direction_z', 'direction_y', 'direction_x']
def __init__(self, gm_domain, bnd_cond = 'Neumann', **kwargs):
super(Gradient, self).__init__()
self.gm_domain = gm_domain # Domain of Grad Operator
- self.correlation = kwargs.get('correlation','Space')
+ self.correlation = kwargs.get('correlation',Gredient.CORRELATION_SPACE)
- if self.correlation=='Space':
+ if self.correlation==Gredient.CORRELATION_SPACE:
if self.gm_domain.channels>1:
- self.gm_range = BlockGeometry(*[self.gm_domain for _ in range(self.gm_domain.length-1)] )
- self.ind = numpy.arange(1,self.gm_domain.length)
- else:
+ self.gm_range = BlockGeometry(*[self.gm_domain for _ in range(self.gm_domain.length-1)] )
+ if self.gm_domain.length == 4:
+ # 3D + Channel
+ # expected Grad_order = ['channels', 'direction_z', 'direction_y', 'direction_x']
+ expected_order = [ImageGeometry.CHANNEL, ImageGeometry.VERTICAL, ImageGeometry.HORIZONTAL_Y, ImageGeometry.HORIZONTAL_X]
+ else:
+ # 2D + Channel
+ # expected Grad_order = ['channels', 'direction_y', 'direction_x']
+ expected_order = [ImageGeometry.CHANNEL, ImageGeometry.HORIZONTAL_Y, ImageGeometry.HORIZONTAL_X]
+ order = self.gm_domain.get_order_by_label(self.gm_domain.dimension_labels, expected_order)
+ self.ind = order[1:]
+ #self.ind = numpy.arange(1,self.gm_domain.length)
+ else:
+ # no channel info
self.gm_range = BlockGeometry(*[self.gm_domain for _ in range(self.gm_domain.length) ] )
- self.ind = numpy.arange(self.gm_domain.length)
- elif self.correlation=='SpaceChannels':
+ if self.gm_domain.length == 3:
+ # 3D
+ # expected Grad_order = ['direction_z', 'direction_y', 'direction_x']
+ expected_order = [ImageGeometry.VERTICAL, ImageGeometry.HORIZONTAL_Y, ImageGeometry.HORIZONTAL_X]
+ else:
+ # 2D
+ expected_order = [ImageGeometry.VERTICAL, ImageGeometry.HORIZONTAL_Y, ImageGeometry.HORIZONTAL_X]
+ self.ind = self.gm_domain.get_order_by_label(self.gm_domain.dimension_labels, expected_order)
+ # self.ind = numpy.arange(self.gm_domain.length)
+ elif self.correlation==Gredient.CORRELATION_SPACECHANNEL:
if self.gm_domain.channels>1:
self.gm_range = BlockGeometry(*[self.gm_domain for _ in range(self.gm_domain.length)])
self.ind = range(self.gm_domain.length)