summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Wrappers/Python/ccpi/framework.py225
-rw-r--r--Wrappers/Python/test/regularizers.py59
2 files changed, 245 insertions, 39 deletions
diff --git a/Wrappers/Python/ccpi/framework.py b/Wrappers/Python/ccpi/framework.py
index 035c729..3cfa2a0 100644
--- a/Wrappers/Python/ccpi/framework.py
+++ b/Wrappers/Python/ccpi/framework.py
@@ -16,6 +16,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from __future__ import division
import abc
import numpy
import sys
@@ -91,8 +93,10 @@ class CCPiBaseClass(ABC):
if self.debug:
print ("{0}: {1}".format(self.__class__.__name__, msg))
-class DataSet():
- '''Generic class to hold data'''
+class DataSet(object):
+ '''Generic class to hold data
+
+ Data is currently held in a numpy arrays'''
def __init__ (self, array, deep_copy=True, dimension_labels=None,
**kwargs):
@@ -199,8 +203,174 @@ class DataSet():
numpy.shape(array)))
self.array = array[:]
-
-
+ def checkDimensions(self, other):
+ return self.shape == other.shape
+
+ def __add__(self, other):
+ if issubclass(type(other), DataSet):
+ if self.checkDimensions(other):
+ out = self.as_array() + other.as_array()
+ return DataSet(out,
+ deep_copy=True,
+ dimension_labels=self.dimension_labels)
+ else:
+ raise ValueError('Wrong shape: {0} and {1}'.format(self.shape,
+ other.shape))
+ elif isinstance(other, (int, float, complex)):
+ return DataSet(self.as_array() + other,
+ deep_copy=True,
+ dimension_labels=self.dimension_labels)
+ else:
+ raise TypeError('Cannot {0} DataSet with {1}'.format("add" ,
+ type(other)))
+ # __add__
+
+ def __sub__(self, other):
+ if issubclass(type(other), DataSet):
+ if self.checkDimensions(other):
+ out = self.as_array() - other.as_array()
+ return DataSet(out,
+ deep_copy=True,
+ dimension_labels=self.dimension_labels)
+ else:
+ raise ValueError('Wrong shape: {0} and {1}'.format(self.shape,
+ other.shape))
+ elif isinstance(other, (int, float, complex)):
+ return DataSet(self.as_array() - other,
+ deep_copy=True,
+ dimension_labels=self.dimension_labels)
+ else:
+ raise TypeError('Cannot {0} DataSet with {1}'.format("subtract" ,
+ type(other)))
+ # __sub__
+ def __truediv__(self,other):
+ return self.__div__(other)
+
+ def __div__(self, other):
+ print ("calling __div__")
+ if issubclass(type(other), DataSet):
+ if self.checkDimensions(other):
+ out = self.as_array() / other.as_array()
+ return DataSet(out,
+ deep_copy=True,
+ dimension_labels=self.dimension_labels)
+ else:
+ raise ValueError('Wrong shape: {0} and {1}'.format(self.shape,
+ other.shape))
+ elif isinstance(other, (int, float, complex)):
+ return DataSet(self.as_array() / other,
+ deep_copy=True,
+ dimension_labels=self.dimension_labels)
+ else:
+ raise TypeError('Cannot {0} DataSet with {1}'.format("divide" ,
+ type(other)))
+ # __div__
+
+ def __pow__(self, other):
+ if issubclass(type(other), DataSet):
+ if self.checkDimensions(other):
+ out = self.as_array() ** other.as_array()
+ return DataSet(out,
+ deep_copy=True,
+ dimension_labels=self.dimension_labels)
+ else:
+ raise ValueError('Wrong shape: {0} and {1}'.format(self.shape,
+ other.shape))
+ elif isinstance(other, (int, float, complex)):
+ return DataSet(self.as_array() ** other,
+ deep_copy=True,
+ dimension_labels=self.dimension_labels)
+ else:
+ raise TypeError('Cannot {0} DataSet with {1}'.format("power" ,
+ type(other)))
+ # __pow__
+
+ def __mul__(self, other):
+ if issubclass(type(other), DataSet):
+ if self.checkDimensions(other):
+ out = self.as_array() * other.as_array()
+ return DataSet(out,
+ deep_copy=True,
+ dimension_labels=self.dimension_labels)
+ else:
+ raise ValueError('Wrong shape: {0} and {1}'.format(self.shape,
+ other.shape))
+ elif isinstance(other, (int, float, complex)):
+ return DataSet(self.as_array() * other,
+ deep_copy=True,
+ dimension_labels=self.dimension_labels)
+ else:
+ raise TypeError('Cannot {0} DataSet with {1}'.format("multiply" ,
+ type(other)))
+ # __mul__
+
+
+ #def __abs__(self):
+ # operation = FM.OPERATION.ABS
+ # return self.callFieldMath(operation, None, self.mask, self.maskOnValue)
+ # __abs__
+
+ # reverse operand
+ def __radd__(self, other):
+ return self + other
+ # __radd__
+
+ def __rsub__(self, other):
+ return (-1 * self) + other
+ # __rsub__
+
+ def __rmul__(self, other):
+ return self * other
+ # __rmul__
+
+ def __rdiv__(self, other):
+ print ("call __rdiv__")
+ return pow(self / other, -1)
+ # __rdiv__
+ def __rtruediv__(self, other):
+ return self.__rdiv__(other)
+
+ def __rpow__(self, other):
+ if isinstance(other, (int, float)) :
+ fother = numpy.ones(numpy.shape(self.array)) * other
+ return DataSet(fother ** self.array ,
+ dimension_labels=self.dimension_labels)
+ elif issubclass(other, DataSet):
+ if self.checkDimensions(other):
+ return DataSet(other.as_array() ** self.array ,
+ dimension_labels=self.dimension_labels)
+ else:
+ raise ValueError('Dimensions do not match')
+ # __rpow__
+
+
+ # in-place arithmetic operators:
+ # (+=, -=, *=, /= , //=,
+
+ def __iadd__(self, other):
+ return self + other
+ # __iadd__
+
+ def __imul__(self, other):
+ return self * other
+ # __imul__
+
+ def __isub__(self, other):
+ return self - other
+ # __isub__
+
+ def __idiv__(self, other):
+ print ("call __idiv__")
+ return self / other
+ # __idiv__
+
+ def __str__ (self):
+ repres = ""
+ repres += "Number of dimensions: {0}\n".format(self.number_of_dimensions)
+ repres += "Shape: {0}\n".format(self.shape)
+ repres += "Axis labels: {0}\n".format(self.dimension_labels)
+ repres += "Representation: {0}\n".format(self.array)
+ return repres
@@ -219,7 +389,9 @@ class VolumeData(DataSet):
raise ValueError('Number of dimensions are not 2 or 3: {0}'\
.format(array.number_of_dimensions))
- DataSet.__init__(self, array.as_array(), deep_copy,
+ #DataSet.__init__(self, array.as_array(), deep_copy,
+ # array.dimension_labels, **kwargs)
+ super(VolumeData, self).__init__(array.as_array(), deep_copy,
array.dimension_labels, **kwargs)
elif type(array) == numpy.ndarray:
if not ( array.ndim == 3 or array.ndim == 2 ):
@@ -236,8 +408,9 @@ class VolumeData(DataSet):
dimension_labels = ['horizontal' ,
'vertical']
- DataSet.__init__(self, array, deep_copy, dimension_labels, **kwargs)
-
+ #DataSet.__init__(self, array, deep_copy, dimension_labels, **kwargs)
+ super(VolumeData, self).__init__(array, deep_copy,
+ dimension_labels, **kwargs)
# load metadata from kwargs if present
for key, value in kwargs.items():
@@ -287,7 +460,7 @@ class SinogramData(DataSet):
# assume it is parallel beam
pass
-class DataSetProcessor():
+class DataSetProcessor(object):
'''Defines a generic DataSet processor
accepts DataSet as inputs and
@@ -341,6 +514,7 @@ class DataSetProcessor():
elif self.mTime > self.runTime:
shouldRun = True
+ # CHECK this
if self.store_output and shouldRun:
self.runTime = datetime.now()
self.output = self.process()
@@ -405,8 +579,8 @@ class AX(DataSetProcessor):
'input':None,
}
- DataSetProcessor.__init__(self, **kwargs)
-
+ #DataSetProcessor.__init__(self, **kwargs)
+ super(AX, self).__init__(**kwargs)
def checkInput(self, dataset):
return True
@@ -438,8 +612,8 @@ class PixelByPixelDataSetProcessor(DataSetProcessor):
kwargs = {'pyfunc':None,
'input':None,
}
- DataSetProcessor.__init__(self, **kwargs)
-
+ #DataSetProcessor.__init__(self, **kwargs)
+ super(PixelByPixelDataSetProcessor, self).__init__(**kwargs)
def checkInput(self, dataset):
return True
@@ -530,4 +704,29 @@ if __name__ == '__main__':
chain.setInputProcessor(ax)
print ("chain in {0} out {1}".format(ax.getOutput().as_array(), chain.getOutput().as_array()))
- \ No newline at end of file
+ # testing arithmetic operations
+
+ print (b)
+ print ((b+1))
+ print ((1+b))
+
+ print (b)
+ print ((b*2))
+
+ print (b)
+ print ((2*b))
+
+ print (b)
+ print ((b/2))
+
+ print (b)
+ print ((2/b))
+
+ print (b)
+ print ((b**2))
+
+ print (b)
+ print ((2**b))
+
+
+ \ No newline at end of file
diff --git a/Wrappers/Python/test/regularizers.py b/Wrappers/Python/test/regularizers.py
index 25873c7..04ac3aa 100644
--- a/Wrappers/Python/test/regularizers.py
+++ b/Wrappers/Python/test/regularizers.py
@@ -17,6 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
import os
@@ -27,7 +28,7 @@ from ccpi.filters.cpu_regularizers_boost import SplitBregman_TV , FGP_TV ,\
#from ccpi.filters.cpu_regularizers_cython import some
try:
- from ccpi.filter import gpu_regularizers as gpu
+ from ccpi.filters import gpu_regularizers as gpu
class PatchBasedRegGPU(DataSetProcessor23D):
'''Regularizers DataSetProcessor for PatchBasedReg
@@ -40,7 +41,7 @@ try:
'similarity_window_ratio': None,
'PB_filtering_parameter': None
}
- DataSetProcessor.__init__(self, **attributes)
+ super(PatchBasedRegGPU, self).__init__(**attributes)
def process(self):
@@ -67,7 +68,7 @@ try:
'similarity_window_ratio': None,
'PB_filtering_parameter': None
}
- DataSetProcessor.__init__(self, **attributes)
+ super(Diff4thHajiaboli, self).__init__(self, **attributes)
def process(self):
@@ -97,7 +98,7 @@ class SBTV(DataSetProcessor23D):
'tolerance_constant': 0.0001,
'TV_penalty':0
}
- DataSetProcessor.__init__(self, **attributes)
+ super(SBTV , self).__init__(**attributes)
def process(self):
@@ -124,7 +125,7 @@ class FGPTV(DataSetProcessor23D):
'tolerance_constant': 0.0001,
'TV_penalty':0
}
- DataSetProcessor.__init__(self, **attributes)
+ super(FGPTV, self).__init__(**attributes)
def process(self):
@@ -153,7 +154,7 @@ class LLT(DataSetProcessor23D):
'tolerance_constant': 0,
'restrictive_Z_smoothing': None
}
- DataSetProcessor.__init__(self, **attributes)
+ super(LLT, self).__init__(**attributes)
def process(self):
@@ -182,7 +183,7 @@ class PatchBasedReg(DataSetProcessor23D):
'similarity_window_ratio': None,
'PB_filtering_parameter': None
}
- DataSetProcessor.__init__(self, **attributes)
+ super(PatchBasedReg, self).__init__(**attributes)
def process(self):
@@ -204,13 +205,17 @@ class TGVPD(DataSetProcessor23D):
'''
- def __init__(self):
+ def __init__(self,**kwargs):
attributes = {'regularization_parameter':None,
'first_order_term': None,
'second_order_term': None,
'number_of_iterations': None
}
- DataSetProcessor.__init__(self, **attributes)
+ for key, value in kwargs.items():
+ if key in attributes.keys():
+ attributes[key] = value
+
+ super(TGVPD, self).__init__(**attributes)
def process(self):
@@ -247,15 +252,17 @@ if __name__ == '__main__':
"lena_gray_512.tif")
Im = plt.imread(filename)
Im = np.asarray(Im, dtype='float32')
-
- perc = 0.15
+
+ Im = Im/255
+
+ perc = 0.075
u0 = Im + np.random.normal(loc = Im ,
- scale = perc * Im ,
- size = np.shape(Im))
+ scale = perc * Im ,
+ size = np.shape(Im))
# map the u0 u0->u0>0
f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1)
u0 = f(u0).astype('float32')
-
+
lena = DataSet(u0, False, ['X','Y'])
## plot
@@ -272,9 +279,9 @@ if __name__ == '__main__':
reg3 = SBTV()
- reg3.number_of_iterations = 350
- reg3.tolerance_constant = 0.01
- reg3.regularization_parameter = 40
+ reg3.number_of_iterations = 40
+ reg3.tolerance_constant = 0.0001
+ reg3.regularization_parameter = 15
reg3.TV_penalty = 0
reg3.setInput(lena)
dataprocessoroutput = reg3.getOutput()
@@ -293,9 +300,9 @@ if __name__ == '__main__':
##########################################################################
reg4 = FGPTV()
- reg4.number_of_iterations = 350
- reg4.tolerance_constant = 0.01
- reg4.regularization_parameter = 40
+ reg4.number_of_iterations = 200
+ reg4.tolerance_constant = 1e-4
+ reg4.regularization_parameter = 0.05
reg4.TV_penalty = 0
reg4.setInput(lena)
dataprocessoroutput2 = reg4.getOutput()
@@ -313,10 +320,10 @@ if __name__ == '__main__':
###########################################################################
reg6 = LLT()
- reg6.regularization_parameter = 25
- reg6.time_step = 0.0003
- reg6.number_of_iterations = 300
- reg6.tolerance_constant = 0.001
+ reg6.regularization_parameter = 5
+ reg6.time_step = 0.00035
+ reg6.number_of_iterations = 350
+ reg6.tolerance_constant = 0.0001
reg6.restrictive_Z_smoothing = 0
reg6.setInput(lena)
llt = reg6.getOutput()
@@ -336,7 +343,7 @@ if __name__ == '__main__':
reg7.regularization_parameter = 0.05
reg7.searching_window_ratio = 3
reg7.similarity_window_ratio = 1
- reg7.PB_filtering_parameter = 0.08
+ reg7.PB_filtering_parameter = 0.06
reg7.setInput(lena)
pbr = reg7.getOutput()
# plot
@@ -352,7 +359,7 @@ if __name__ == '__main__':
###########################################################################
reg5 = TGVPD()
- reg5.regularization_parameter = 0.05
+ reg5.regularization_parameter = 0.07
reg5.first_order_term = 1.3
reg5.second_order_term = 1
reg5.number_of_iterations = 550