From 5aaa46237fbf0a6bb008fe81576cabc61e3b1fce Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Thu, 3 Aug 2017 15:26:29 +0100
Subject: Added Python modules

Matlab2Python_utils.cpp contains utilities for handling numpy arrays.
Together with setup_test.py it creates a functional module for testing.

fista_module.cpp and setup.py are meant for the real fista module.
---
 src/Python/Matlab2Python_utils.cpp | 206 ++++++++++++++++++++++++
 src/Python/fista_module.cpp        | 315 +++++++++++++++++++++++++++++++++++++
 src/Python/setup.py                |  58 +++++++
 src/Python/setup_test.py           |  58 +++++++
 4 files changed, 637 insertions(+)
 create mode 100644 src/Python/Matlab2Python_utils.cpp
 create mode 100644 src/Python/fista_module.cpp
 create mode 100644 src/Python/setup.py
 create mode 100644 src/Python/setup_test.py

(limited to 'src/Python')

diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp
new file mode 100644
index 0000000..138e8da
--- /dev/null
+++ b/src/Python/Matlab2Python_utils.cpp
@@ -0,0 +1,206 @@
+/*
+This work is part of the Core Imaging Library developed by
+Visual Analytics and Imaging System Group of the Science Technology
+Facilities Council, STFC
+
+Copyright 2017 Daniil Kazanteev
+Copyright 2017 Srikanth Nagella, Edoardo Pasca
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
+
+#include <iostream>
+#include <cmath>
+
+#include <boost/python.hpp>
+#include <boost/python/numpy.hpp>
+#include "boost/tuple/tuple.hpp"
+
+#if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64)
+#include <windows.h>
+// this trick only if compiler is MSVC
+__if_not_exists(uint8_t) { typedef __int8 uint8_t; }
+__if_not_exists(uint16_t) { typedef __int8 uint16_t; }
+#endif
+
+namespace bp = boost::python;
+namespace np = boost::python::numpy;
+
+/*! in the Matlab implementation this is called as
+void mexFunction(
+int nlhs, mxArray *plhs[],
+int nrhs, const mxArray *prhs[])
+where:
+prhs Array of pointers to the INPUT mxArrays
+nrhs int number of INPUT mxArrays
+
+nlhs Array of pointers to the OUTPUT mxArrays
+plhs int number of OUTPUT mxArrays
+
+***********************************************************
+	
+***********************************************************
+double mxGetScalar(const mxArray *pm);
+args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray.
+Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray.	In C, mxGetScalar returns a double.
+***********************************************************
+char *mxArrayToString(const mxArray *array_ptr);
+args: array_ptr Pointer to mxCHAR array.
+Returns: C-style string. Returns NULL on failure. Possible reasons for failure include out of memory and specifying an array that is not an mxCHAR array.
+Description: Call mxArrayToString to copy the character data of an mxCHAR array into a C-style string.
+***********************************************************
+mxClassID mxGetClassID(const mxArray *pm);
+args: pm Pointer to an mxArray
+Returns: Numeric identifier of the class (category) of the mxArray that pm points to.For user-defined types,
+mxGetClassId returns a unique value identifying the class of the array contents.
+Use mxIsClass to determine whether an array is of a specific user-defined type.
+
+mxClassID Value	  MATLAB Type   MEX Type	 C Primitive Type
+mxINT8_CLASS 	  int8	        int8_T	     char, byte
+mxUINT8_CLASS	  uint8	        uint8_T	     unsigned char, byte
+mxINT16_CLASS	  int16	        int16_T	     short
+mxUINT16_CLASS	  uint16	    uint16_T	 unsigned short
+mxINT32_CLASS	  int32	        int32_T	     int
+mxUINT32_CLASS	  uint32	    uint32_T	 unsigned int
+mxINT64_CLASS	  int64	        int64_T	     long long
+mxUINT64_CLASS	  uint64	    uint64_T 	 unsigned long long
+mxSINGLE_CLASS	  single	    float	     float
+mxDOUBLE_CLASS	  double	    double	     double
+
+****************************************************************
+double *mxGetPr(const mxArray *pm);
+args: pm Pointer to an mxArray of type double
+Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data.
+****************************************************************
+mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims,
+mxClassID classid, mxComplexity ComplexFlag);
+args: ndimNumber of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2.
+dims Dimensions array. Each element in the dimensions array contains the size of the array in that dimension.
+For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array.
+classid Identifier for the class of the array, which determines the way the numerical data is represented in memory.
+For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer.
+ComplexFlag  If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran).
+Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran).
+If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not
+enough free heap space to create the mxArray.
+*/
+
+void mexErrMessageText(char* text) {
+	std::cerr << text << std::endl;
+}
+
+/*
+double mxGetScalar(const mxArray *pm);
+args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray.
+Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray.	In C, mxGetScalar returns a double.
+*/
+
+template<typename T>
+double mxGetScalar(const np::ndarray plh) {
+	return (double)bp::extract<T>(plh[0]);
+}
+
+
+
+template<typename T>
+T * mxGetData(const np::ndarray pm) {
+    //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray.
+	//Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double.
+	/*Access the numpy array pointer:
+	char * get_data() const;
+	Returns:	Array�s raw data pointer as a char
+	Note:	This returns char so stride math works properly on it.User will have to reinterpret_cast it.
+	probably this would work.
+	A = reinterpret_cast<float *>(prhs[0]);
+	*/
+	return reinterpret_cast<T *>(prhs[0]);
+}
+
+template<typename T>
+np::ndarray zeros(int dims , int * dim_array, T el) {
+	bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
+	np::dtype dtype = np::dtype::get_builtin<T>();
+	np::ndarray zz = np::zeros(shape, dtype);
+	return zz;
+}
+
+
+bp::list mexFunction( np::ndarray input ) {
+	int number_of_dims = input.get_nd();
+	int dim_array[3];
+
+	dim_array[0] = input.shape(0);
+	dim_array[1] = input.shape(1);
+	if (number_of_dims == 2) {
+		dim_array[2] = -1;
+	}
+	else {
+		dim_array[2] = input.shape(2);
+	}
+
+	/**************************************************************************/
+	np::ndarray zz = zeros(3, dim_array, (int)0);
+	np::ndarray fzz = zeros(3, dim_array, (float)0);
+	/**************************************************************************/
+	
+	int * A = reinterpret_cast<int *>( input.get_data() );
+	int * B = reinterpret_cast<int *>( zz.get_data() );
+	float * C = reinterpret_cast<float *>(fzz.get_data());
+
+	//Copy data and cast
+	for (int i = 0; i < dim_array[0]; i++) {
+		for (int j = 0; j < dim_array[1]; j++) {
+			for (int k = 0; k < dim_array[2]; k++) {
+				int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i;
+				int val = (*(A + index));
+				float fval = (float)val;
+				std::memcpy(B + index , &val, sizeof(int));
+				std::memcpy(C + index , &fval, sizeof(float));
+			}
+		}
+	}
+
+
+	bp::list result;
+
+	result.append<int>(number_of_dims);
+	result.append<int>(dim_array[0]);
+	result.append<int>(dim_array[1]);
+	result.append<int>(dim_array[2]);
+	result.append<np::ndarray>(zz);
+	result.append<np::ndarray>(fzz);
+
+	//result.append<bp::tuple>(tup);
+	return result;
+
+}
+
+
+BOOST_PYTHON_MODULE(fista)
+{
+	np::initialize();
+
+	//To specify that this module is a package
+	bp::object package = bp::scope();
+	package.attr("__path__") = "fista";
+
+	np::dtype dt1 = np::dtype::get_builtin<uint8_t>();
+	np::dtype dt2 = np::dtype::get_builtin<uint16_t>();
+	
+	//import_array();
+	//numpy_boost_python_register_type<float, 1>();
+	//numpy_boost_python_register_type<float, 2>();
+	//numpy_boost_python_register_type<float, 3>();
+	//numpy_boost_python_register_type<double, 3>();
+	def("mexFunction", mexFunction);
+}
\ No newline at end of file
diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp
new file mode 100644
index 0000000..5344083
--- /dev/null
+++ b/src/Python/fista_module.cpp
@@ -0,0 +1,315 @@
+/*
+This work is part of the Core Imaging Library developed by
+Visual Analytics and Imaging System Group of the Science Technology
+Facilities Council, STFC
+
+Copyright 2017 Daniil Kazanteev
+Copyright 2017 Srikanth Nagella, Edoardo Pasca
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
+
+#include <iostream>
+#include <cmath>
+
+#include <boost/python.hpp>
+#include <boost/python/numpy.hpp>
+#include "boost/tuple/tuple.hpp"
+
+// include the regularizers
+#include "FGP_TV_core.h"
+#include "LLT_model_core.h"
+#include "PatchBased_Regul_core.h"
+#include "SplitBregman_TV_core.h"
+#include "TGV_PD_core.h"
+
+#if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64)
+#include <windows.h>
+// this trick only if compiler is MSVC
+__if_not_exists(uint8_t) { typedef __int8 uint8_t; }
+__if_not_exists(uint16_t) { typedef __int8 uint16_t; }
+#endif
+
+namespace bp = boost::python;
+namespace np = boost::python::numpy;
+
+
+/*! in the Matlab implementation this is called as
+void mexFunction(
+int nlhs, mxArray *plhs[],
+int nrhs, const mxArray *prhs[])
+where:
+prhs Array of pointers to the INPUT mxArrays
+nrhs int number of INPUT mxArrays
+
+nlhs Array of pointers to the OUTPUT mxArrays
+plhs int number of OUTPUT mxArrays
+
+***********************************************************
+mxGetData
+args: pm Pointer to an mxArray
+Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data.
+***********************************************************
+double mxGetScalar(const mxArray *pm);
+args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray.
+Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray.	In C, mxGetScalar returns a double.
+***********************************************************
+char *mxArrayToString(const mxArray *array_ptr);
+args: array_ptr Pointer to mxCHAR array.
+Returns: C-style string. Returns NULL on failure. Possible reasons for failure include out of memory and specifying an array that is not an mxCHAR array.
+Description: Call mxArrayToString to copy the character data of an mxCHAR array into a C-style string.
+***********************************************************
+mxClassID mxGetClassID(const mxArray *pm);
+args: pm Pointer to an mxArray
+Returns: Numeric identifier of the class (category) of the mxArray that pm points to.For user-defined types,
+mxGetClassId returns a unique value identifying the class of the array contents.
+Use mxIsClass to determine whether an array is of a specific user-defined type.
+
+mxClassID Value	  MATLAB Type   MEX Type	 C Primitive Type
+mxINT8_CLASS 	  int8	        int8_T	     char, byte
+mxUINT8_CLASS	  uint8	        uint8_T	     unsigned char, byte
+mxINT16_CLASS	  int16	        int16_T	     short
+mxUINT16_CLASS	  uint16	    uint16_T	 unsigned short
+mxINT32_CLASS	  int32	        int32_T	     int
+mxUINT32_CLASS	  uint32	    uint32_T	 unsigned int
+mxINT64_CLASS	  int64	        int64_T	     long long
+mxUINT64_CLASS	  uint64	    uint64_T 	 unsigned long long
+mxSINGLE_CLASS	  single	    float	     float
+mxDOUBLE_CLASS	  double	    double	     double
+
+****************************************************************
+double *mxGetPr(const mxArray *pm);
+args: pm Pointer to an mxArray of type double
+Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data.
+****************************************************************
+mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag);
+args:  ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2.
+       dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension.
+	         For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array.
+       classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory.
+                For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer.
+       ComplexFlag:  If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). 
+	                 Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran).
+
+Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran).
+       If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not
+       enough free heap space to create the mxArray.
+*/
+
+template<typename T>
+np::ndarray zeros(int dims, int * dim_array, T el) {
+	bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
+	np::dtype dtype = np::dtype::get_builtin<T>();
+	np::ndarray zz = np::zeros(shape, dtype);
+	return zz;
+}
+
+
+bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, double d_epsil, int TV_type) {
+	/* C-OMP implementation of Split Bregman - TV denoising-regularization model (2D/3D)
+	*
+	* Input Parameters:
+	* 1. Noisy image/volume
+	* 2. lambda - regularization parameter
+	* 3. Number of iterations [OPTIONAL parameter]
+	* 4. eplsilon - tolerance constant [OPTIONAL parameter]
+	* 5. TV-type: 'iso' or 'l1' [OPTIONAL parameter]
+	*
+	* Output:
+	* Filtered/regularized image
+	*
+	* All sanity checks and default values are set in Python
+	*/
+	int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV;
+	const int dim_array[3];
+	float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old;
+
+	//number_of_dims = mxGetNumberOfDimensions(prhs[0]);
+	//dim_array = mxGetDimensions(prhs[0]);
+	number_of_dims = input.get_nd();
+
+	dim_array[0] = input.shape(0);
+	dim_array[1] = input.shape(1);
+	if (number_of_dims == 2) {
+		dim_array[2] = -11;
+	}
+	else {
+		dim_array[2] = input.shape(2);
+	}
+
+	/*Handling Matlab input data*/
+	//if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')");
+
+	/*Handling Matlab input data*/
+	//A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */
+	A = reinterpret_cast<float *>(input.get_data());
+
+
+	//mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */
+	mu = (float)d_mu;
+	//iter = 35; /* default iterations number */
+	iter = niterations;
+	//epsil = 0.0001; /* default tolerance constant */
+	epsil = (float)d_epsil;
+	//methTV = 0;  /* default isotropic TV penalty */
+	methTV = TV_type;
+	//if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5))  iter = (int)mxGetScalar(prhs[2]); /* iterations number */
+	//if ((nrhs == 4) || (nrhs == 5))  epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */
+	//if (nrhs == 5) {
+	//	char *penalty_type;
+	//	penalty_type = mxArrayToString(prhs[4]); /* choosing TV penalty: 'iso' or 'l1', 'iso' is the default */
+	//	if ((strcmp(penalty_type, "l1") != 0) && (strcmp(penalty_type, "iso") != 0)) mexErrMsgTxt("Choose TV type: 'iso' or 'l1',");
+	//	if (strcmp(penalty_type, "l1") == 0)  methTV = 1;  /* enable 'l1' penalty */
+	//	mxFree(penalty_type);
+	//}
+	//if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); }
+
+	lambda = 2.0f*mu;
+	count = 1;
+	re_old = 0.0f;
+	/*Handling Matlab output data*/
+	dimY = dim_array[0]; dimX = dim_array[1]; dimZ = dim_array[2];
+
+	if (number_of_dims == 2) {
+		dimZ = 1; /*2D case*/
+		/*
+		mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag);
+args:  ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2.
+       dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension.
+	         For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array.
+       classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory.
+                For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer.
+       ComplexFlag:  If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). 
+	                 Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran).
+
+					 mxCreateNumericArray initializes all its real data elements to 0.
+*/
+
+/*
+		U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+*/
+		//U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		U = A = reinterpret_cast<float *>input.get_data();
+		U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		copyIm(A, U, dimX, dimY, dimZ); /*initialize */
+
+										/* begin outer SB iterations */
+		for (ll = 0; ll<iter; ll++) {
+
+			/*storing old values*/
+			copyIm(U, U_old, dimX, dimY, dimZ);
+
+			/*GS iteration */
+			gauss_seidel2D(U, A, Dx, Dy, Bx, By, dimX, dimY, lambda, mu);
+
+			if (methTV == 1)  updDxDy_shrinkAniso2D(U, Dx, Dy, Bx, By, dimX, dimY, lambda);
+			else updDxDy_shrinkIso2D(U, Dx, Dy, Bx, By, dimX, dimY, lambda);
+
+			updBxBy2D(U, Dx, Dy, Bx, By, dimX, dimY);
+
+			/* calculate norm to terminate earlier */
+			re = 0.0f; re1 = 0.0f;
+			for (j = 0; j<dimX*dimY*dimZ; j++)
+			{
+				re += pow(U_old[j] - U[j], 2);
+				re1 += pow(U_old[j], 2);
+			}
+			re = sqrt(re) / sqrt(re1);
+			if (re < epsil)  count++;
+			if (count > 4) break;
+
+			/* check that the residual norm is decreasing */
+			if (ll > 2) {
+				if (re > re_old) break;
+			}
+			re_old = re;
+			/*printf("%f %i %i \n", re, ll, count); */
+
+			/*copyIm(U_old, U, dimX, dimY, dimZ); */
+		}
+		printf("SB iterations stopped at iteration: %i\n", ll);
+	}
+	if (number_of_dims == 3) {
+		U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+
+		copyIm(A, U, dimX, dimY, dimZ); /*initialize */
+
+										/* begin outer SB iterations */
+		for (ll = 0; ll<iter; ll++) {
+
+			/*storing old values*/
+			copyIm(U, U_old, dimX, dimY, dimZ);
+
+			/*GS iteration */
+			gauss_seidel3D(U, A, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda, mu);
+
+			if (methTV == 1) updDxDyDz_shrinkAniso3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda);
+			else updDxDyDz_shrinkIso3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda);
+
+			updBxByBz3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ);
+
+			/* calculate norm to terminate earlier */
+			re = 0.0f; re1 = 0.0f;
+			for (j = 0; j<dimX*dimY*dimZ; j++)
+			{
+				re += pow(U[j] - U_old[j], 2);
+				re1 += pow(U[j], 2);
+			}
+			re = sqrt(re) / sqrt(re1);
+			if (re < epsil)  count++;
+			if (count > 4) break;
+
+			/* check that the residual norm is decreasing */
+			if (ll > 2) {
+				if (re > re_old) break;
+			}
+			/*printf("%f %i %i \n", re, ll, count); */
+			re_old = re;
+		}
+		printf("SB iterations stopped at iteration: %i\n", ll);
+	}
+	bp::list result;
+	return result;
+}
+	
+
+BOOST_PYTHON_MODULE(fista)
+{
+	np::initialize();
+
+	//To specify that this module is a package
+	bp::object package = bp::scope();
+	package.attr("__path__") = "fista";
+
+	np::dtype dt1 = np::dtype::get_builtin<uint8_t>();
+	np::dtype dt2 = np::dtype::get_builtin<uint16_t>();
+
+	
+	def("mexFunction", mexFunction);
+}
\ No newline at end of file
diff --git a/src/Python/setup.py b/src/Python/setup.py
new file mode 100644
index 0000000..ffb9c02
--- /dev/null
+++ b/src/Python/setup.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python
+
+import setuptools
+from distutils.core import setup
+from distutils.extension import Extension
+from Cython.Distutils import build_ext
+
+import os
+import sys
+import numpy
+import platform	
+
+cil_version=os.environ['CIL_VERSION']
+if  cil_version == '':
+    print("Please set the environmental variable CIL_VERSION")
+    sys.exit(1)
+
+library_include_path = ""
+library_lib_path = ""
+try:
+    library_include_path = os.environ['LIBRARY_INC']
+    library_lib_path = os.environ['LIBRARY_LIB']
+except:
+    library_include_path = os.environ['PREFIX']+'/include'
+    pass
+    
+extra_include_dirs = [numpy.get_include(), library_include_path]
+extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\envs\\cil27\\Library\\lib"]
+extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x']
+extra_libraries = []
+if platform.system() == 'Windows':
+    extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB']   
+    extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."]
+    if sys.version_info.major == 3 :   
+        extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64']
+    else:
+        extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64']
+else:
+    extra_include_dirs += ["../ContourTree/", "../Core/","."]
+    if sys.version_info.major == 3:
+        extra_libraries += ['boost_python3', 'boost_numpy3','gomp']
+    else:
+        extra_libraries += ['boost_python', 'boost_numpy','gomp']
+
+setup(
+    name='ccpi',
+	description='CCPi Core Imaging Library - FISTA Reconstruction Module',
+	version=cil_version,
+    cmdclass = {'build_ext': build_ext},
+    ext_modules = [Extension("fista",
+                             sources=[  "Matlab2Python_utils.cpp",
+                                        ],
+                             include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), 
+    
+    ],
+	zip_safe = False,	
+	packages = {'ccpi','ccpi.reconstruction'},
+)
diff --git a/src/Python/setup_test.py b/src/Python/setup_test.py
new file mode 100644
index 0000000..ffb9c02
--- /dev/null
+++ b/src/Python/setup_test.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python
+
+import setuptools
+from distutils.core import setup
+from distutils.extension import Extension
+from Cython.Distutils import build_ext
+
+import os
+import sys
+import numpy
+import platform	
+
+cil_version=os.environ['CIL_VERSION']
+if  cil_version == '':
+    print("Please set the environmental variable CIL_VERSION")
+    sys.exit(1)
+
+library_include_path = ""
+library_lib_path = ""
+try:
+    library_include_path = os.environ['LIBRARY_INC']
+    library_lib_path = os.environ['LIBRARY_LIB']
+except:
+    library_include_path = os.environ['PREFIX']+'/include'
+    pass
+    
+extra_include_dirs = [numpy.get_include(), library_include_path]
+extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\envs\\cil27\\Library\\lib"]
+extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x']
+extra_libraries = []
+if platform.system() == 'Windows':
+    extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB']   
+    extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."]
+    if sys.version_info.major == 3 :   
+        extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64']
+    else:
+        extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64']
+else:
+    extra_include_dirs += ["../ContourTree/", "../Core/","."]
+    if sys.version_info.major == 3:
+        extra_libraries += ['boost_python3', 'boost_numpy3','gomp']
+    else:
+        extra_libraries += ['boost_python', 'boost_numpy','gomp']
+
+setup(
+    name='ccpi',
+	description='CCPi Core Imaging Library - FISTA Reconstruction Module',
+	version=cil_version,
+    cmdclass = {'build_ext': build_ext},
+    ext_modules = [Extension("fista",
+                             sources=[  "Matlab2Python_utils.cpp",
+                                        ],
+                             include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), 
+    
+    ],
+	zip_safe = False,	
+	packages = {'ccpi','ccpi.reconstruction'},
+)
-- 
cgit v1.2.3


From 22d41e596544668e71f3abef321d48f0a54f0f53 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Thu, 3 Aug 2017 16:52:20 +0100
Subject: added FGP_TV wrapper

---
 src/Python/fista_module.cpp | 576 ++++++++++++++++++++++++++++++++++++--------
 1 file changed, 473 insertions(+), 103 deletions(-)

(limited to 'src/Python')

diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp
index 5344083..2492884 100644
--- a/src/Python/fista_module.cpp
+++ b/src/Python/fista_module.cpp
@@ -26,12 +26,10 @@ limitations under the License.
 #include <boost/python/numpy.hpp>
 #include "boost/tuple/tuple.hpp"
 
-// include the regularizers
-#include "FGP_TV_core.h"
-#include "LLT_model_core.h"
-#include "PatchBased_Regul_core.h"
 #include "SplitBregman_TV_core.h"
-#include "TGV_PD_core.h"
+#include "FGP_TV_core.h"
+
+
 
 #if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64)
 #include <windows.h>
@@ -43,7 +41,6 @@ __if_not_exists(uint16_t) { typedef __int8 uint16_t; }
 namespace bp = boost::python;
 namespace np = boost::python::numpy;
 
-
 /*! in the Matlab implementation this is called as
 void mexFunction(
 int nlhs, mxArray *plhs[],
@@ -56,9 +53,7 @@ nlhs Array of pointers to the OUTPUT mxArrays
 plhs int number of OUTPUT mxArrays
 
 ***********************************************************
-mxGetData
-args: pm Pointer to an mxArray
-Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data.
+
 ***********************************************************
 double mxGetScalar(const mxArray *pm);
 args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray.
@@ -92,77 +87,143 @@ double *mxGetPr(const mxArray *pm);
 args: pm Pointer to an mxArray of type double
 Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data.
 ****************************************************************
-mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag);
-args:  ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2.
-       dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension.
-	         For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array.
-       classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory.
-                For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer.
-       ComplexFlag:  If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). 
-	                 Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran).
-
+mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims,
+mxClassID classid, mxComplexity ComplexFlag);
+args: ndimNumber of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2.
+dims Dimensions array. Each element in the dimensions array contains the size of the array in that dimension.
+For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array.
+classid Identifier for the class of the array, which determines the way the numerical data is represented in memory.
+For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer.
+ComplexFlag  If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran).
 Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran).
-       If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not
-       enough free heap space to create the mxArray.
+If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not
+enough free heap space to create the mxArray.
+*/
+
+void mexErrMessageText(char* text) {
+	std::cerr << text << std::endl;
+}
+
+/*
+double mxGetScalar(const mxArray *pm);
+args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray.
+Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray.	In C, mxGetScalar returns a double.
 */
 
 template<typename T>
-np::ndarray zeros(int dims, int * dim_array, T el) {
-	bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
-	np::dtype dtype = np::dtype::get_builtin<T>();
-	np::ndarray zz = np::zeros(shape, dtype);
-	return zz;
+double mxGetScalar(const np::ndarray plh) {
+	return (double)bp::extract<T>(plh[0]);
 }
 
 
-bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, double d_epsil, int TV_type) {
-	/* C-OMP implementation of Split Bregman - TV denoising-regularization model (2D/3D)
-	*
-	* Input Parameters:
-	* 1. Noisy image/volume
-	* 2. lambda - regularization parameter
-	* 3. Number of iterations [OPTIONAL parameter]
-	* 4. eplsilon - tolerance constant [OPTIONAL parameter]
-	* 5. TV-type: 'iso' or 'l1' [OPTIONAL parameter]
-	*
-	* Output:
-	* Filtered/regularized image
-	*
-	* All sanity checks and default values are set in Python
+
+template<typename T>
+T * mxGetData(const np::ndarray pm) {
+	//args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray.
+	//Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double.
+	/*Access the numpy array pointer:
+	char * get_data() const;
+	Returns:	Array�s raw data pointer as a char
+	Note:	This returns char so stride math works properly on it.User will have to reinterpret_cast it.
+	probably this would work.
+	A = reinterpret_cast<float *>(prhs[0]);
 	*/
+	return reinterpret_cast<T *>(prhs[0]);
+}
+
+
+
+
+bp::list mexFunction(np::ndarray input) {
+	int number_of_dims = input.get_nd();
+	int dim_array[3];
+
+	dim_array[0] = input.shape(0);
+	dim_array[1] = input.shape(1);
+	if (number_of_dims == 2) {
+		dim_array[2] = -1;
+	}
+	else {
+		dim_array[2] = input.shape(2);
+	}
+
+	/**************************************************************************/
+	np::ndarray zz = zeros(3, dim_array, (int)0);
+	np::ndarray fzz = zeros(3, dim_array, (float)0);
+	/**************************************************************************/
+
+	int * A = reinterpret_cast<int *>(input.get_data());
+	int * B = reinterpret_cast<int *>(zz.get_data());
+	float * C = reinterpret_cast<float *>(fzz.get_data());
+
+	//Copy data and cast
+	for (int i = 0; i < dim_array[0]; i++) {
+		for (int j = 0; j < dim_array[1]; j++) {
+			for (int k = 0; k < dim_array[2]; k++) {
+				int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i;
+				int val = (*(A + index));
+				float fval = (float)val;
+				std::memcpy(B + index, &val, sizeof(int));
+				std::memcpy(C + index, &fval, sizeof(float));
+			}
+		}
+	}
+
+
+	bp::list result;
+
+	result.append<int>(number_of_dims);
+	result.append<int>(dim_array[0]);
+	result.append<int>(dim_array[1]);
+	result.append<int>(dim_array[2]);
+	result.append<np::ndarray>(zz);
+	result.append<np::ndarray>(fzz);
+
+	//result.append<bp::tuple>(tup);
+	return result;
+
+}
+
+bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) {
+	
+	// the result is in the following list
+	bp::list result;
+		
 	int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV;
-	const int dim_array[3];
+	const int  *dim_array;
 	float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old;
-
+	
 	//number_of_dims = mxGetNumberOfDimensions(prhs[0]);
 	//dim_array = mxGetDimensions(prhs[0]);
-	number_of_dims = input.get_nd();
+
+	int number_of_dims = input.get_nd();
+	int dim_array[3];
 
 	dim_array[0] = input.shape(0);
 	dim_array[1] = input.shape(1);
 	if (number_of_dims == 2) {
-		dim_array[2] = -11;
+		dim_array[2] = -1;
 	}
 	else {
 		dim_array[2] = input.shape(2);
 	}
 
-	/*Handling Matlab input data*/
+	// Parameter handling is be done in Python
+	///*Handling Matlab input data*/
 	//if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')");
 
-	/*Handling Matlab input data*/
+	///*Handling Matlab input data*/
 	//A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */
 	A = reinterpret_cast<float *>(input.get_data());
 
-
 	//mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */
 	mu = (float)d_mu;
+
 	//iter = 35; /* default iterations number */
-	iter = niterations;
+	
 	//epsil = 0.0001; /* default tolerance constant */
 	epsil = (float)d_epsil;
 	//methTV = 0;  /* default isotropic TV penalty */
-	methTV = TV_type;
 	//if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5))  iter = (int)mxGetScalar(prhs[2]); /* iterations number */
 	//if ((nrhs == 4) || (nrhs == 5))  epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */
 	//if (nrhs == 5) {
@@ -182,34 +243,31 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, doub
 
 	if (number_of_dims == 2) {
 		dimZ = 1; /*2D case*/
-		/*
-		mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag);
-args:  ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2.
-       dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension.
-	         For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array.
-       classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory.
-                For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer.
-       ComplexFlag:  If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). 
-	                 Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran).
-
-					 mxCreateNumericArray initializes all its real data elements to 0.
-*/
-
-/*
-		U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
-		U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
-		Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
-		Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
-		Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
-		By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
-*/
 		//U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
-		U = A = reinterpret_cast<float *>input.get_data();
-		U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
-		Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
-		Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
-		Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
-		By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		//U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		//Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		//Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		//Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		//By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]);
+		np::dtype dtype = np::dtype::get_builtin<float>();
+
+		np::ndarray npU     = np::zeros(shape, dtype);
+		np::ndarray npU_old = np::zeros(shape, dtype);
+		np::ndarray npDx    = np::zeros(shape, dtype);
+		np::ndarray npDy    = np::zeros(shape, dtype);
+		np::ndarray npBx    = np::zeros(shape, dtype);
+		np::ndarray npBy    = np::zeros(shape, dtype);
+
+		U     = reinterpret_cast<float *>(npU.get_data());
+		U_old = reinterpret_cast<float *>(npU_old.get_data());
+		Dx    = reinterpret_cast<float *>(npDx.get_data());
+		Dy    = reinterpret_cast<float *>(npDy.get_data());
+		Bx    = reinterpret_cast<float *>(npBx.get_data());
+		By    = reinterpret_cast<float *>(npBy.get_data());
+
+		
+
 		copyIm(A, U, dimX, dimY, dimZ); /*initialize */
 
 										/* begin outer SB iterations */
@@ -245,59 +303,370 @@ args:  ndim: Number of dimensions. If you specify a value for ndim that is less
 			/*printf("%f %i %i \n", re, ll, count); */
 
 			/*copyIm(U_old, U, dimX, dimY, dimZ); */
+			result.append<np::ndarray>(npU);
+			result.append<int>(ll);
+		}
+		//printf("SB iterations stopped at iteration: %i\n", ll);
+		if (number_of_dims == 3) {
+			/*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+			U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+			Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+			Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+			Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+			Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+			By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+			Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/
+			bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
+			np::dtype dtype = np::dtype::get_builtin<float>();
+
+			np::ndarray npU     = np::zeros(shape, dtype);
+			np::ndarray npU_old = np::zeros(shape, dtype);
+			np::ndarray npDx    = np::zeros(shape, dtype);
+			np::ndarray npDy    = np::zeros(shape, dtype);
+			np::ndarray npDz    = np::zeros(shape, dtype);
+			np::ndarray npBx    = np::zeros(shape, dtype);
+			np::ndarray npBy    = np::zeros(shape, dtype);
+			np::ndarray npBz    = np::zeros(shape, dtype);
+
+			U     = reinterpret_cast<float *>(npU.get_data());
+			U_old = reinterpret_cast<float *>(npU_old.get_data());
+			Dx    = reinterpret_cast<float *>(npDx.get_data());
+			Dy    = reinterpret_cast<float *>(npDy.get_data());
+			Dz    = reinterpret_cast<float *>(npDz.get_data());
+			Bx    = reinterpret_cast<float *>(npBx.get_data());
+			By    = reinterpret_cast<float *>(npBy.get_data());
+			Bz    = reinterpret_cast<float *>(npBz.get_data());
+
+			copyIm(A, U, dimX, dimY, dimZ); /*initialize */
+
+											/* begin outer SB iterations */
+			for (ll = 0; ll<iter; ll++) {
+
+				/*storing old values*/
+				copyIm(U, U_old, dimX, dimY, dimZ);
+
+				/*GS iteration */
+				gauss_seidel3D(U, A, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda, mu);
+
+				if (methTV == 1) updDxDyDz_shrinkAniso3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda);
+				else updDxDyDz_shrinkIso3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda);
+
+				updBxByBz3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ);
+
+				/* calculate norm to terminate earlier */
+				re = 0.0f; re1 = 0.0f;
+				for (j = 0; j<dimX*dimY*dimZ; j++)
+				{
+					re += pow(U[j] - U_old[j], 2);
+					re1 += pow(U[j], 2);
+				}
+				re = sqrt(re) / sqrt(re1);
+				if (re < epsil)  count++;
+				if (count > 4) break;
+
+				/* check that the residual norm is decreasing */
+				if (ll > 2) {
+					if (re > re_old) break;
+				}
+				/*printf("%f %i %i \n", re, ll, count); */
+				re_old = re;
+			}
+			//printf("SB iterations stopped at iteration: %i\n", ll);
+			result.append<np::ndarray>(npU);
+			result.append<int>(ll);
 		}
-		printf("SB iterations stopped at iteration: %i\n", ll);
 	}
-	if (number_of_dims == 3) {
-		U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
-		U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
-		Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
-		Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
-		Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
-		Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
-		By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
-		Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+	return result;
 
-		copyIm(A, U, dimX, dimY, dimZ); /*initialize */
+}
 
-										/* begin outer SB iterations */
+bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) {
+
+	// the result is in the following list
+	bp::list result;
+
+	int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV;
+	float *A, *D = NULL, *D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL;
+	float lambda, tk, tkp1, re, re1, re_old, epsil, funcval;
+
+	//number_of_dims = mxGetNumberOfDimensions(prhs[0]);
+	//dim_array = mxGetDimensions(prhs[0]);
+
+	int number_of_dims = input.get_nd();
+	int dim_array[3];
+
+	dim_array[0] = input.shape(0);
+	dim_array[1] = input.shape(1);
+	if (number_of_dims == 2) {
+		dim_array[2] = -1;
+	}
+	else {
+		dim_array[2] = input.shape(2);
+	}
+
+	// Parameter handling is be done in Python
+	///*Handling Matlab input data*/
+	//if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')");
+
+	///*Handling Matlab input data*/
+	//A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */
+	A = reinterpret_cast<float *>(input.get_data());
+
+	//mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */
+	mu = (float)d_mu;
+
+	//iter = 35; /* default iterations number */
+
+	//epsil = 0.0001; /* default tolerance constant */
+	epsil = (float)d_epsil;
+	//methTV = 0;  /* default isotropic TV penalty */
+	//if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5))  iter = (int)mxGetScalar(prhs[2]); /* iterations number */
+	//if ((nrhs == 4) || (nrhs == 5))  epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */
+	//if (nrhs == 5) {
+	//	char *penalty_type;
+	//	penalty_type = mxArrayToString(prhs[4]); /* choosing TV penalty: 'iso' or 'l1', 'iso' is the default */
+	//	if ((strcmp(penalty_type, "l1") != 0) && (strcmp(penalty_type, "iso") != 0)) mexErrMsgTxt("Choose TV type: 'iso' or 'l1',");
+	//	if (strcmp(penalty_type, "l1") == 0)  methTV = 1;  /* enable 'l1' penalty */
+	//	mxFree(penalty_type);
+	//}
+	//if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); }
+
+	//plhs[1] = mxCreateNumericMatrix(1, 1, mxSINGLE_CLASS, mxREAL);
+	bp::tuple shape1 = bp::make_tuple(dim_array[0], dim_array[1]);
+	np::dtype dtype = np::dtype::get_builtin<float>();
+	np::ndarray out1 = np::zeros(shape1, dtype);
+	
+	//float *funcvalA = (float *)mxGetData(plhs[1]);
+	float * funcvalA = reinterpret_cast<float *>(out1.get_data());
+	//if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); }
+
+	/*Handling Matlab output data*/
+	dimX = dim_array[0]; dimY = dim_array[1]; dimZ = dim_array[2];
+
+	tk = 1.0f;
+	tkp1 = 1.0f;
+	count = 1;
+	re_old = 0.0f;
+
+	if (number_of_dims == 2) {
+		dimZ = 1; /*2D case*/
+		/*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		D_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		P1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		P2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		R1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		R2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/
+
+		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]);
+		np::dtype dtype = np::dtype::get_builtin<float>();
+
+
+		np::ndarray npD      = np::zeros(shape, dtype);
+		np::ndarray npD_old  = np::zeros(shape, dtype);
+		np::ndarray npP1     = np::zeros(shape, dtype);
+		np::ndarray npP2     = np::zeros(shape, dtype);
+		np::ndarray npP1_old = np::zeros(shape, dtype);
+		np::ndarray npP2_old = np::zeros(shape, dtype);
+		np::ndarray npR1     = np::zeros(shape, dtype);
+		np::ndarray npR2     = zeros(2, dim_array, (float)0);
+
+		D      = reinterpret_cast<float *>(npD.get_data());
+		D_old  = reinterpret_cast<float *>(npD_old.get_data());
+		P1     = reinterpret_cast<float *>(npP1.get_data());
+		P2     = reinterpret_cast<float *>(npP2.get_data());
+		P1_old = reinterpret_cast<float *>(npP1_old.get_data());
+		P2_old = reinterpret_cast<float *>(npP2_old.get_data());
+		R1     = reinterpret_cast<float *>(npR1.get_data());
+		R2     = reinterpret_cast<float *>(npR2.get_data());
+
+		/* begin iterations */
 		for (ll = 0; ll<iter; ll++) {
 
-			/*storing old values*/
-			copyIm(U, U_old, dimX, dimY, dimZ);
+			/* computing the gradient of the objective function */
+			Obj_func2D(A, D, R1, R2, lambda, dimX, dimY);
 
-			/*GS iteration */
-			gauss_seidel3D(U, A, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda, mu);
+			/*Taking a step towards minus of the gradient*/
+			Grad_func2D(P1, P2, D, R1, R2, lambda, dimX, dimY);
 
-			if (methTV == 1) updDxDyDz_shrinkAniso3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda);
-			else updDxDyDz_shrinkIso3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda);
+			/* projection step */
+			Proj_func2D(P1, P2, methTV, dimX, dimY);
 
-			updBxByBz3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ);
+			/*updating R and t*/
+			tkp1 = (1.0f + sqrt(1.0f + 4.0f*tk*tk))*0.5f;
+			Rupd_func2D(P1, P1_old, P2, P2_old, R1, R2, tkp1, tk, dimX, dimY);
 
-			/* calculate norm to terminate earlier */
+			/* calculate norm */
 			re = 0.0f; re1 = 0.0f;
 			for (j = 0; j<dimX*dimY*dimZ; j++)
 			{
-				re += pow(U[j] - U_old[j], 2);
-				re1 += pow(U[j], 2);
+				re += pow(D[j] - D_old[j], 2);
+				re1 += pow(D[j], 2);
 			}
 			re = sqrt(re) / sqrt(re1);
 			if (re < epsil)  count++;
-			if (count > 4) break;
+			if (count > 3) {
+				Obj_func2D(A, D, P1, P2, lambda, dimX, dimY);
+				funcval = 0.0f;
+				for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
+				//funcvalA[0] = sqrt(funcval);
+				float fv = sqrt(funcval);
+				std::memcpy(funcvalA, &fv), sizeof(float));
+				break;
+			}
 
 			/* check that the residual norm is decreasing */
 			if (ll > 2) {
-				if (re > re_old) break;
+				if (re > re_old) {
+					Obj_func2D(A, D, P1, P2, lambda, dimX, dimY);
+					funcval = 0.0f;
+					for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
+					//funcvalA[0] = sqrt(funcval);
+					float fv = sqrt(funcval);
+					std::memcpy(funcvalA, &fv), sizeof(float));
+					break;
+				}
 			}
+			re_old = re;
 			/*printf("%f %i %i \n", re, ll, count); */
+
+			/*storing old values*/
+			copyIm(D, D_old, dimX, dimY, dimZ);
+			copyIm(P1, P1_old, dimX, dimY, dimZ);
+			copyIm(P2, P2_old, dimX, dimY, dimZ);
+			tk = tkp1;
+
+			/* calculating the objective function value */
+			if (ll == (iter - 1)) {
+				Obj_func2D(A, D, P1, P2, lambda, dimX, dimY);
+				funcval = 0.0f;
+				for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
+				//funcvalA[0] = sqrt(funcval);
+				float fv = sqrt(funcval);
+				std::memcpy(funcvalA, &fv), sizeof(float));
+			}
+		}
+		//printf("FGP-TV iterations stopped at iteration %i with the function value %f \n", ll, funcvalA[0]);
+		result.append<np::ndarray>(npD);
+		result.append<np::ndarray>(out1);
+		result.append<int>(ll);
+	}
+	if (number_of_dims == 3) {
+		/*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		D_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		P1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		P2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		P3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		P1_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		P2_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		P3_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		R1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		R2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		R3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/
+		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
+		np::dtype dtype = np::dtype::get_builtin<float>();
+		
+		np::ndarray npD      = np::zeros(shape, dtype);
+		np::ndarray npD_old  = np::zeros(shape, dtype);
+		np::ndarray npP1     = np::zeros(shape, dtype);
+		np::ndarray npP2     = np::zeros(shape, dtype);
+		np::ndarray npP3     = np::zeros(shape, dtype);
+		np::ndarray npP1_old = np::zeros(shape, dtype);
+		np::ndarray npP2_old = np::zeros(shape, dtype);
+		np::ndarray npP3_old = np::zeros(shape, dtype);
+		np::ndarray npR1     = np::zeros(shape, dtype);
+		np::ndarray npR2     = np::zeros(shape, dtype);
+		np::ndarray npR3     = np::zeros(shape, dtype);
+
+		D      = reinterpret_cast<float *>(npD.get_data());
+		D_old  = reinterpret_cast<float *>(npD_old.get_data());
+		P1     = reinterpret_cast<float *>(npP1.get_data());
+		P2     = reinterpret_cast<float *>(npP2.get_data());
+		P3     = reinterpret_cast<float *>(npP3.get_data());
+		P1_old = reinterpret_cast<float *>(npP1_old.get_data());
+		P2_old = reinterpret_cast<float *>(npP2_old.get_data());
+		P3_old = reinterpret_cast<float *>(npP3_old.get_data());
+		R1     = reinterpret_cast<float *>(npR1.get_data());
+		R2     = reinterpret_cast<float *>(npR2.get_data());
+		R2     = reinterpret_cast<float *>(npR3.get_data());
+		/* begin iterations */
+		for (ll = 0; ll<iter; ll++) {
+
+			/* computing the gradient of the objective function */
+			Obj_func3D(A, D, R1, R2, R3, lambda, dimX, dimY, dimZ);
+
+			/*Taking a step towards minus of the gradient*/
+			Grad_func3D(P1, P2, P3, D, R1, R2, R3, lambda, dimX, dimY, dimZ);
+
+			/* projection step */
+			Proj_func3D(P1, P2, P3, dimX, dimY, dimZ);
+
+			/*updating R and t*/
+			tkp1 = (1.0f + sqrt(1.0f + 4.0f*tk*tk))*0.5f;
+			Rupd_func3D(P1, P1_old, P2, P2_old, P3, P3_old, R1, R2, R3, tkp1, tk, dimX, dimY, dimZ);
+
+			/* calculate norm - stopping rules*/
+			re = 0.0f; re1 = 0.0f;
+			for (j = 0; j<dimX*dimY*dimZ; j++)
+			{
+				re += pow(D[j] - D_old[j], 2);
+				re1 += pow(D[j], 2);
+			}
+			re = sqrt(re) / sqrt(re1);
+			/* stop if the norm residual is less than the tolerance EPS */
+			if (re < epsil)  count++;
+			if (count > 3) {
+				Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ);
+				funcval = 0.0f;
+				for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
+				//funcvalA[0] = sqrt(funcval);
+				float fv = sqrt(funcval);
+				std::memcpy(funcvalA, &fv), sizeof(float));
+				break;
+			}
+
+			/* check that the residual norm is decreasing */
+			if (ll > 2) {
+				if (re > re_old) {
+					Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ);
+					funcval = 0.0f;
+					for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
+					//funcvalA[0] = sqrt(funcval);
+					float fv = sqrt(funcval);
+					std::memcpy(funcvalA, &fv), sizeof(float));
+					break;
+				}
+			}
+
 			re_old = re;
+			/*printf("%f %i %i \n", re, ll, count); */
+
+			/*storing old values*/
+			copyIm(D, D_old, dimX, dimY, dimZ);
+			copyIm(P1, P1_old, dimX, dimY, dimZ);
+			copyIm(P2, P2_old, dimX, dimY, dimZ);
+			copyIm(P3, P3_old, dimX, dimY, dimZ);
+			tk = tkp1;
+
+			if (ll == (iter - 1)) {
+				Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ);
+				funcval = 0.0f;
+				for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
+				//funcvalA[0] = sqrt(funcval);
+				float fv = sqrt(funcval);
+				std::memcpy(funcvalA, &fv), sizeof(float));
+			}
+
 		}
-		printf("SB iterations stopped at iteration: %i\n", ll);
+		//printf("FGP-TV iterations stopped at iteration %i with the function value %f \n", ll, funcvalA[0]);
+		result.append<np::ndarray>(npD);
+		result.append<np::ndarray>(out1);
+		result.append<int>(ll);
 	}
-	bp::list result;
+
 	return result;
 }
-	
 
 BOOST_PYTHON_MODULE(fista)
 {
@@ -310,6 +679,7 @@ BOOST_PYTHON_MODULE(fista)
 	np::dtype dt1 = np::dtype::get_builtin<uint8_t>();
 	np::dtype dt2 = np::dtype::get_builtin<uint16_t>();
 
-	
 	def("mexFunction", mexFunction);
+	def("SplitBregman_TV", SplitBregman_TV);
+	def("FGP_TV", FGP_TV);
 }
\ No newline at end of file
-- 
cgit v1.2.3


From 12dbe738d5a2af5573e33a31f1745a50dba165ba Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Fri, 4 Aug 2017 16:15:03 +0100
Subject: compilation fixes

---
 src/Python/setup.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

(limited to 'src/Python')

diff --git a/src/Python/setup.py b/src/Python/setup.py
index ffb9c02..a8feb1c 100644
--- a/src/Python/setup.py
+++ b/src/Python/setup.py
@@ -29,14 +29,14 @@ extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\env
 extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x']
 extra_libraries = []
 if platform.system() == 'Windows':
-    extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB']   
-    extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."]
+    extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB' , '/openmp' ]   
+    extra_include_dirs += ["..\\..\\main_func\\regularizers_CPU\\","."]
     if sys.version_info.major == 3 :   
         extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64']
     else:
         extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64']
 else:
-    extra_include_dirs += ["../ContourTree/", "../Core/","."]
+    extra_include_dirs += ["../../main_func/regularizers_CPU","."]
     if sys.version_info.major == 3:
         extra_libraries += ['boost_python3', 'boost_numpy3','gomp']
     else:
@@ -47,8 +47,12 @@ setup(
 	description='CCPi Core Imaging Library - FISTA Reconstruction Module',
 	version=cil_version,
     cmdclass = {'build_ext': build_ext},
-    ext_modules = [Extension("fista",
-                             sources=[  "Matlab2Python_utils.cpp",
+    ext_modules = [Extension("regularizers",
+                             sources=["fista_module.cpp",
+                                      "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c",
+                                      "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c",
+                                      "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c",
+                                      "..\\..\\main_func\\regularizers_CPU\\utils.c"
                                         ],
                              include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), 
     
-- 
cgit v1.2.3


From 36e4c296223f67bb917511089ec59533460f1695 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Fri, 4 Aug 2017 16:15:17 +0100
Subject: test facility for regularizers

---
 src/Python/test_regularizers.py | 265 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 265 insertions(+)
 create mode 100644 src/Python/test_regularizers.py

(limited to 'src/Python')

diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py
new file mode 100644
index 0000000..6abfba4
--- /dev/null
+++ b/src/Python/test_regularizers.py
@@ -0,0 +1,265 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Aug  4 11:10:05 2017
+
+@author: ofn77899
+"""
+
+from ccpi.viewer.CILViewer2D import Converter
+import vtk
+
+import regularizers
+import matplotlib.pyplot as plt
+import numpy as np
+import os    
+from enum import Enum
+
+class Regularizer():
+    '''Class to handle regularizer algorithms to be used during reconstruction
+    
+    Currently 5 regularization algorithms are available:
+        
+    1) SplitBregman_TV
+    2) FGP_TV
+    3)
+    4)
+    5)
+    
+    Usage:
+        the regularizer can be invoked as object or as static method
+        Depending on the actual regularizer the input parameter may vary, and 
+        a different default setting is defined.
+        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+
+        out = reg(input=u0, regularization_parameter=10., number_of_iterations=30,
+          tolerance_constant=1e-4, 
+          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+
+        out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10.,
+          number_of_iterations=30, tolerance_constant=1e-4, 
+          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+        
+        A number of optional parameters can be passed or skipped
+        out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. )
+
+    '''
+    class Algorithm(Enum):
+        SplitBregman_TV = regularizers.SplitBregman_TV
+        FGP_TV = regularizers.FGP_TV
+        LLT_model = regularizers.LLT_model
+    # Algorithm
+    
+    class TotalVariationPenalty(Enum):
+        isotropic = 0
+        l1 = 1
+    # TotalVariationPenalty
+        
+    def __init__(self , algorithm):
+        
+        self.algorithm = algorithm
+        self.pars = self.parsForAlgorithm(algorithm)
+    # __init__
+        
+    def parsForAlgorithm(self, algorithm):
+        pars = dict()
+        if algorithm == Regularizer.Algorithm.SplitBregman_TV :
+            pars['algorithm'] = algorithm
+            pars['input'] = None
+            pars['regularization_parameter'] = None
+            pars['number_of_iterations'] = 35
+            pars['tolerance_constant'] = 0.0001
+            pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic
+        elif algorithm == Regularizer.Algorithm.FGP_TV :
+            pars['algorithm'] = algorithm
+            pars['input'] = None
+            pars['regularization_parameter'] = None
+            pars['number_of_iterations'] = 50
+            pars['tolerance_constant'] = 0.001
+            pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic
+        elif algorithm == Regularizer.Algorithm.LLT_model:
+            pars['algorithm'] = algorithm
+            pars['input'] = None
+            pars['regularization_parameter'] = None
+            pars['time_step'] = None
+            pars['number_of_iterations'] = None
+            pars['tolerance_constant'] = None
+            pars['restrictive_Z_smoothing'] = 0
+            
+        return pars
+    # parsForAlgorithm
+        
+    def __call__(self, input, regularization_parameter, **kwargs):
+        
+        if kwargs is not None:
+            for key, value in kwargs.items():
+                #print("{0} = {1}".format(key, value))
+                self.pars[key] = value
+        self.pars['input'] = input
+        self.pars['regularization_parameter'] = regularization_parameter
+        #for key, value in self.pars.items():
+        #        print("{0} = {1}".format(key, value))
+                
+        if self.algorithm == Regularizer.Algorithm.SplitBregman_TV :
+            return self.algorithm(input, regularization_parameter,
+                              self.pars['number_of_iterations'],
+                              self.pars['tolerance_constant'],
+                              self.pars['TV_penalty'].value )    
+        elif self.algorithm == Regularizer.Algorithm.FGP_TV :
+            return self.algorithm(input, regularization_parameter,
+                              self.pars['number_of_iterations'],
+                              self.pars['tolerance_constant'],
+                              self.pars['TV_penalty'].value )
+        elif self.algorithm == Regularizer.Algorithm.LLT_model :
+            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
+            # no default
+            if None in self.pars:
+                raise Exception("Not all parameters have been provided")
+            else:
+                return self.algorithm(input, 
+                                  regularization_parameter,
+                                  self.pars['time_step'] , 
+                                  self.pars['number_of_iterations'],
+                                  self.pars['tolerance_constant'],
+                                  self.pars['restrictive_Z_smoothing'] )
+            
+        
+    # __call__
+    
+    @staticmethod
+    def SplitBregman_TV(input, regularization_parameter , **kwargs):
+        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+        out = list( reg(input, regularization_parameter, **kwargs) )
+        out.append(reg.pars)
+        return out
+        
+    @staticmethod
+    def FGP_TV(input, regularization_parameter , **kwargs):
+        reg = Regularizer(Regularizer.Algorithm.FGP_TV)
+        out = list( reg(input, regularization_parameter, **kwargs) )
+        out.append(reg.pars)
+        return out
+    
+    @staticmethod
+    def LLT_model(input, regularization_parameter , time_step, number_of_iterations,
+                  tolerance_constant, restrictive_Z_smoothing=0):
+        reg = Regularizer(Regularizer.Algorithm.FGP_TV)
+        out = list( reg(input, regularization_parameter, time_step=time_step, 
+                        number_of_iterations=number_of_iterations,
+                        tolerance_constant=tolerance_constant, 
+                        restrictive_Z_smoothing=restrictive_Z_smoothing) )
+        out.append(reg.pars)
+        return out
+        
+
+#Example:
+# figure;
+# Im = double(imread('lena_gray_256.tif'))/255;  % loading image
+# u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0;
+# u = SplitBregman_TV(single(u0), 10, 30, 1e-04);
+
+filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif"
+reader = vtk.vtkTIFFReader()
+reader.SetFileName(os.path.normpath(filename))
+reader.Update()
+#vtk returns 3D images, let's take just the one slice there is as 2D
+Im = Converter.vtk2numpy(reader.GetOutput()).T[0]/255
+
+#imgplot = plt.imshow(Im)
+perc = 0.05
+u0 = Im + (perc* np.random.normal(size=np.shape(Im)))
+# map the u0 u0->u0>0
+f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1)
+u0 = f(u0).astype('float32')
+
+# plot 
+fig = plt.figure()
+a=fig.add_subplot(2,3,1)
+a.set_title('Original')
+imgplot = plt.imshow(Im)
+
+a=fig.add_subplot(2,3,2)
+a.set_title('noise')
+imgplot = plt.imshow(u0)
+
+
+##############################################################################
+# Call regularizer
+
+####################### SplitBregman_TV #####################################
+# u = SplitBregman_TV(single(u0), 10, 30, 1e-04);
+
+reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+
+out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30,
+          #tolerance_constant=1e-4, 
+          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+
+out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30,
+          tolerance_constant=1e-4, 
+          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. )
+pars = out2[2]
+
+a=fig.add_subplot(2,3,3)
+a.set_title('SplitBregman_TV')
+textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s'
+textstr = textstr % (pars['regularization_parameter'], 
+                     pars['number_of_iterations'], 
+                     pars['tolerance_constant'],
+                     pars['TV_penalty'].name)
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+        verticalalignment='top', bbox=props)
+imgplot = plt.imshow(out2[0])
+
+###################### FGP_TV #########################################
+# u = FGP_TV(single(u0), 0.05, 100, 1e-04);
+out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05,
+                          number_of_iterations=10)
+pars = out2[-1]
+
+a=fig.add_subplot(2,3,4)
+a.set_title('FGP_TV')
+textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s'
+textstr = textstr % (pars['regularization_parameter'], 
+                     pars['number_of_iterations'], 
+                     pars['tolerance_constant'],
+                     pars['TV_penalty'].name)
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+        verticalalignment='top', bbox=props)
+imgplot = plt.imshow(out2[0])
+
+###################### LLT_model #########################################
+# * u0 = Im + .03*randn(size(Im)); % adding noise
+# [Den] = LLT_model(single(u0), 10, 0.1, 1);
+out2 = Regularizer.LLT_model(input=u0, regularization_parameter=10.,
+                          time_step=0.1,
+                          tolerance_constant=1e-4,
+                          number_of_iterations=10)
+pars = out2[-1]
+
+a=fig.add_subplot(2,3,5)
+a.set_title('LLT_model')
+textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f'
+textstr = textstr % (pars['regularization_parameter'], 
+                     pars['number_of_iterations'], 
+                     pars['tolerance_constant'],
+                     pars['time_step']
+                     )
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+        verticalalignment='top', bbox=props)
+imgplot = plt.imshow(out2[0])
+
+
+
-- 
cgit v1.2.3


From fd496731c8e9d4975864d76dbb6574cbeee7cf98 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Fri, 4 Aug 2017 16:16:37 +0100
Subject: Added 3 regularizers

SplitBregman_TV
FGP_TV
LLT_model
---
 src/Python/fista_module.cpp | 266 ++++++++++++++++++++++++++++++++++++++------
 1 file changed, 232 insertions(+), 34 deletions(-)

(limited to 'src/Python')

diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp
index 2492884..d890b10 100644
--- a/src/Python/fista_module.cpp
+++ b/src/Python/fista_module.cpp
@@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by
 Visual Analytics and Imaging System Group of the Science Technology
 Facilities Council, STFC
 
-Copyright 2017 Daniil Kazanteev
+Copyright 2017 Daniil Kazantsev
 Copyright 2017 Srikanth Nagella, Edoardo Pasca
 
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,6 +28,8 @@ limitations under the License.
 
 #include "SplitBregman_TV_core.h"
 #include "FGP_TV_core.h"
+#include "LLT_model_core.h"
+#include "utils.h"
 
 
 
@@ -131,6 +133,18 @@ T * mxGetData(const np::ndarray pm) {
 	return reinterpret_cast<T *>(prhs[0]);
 }
 
+template<typename T>
+np::ndarray zeros(int dims, int * dim_array, T el) {
+	bp::tuple shape;
+	if (dims == 3)
+		shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
+	else if (dims == 2)
+		shape = bp::make_tuple(dim_array[0], dim_array[1]);
+	np::dtype dtype = np::dtype::get_builtin<T>();
+	np::ndarray zz = np::zeros(shape, dtype);
+	return zz;
+}
+
 
 
 
@@ -169,7 +183,6 @@ bp::list mexFunction(np::ndarray input) {
 		}
 	}
 
-
 	bp::list result;
 
 	result.append<int>(number_of_dims);
@@ -189,14 +202,14 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi
 	// the result is in the following list
 	bp::list result;
 		
-	int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV;
-	const int  *dim_array;
+	int number_of_dims, dimX, dimY, dimZ, ll, j, count;
+	//const int  *dim_array;
 	float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old;
 	
 	//number_of_dims = mxGetNumberOfDimensions(prhs[0]);
 	//dim_array = mxGetDimensions(prhs[0]);
 
-	int number_of_dims = input.get_nd();
+	number_of_dims = input.get_nd();
 	int dim_array[3];
 
 	dim_array[0] = input.shape(0);
@@ -252,26 +265,26 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi
 		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]);
 		np::dtype dtype = np::dtype::get_builtin<float>();
 
-		np::ndarray npU     = np::zeros(shape, dtype);
+		np::ndarray npU = np::zeros(shape, dtype);
 		np::ndarray npU_old = np::zeros(shape, dtype);
-		np::ndarray npDx    = np::zeros(shape, dtype);
-		np::ndarray npDy    = np::zeros(shape, dtype);
-		np::ndarray npBx    = np::zeros(shape, dtype);
-		np::ndarray npBy    = np::zeros(shape, dtype);
+		np::ndarray npDx = np::zeros(shape, dtype);
+		np::ndarray npDy = np::zeros(shape, dtype);
+		np::ndarray npBx = np::zeros(shape, dtype);
+		np::ndarray npBy = np::zeros(shape, dtype);
 
-		U     = reinterpret_cast<float *>(npU.get_data());
+		U = reinterpret_cast<float *>(npU.get_data());
 		U_old = reinterpret_cast<float *>(npU_old.get_data());
-		Dx    = reinterpret_cast<float *>(npDx.get_data());
-		Dy    = reinterpret_cast<float *>(npDy.get_data());
-		Bx    = reinterpret_cast<float *>(npBx.get_data());
-		By    = reinterpret_cast<float *>(npBy.get_data());
+		Dx = reinterpret_cast<float *>(npDx.get_data());
+		Dy = reinterpret_cast<float *>(npDy.get_data());
+		Bx = reinterpret_cast<float *>(npBx.get_data());
+		By = reinterpret_cast<float *>(npBy.get_data());
+
 
-		
 
 		copyIm(A, U, dimX, dimY, dimZ); /*initialize */
 
 										/* begin outer SB iterations */
-		for (ll = 0; ll<iter; ll++) {
+		for (ll = 0; ll < iter; ll++) {
 
 			/*storing old values*/
 			copyIm(U, U_old, dimX, dimY, dimZ);
@@ -286,7 +299,7 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi
 
 			/* calculate norm to terminate earlier */
 			re = 0.0f; re1 = 0.0f;
-			for (j = 0; j<dimX*dimY*dimZ; j++)
+			for (j = 0; j < dimX*dimY*dimZ; j++)
 			{
 				re += pow(U_old[j] - U[j], 2);
 				re1 += pow(U_old[j], 2);
@@ -303,11 +316,13 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi
 			/*printf("%f %i %i \n", re, ll, count); */
 
 			/*copyIm(U_old, U, dimX, dimY, dimZ); */
-			result.append<np::ndarray>(npU);
-			result.append<int>(ll);
+			
 		}
 		//printf("SB iterations stopped at iteration: %i\n", ll);
-		if (number_of_dims == 3) {
+		result.append<np::ndarray>(npU);
+		result.append<int>(ll);
+	}
+	if (number_of_dims == 3) {
 			/*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
 			U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
 			Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
@@ -375,24 +390,25 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi
 			result.append<np::ndarray>(npU);
 			result.append<int>(ll);
 		}
-	}
 	return result;
 
-}
+	}
+
+
 
 bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) {
 
 	// the result is in the following list
 	bp::list result;
 
-	int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV;
+	int number_of_dims, dimX, dimY, dimZ, ll, j, count;
 	float *A, *D = NULL, *D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL;
-	float lambda, tk, tkp1, re, re1, re_old, epsil, funcval;
+	float lambda, tk, tkp1, re, re1, re_old, epsil, funcval, mu;
 
 	//number_of_dims = mxGetNumberOfDimensions(prhs[0]);
 	//dim_array = mxGetDimensions(prhs[0]);
 
-	int number_of_dims = input.get_nd();
+	number_of_dims = input.get_nd();
 	int dim_array[3];
 
 	dim_array[0] = input.shape(0);
@@ -512,7 +528,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me
 				for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
 				//funcvalA[0] = sqrt(funcval);
 				float fv = sqrt(funcval);
-				std::memcpy(funcvalA, &fv), sizeof(float));
+				std::memcpy(funcvalA, &fv, sizeof(float));
 				break;
 			}
 
@@ -524,7 +540,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me
 					for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
 					//funcvalA[0] = sqrt(funcval);
 					float fv = sqrt(funcval);
-					std::memcpy(funcvalA, &fv), sizeof(float));
+					std::memcpy(funcvalA, &fv, sizeof(float));
 					break;
 				}
 			}
@@ -544,7 +560,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me
 				for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
 				//funcvalA[0] = sqrt(funcval);
 				float fv = sqrt(funcval);
-				std::memcpy(funcvalA, &fv), sizeof(float));
+				std::memcpy(funcvalA, &fv, sizeof(float));
 			}
 		}
 		//printf("FGP-TV iterations stopped at iteration %i with the function value %f \n", ll, funcvalA[0]);
@@ -622,7 +638,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me
 				for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
 				//funcvalA[0] = sqrt(funcval);
 				float fv = sqrt(funcval);
-				std::memcpy(funcvalA, &fv), sizeof(float));
+				std::memcpy(funcvalA, &fv, sizeof(float));
 				break;
 			}
 
@@ -634,7 +650,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me
 					for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
 					//funcvalA[0] = sqrt(funcval);
 					float fv = sqrt(funcval);
-					std::memcpy(funcvalA, &fv), sizeof(float));
+					std::memcpy(funcvalA, &fv, sizeof(float));
 					break;
 				}
 			}
@@ -655,7 +671,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me
 				for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2);
 				//funcvalA[0] = sqrt(funcval);
 				float fv = sqrt(funcval);
-				std::memcpy(funcvalA, &fv), sizeof(float));
+				std::memcpy(funcvalA, &fv, sizeof(float));
 			}
 
 		}
@@ -668,13 +684,194 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me
 	return result;
 }
 
-BOOST_PYTHON_MODULE(fista)
+bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) {
+	// the result is in the following list
+	bp::list result;
+
+	int number_of_dims, dimX, dimY, dimZ, ll, j, count;
+	//const int  *dim_array;
+	float *U0, *U = NULL, *U_old = NULL, *D1 = NULL, *D2 = NULL, *D3 = NULL, lambda, tau, re, re1, epsil, re_old;
+	unsigned short *Map = NULL;
+
+	number_of_dims = input.get_nd();
+	int dim_array[3];
+
+	dim_array[0] = input.shape(0);
+	dim_array[1] = input.shape(1);
+	if (number_of_dims == 2) {
+		dim_array[2] = -1;
+	}
+	else {
+		dim_array[2] = input.shape(2);
+	}
+
+	///*Handling Matlab input data*/
+	//U0 = (float *)mxGetData(prhs[0]); /*origanal noise image/volume*/
+	//if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); }
+	//lambda = (float)mxGetScalar(prhs[1]); /*regularization parameter*/
+	//tau = (float)mxGetScalar(prhs[2]); /* time-step */
+	//iter = (int)mxGetScalar(prhs[3]); /*iterations number*/
+	//epsil = (float)mxGetScalar(prhs[4]); /* tolerance constant */
+	//switcher = (int)mxGetScalar(prhs[5]); /*switch on (1) restrictive smoothing in Z dimension*/
+	
+	U0 = reinterpret_cast<float *>(input.get_data());
+	lambda = (float)d_lambda;
+	tau = (float)d_tau;
+	// iter is passed as parameter
+	epsil = (float)d_epsil;
+	// switcher is passed as parameter
+										  /*Handling Matlab output data*/
+	dimX = dim_array[0]; dimY = dim_array[1];  dimZ = 1;
+
+	if (number_of_dims == 2) {
+		/*2D case*/
+		/*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		D1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		D2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/
+
+		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]);
+		np::dtype dtype = np::dtype::get_builtin<float>();
+
+
+		np::ndarray npU = np::zeros(shape, dtype);
+		np::ndarray npU_old = np::zeros(shape, dtype);
+		np::ndarray npD1 = np::zeros(shape, dtype);
+		np::ndarray npD2 = np::zeros(shape, dtype);
+		
+
+		U = reinterpret_cast<float *>(npU.get_data());
+		U_old = reinterpret_cast<float *>(npU_old.get_data());
+		D1 = reinterpret_cast<float *>(npD1.get_data());
+		D2 = reinterpret_cast<float *>(npD2.get_data());
+		
+		/*Copy U0 to U*/
+		copyIm(U0, U, dimX, dimY, dimZ);
+
+		count = 1;
+		re_old = 0.0f;
+
+		for (ll = 0; ll < iter; ll++) {
+
+			copyIm(U, U_old, dimX, dimY, dimZ);
+
+			/*estimate inner derrivatives */
+			der2D(U, D1, D2, dimX, dimY, dimZ);
+			/* calculate div^2 and update */
+			div_upd2D(U0, U, D1, D2, dimX, dimY, dimZ, lambda, tau);
+
+			/* calculate norm to terminate earlier */
+			re = 0.0f; re1 = 0.0f;
+			for (j = 0; j<dimX*dimY*dimZ; j++)
+			{
+				re += pow(U_old[j] - U[j], 2);
+				re1 += pow(U_old[j], 2);
+			}
+			re = sqrt(re) / sqrt(re1);
+			if (re < epsil)  count++;
+			if (count > 4) break;
+
+			/* check that the residual norm is decreasing */
+			if (ll > 2) {
+				if (re > re_old) break;
+			}
+			re_old = re;
+
+		} /*end of iterations*/
+		  //printf("HO iterations stopped at iteration: %i\n", ll);
+
+		result.append<np::ndarray>(npU);
+	}
+	else if (number_of_dims == 3) {
+		/*3D case*/
+		dimZ = dim_array[2];
+		/*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		D1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		D2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		D3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));
+		if (switcher != 0) {
+			Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL));
+		}*/
+		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]);
+		np::dtype dtype = np::dtype::get_builtin<float>();
+
+
+		np::ndarray npU = np::zeros(shape, dtype);
+		np::ndarray npU_old = np::zeros(shape, dtype);
+		np::ndarray npD1 = np::zeros(shape, dtype);
+		np::ndarray npD2 = np::zeros(shape, dtype);
+		np::ndarray npD3 = np::zeros(shape, dtype);
+		np::ndarray npMap = np::zeros(shape, np::dtype::get_builtin<unsigned short>());
+		Map = reinterpret_cast<unsigned short *>(npMap.get_data());
+		if (switcher != 0) {
+			//Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL));
+			
+			Map = reinterpret_cast<unsigned short *>(npMap.get_data());
+		}
+
+		U = reinterpret_cast<float *>(npU.get_data());
+		U_old = reinterpret_cast<float *>(npU_old.get_data());
+		D1 = reinterpret_cast<float *>(npD1.get_data());
+		D2 = reinterpret_cast<float *>(npD2.get_data());
+		D3 = reinterpret_cast<float *>(npD2.get_data());
+		
+		/*Copy U0 to U*/
+		copyIm(U0, U, dimX, dimY, dimZ);
+
+		count = 1;
+		re_old = 0.0f;
+	
+
+		if (switcher == 1) {
+			/* apply restrictive smoothing */
+			calcMap(U, Map, dimX, dimY, dimZ);
+			/*clear outliers */
+			cleanMap(Map, dimX, dimY, dimZ);
+		}
+		for (ll = 0; ll < iter; ll++) {
+
+			copyIm(U, U_old, dimX, dimY, dimZ);
+
+			/*estimate inner derrivatives */
+			der3D(U, D1, D2, D3, dimX, dimY, dimZ);
+			/* calculate div^2 and update */
+			div_upd3D(U0, U, D1, D2, D3, Map, switcher, dimX, dimY, dimZ, lambda, tau);
+
+			/* calculate norm to terminate earlier */
+			re = 0.0f; re1 = 0.0f;
+			for (j = 0; j<dimX*dimY*dimZ; j++)
+			{
+				re += pow(U_old[j] - U[j], 2);
+				re1 += pow(U_old[j], 2);
+			}
+			re = sqrt(re) / sqrt(re1);
+			if (re < epsil)  count++;
+			if (count > 4) break;
+
+			/* check that the residual norm is decreasing */
+			if (ll > 2) {
+				if (re > re_old) break;
+			}
+			re_old = re;
+
+		} /*end of iterations*/
+		//printf("HO iterations stopped at iteration: %i\n", ll);
+		result.append<np::ndarray>(npU);
+		if (switcher != 0) result.append<np::ndarray>(npMap);
+
+	}
+	return result;
+}
+
+
+BOOST_PYTHON_MODULE(regularizers)
 {
 	np::initialize();
 
 	//To specify that this module is a package
 	bp::object package = bp::scope();
-	package.attr("__path__") = "fista";
+	package.attr("__path__") = "regularizers";
 
 	np::dtype dt1 = np::dtype::get_builtin<uint8_t>();
 	np::dtype dt2 = np::dtype::get_builtin<uint16_t>();
@@ -682,4 +879,5 @@ BOOST_PYTHON_MODULE(fista)
 	def("mexFunction", mexFunction);
 	def("SplitBregman_TV", SplitBregman_TV);
 	def("FGP_TV", FGP_TV);
+	def("LLT_model", LLT_model);
 }
\ No newline at end of file
-- 
cgit v1.2.3


From 4bef3726577ddf1bf2b594620e106573c6f18693 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Fri, 4 Aug 2017 16:16:53 +0100
Subject: minor change

---
 src/Python/Matlab2Python_utils.cpp | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

(limited to 'src/Python')

diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp
index 138e8da..6aaad90 100644
--- a/src/Python/Matlab2Python_utils.cpp
+++ b/src/Python/Matlab2Python_utils.cpp
@@ -128,7 +128,11 @@ T * mxGetData(const np::ndarray pm) {
 
 template<typename T>
 np::ndarray zeros(int dims , int * dim_array, T el) {
-	bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
+	bp::tuple shape;
+	if (dims == 3)
+		shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
+	else if (dims == 2)
+		shape = bp::make_tuple(dim_array[0], dim_array[1]);
 	np::dtype dtype = np::dtype::get_builtin<T>();
 	np::ndarray zz = np::zeros(shape, dtype);
 	return zz;
@@ -163,7 +167,7 @@ bp::list mexFunction( np::ndarray input ) {
 			for (int k = 0; k < dim_array[2]; k++) {
 				int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i;
 				int val = (*(A + index));
-				float fval = (float)val;
+				float fval = sqrt((float)val);
 				std::memcpy(B + index , &val, sizeof(int));
 				std::memcpy(C + index , &fval, sizeof(float));
 			}
@@ -186,7 +190,7 @@ bp::list mexFunction( np::ndarray input ) {
 }
 
 
-BOOST_PYTHON_MODULE(fista)
+BOOST_PYTHON_MODULE(prova)
 {
 	np::initialize();
 
-- 
cgit v1.2.3


From 662ab4ac9c3d89cdc1527c2a2bdcf442f3b6a173 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Fri, 4 Aug 2017 16:17:18 +0100
Subject: test for general boost::python / numpy routines

---
 src/Python/test.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)
 create mode 100644 src/Python/test.py

(limited to 'src/Python')

diff --git a/src/Python/test.py b/src/Python/test.py
new file mode 100644
index 0000000..e283f89
--- /dev/null
+++ b/src/Python/test.py
@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Aug  3 14:08:09 2017
+
+@author: ofn77899
+"""
+
+import fista
+import numpy as np
+
+a = np.asarray([i for i in range(3*4*5)])
+a = a.reshape([3,4,5])
+print (a)
+b = fista.mexFunction(a)
+#print (b)
+print (b[4].shape)
+print (b[4])
+print (b[5])
\ No newline at end of file
-- 
cgit v1.2.3


From ecfb1146dc1de9ea6d8c6587d15417a9690f5ab4 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Fri, 4 Aug 2017 16:49:21 +0100
Subject: added PatchBased_Regul

---
 src/Python/fista_module.cpp | 123 +++++++++++++++++++++++++++++++++++++++++++-
 src/Python/setup.py         |   1 +
 2 files changed, 123 insertions(+), 1 deletion(-)

(limited to 'src/Python')

diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp
index d890b10..c2d9352 100644
--- a/src/Python/fista_module.cpp
+++ b/src/Python/fista_module.cpp
@@ -29,6 +29,7 @@ limitations under the License.
 #include "SplitBregman_TV_core.h"
 #include "FGP_TV_core.h"
 #include "LLT_model_core.h"
+#include "PatchBased_Regul_core.h"
 #include "utils.h"
 
 
@@ -793,7 +794,7 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d
 		if (switcher != 0) {
 			Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL));
 		}*/
-		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]);
+		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
 		np::dtype dtype = np::dtype::get_builtin<float>();
 
 
@@ -865,6 +866,126 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d
 }
 
 
+bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW,  double d_h, double d_lambda) {
+	// the result is in the following list
+	bp::list result;
+
+	int N, M, Z, numdims, SearchW, /*SimilW, SearchW_real,*/ padXY, newsizeX, newsizeY, newsizeZ, switchpad_crop;
+	//const int  *dims;
+	float *A, *B = NULL, *Ap = NULL, *Bp = NULL, h, lambda;
+
+	numdims = input.get_nd();
+	int dims[3];
+
+	dims[0] = input.shape(0);
+	dims[1] = input.shape(1);
+	if (numdims == 2) {
+		dims[2] = -1;
+	}
+	else {
+		dims[2] = input.shape(2);
+	}
+	/*numdims = mxGetNumberOfDimensions(prhs[0]);
+	dims = mxGetDimensions(prhs[0]);*/
+
+	N = dims[0];
+	M = dims[1];
+	Z = dims[2];
+
+	//if ((numdims < 2) || (numdims > 3)) { mexErrMsgTxt("The input should be 2D image or 3D volume"); }
+	//if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); }
+
+	//if (nrhs != 5) mexErrMsgTxt("Five inputs reqired: Image(2D,3D), SearchW, SimilW, Threshold, Regularization parameter");
+
+	///*Handling inputs*/
+	//A = (float *)mxGetData(prhs[0]);    /* the image to regularize/filter */
+	//SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */
+	//SimilW = (int)mxGetScalar(prhs[2]);  /* the similarity window ratio */
+	//h = (float)mxGetScalar(prhs[3]);  /* parameter for the PB filtering function */
+	//lambda = (float)mxGetScalar(prhs[4]); /* regularization parameter */
+
+	//if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0");
+	//if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0");
+
+	SearchW = SearchW_real + 2 * SimilW;
+
+	/* SearchW_full = 2*SearchW + 1; */ /* the full searching window  size */
+										/* SimilW_full = 2*SimilW + 1;  */  /* the full similarity window  size */
+
+
+	padXY = SearchW + 2 * SimilW; /* padding sizes */
+	newsizeX = N + 2 * (padXY); /* the X size of the padded array */
+	newsizeY = M + 2 * (padXY); /* the Y size of the padded array */
+	newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */
+	int N_dims[] = { newsizeX, newsizeY, newsizeZ };
+
+	/******************************2D case ****************************/
+	if (numdims == 2) {
+		///*Handling output*/
+		//B = (float*)mxGetData(plhs[0] = mxCreateNumericMatrix(N, M, mxSINGLE_CLASS, mxREAL));
+		///*allocating memory for the padded arrays */
+		//Ap = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL));
+		//Bp = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL));
+		///**************************************************************************/
+
+		bp::tuple shape = bp::make_tuple(N, M);
+		np::dtype dtype = np::dtype::get_builtin<float>();
+
+		np::ndarray npB = np::zeros(shape, dtype);
+
+		shape = bp::make_tuple(newsizeX, newsizeY);
+		np::ndarray npAp = np::zeros(shape, dtype);
+		np::ndarray npBp = np::zeros(shape, dtype);
+		B = reinterpret_cast<float *>(npB.get_data());
+		Ap = reinterpret_cast<float *>(npAp.get_data());
+		Bp = reinterpret_cast<float *>(npBp.get_data());		
+
+		/*Perform padding of image A to the size of [newsizeX * newsizeY] */
+		switchpad_crop = 0; /*padding*/
+		pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop);
+
+		/* Do PB regularization with the padded array  */
+		PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda);
+
+		switchpad_crop = 1; /*cropping*/
+		pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop);
+		result.append<np::ndarray>(npB);
+	}
+	else
+	{
+		/******************************3D case ****************************/
+		///*Handling output*/
+		//B = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dims, mxSINGLE_CLASS, mxREAL));
+		///*allocating memory for the padded arrays */
+		//Ap = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL));
+		//Bp = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL));
+		/**************************************************************************/
+		bp::tuple shape = bp::make_tuple(dims[0], dims[1], dims[2]);
+		bp::tuple shape_AB = bp::make_tuple(N_dims[0], N_dims[1], N_dims[2]);
+		np::dtype dtype = np::dtype::get_builtin<float>();
+
+		np::ndarray npB = np::zeros(shape, dtype);
+		np::ndarray npAp = np::zeros(shape_AB, dtype);
+		np::ndarray npBp = np::zeros(shape_AB, dtype);
+		B = reinterpret_cast<float *>(npB.get_data());
+		Ap = reinterpret_cast<float *>(npAp.get_data());
+		Bp = reinterpret_cast<float *>(npBp.get_data());
+		/*Perform padding of image A to the size of [newsizeX * newsizeY * newsizeZ] */
+		switchpad_crop = 0; /*padding*/
+		pad_crop(A, Ap, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop);
+
+		/* Do PB regularization with the padded array  */
+		PB_FUNC3D(Ap, Bp, newsizeY, newsizeX, newsizeZ, padXY, SearchW, SimilW, (float)h, (float)lambda);
+
+		switchpad_crop = 1; /*cropping*/
+		pad_crop(Bp, B, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop);
+
+		result.append<np::ndarray>(npB);
+	} /*end else ndims*/
+
+	return result;
+}
+
 BOOST_PYTHON_MODULE(regularizers)
 {
 	np::initialize();
diff --git a/src/Python/setup.py b/src/Python/setup.py
index a8feb1c..a4eed14 100644
--- a/src/Python/setup.py
+++ b/src/Python/setup.py
@@ -52,6 +52,7 @@ setup(
                                       "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c",
                                       "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c",
                                       "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c",
+                                      "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c",
                                       "..\\..\\main_func\\regularizers_CPU\\utils.c"
                                         ],
                              include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), 
-- 
cgit v1.2.3


From 753f3477bde8fc250adc542bbeffc03d369107e1 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Mon, 7 Aug 2017 17:21:12 +0100
Subject: added TGV_PD, removed useless code

---
 src/Python/fista_module.cpp | 245 ++++++++++++++++++++++++++------------------
 1 file changed, 146 insertions(+), 99 deletions(-)

(limited to 'src/Python')

diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp
index c2d9352..eacda3d 100644
--- a/src/Python/fista_module.cpp
+++ b/src/Python/fista_module.cpp
@@ -30,6 +30,7 @@ limitations under the License.
 #include "FGP_TV_core.h"
 #include "LLT_model_core.h"
 #include "PatchBased_Regul_core.h"
+#include "TGV_PD_core.h"
 #include "utils.h"
 
 
@@ -103,101 +104,8 @@ If unsuccessful in a MEX file, the MEX file terminates and returns control to th
 enough free heap space to create the mxArray.
 */
 
-void mexErrMessageText(char* text) {
-	std::cerr << text << std::endl;
-}
-
-/*
-double mxGetScalar(const mxArray *pm);
-args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray.
-Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray.	In C, mxGetScalar returns a double.
-*/
-
-template<typename T>
-double mxGetScalar(const np::ndarray plh) {
-	return (double)bp::extract<T>(plh[0]);
-}
-
-
-
-template<typename T>
-T * mxGetData(const np::ndarray pm) {
-	//args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray.
-	//Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double.
-	/*Access the numpy array pointer:
-	char * get_data() const;
-	Returns:	Array�s raw data pointer as a char
-	Note:	This returns char so stride math works properly on it.User will have to reinterpret_cast it.
-	probably this would work.
-	A = reinterpret_cast<float *>(prhs[0]);
-	*/
-	return reinterpret_cast<T *>(prhs[0]);
-}
-
-template<typename T>
-np::ndarray zeros(int dims, int * dim_array, T el) {
-	bp::tuple shape;
-	if (dims == 3)
-		shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
-	else if (dims == 2)
-		shape = bp::make_tuple(dim_array[0], dim_array[1]);
-	np::dtype dtype = np::dtype::get_builtin<T>();
-	np::ndarray zz = np::zeros(shape, dtype);
-	return zz;
-}
 
 
-
-
-bp::list mexFunction(np::ndarray input) {
-	int number_of_dims = input.get_nd();
-	int dim_array[3];
-
-	dim_array[0] = input.shape(0);
-	dim_array[1] = input.shape(1);
-	if (number_of_dims == 2) {
-		dim_array[2] = -1;
-	}
-	else {
-		dim_array[2] = input.shape(2);
-	}
-
-	/**************************************************************************/
-	np::ndarray zz = zeros(3, dim_array, (int)0);
-	np::ndarray fzz = zeros(3, dim_array, (float)0);
-	/**************************************************************************/
-
-	int * A = reinterpret_cast<int *>(input.get_data());
-	int * B = reinterpret_cast<int *>(zz.get_data());
-	float * C = reinterpret_cast<float *>(fzz.get_data());
-
-	//Copy data and cast
-	for (int i = 0; i < dim_array[0]; i++) {
-		for (int j = 0; j < dim_array[1]; j++) {
-			for (int k = 0; k < dim_array[2]; k++) {
-				int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i;
-				int val = (*(A + index));
-				float fval = (float)val;
-				std::memcpy(B + index, &val, sizeof(int));
-				std::memcpy(C + index, &fval, sizeof(float));
-			}
-		}
-	}
-
-	bp::list result;
-
-	result.append<int>(number_of_dims);
-	result.append<int>(dim_array[0]);
-	result.append<int>(dim_array[1]);
-	result.append<int>(dim_array[2]);
-	result.append<np::ndarray>(zz);
-	result.append<np::ndarray>(fzz);
-
-	//result.append<bp::tuple>(tup);
-	return result;
-
-}
-
 bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) {
 	
 	// the result is in the following list
@@ -487,7 +395,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me
 		np::ndarray npP1_old = np::zeros(shape, dtype);
 		np::ndarray npP2_old = np::zeros(shape, dtype);
 		np::ndarray npR1     = np::zeros(shape, dtype);
-		np::ndarray npR2     = zeros(2, dim_array, (float)0);
+		np::ndarray npR2     = np::zeros(shape, dtype);
 
 		D      = reinterpret_cast<float *>(npD.get_data());
 		D_old  = reinterpret_cast<float *>(npD_old.get_data());
@@ -866,7 +774,7 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d
 }
 
 
-bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW,  double d_h, double d_lambda) {
+bp::list PatchBased_Regul(np::ndarray input, double d_lambda, int SearchW_real, int SimilW,  double d_h) {
 	// the result is in the following list
 	bp::list result;
 
@@ -899,6 +807,7 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW,  doub
 
 	///*Handling inputs*/
 	//A = (float *)mxGetData(prhs[0]);    /* the image to regularize/filter */
+	A = reinterpret_cast<float *>(input.get_data());
 	//SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */
 	//SimilW = (int)mxGetScalar(prhs[2]);  /* the similarity window ratio */
 	//h = (float)mxGetScalar(prhs[3]);  /* parameter for the PB filtering function */
@@ -907,6 +816,8 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW,  doub
 	//if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0");
 	//if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0");
 
+	lambda = (float)d_lambda;
+	h = (float)d_h;
 	SearchW = SearchW_real + 2 * SimilW;
 
 	/* SearchW_full = 2*SearchW + 1; */ /* the full searching window  size */
@@ -918,7 +829,6 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW,  doub
 	newsizeY = M + 2 * (padXY); /* the Y size of the padded array */
 	newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */
 	int N_dims[] = { newsizeX, newsizeY, newsizeZ };
-
 	/******************************2D case ****************************/
 	if (numdims == 2) {
 		///*Handling output*/
@@ -943,12 +853,13 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW,  doub
 		/*Perform padding of image A to the size of [newsizeX * newsizeY] */
 		switchpad_crop = 0; /*padding*/
 		pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop);
-
+		
 		/* Do PB regularization with the padded array  */
 		PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda);
-
+		
 		switchpad_crop = 1; /*cropping*/
 		pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop);
+		
 		result.append<np::ndarray>(npB);
 	}
 	else
@@ -983,6 +894,141 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW,  doub
 		result.append<np::ndarray>(npB);
 	} /*end else ndims*/
 
+	return result;
+}
+
+bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_alpha0, int iter) {
+	// the result is in the following list
+	bp::list result;
+	int number_of_dims, /*iter,*/ dimX, dimY, dimZ, ll;
+	//const int  *dim_array;
+	float *A, *U, *U_old, *P1, *P2, *P3, *Q1, *Q2, *Q3, *Q4, *Q5, *Q6, *Q7, *Q8, *Q9, *V1, *V1_old, *V2, *V2_old, *V3, *V3_old, lambda, L2, tau, sigma, alpha1, alpha0;
+
+	//number_of_dims = mxGetNumberOfDimensions(prhs[0]);
+	//dim_array = mxGetDimensions(prhs[0]);
+	number_of_dims = input.get_nd();
+	int dim_array[3];
+
+	dim_array[0] = input.shape(0);
+	dim_array[1] = input.shape(1);
+	if (number_of_dims == 2) {
+		dim_array[2] = -1;
+	}
+	else {
+		dim_array[2] = input.shape(2);
+	}
+	/*Handling Matlab input data*/
+	//A = (float *)mxGetData(prhs[0]); /*origanal noise image/volume*/
+	//if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); }
+	
+	A = reinterpret_cast<float *>(input.get_data());
+
+	//lambda = (float)mxGetScalar(prhs[1]); /*regularization parameter*/
+	//alpha1 = (float)mxGetScalar(prhs[2]); /*first-order term*/
+	//alpha0 = (float)mxGetScalar(prhs[3]); /*second-order term*/
+	//iter = (int)mxGetScalar(prhs[4]); /*iterations number*/
+	//if (nrhs != 5) mexErrMsgTxt("Five input parameters is reqired: Image(2D/3D), Regularization parameter, alpha1, alpha0, Iterations");
+	lambda = (float)d_lambda;
+	alpha1 = (float)d_alpha1;
+	alpha0 = (float)d_alpha0;
+
+	/*Handling Matlab output data*/
+	dimX = dim_array[0]; dimY = dim_array[1];
+
+	if (number_of_dims == 2) {
+		/*2D case*/
+		dimZ = 1;
+		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]);
+		np::dtype dtype = np::dtype::get_builtin<float>();
+
+		np::ndarray npU = np::zeros(shape, dtype);
+		np::ndarray npP1 = np::zeros(shape, dtype);
+		np::ndarray npP2 = np::zeros(shape, dtype);
+		np::ndarray npQ1 = np::zeros(shape, dtype);
+		np::ndarray npQ2 = np::zeros(shape, dtype);
+		np::ndarray npQ3 = np::zeros(shape, dtype);
+		np::ndarray npV1 = np::zeros(shape, dtype);
+		np::ndarray npV1_old = np::zeros(shape, dtype);
+		np::ndarray npV2 = np::zeros(shape, dtype);
+		np::ndarray npV2_old = np::zeros(shape, dtype);
+		np::ndarray npU_old = np::zeros(shape, dtype);
+
+		U = reinterpret_cast<float *>(npU.get_data());
+		U_old = reinterpret_cast<float *>(npU_old.get_data());
+		P1 = reinterpret_cast<float *>(npP1.get_data());
+		P2 = reinterpret_cast<float *>(npP2.get_data());
+		Q1 = reinterpret_cast<float *>(npQ1.get_data());
+		Q2 = reinterpret_cast<float *>(npQ2.get_data());
+		Q3 = reinterpret_cast<float *>(npQ3.get_data());
+		V1 = reinterpret_cast<float *>(npV1.get_data());
+		V1_old = reinterpret_cast<float *>(npV1_old.get_data());
+		V2 = reinterpret_cast<float *>(npV2.get_data());
+		V2_old = reinterpret_cast<float *>(npV2_old.get_data());
+		//U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+		/*dual variables*/
+		/*P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+		Q1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		Q2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		Q3 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+		U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+		V1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		V1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		V2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		V2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/
+		/*printf("%i \n", i);*/
+		L2 = 12.0; /*Lipshitz constant*/
+		tau = 1.0 / pow(L2, 0.5);
+		sigma = 1.0 / pow(L2, 0.5);
+
+		/*Copy A to U*/
+		copyIm(A, U, dimX, dimY, dimZ);
+		/* Here primal-dual iterations begin for 2D */
+		for (ll = 0; ll < iter; ll++) {
+
+			/* Calculate Dual Variable P */
+			DualP_2D(U, V1, V2, P1, P2, dimX, dimY, dimZ, sigma);
+
+			/*Projection onto convex set for P*/
+			ProjP_2D(P1, P2, dimX, dimY, dimZ, alpha1);
+
+			/* Calculate Dual Variable Q */
+			DualQ_2D(V1, V2, Q1, Q2, Q3, dimX, dimY, dimZ, sigma);
+
+			/*Projection onto convex set for Q*/
+			ProjQ_2D(Q1, Q2, Q3, dimX, dimY, dimZ, alpha0);
+
+			/*saving U into U_old*/
+			copyIm(U, U_old, dimX, dimY, dimZ);
+
+			/*adjoint operation  -> divergence and projection of P*/
+			DivProjP_2D(U, A, P1, P2, dimX, dimY, dimZ, lambda, tau);
+
+			/*get updated solution U*/
+			newU(U, U_old, dimX, dimY, dimZ);
+
+			/*saving V into V_old*/
+			copyIm(V1, V1_old, dimX, dimY, dimZ);
+			copyIm(V2, V2_old, dimX, dimY, dimZ);
+
+			/* upd V*/
+			UpdV_2D(V1, V2, P1, P2, Q1, Q2, Q3, dimX, dimY, dimZ, tau);
+
+			/*get new V*/
+			newU(V1, V1_old, dimX, dimY, dimZ);
+			newU(V2, V2_old, dimX, dimY, dimZ);
+		} /*end of iterations*/
+	
+		result.append<np::ndarray>(npU);
+	}
+	
+
+	
+	
 	return result;
 }
 
@@ -997,8 +1043,9 @@ BOOST_PYTHON_MODULE(regularizers)
 	np::dtype dt1 = np::dtype::get_builtin<uint8_t>();
 	np::dtype dt2 = np::dtype::get_builtin<uint16_t>();
 
-	def("mexFunction", mexFunction);
 	def("SplitBregman_TV", SplitBregman_TV);
 	def("FGP_TV", FGP_TV);
 	def("LLT_model", LLT_model);
+	def("PatchBased_Regul", PatchBased_Regul);
+	def("TGV_PD", TGV_PD);
 }
\ No newline at end of file
-- 
cgit v1.2.3


From 4534a11d1c32a65484f4f38348c27a7bb2d9ad19 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Mon, 7 Aug 2017 17:21:54 +0100
Subject: added TGV_PD

---
 src/Python/setup.py             |   1 +
 src/Python/test_regularizers.py | 195 ++++++++++++++++++++++++++++++++++------
 2 files changed, 168 insertions(+), 28 deletions(-)

(limited to 'src/Python')

diff --git a/src/Python/setup.py b/src/Python/setup.py
index a4eed14..0468722 100644
--- a/src/Python/setup.py
+++ b/src/Python/setup.py
@@ -53,6 +53,7 @@ setup(
                                       "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c",
                                       "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c",
                                       "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c",
+                                      "..\\..\\main_func\\regularizers_CPU\\TGV_PD_core.c",
                                       "..\\..\\main_func\\regularizers_CPU\\utils.c"
                                         ],
                              include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), 
diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py
index 6abfba4..6a34749 100644
--- a/src/Python/test_regularizers.py
+++ b/src/Python/test_regularizers.py
@@ -47,6 +47,8 @@ class Regularizer():
         SplitBregman_TV = regularizers.SplitBregman_TV
         FGP_TV = regularizers.FGP_TV
         LLT_model = regularizers.LLT_model
+        PatchBased_Regul = regularizers.PatchBased_Regul
+        TGV_PD = regularizers.TGV_PD
     # Algorithm
     
     class TotalVariationPenalty(Enum):
@@ -55,13 +57,17 @@ class Regularizer():
     # TotalVariationPenalty
         
     def __init__(self , algorithm):
-        
+        self.setAlgorithm ( algorithm )
+    # __init__
+    
+    def setAlgorithm(self, algorithm):
         self.algorithm = algorithm
         self.pars = self.parsForAlgorithm(algorithm)
-    # __init__
+    # setAlgorithm
         
     def parsForAlgorithm(self, algorithm):
         pars = dict()
+        
         if algorithm == Regularizer.Algorithm.SplitBregman_TV :
             pars['algorithm'] = algorithm
             pars['input'] = None
@@ -69,6 +75,7 @@ class Regularizer():
             pars['number_of_iterations'] = 35
             pars['tolerance_constant'] = 0.0001
             pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic
+            
         elif algorithm == Regularizer.Algorithm.FGP_TV :
             pars['algorithm'] = algorithm
             pars['input'] = None
@@ -76,6 +83,7 @@ class Regularizer():
             pars['number_of_iterations'] = 50
             pars['tolerance_constant'] = 0.001
             pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic
+            
         elif algorithm == Regularizer.Algorithm.LLT_model:
             pars['algorithm'] = algorithm
             pars['input'] = None
@@ -85,6 +93,24 @@ class Regularizer():
             pars['tolerance_constant'] = None
             pars['restrictive_Z_smoothing'] = 0
             
+        elif algorithm == Regularizer.Algorithm.PatchBased_Regul:
+            pars['algorithm'] = algorithm
+            pars['input'] = None
+            pars['searching_window_ratio'] = None
+            pars['similarity_window_ratio'] = None
+            pars['PB_filtering_parameter'] = None
+            pars['regularization_parameter'] = None
+            
+        elif algorithm == Regularizer.Algorithm.TGV_PD:
+            pars['algorithm'] = algorithm
+            pars['input'] = None
+            pars['first_order_term'] = None
+            pars['second_order_term'] = None
+            pars['number_of_iterations'] = None
+            pars['regularization_parameter'] = None
+            
+            
+            
         return pars
     # parsForAlgorithm
         
@@ -98,6 +124,8 @@ class Regularizer():
         self.pars['regularization_parameter'] = regularization_parameter
         #for key, value in self.pars.items():
         #        print("{0} = {1}".format(key, value))
+        if None in self.pars:
+                raise Exception("Not all parameters have been provided")
                 
         if self.algorithm == Regularizer.Algorithm.SplitBregman_TV :
             return self.algorithm(input, regularization_parameter,
@@ -112,15 +140,27 @@ class Regularizer():
         elif self.algorithm == Regularizer.Algorithm.LLT_model :
             #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
             # no default
-            if None in self.pars:
-                raise Exception("Not all parameters have been provided")
-            else:
-                return self.algorithm(input, 
-                                  regularization_parameter,
-                                  self.pars['time_step'] , 
-                                  self.pars['number_of_iterations'],
-                                  self.pars['tolerance_constant'],
-                                  self.pars['restrictive_Z_smoothing'] )
+            return self.algorithm(input, 
+                              regularization_parameter,
+                              self.pars['time_step'] , 
+                              self.pars['number_of_iterations'],
+                              self.pars['tolerance_constant'],
+                              self.pars['restrictive_Z_smoothing'] )
+        elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul :
+            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
+            # no default
+            return self.algorithm(input, regularization_parameter,
+                                  self.pars['searching_window_ratio'] , 
+                                  self.pars['similarity_window_ratio'] , 
+                                  self.pars['PB_filtering_parameter'])
+        elif self.algorithm == Regularizer.Algorithm.TGV_PD :
+            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
+            # no default
+            return self.algorithm(input, regularization_parameter,
+                                  self.pars['first_order_term'] , 
+                                  self.pars['second_order_term'] , 
+                                  self.pars['number_of_iterations'])
+            
             
         
     # __call__
@@ -142,13 +182,40 @@ class Regularizer():
     @staticmethod
     def LLT_model(input, regularization_parameter , time_step, number_of_iterations,
                   tolerance_constant, restrictive_Z_smoothing=0):
-        reg = Regularizer(Regularizer.Algorithm.FGP_TV)
+        reg = Regularizer(Regularizer.Algorithm.LLT_model)
         out = list( reg(input, regularization_parameter, time_step=time_step, 
                         number_of_iterations=number_of_iterations,
                         tolerance_constant=tolerance_constant, 
                         restrictive_Z_smoothing=restrictive_Z_smoothing) )
         out.append(reg.pars)
         return out
+    
+    @staticmethod
+    def PatchBased_Regul(input, regularization_parameter,
+                        searching_window_ratio, 
+                        similarity_window_ratio,
+                        PB_filtering_parameter):
+        reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul)   
+        out = list( reg(input, 
+                        regularization_parameter,
+                        searching_window_ratio=searching_window_ratio, 
+                        similarity_window_ratio=similarity_window_ratio,
+                        PB_filtering_parameter=PB_filtering_parameter )
+            )
+        out.append(reg.pars)
+        return out
+    
+    @staticmethod
+    def TGV_PD(input, regularization_parameter , first_order_term, 
+               second_order_term, number_of_iterations):
+        
+        reg = Regularizer(Regularizer.Algorithm.TGV_PD)
+        out = list( reg(input, regularization_parameter, 
+                        first_order_term=first_order_term, 
+                        second_order_term=second_order_term,
+                        number_of_iterations=number_of_iterations) )
+        out.append(reg.pars)
+        return out
         
 
 #Example:
@@ -171,17 +238,17 @@ u0 = Im + (perc* np.random.normal(size=np.shape(Im)))
 f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1)
 u0 = f(u0).astype('float32')
 
-# plot 
+## plot 
 fig = plt.figure()
-a=fig.add_subplot(2,3,1)
-a.set_title('Original')
-imgplot = plt.imshow(Im)
+#a=fig.add_subplot(3,3,1)
+#a.set_title('Original')
+#imgplot = plt.imshow(Im)
 
-a=fig.add_subplot(2,3,2)
+a=fig.add_subplot(2,3,1)
 a.set_title('noise')
 imgplot = plt.imshow(u0)
 
-
+reg_output = []
 ##############################################################################
 # Call regularizer
 
@@ -199,8 +266,9 @@ out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., numbe
           TV_Penalty=Regularizer.TotalVariationPenalty.l1)
 out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. )
 pars = out2[2]
+reg_output.append(out2)
 
-a=fig.add_subplot(2,3,3)
+a=fig.add_subplot(2,3,2)
 a.set_title('SplitBregman_TV')
 textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s'
 textstr = textstr % (pars['regularization_parameter'], 
@@ -213,7 +281,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
 # place a text box in upper left in axes coords
 a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
         verticalalignment='top', bbox=props)
-imgplot = plt.imshow(out2[0])
+imgplot = plt.imshow(reg_output[-1][0])
 
 ###################### FGP_TV #########################################
 # u = FGP_TV(single(u0), 0.05, 100, 1e-04);
@@ -221,7 +289,9 @@ out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05,
                           number_of_iterations=10)
 pars = out2[-1]
 
-a=fig.add_subplot(2,3,4)
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,3)
 a.set_title('FGP_TV')
 textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s'
 textstr = textstr % (pars['regularization_parameter'], 
@@ -234,18 +304,23 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
 # place a text box in upper left in axes coords
 a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
         verticalalignment='top', bbox=props)
-imgplot = plt.imshow(out2[0])
+imgplot = plt.imshow(reg_output[-1][0])
 
 ###################### LLT_model #########################################
 # * u0 = Im + .03*randn(size(Im)); % adding noise
 # [Den] = LLT_model(single(u0), 10, 0.1, 1);
-out2 = Regularizer.LLT_model(input=u0, regularization_parameter=10.,
-                          time_step=0.1,
-                          tolerance_constant=1e-4,
-                          number_of_iterations=10)
+#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); 
+#input, regularization_parameter , time_step, number_of_iterations,
+#                  tolerance_constant, restrictive_Z_smoothing=0
+out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25,
+                          time_step=0.0003,
+                          tolerance_constant=0.0001,
+                          number_of_iterations=300)
 pars = out2[-1]
 
-a=fig.add_subplot(2,3,5)
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,4)
 a.set_title('LLT_model')
 textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f'
 textstr = textstr % (pars['regularization_parameter'], 
@@ -259,7 +334,71 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
 # place a text box in upper left in axes coords
 a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
         verticalalignment='top', bbox=props)
-imgplot = plt.imshow(out2[0])
+imgplot = plt.imshow(reg_output[-1][0])
+
+###################### PatchBased_Regul #########################################
+# Quick 2D denoising example in Matlab:   
+#   Im = double(imread('lena_gray_256.tif'))/255;  % loading image
+#   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+#   ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); 
+
+out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05,
+                          searching_window_ratio=3,
+                          similarity_window_ratio=1,
+                          PB_filtering_parameter=0.08)
+pars = out2[-1]
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,5)
+a.set_title('PatchBased_Regul')
+textstr = 'regularization_parameter=%.2f\nsearching_window_ratio=%d\nsimilarity_window_ratio=%.2e\nPB_filtering_parameter=%f'
+textstr = textstr % (pars['regularization_parameter'], 
+                     pars['searching_window_ratio'], 
+                     pars['similarity_window_ratio'],
+                     pars['PB_filtering_parameter'])
+
+
+
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+        verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0])
+
+
+###################### TGV_PD #########################################
+# Quick 2D denoising example in Matlab:   
+#   Im = double(imread('lena_gray_256.tif'))/255;  % loading image
+#   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+#   u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550);
+
+
+out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05,
+                          first_order_term=1.3,
+                          second_order_term=1,
+                          number_of_iterations=550)
+pars = out2[-1]
+reg_output.append(out2)
+
+a=fig.add_subplot(2,3,6)
+a.set_title('TGV_PD')
+textstr = 'regularization_parameter=%.2f\nfirst_order_term=%.2f\nsecond_order_term=%.2f\nnumber_of_iterations=%d'
+textstr = textstr % (pars['regularization_parameter'], 
+                     pars['first_order_term'], 
+                     pars['second_order_term'],
+                     pars['number_of_iterations'])
+
+
+
+
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+        verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0])
 
 
 
-- 
cgit v1.2.3


From 3fffd568589137b17d1fbe44e55a757e3745a3b1 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Wed, 11 Oct 2017 15:42:05 +0100
Subject: added simple_astra_test.py

---
 src/Python/test/simple_astra_test.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)
 create mode 100644 src/Python/test/simple_astra_test.py

(limited to 'src/Python')

diff --git a/src/Python/test/simple_astra_test.py b/src/Python/test/simple_astra_test.py
new file mode 100644
index 0000000..905eeea
--- /dev/null
+++ b/src/Python/test/simple_astra_test.py
@@ -0,0 +1,25 @@
+import astra
+import numpy
+
+detectorSpacingX = 1.0
+detectorSpacingY = 1.0
+det_row_count = 128
+det_col_count = 128
+
+angles_rad = numpy.asarray([i for i in range(360)], dtype=float) / 180. * numpy.pi
+
+proj_geom = astra.creators.create_proj_geom('parallel3d',
+                                            detectorSpacingX,
+                                            detectorSpacingY,
+                                            det_row_count,
+                                            det_col_count,
+                                            angles_rad)
+
+image_size_x = 64
+image_size_y = 64
+image_size_z = 32
+
+vol_geom = astra.creators.create_vol_geom(image_size_x,image_size_y,image_size_z)
+
+x1 = numpy.random.rand(image_size_z,image_size_y,image_size_x)
+sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom)
-- 
cgit v1.2.3


From 0611d34c31fa1e706c3bcd7e17651f7555469e00 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Thu, 17 Aug 2017 16:33:09 +0100
Subject: initial revision

---
 src/Python/test/simple_astra_test.py | 25 -------------------------
 1 file changed, 25 deletions(-)
 delete mode 100644 src/Python/test/simple_astra_test.py

(limited to 'src/Python')

diff --git a/src/Python/test/simple_astra_test.py b/src/Python/test/simple_astra_test.py
deleted file mode 100644
index 905eeea..0000000
--- a/src/Python/test/simple_astra_test.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import astra
-import numpy
-
-detectorSpacingX = 1.0
-detectorSpacingY = 1.0
-det_row_count = 128
-det_col_count = 128
-
-angles_rad = numpy.asarray([i for i in range(360)], dtype=float) / 180. * numpy.pi
-
-proj_geom = astra.creators.create_proj_geom('parallel3d',
-                                            detectorSpacingX,
-                                            detectorSpacingY,
-                                            det_row_count,
-                                            det_col_count,
-                                            angles_rad)
-
-image_size_x = 64
-image_size_y = 64
-image_size_z = 32
-
-vol_geom = astra.creators.create_vol_geom(image_size_x,image_size_y,image_size_z)
-
-x1 = numpy.random.rand(image_size_z,image_size_y,image_size_x)
-sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom)
-- 
cgit v1.2.3


From bc29e0690d856ad9dd147b435d34c5761556a1e5 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Wed, 23 Aug 2017 12:55:19 +0100
Subject: Regularizer.pyfirst commit

---
 src/Python/Regularizer.py | 322 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 322 insertions(+)
 create mode 100644 src/Python/Regularizer.py

(limited to 'src/Python')

diff --git a/src/Python/Regularizer.py b/src/Python/Regularizer.py
new file mode 100644
index 0000000..15dbbb4
--- /dev/null
+++ b/src/Python/Regularizer.py
@@ -0,0 +1,322 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Aug  8 14:26:00 2017
+
+@author: ofn77899
+"""
+
+import regularizers
+import numpy as np
+from enum import Enum
+import timeit
+
+class Regularizer():
+    '''Class to handle regularizer algorithms to be used during reconstruction
+    
+    Currently 5 CPU (OMP) regularization algorithms are available:
+        
+    1) SplitBregman_TV
+    2) FGP_TV
+    3) LLT_model
+    4) PatchBased_Regul
+    5) TGV_PD
+    
+    Usage:
+        the regularizer can be invoked as object or as static method
+        Depending on the actual regularizer the input parameter may vary, and 
+        a different default setting is defined.
+        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+
+        out = reg(input=u0, regularization_parameter=10., number_of_iterations=30,
+          tolerance_constant=1e-4, 
+          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+
+        out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10.,
+          number_of_iterations=30, tolerance_constant=1e-4, 
+          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+        
+        A number of optional parameters can be passed or skipped
+        out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. )
+
+    '''
+    class Algorithm(Enum):
+        SplitBregman_TV = regularizers.SplitBregman_TV
+        FGP_TV = regularizers.FGP_TV
+        LLT_model = regularizers.LLT_model
+        PatchBased_Regul = regularizers.PatchBased_Regul
+        TGV_PD = regularizers.TGV_PD
+    # Algorithm
+    
+    class TotalVariationPenalty(Enum):
+        isotropic = 0
+        l1 = 1
+    # TotalVariationPenalty
+        
+    def __init__(self , algorithm, debug = True):
+        self.setAlgorithm ( algorithm )
+        self.debug = debug
+    # __init__
+    
+    def setAlgorithm(self, algorithm):
+        self.algorithm = algorithm
+        self.pars = self.getDefaultParsForAlgorithm(algorithm)
+    # setAlgorithm
+        
+    def getDefaultParsForAlgorithm(self, algorithm):
+        pars = dict()
+        
+        if algorithm == Regularizer.Algorithm.SplitBregman_TV :
+            pars['algorithm'] = algorithm
+            pars['input'] = None
+            pars['regularization_parameter'] = None
+            pars['number_of_iterations'] = 35
+            pars['tolerance_constant'] = 0.0001
+            pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic
+            
+        elif algorithm == Regularizer.Algorithm.FGP_TV :
+            pars['algorithm'] = algorithm
+            pars['input'] = None
+            pars['regularization_parameter'] = None
+            pars['number_of_iterations'] = 50
+            pars['tolerance_constant'] = 0.001
+            pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic
+            
+        elif algorithm == Regularizer.Algorithm.LLT_model:
+            pars['algorithm'] = algorithm
+            pars['input'] = None
+            pars['regularization_parameter'] = None
+            pars['time_step'] = None
+            pars['number_of_iterations'] = None
+            pars['tolerance_constant'] = None
+            pars['restrictive_Z_smoothing'] = 0
+            
+        elif algorithm == Regularizer.Algorithm.PatchBased_Regul:
+            pars['algorithm'] = algorithm
+            pars['input'] = None
+            pars['searching_window_ratio'] = None
+            pars['similarity_window_ratio'] = None
+            pars['PB_filtering_parameter'] = None
+            pars['regularization_parameter'] = None
+            
+        elif algorithm == Regularizer.Algorithm.TGV_PD:
+            pars['algorithm'] = algorithm
+            pars['input'] = None
+            pars['first_order_term'] = None
+            pars['second_order_term'] = None
+            pars['number_of_iterations'] = None
+            pars['regularization_parameter'] = None
+            
+        else:
+            raise Exception('Unknown regularizer algorithm')
+            
+        return pars
+    # parsForAlgorithm
+    
+    def setParameter(self, **kwargs):
+        '''set named parameter for the regularization engine
+        
+        raises Exception if the named parameter is not recognized
+        Typical usage is:
+            
+        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+        reg.setParameter(input=u0)    
+        reg.setParameter(regularization_parameter=10.)
+        
+        it can be also used as
+        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+        reg.setParameter(input=u0 , regularization_parameter=10.)
+        '''
+        
+        for key , value in kwargs.items():
+            if key in self.pars.keys():
+                self.pars[key] = value
+            else:
+                raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key))
+    # setParameter
+	
+    def getParameter(self, **kwargs):
+        ret = {}
+        for key , value in kwargs.items():
+            if key in self.pars.keys():
+                ret[key] = self.pars[key]
+        else:
+            raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key))
+    # setParameter
+	
+        
+    def __call__(self, input = None, regularization_parameter = None, **kwargs):
+        '''Actual call for the regularizer. 
+        
+        One can either set the regularization parameters first and then call the
+        algorithm or set the regularization parameter during the call (as 
+        is done in the static methods). 
+        '''
+        
+        if kwargs is not None:
+            for key, value in kwargs.items():
+                #print("{0} = {1}".format(key, value))                        
+                self.pars[key] = value
+                    
+        if input is not None: 
+            self.pars['input'] = input
+        if regularization_parameter is not None:
+            self.pars['regularization_parameter'] = regularization_parameter
+            
+        if self.debug:
+            print ("--------------------------------------------------")
+            for key, value in self.pars.items():
+                if key== 'algorithm' :
+                    print("{0} = {1}".format(key, value.__name__))
+                elif key == 'input':
+                    print("{0} = {1}".format(key, np.shape(value)))
+                else:
+                    print("{0} = {1}".format(key, value))
+        
+            
+        if None in self.pars:
+                raise Exception("Not all parameters have been provided")
+        
+        input = self.pars['input']
+        regularization_parameter = self.pars['regularization_parameter']
+        if self.algorithm == Regularizer.Algorithm.SplitBregman_TV :
+            return self.algorithm(input, regularization_parameter,
+                              self.pars['number_of_iterations'],
+                              self.pars['tolerance_constant'],
+                              self.pars['TV_penalty'].value )    
+        elif self.algorithm == Regularizer.Algorithm.FGP_TV :
+            return self.algorithm(input, regularization_parameter,
+                              self.pars['number_of_iterations'],
+                              self.pars['tolerance_constant'],
+                              self.pars['TV_penalty'].value )
+        elif self.algorithm == Regularizer.Algorithm.LLT_model :
+            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
+            # no default
+            return self.algorithm(input, 
+                              regularization_parameter,
+                              self.pars['time_step'] , 
+                              self.pars['number_of_iterations'],
+                              self.pars['tolerance_constant'],
+                              self.pars['restrictive_Z_smoothing'] )
+        elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul :
+            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
+            # no default
+            return self.algorithm(input, regularization_parameter,
+                                  self.pars['searching_window_ratio'] , 
+                                  self.pars['similarity_window_ratio'] , 
+                                  self.pars['PB_filtering_parameter'])
+        elif self.algorithm == Regularizer.Algorithm.TGV_PD :
+            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
+            # no default
+            if len(np.shape(input)) == 2:
+                return self.algorithm(input, regularization_parameter,
+                                  self.pars['first_order_term'] , 
+                                  self.pars['second_order_term'] , 
+                                  self.pars['number_of_iterations'])
+            elif len(np.shape(input)) == 3:
+                #assuming it's 3D
+                # run independent calls on each slice
+                out3d = input.copy()
+                for i in range(np.shape(input)[2]):
+                    out = self.algorithm(input, regularization_parameter,
+                                 self.pars['first_order_term'] , 
+                                 self.pars['second_order_term'] , 
+                                 self.pars['number_of_iterations'])
+                    # copy the result in the 3D image
+                    out3d.T[i] = out[0].copy()
+                # append the rest of the info that the algorithm returns
+                output = [out3d]
+                for i in range(1,len(out)):
+                    output.append(out[i])
+                return output
+                
+                
+            
+            
+        
+    # __call__
+    
+    @staticmethod
+    def SplitBregman_TV(input, regularization_parameter , **kwargs):
+        start_time = timeit.default_timer()
+        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+        out = list( reg(input, regularization_parameter, **kwargs) )
+        out.append(reg.pars)
+        txt = reg.printParametersToString()
+        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
+        out.append(txt)
+        return out
+        
+    @staticmethod
+    def FGP_TV(input, regularization_parameter , **kwargs):
+        start_time = timeit.default_timer()
+        reg = Regularizer(Regularizer.Algorithm.FGP_TV)
+        out = list( reg(input, regularization_parameter, **kwargs) )
+        out.append(reg.pars)
+        txt = reg.printParametersToString()
+        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
+        out.append(txt)
+        return out
+    
+    @staticmethod
+    def LLT_model(input, regularization_parameter , time_step, number_of_iterations,
+                  tolerance_constant, restrictive_Z_smoothing=0):
+        start_time = timeit.default_timer()
+        reg = Regularizer(Regularizer.Algorithm.LLT_model)
+        out = list( reg(input, regularization_parameter, time_step=time_step, 
+                        number_of_iterations=number_of_iterations,
+                        tolerance_constant=tolerance_constant, 
+                        restrictive_Z_smoothing=restrictive_Z_smoothing) )
+        out.append(reg.pars)
+        txt = reg.printParametersToString()
+        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
+        out.append(txt)
+        return out
+    
+    @staticmethod
+    def PatchBased_Regul(input, regularization_parameter,
+                        searching_window_ratio, 
+                        similarity_window_ratio,
+                        PB_filtering_parameter):
+        start_time = timeit.default_timer()
+        reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul)   
+        out = list( reg(input, 
+                        regularization_parameter,
+                        searching_window_ratio=searching_window_ratio, 
+                        similarity_window_ratio=similarity_window_ratio,
+                        PB_filtering_parameter=PB_filtering_parameter )
+            )
+        out.append(reg.pars)
+        txt = reg.printParametersToString()
+        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
+        out.append(txt)
+        return out
+    
+    @staticmethod
+    def TGV_PD(input, regularization_parameter , first_order_term, 
+               second_order_term, number_of_iterations):
+        start_time = timeit.default_timer()
+        
+        reg = Regularizer(Regularizer.Algorithm.TGV_PD)
+        out = list( reg(input, regularization_parameter, 
+                        first_order_term=first_order_term, 
+                        second_order_term=second_order_term,
+                        number_of_iterations=number_of_iterations) )
+        out.append(reg.pars)
+        txt = reg.printParametersToString()
+        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
+        out.append(txt)
+        
+        return out
+    
+    def printParametersToString(self):
+        txt = r''
+        for key, value in self.pars.items():
+            if key== 'algorithm' :
+                txt += "{0} = {1}".format(key, value.__name__)
+            elif key == 'input':
+                txt += "{0} = {1}".format(key, np.shape(value))
+            else:
+                txt += "{0} = {1}".format(key, value)
+            txt += '\n'
+        return txt
+        
-- 
cgit v1.2.3


From 48a4d5315b4b6ca62eaa931912b6a02993979688 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Wed, 23 Aug 2017 12:56:09 +0100
Subject: Test module for Boost Python

currently can pass a function to the C++ layer to be evaluated.
---
 src/Python/Matlab2Python_utils.cpp | 68 +++++++++++++++++++++++++++++++++++++-
 src/Python/setup_test.py           |  6 ++--
 src/Python/test.py                 | 34 ++++++++++++++++---
 3 files changed, 99 insertions(+), 9 deletions(-)

(limited to 'src/Python')

diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp
index 6aaad90..e15d738 100644
--- a/src/Python/Matlab2Python_utils.cpp
+++ b/src/Python/Matlab2Python_utils.cpp
@@ -175,6 +175,71 @@ bp::list mexFunction( np::ndarray input ) {
 	}
 
 
+	bp::list result;
+
+	result.append<int>(number_of_dims);
+	result.append<int>(dim_array[0]);
+	result.append<int>(dim_array[1]);
+	result.append<int>(dim_array[2]);
+	result.append<np::ndarray>(zz);
+	result.append<np::ndarray>(fzz);
+
+	//result.append<bp::tuple>(tup);
+	return result;
+
+}
+bp::list doSomething(np::ndarray input, PyObject *pyobj , PyObject *pyobj2) {
+
+	boost::python::object output(boost::python::handle<>(boost::python::borrowed(pyobj)));
+	int isOutput = !(output == boost::python::api::object());
+
+	boost::python::object calculate(boost::python::handle<>(boost::python::borrowed(pyobj2)));
+	int isCalculate = !(calculate == boost::python::api::object());
+
+	int number_of_dims = input.get_nd();
+	int dim_array[3];
+
+	dim_array[0] = input.shape(0);
+	dim_array[1] = input.shape(1);
+	if (number_of_dims == 2) {
+		dim_array[2] = -1;
+	}
+	else {
+		dim_array[2] = input.shape(2);
+	}
+
+	/**************************************************************************/
+	np::ndarray zz = zeros(3, dim_array, (int)0);
+	np::ndarray fzz = zeros(3, dim_array, (float)0);
+	/**************************************************************************/
+
+	int * A = reinterpret_cast<int *>(input.get_data());
+	int * B = reinterpret_cast<int *>(zz.get_data());
+	float * C = reinterpret_cast<float *>(fzz.get_data());
+
+	//Copy data and cast
+	for (int i = 0; i < dim_array[0]; i++) {
+		for (int j = 0; j < dim_array[1]; j++) {
+			for (int k = 0; k < dim_array[2]; k++) {
+				int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i;
+				int val = (*(A + index));
+				float fval = sqrt((float)val);
+				std::memcpy(B + index, &val, sizeof(int));
+				std::memcpy(C + index, &fval, sizeof(float));
+				// if the PyObj is not None evaluate the function 
+				if (isOutput)	
+					output(fval);
+				if (isCalculate) {
+					float nfval = (float)bp::extract<float>(calculate(val));
+					if (isOutput)
+						output(nfval);
+					std::memcpy(C + index, &nfval, sizeof(float));
+				}
+			}
+		}
+	}
+
+
 	bp::list result;
 
 	result.append<int>(number_of_dims);
@@ -196,7 +261,7 @@ BOOST_PYTHON_MODULE(prova)
 
 	//To specify that this module is a package
 	bp::object package = bp::scope();
-	package.attr("__path__") = "fista";
+	package.attr("__path__") = "prova";
 
 	np::dtype dt1 = np::dtype::get_builtin<uint8_t>();
 	np::dtype dt2 = np::dtype::get_builtin<uint16_t>();
@@ -207,4 +272,5 @@ BOOST_PYTHON_MODULE(prova)
 	//numpy_boost_python_register_type<float, 3>();
 	//numpy_boost_python_register_type<double, 3>();
 	def("mexFunction", mexFunction);
+	def("doSomething", doSomething);
 }
\ No newline at end of file
diff --git a/src/Python/setup_test.py b/src/Python/setup_test.py
index ffb9c02..7c86175 100644
--- a/src/Python/setup_test.py
+++ b/src/Python/setup_test.py
@@ -30,13 +30,13 @@ extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'
 extra_libraries = []
 if platform.system() == 'Windows':
     extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB']   
-    extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."]
+    #extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."]
     if sys.version_info.major == 3 :   
         extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64']
     else:
         extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64']
 else:
-    extra_include_dirs += ["../ContourTree/", "../Core/","."]
+    #extra_include_dirs += ["../ContourTree/", "../Core/","."]
     if sys.version_info.major == 3:
         extra_libraries += ['boost_python3', 'boost_numpy3','gomp']
     else:
@@ -47,7 +47,7 @@ setup(
 	description='CCPi Core Imaging Library - FISTA Reconstruction Module',
 	version=cil_version,
     cmdclass = {'build_ext': build_ext},
-    ext_modules = [Extension("fista",
+    ext_modules = [Extension("prova",
                              sources=[  "Matlab2Python_utils.cpp",
                                         ],
                              include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), 
diff --git a/src/Python/test.py b/src/Python/test.py
index e283f89..db47380 100644
--- a/src/Python/test.py
+++ b/src/Python/test.py
@@ -5,14 +5,38 @@ Created on Thu Aug  3 14:08:09 2017
 @author: ofn77899
 """
 
-import fista
+import prova
 import numpy as np
 
-a = np.asarray([i for i in range(3*4*5)])
-a = a.reshape([3,4,5])
+a = np.asarray([i for i in range(1*2*3)])
+a = a.reshape([1,2,3])
 print (a)
-b = fista.mexFunction(a)
+b = prova.mexFunction(a)
 #print (b)
 print (b[4].shape)
 print (b[4])
-print (b[5])
\ No newline at end of file
+print (b[5])
+
+def print_element(input):
+	print ("f: {0}".format(input))
+	
+prova.doSomething(a, print_element, None)
+
+c = []
+def append_to_list(input, shouldPrint=False):
+	c.append(input)
+	if shouldPrint:
+		print ("{0} appended to list {1}".format(input, c))
+
+def element_wise_algebra(input, shouldPrint=True):
+	ret = input - 7
+	if shouldPrint:
+		print ("element_wise {0}".format(ret))
+	return ret
+		
+prova.doSomething(a, append_to_list, None)
+#print ("this is c: {0}".format(c))
+
+b = prova.doSomething(a, None, element_wise_algebra)
+#print (a)
+print (b[5])
-- 
cgit v1.2.3


From c28385d0dd5efcb32bd2c33e4bd93ba61f959b3f Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Wed, 23 Aug 2017 14:27:28 +0100
Subject: updated test for regularizer API

---
 src/Python/test_regularizers.py | 590 ++++++++++++++++++++--------------------
 1 file changed, 290 insertions(+), 300 deletions(-)

(limited to 'src/Python')

diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py
index 6a34749..5d25f02 100644
--- a/src/Python/test_regularizers.py
+++ b/src/Python/test_regularizers.py
@@ -8,216 +8,37 @@ Created on Fri Aug  4 11:10:05 2017
 from ccpi.viewer.CILViewer2D import Converter
 import vtk
 
-import regularizers
 import matplotlib.pyplot as plt
 import numpy as np
 import os    
 from enum import Enum
-
-class Regularizer():
-    '''Class to handle regularizer algorithms to be used during reconstruction
-    
-    Currently 5 regularization algorithms are available:
-        
-    1) SplitBregman_TV
-    2) FGP_TV
-    3)
-    4)
-    5)
-    
-    Usage:
-        the regularizer can be invoked as object or as static method
-        Depending on the actual regularizer the input parameter may vary, and 
-        a different default setting is defined.
-        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
-
-        out = reg(input=u0, regularization_parameter=10., number_of_iterations=30,
-          tolerance_constant=1e-4, 
-          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
-
-        out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10.,
-          number_of_iterations=30, tolerance_constant=1e-4, 
-          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
-        
-        A number of optional parameters can be passed or skipped
-        out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. )
-
-    '''
-    class Algorithm(Enum):
-        SplitBregman_TV = regularizers.SplitBregman_TV
-        FGP_TV = regularizers.FGP_TV
-        LLT_model = regularizers.LLT_model
-        PatchBased_Regul = regularizers.PatchBased_Regul
-        TGV_PD = regularizers.TGV_PD
-    # Algorithm
-    
-    class TotalVariationPenalty(Enum):
-        isotropic = 0
-        l1 = 1
-    # TotalVariationPenalty
-        
-    def __init__(self , algorithm):
-        self.setAlgorithm ( algorithm )
-    # __init__
-    
-    def setAlgorithm(self, algorithm):
-        self.algorithm = algorithm
-        self.pars = self.parsForAlgorithm(algorithm)
-    # setAlgorithm
-        
-    def parsForAlgorithm(self, algorithm):
-        pars = dict()
-        
-        if algorithm == Regularizer.Algorithm.SplitBregman_TV :
-            pars['algorithm'] = algorithm
-            pars['input'] = None
-            pars['regularization_parameter'] = None
-            pars['number_of_iterations'] = 35
-            pars['tolerance_constant'] = 0.0001
-            pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic
-            
-        elif algorithm == Regularizer.Algorithm.FGP_TV :
-            pars['algorithm'] = algorithm
-            pars['input'] = None
-            pars['regularization_parameter'] = None
-            pars['number_of_iterations'] = 50
-            pars['tolerance_constant'] = 0.001
-            pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic
-            
-        elif algorithm == Regularizer.Algorithm.LLT_model:
-            pars['algorithm'] = algorithm
-            pars['input'] = None
-            pars['regularization_parameter'] = None
-            pars['time_step'] = None
-            pars['number_of_iterations'] = None
-            pars['tolerance_constant'] = None
-            pars['restrictive_Z_smoothing'] = 0
-            
-        elif algorithm == Regularizer.Algorithm.PatchBased_Regul:
-            pars['algorithm'] = algorithm
-            pars['input'] = None
-            pars['searching_window_ratio'] = None
-            pars['similarity_window_ratio'] = None
-            pars['PB_filtering_parameter'] = None
-            pars['regularization_parameter'] = None
-            
-        elif algorithm == Regularizer.Algorithm.TGV_PD:
-            pars['algorithm'] = algorithm
-            pars['input'] = None
-            pars['first_order_term'] = None
-            pars['second_order_term'] = None
-            pars['number_of_iterations'] = None
-            pars['regularization_parameter'] = None
-            
-            
-            
-        return pars
-    # parsForAlgorithm
-        
-    def __call__(self, input, regularization_parameter, **kwargs):
-        
-        if kwargs is not None:
-            for key, value in kwargs.items():
-                #print("{0} = {1}".format(key, value))
-                self.pars[key] = value
-        self.pars['input'] = input
-        self.pars['regularization_parameter'] = regularization_parameter
-        #for key, value in self.pars.items():
-        #        print("{0} = {1}".format(key, value))
-        if None in self.pars:
-                raise Exception("Not all parameters have been provided")
-                
-        if self.algorithm == Regularizer.Algorithm.SplitBregman_TV :
-            return self.algorithm(input, regularization_parameter,
-                              self.pars['number_of_iterations'],
-                              self.pars['tolerance_constant'],
-                              self.pars['TV_penalty'].value )    
-        elif self.algorithm == Regularizer.Algorithm.FGP_TV :
-            return self.algorithm(input, regularization_parameter,
-                              self.pars['number_of_iterations'],
-                              self.pars['tolerance_constant'],
-                              self.pars['TV_penalty'].value )
-        elif self.algorithm == Regularizer.Algorithm.LLT_model :
-            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
-            # no default
-            return self.algorithm(input, 
-                              regularization_parameter,
-                              self.pars['time_step'] , 
-                              self.pars['number_of_iterations'],
-                              self.pars['tolerance_constant'],
-                              self.pars['restrictive_Z_smoothing'] )
-        elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul :
-            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
-            # no default
-            return self.algorithm(input, regularization_parameter,
-                                  self.pars['searching_window_ratio'] , 
-                                  self.pars['similarity_window_ratio'] , 
-                                  self.pars['PB_filtering_parameter'])
-        elif self.algorithm == Regularizer.Algorithm.TGV_PD :
-            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
-            # no default
-            return self.algorithm(input, regularization_parameter,
-                                  self.pars['first_order_term'] , 
-                                  self.pars['second_order_term'] , 
-                                  self.pars['number_of_iterations'])
-            
-            
-        
-    # __call__
-    
-    @staticmethod
-    def SplitBregman_TV(input, regularization_parameter , **kwargs):
-        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
-        out = list( reg(input, regularization_parameter, **kwargs) )
-        out.append(reg.pars)
-        return out
-        
-    @staticmethod
-    def FGP_TV(input, regularization_parameter , **kwargs):
-        reg = Regularizer(Regularizer.Algorithm.FGP_TV)
-        out = list( reg(input, regularization_parameter, **kwargs) )
-        out.append(reg.pars)
-        return out
-    
-    @staticmethod
-    def LLT_model(input, regularization_parameter , time_step, number_of_iterations,
-                  tolerance_constant, restrictive_Z_smoothing=0):
-        reg = Regularizer(Regularizer.Algorithm.LLT_model)
-        out = list( reg(input, regularization_parameter, time_step=time_step, 
-                        number_of_iterations=number_of_iterations,
-                        tolerance_constant=tolerance_constant, 
-                        restrictive_Z_smoothing=restrictive_Z_smoothing) )
-        out.append(reg.pars)
-        return out
-    
-    @staticmethod
-    def PatchBased_Regul(input, regularization_parameter,
-                        searching_window_ratio, 
-                        similarity_window_ratio,
-                        PB_filtering_parameter):
-        reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul)   
-        out = list( reg(input, 
-                        regularization_parameter,
-                        searching_window_ratio=searching_window_ratio, 
-                        similarity_window_ratio=similarity_window_ratio,
-                        PB_filtering_parameter=PB_filtering_parameter )
-            )
-        out.append(reg.pars)
-        return out
-    
-    @staticmethod
-    def TGV_PD(input, regularization_parameter , first_order_term, 
-               second_order_term, number_of_iterations):
-        
-        reg = Regularizer(Regularizer.Algorithm.TGV_PD)
-        out = list( reg(input, regularization_parameter, 
-                        first_order_term=first_order_term, 
-                        second_order_term=second_order_term,
-                        number_of_iterations=number_of_iterations) )
-        out.append(reg.pars)
-        return out
-        
-
+import timeit
+
+from Regularizer import Regularizer
+
+###############################################################################
+#https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956
+#NRMSE a normalization of the root of the mean squared error
+#NRMSE is simply 1 - [RMSE / (maxval - minval)]. Where maxval is the maximum
+# intensity from the two images being compared, and respectively the same for
+# minval. RMSE is given by the square root of MSE: 
+# sqrt[(sum(A - B) ** 2) / |A|],
+# where |A| means the number of elements in A. By doing this, the maximum value
+# given by RMSE is maxval.
+
+def nrmse(im1, im2):
+    a, b = im1.shape
+    rmse = np.sqrt(np.sum((im2 - im1) ** 2) / float(a * b))
+    max_val = max(np.max(im1), np.max(im2))
+    min_val = min(np.min(im1), np.min(im2))
+    return 1 - (rmse / (max_val - min_val))
+###############################################################################
+
+###############################################################################
+#
+#  2D Regularizers
+#
+###############################################################################
 #Example:
 # figure;
 # Im = double(imread('lena_gray_256.tif'))/255;  % loading image
@@ -255,49 +76,55 @@ reg_output = []
 ####################### SplitBregman_TV #####################################
 # u = SplitBregman_TV(single(u0), 10, 30, 1e-04);
 
-reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+use_object = True
+if use_object:
+    reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+    reg.setParameter(input=u0)    
+    reg.setParameter(regularization_parameter=10.)
+    # or 
+    # reg.setParameter(input=u0, regularization_parameter=10., #number_of_iterations=30,
+              #tolerance_constant=1e-4, 
+              #TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+    plotme = reg() [0]
+    pars = reg.pars
+    textstr = reg.printParametersToString() 
+    
+    #out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30,
+              #tolerance_constant=1e-4, 
+    #          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+    
+#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30,
+#          tolerance_constant=1e-4, 
+#          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
 
-out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30,
-          #tolerance_constant=1e-4, 
-          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
-
-out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30,
-          tolerance_constant=1e-4, 
-          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
-out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. )
-pars = out2[2]
-reg_output.append(out2)
+else:
+    out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. )
+    pars = out2[2]
+    reg_output.append(out2)
+    plotme = reg_output[-1][0]
+    textstr = out2[-1]
 
 a=fig.add_subplot(2,3,2)
-a.set_title('SplitBregman_TV')
-textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s'
-textstr = textstr % (pars['regularization_parameter'], 
-                     pars['number_of_iterations'], 
-                     pars['tolerance_constant'],
-                     pars['TV_penalty'].name)
+
 
 # these are matplotlib.patch.Patch properties
 props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
 # place a text box in upper left in axes coords
 a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
         verticalalignment='top', bbox=props)
-imgplot = plt.imshow(reg_output[-1][0])
+imgplot = plt.imshow(plotme)
 
 ###################### FGP_TV #########################################
 # u = FGP_TV(single(u0), 0.05, 100, 1e-04);
-out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05,
-                          number_of_iterations=10)
-pars = out2[-1]
+out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005,
+                          number_of_iterations=200)
+pars = out2[-2]
 
 reg_output.append(out2)
 
 a=fig.add_subplot(2,3,3)
-a.set_title('FGP_TV')
-textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s'
-textstr = textstr % (pars['regularization_parameter'], 
-                     pars['number_of_iterations'], 
-                     pars['tolerance_constant'],
-                     pars['TV_penalty'].name)
+
+textstr = out2[-1]
 
 # these are matplotlib.patch.Patch properties
 props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
@@ -316,50 +143,12 @@ out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25,
                           time_step=0.0003,
                           tolerance_constant=0.0001,
                           number_of_iterations=300)
-pars = out2[-1]
+pars = out2[-2]
 
 reg_output.append(out2)
 
 a=fig.add_subplot(2,3,4)
-a.set_title('LLT_model')
-textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f'
-textstr = textstr % (pars['regularization_parameter'], 
-                     pars['number_of_iterations'], 
-                     pars['tolerance_constant'],
-                     pars['time_step']
-                     )
-
-# these are matplotlib.patch.Patch properties
-props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
-# place a text box in upper left in axes coords
-a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
-        verticalalignment='top', bbox=props)
-imgplot = plt.imshow(reg_output[-1][0])
-
-###################### PatchBased_Regul #########################################
-# Quick 2D denoising example in Matlab:   
-#   Im = double(imread('lena_gray_256.tif'))/255;  % loading image
-#   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
-#   ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); 
-
-out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05,
-                          searching_window_ratio=3,
-                          similarity_window_ratio=1,
-                          PB_filtering_parameter=0.08)
-pars = out2[-1]
-reg_output.append(out2)
-
-a=fig.add_subplot(2,3,5)
-a.set_title('PatchBased_Regul')
-textstr = 'regularization_parameter=%.2f\nsearching_window_ratio=%d\nsimilarity_window_ratio=%.2e\nPB_filtering_parameter=%f'
-textstr = textstr % (pars['regularization_parameter'], 
-                     pars['searching_window_ratio'], 
-                     pars['similarity_window_ratio'],
-                     pars['PB_filtering_parameter'])
-
-
-
-
+textstr = out2[-1]
 # these are matplotlib.patch.Patch properties
 props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
 # place a text box in upper left in axes coords
@@ -367,6 +156,215 @@ a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
         verticalalignment='top', bbox=props)
 imgplot = plt.imshow(reg_output[-1][0])
 
+# ###################### PatchBased_Regul #########################################
+# # Quick 2D denoising example in Matlab:   
+# #   Im = double(imread('lena_gray_256.tif'))/255;  % loading image
+# #   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+# #   ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); 
+
+# out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05,
+                          # searching_window_ratio=3,
+                          # similarity_window_ratio=1,
+                          # PB_filtering_parameter=0.08)
+# pars = out2[-2]
+# reg_output.append(out2)
+
+# a=fig.add_subplot(2,3,5)
+
+
+# textstr = out2[-1]
+
+# # these are matplotlib.patch.Patch properties
+# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# # place a text box in upper left in axes coords
+# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+        # verticalalignment='top', bbox=props)
+# imgplot = plt.imshow(reg_output[-1][0])
+
+
+# ###################### TGV_PD #########################################
+# # Quick 2D denoising example in Matlab:   
+# #   Im = double(imread('lena_gray_256.tif'))/255;  % loading image
+# #   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+# #   u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550);
+
+
+# out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05,
+                          # first_order_term=1.3,
+                          # second_order_term=1,
+                          # number_of_iterations=550)
+# pars = out2[-2]
+# reg_output.append(out2)
+
+# a=fig.add_subplot(2,3,6)
+
+
+# textstr = out2[-1]
+
+
+# # these are matplotlib.patch.Patch properties
+# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# # place a text box in upper left in axes coords
+# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+        # verticalalignment='top', bbox=props)
+# imgplot = plt.imshow(reg_output[-1][0])
+
+
+plt.show()
+
+################################################################################
+##
+##  3D Regularizers
+##
+################################################################################
+##Example:
+## figure;
+## Im = double(imread('lena_gray_256.tif'))/255;  % loading image
+## u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0;
+## u = SplitBregman_TV(single(u0), 10, 30, 1e-04);
+#
+##filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Reconstruction\python\test\reconstruction_example.mha"
+#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Simpleflex\data\head.mha"
+#
+#reader = vtk.vtkMetaImageReader()
+#reader.SetFileName(os.path.normpath(filename))
+#reader.Update()
+##vtk returns 3D images, let's take just the one slice there is as 2D
+#Im = Converter.vtk2numpy(reader.GetOutput())
+#Im = Im.astype('float32')
+##imgplot = plt.imshow(Im)
+#perc = 0.05
+#u0 = Im + (perc* np.random.normal(size=np.shape(Im)))
+## map the u0 u0->u0>0
+#f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1)
+#u0 = f(u0).astype('float32')
+#converter = Converter.numpy2vtkImporter(u0, reader.GetOutput().GetSpacing(),
+#                                        reader.GetOutput().GetOrigin())
+#converter.Update()
+#writer = vtk.vtkMetaImageWriter()
+#writer.SetInputData(converter.GetOutput())
+#writer.SetFileName(r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\noisy_head.mha")
+##writer.Write()
+#
+#
+### plot 
+#fig3D = plt.figure()
+##a=fig.add_subplot(3,3,1)
+##a.set_title('Original')
+##imgplot = plt.imshow(Im)
+#sliceNo = 32
+#
+#a=fig3D.add_subplot(2,3,1)
+#a.set_title('noise')
+#imgplot = plt.imshow(u0.T[sliceNo])
+#
+#reg_output3d = []
+#
+###############################################################################
+## Call regularizer
+#
+######################## SplitBregman_TV #####################################
+## u = SplitBregman_TV(single(u0), 10, 30, 1e-04);
+#
+##reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+#
+##out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30,
+##          #tolerance_constant=1e-4, 
+##          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+#
+#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30,
+#          tolerance_constant=1e-4, 
+#          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+#
+#
+#pars = out2[-2]
+#reg_output3d.append(out2)
+#
+#a=fig3D.add_subplot(2,3,2)
+#
+#
+#textstr = out2[-1]
+#
+#
+## these are matplotlib.patch.Patch properties
+#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+## place a text box in upper left in axes coords
+#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+#        verticalalignment='top', bbox=props)
+#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])
+#
+####################### FGP_TV #########################################
+## u = FGP_TV(single(u0), 0.05, 100, 1e-04);
+#out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005,
+#                          number_of_iterations=200)
+#pars = out2[-2]
+#reg_output3d.append(out2)
+#
+#a=fig3D.add_subplot(2,3,2)
+#
+#
+#textstr = out2[-1]
+#
+#
+## these are matplotlib.patch.Patch properties
+#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+## place a text box in upper left in axes coords
+#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+#        verticalalignment='top', bbox=props)
+#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])
+#
+####################### LLT_model #########################################
+## * u0 = Im + .03*randn(size(Im)); % adding noise
+## [Den] = LLT_model(single(u0), 10, 0.1, 1);
+##Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); 
+##input, regularization_parameter , time_step, number_of_iterations,
+##                  tolerance_constant, restrictive_Z_smoothing=0
+#out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25,
+#                          time_step=0.0003,
+#                          tolerance_constant=0.0001,
+#                          number_of_iterations=300)
+#pars = out2[-2]
+#reg_output3d.append(out2)
+#
+#a=fig3D.add_subplot(2,3,2)
+#
+#
+#textstr = out2[-1]
+#
+#
+## these are matplotlib.patch.Patch properties
+#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+## place a text box in upper left in axes coords
+#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+#        verticalalignment='top', bbox=props)
+#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])
+#
+####################### PatchBased_Regul #########################################
+## Quick 2D denoising example in Matlab:   
+##   Im = double(imread('lena_gray_256.tif'))/255;  % loading image
+##   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+##   ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); 
+#
+#out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05,
+#                          searching_window_ratio=3,
+#                          similarity_window_ratio=1,
+#                          PB_filtering_parameter=0.08)
+#pars = out2[-2]
+#reg_output3d.append(out2)
+#
+#a=fig3D.add_subplot(2,3,2)
+#
+#
+#textstr = out2[-1]
+#
+#
+## these are matplotlib.patch.Patch properties
+#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+## place a text box in upper left in axes coords
+#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+#        verticalalignment='top', bbox=props)
+#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])
+#
 
 ###################### TGV_PD #########################################
 # Quick 2D denoising example in Matlab:   
@@ -375,30 +373,22 @@ imgplot = plt.imshow(reg_output[-1][0])
 #   u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550);
 
 
-out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05,
-                          first_order_term=1.3,
-                          second_order_term=1,
-                          number_of_iterations=550)
-pars = out2[-1]
-reg_output.append(out2)
-
-a=fig.add_subplot(2,3,6)
-a.set_title('TGV_PD')
-textstr = 'regularization_parameter=%.2f\nfirst_order_term=%.2f\nsecond_order_term=%.2f\nnumber_of_iterations=%d'
-textstr = textstr % (pars['regularization_parameter'], 
-                     pars['first_order_term'], 
-                     pars['second_order_term'],
-                     pars['number_of_iterations'])
-
-
-
-
-# these are matplotlib.patch.Patch properties
-props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
-# place a text box in upper left in axes coords
-a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
-        verticalalignment='top', bbox=props)
-imgplot = plt.imshow(reg_output[-1][0])
-
-
-
+#out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05,
+#                          first_order_term=1.3,
+#                          second_order_term=1,
+#                          number_of_iterations=550)
+#pars = out2[-2]
+#reg_output3d.append(out2)
+#
+#a=fig3D.add_subplot(2,3,2)
+#
+#
+#textstr = out2[-1]
+#
+#
+## these are matplotlib.patch.Patch properties
+#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+## place a text box in upper left in axes coords
+#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+#        verticalalignment='top', bbox=props)
+#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])
-- 
cgit v1.2.3


From db45d96898f23c3bc97e4c19e834fa976ec301c8 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Wed, 23 Aug 2017 14:31:16 +0100
Subject: initial commit of Reconstructor.py

---
 src/Python/ccpi/reconstruction/Reconstructor.py | 598 ++++++++++++++++++++++++
 1 file changed, 598 insertions(+)
 create mode 100644 src/Python/ccpi/reconstruction/Reconstructor.py

(limited to 'src/Python')

diff --git a/src/Python/ccpi/reconstruction/Reconstructor.py b/src/Python/ccpi/reconstruction/Reconstructor.py
new file mode 100644
index 0000000..ba67327
--- /dev/null
+++ b/src/Python/ccpi/reconstruction/Reconstructor.py
@@ -0,0 +1,598 @@
+# -*- coding: utf-8 -*-
+###############################################################################
+#This work is part of the Core Imaging Library developed by
+#Visual Analytics and Imaging System Group of the Science Technology
+#Facilities Council, STFC
+#
+#Copyright 2017 Edoardo Pasca, Srikanth Nagella
+#Copyright 2017 Daniil Kazantsev
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#http://www.apache.org/licenses/LICENSE-2.0
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+###############################################################################
+
+
+
+import numpy
+import h5py
+from ccpi.reconstruction.parallelbeam import alg
+
+from Regularizer import Regularizer
+from enum import Enum
+
+import astra
+
+
+class Reconstructor:
+    
+    class Algorithm(Enum):
+        CGLS = alg.cgls
+        CGLS_CONV = alg.cgls_conv
+        SIRT = alg.sirt
+        MLEM = alg.mlem
+        CGLS_TICHONOV = alg.cgls_tikhonov
+        CGLS_TVREG = alg.cgls_TVreg
+        FISTA = 'fista'
+        
+    def __init__(self, algorithm = None, projection_data = None,
+                 angles = None, center_of_rotation = None , 
+                 flat_field = None, dark_field = None, 
+                 iterations = None, resolution = None, isLogScale = False, threads = None, 
+                 normalized_projection = None):
+    
+        self.pars = dict()
+        self.pars['algorithm'] = algorithm
+        self.pars['projection_data'] = projection_data
+        self.pars['normalized_projection'] = normalized_projection
+        self.pars['angles'] = angles
+        self.pars['center_of_rotation'] = numpy.double(center_of_rotation)
+        self.pars['flat_field'] = flat_field
+        self.pars['iterations'] = iterations
+        self.pars['dark_field'] = dark_field
+        self.pars['resolution'] = resolution
+        self.pars['isLogScale'] = isLogScale
+        self.pars['threads'] = threads
+        if (iterations != None):
+            self.pars['iterationValues'] = numpy.zeros((iterations)) 
+        
+        if projection_data != None and dark_field != None and flat_field != None:
+            norm = self.normalize(projection_data, dark_field, flat_field, 0.1)
+            self.pars['normalized_projection'] = norm
+            
+    
+    def setPars(self, parameters):
+        keys = ['algorithm','projection_data' ,'normalized_projection', \
+                'angles' , 'center_of_rotation' , 'flat_field', \
+                'iterations','dark_field' , 'resolution', 'isLogScale' , \
+                'threads' , 'iterationValues', 'regularize']
+        
+        for k in keys:
+            if k not in parameters.keys():
+                self.pars[k] = None
+            else:
+                self.pars[k] = parameters[k]
+                
+        
+    def sanityCheck(self):
+        projection_data = self.pars['projection_data']
+        dark_field = self.pars['dark_field']
+        flat_field = self.pars['flat_field']
+        angles = self.pars['angles']
+        
+        if projection_data != None and dark_field != None and \
+            angles != None and flat_field != None:
+            data_shape =  numpy.shape(projection_data)
+            angle_shape = numpy.shape(angles)
+            
+            if angle_shape[0] != data_shape[0]:
+                #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \
+                #                (angle_shape[0] , data_shape[0]) )
+                return (False , 'Projections and angles dimensions do not match: %d vs %d' % \
+                                (angle_shape[0] , data_shape[0]) )
+            
+            if data_shape[1:] != numpy.shape(flat_field):
+                #raise Exception('Projection and flat field dimensions do not match')
+                return (False , 'Projection and flat field dimensions do not match')
+            if data_shape[1:] != numpy.shape(dark_field):
+                #raise Exception('Projection and dark field dimensions do not match')
+                return (False , 'Projection and dark field dimensions do not match')
+            
+            return (True , '' )
+        elif self.pars['normalized_projection'] != None:
+            data_shape =  numpy.shape(self.pars['normalized_projection'])
+            angle_shape = numpy.shape(angles)
+            
+            if angle_shape[0] != data_shape[0]:
+                #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \
+                #                (angle_shape[0] , data_shape[0]) )
+                return (False , 'Projections and angles dimensions do not match: %d vs %d' % \
+                                (angle_shape[0] , data_shape[0]) )
+            else:
+                return (True , '' )
+        else:
+            return (False , 'Not enough data')
+            
+    def reconstruct(self, parameters = None):
+        if parameters != None:
+            self.setPars(parameters)
+        
+        go , reason = self.sanityCheck()
+        if go:
+            return self._reconstruct()
+        else:
+            raise Exception(reason)
+            
+            
+    def _reconstruct(self, parameters=None):
+        if parameters!=None:
+            self.setPars(parameters)
+        parameters = self.pars
+        
+        if parameters['algorithm'] != None and \
+           parameters['normalized_projection'] != None and \
+           parameters['angles'] != None and \
+           parameters['center_of_rotation'] != None and \
+           parameters['iterations'] != None and \
+           parameters['resolution'] != None and\
+           parameters['threads'] != None and\
+           parameters['isLogScale'] != None:
+               
+               
+           if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS,
+                        Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT):
+               #store parameters
+               self.pars = parameters
+               result = parameters['algorithm'](
+                           parameters['normalized_projection'] ,
+                           parameters['angles'],
+                           parameters['center_of_rotation'],
+                           parameters['resolution'],
+                           parameters['iterations'],
+                           parameters['threads'] ,
+                           parameters['isLogScale']
+                           )
+               return result
+           elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV,
+                          Reconstructor.Algorithm.CGLS_TICHONOV, 
+                          Reconstructor.Algorithm.CGLS_TVREG) :
+               self.pars = parameters
+               result = parameters['algorithm'](
+                           parameters['normalized_projection'] ,
+                           parameters['angles'],
+                           parameters['center_of_rotation'],
+                           parameters['resolution'],
+                           parameters['iterations'],
+                           parameters['threads'] ,
+                           parameters['regularize'],
+                           numpy.zeros((parameters['iterations'])),
+                           parameters['isLogScale']
+                           )
+               
+           elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA:
+               pass
+             
+        else:
+           if parameters['projection_data'] != None and \
+                     parameters['dark_field'] != None and \
+                     parameters['flat_field'] != None:
+               norm = self.normalize(parameters['projection_data'],
+                                   parameters['dark_field'], 
+                                   parameters['flat_field'], 0.1)
+               self.pars['normalized_projection'] = norm
+               return self._reconstruct(parameters)
+              
+                
+                
+    def _normalize(self, projection, dark, flat, def_val=0):
+        a = (projection - dark)
+        b = (flat-dark)
+        with numpy.errstate(divide='ignore', invalid='ignore'):
+            c = numpy.true_divide( a, b )
+            c[ ~ numpy.isfinite( c )] = def_val  # set to not zero if 0/0 
+        return c
+    
+    def normalize(self, projections, dark, flat, def_val=0):
+        norm = [self._normalize(projection, dark, flat, def_val) for projection in projections]
+        return numpy.asarray (norm, dtype=numpy.float32)
+        
+    
+    
+class FISTA():
+    '''FISTA-based reconstruction algorithm using ASTRA-toolbox
+    
+    '''
+    # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>>
+    # ___Input___:
+    # params.[] file:
+    #       - .proj_geom (geometry of the projector) [required]
+    #       - .vol_geom (geometry of the reconstructed object) [required]
+    #       - .sino (vectorized in 2D or 3D sinogram) [required]
+    #       - .iterFISTA (iterations for the main loop, default 40)
+    #       - .L_const (Lipschitz constant, default Power method)                                                                                                    )
+    #       - .X_ideal (ideal image, if given)
+    #       - .weights (statisitcal weights, size of the sinogram)
+    #       - .ROI (Region-of-interest, only if X_ideal is given)
+    #       - .initialize (a 'warm start' using SIRT method from ASTRA)
+    #----------------Regularization choices------------------------
+    #       - .Regul_Lambda_FGPTV (FGP-TV regularization parameter)
+    #       - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter)
+    #       - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter)
+    #       - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04)
+    #       - .Regul_Iterations (iterations for the selected penalty, default 25)
+    #       - .Regul_tauLLT (time step parameter for LLT term)
+    #       - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal)
+    #       - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1)
+    #----------------Visualization parameters------------------------
+    #       - .show (visualize reconstruction 1/0, (0 default))
+    #       - .maxvalplot (maximum value to use for imshow[0 maxvalplot])
+    #       - .slice (for 3D volumes - slice number to imshow)
+    # ___Output___:
+    # 1. X - reconstructed image/volume
+    # 2. output - a structure with
+    #    - .Resid_error - residual error (if X_ideal is given)
+    #    - .objective: value of the objective function
+    #    - .L_const: Lipshitz constant to avoid recalculations
+    
+    # References:
+    # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
+    # Problems" by A. Beck and M Teboulle
+    # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo
+    # 3. "A novel tomographic reconstruction method based on the robust
+    # Student's t function for suppressing data outliers" D. Kazantsev et.al.
+    # D. Kazantsev, 2016-17
+    def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs):
+        self.params = dict()
+        self.params['projector_geometry'] = projector_geometry
+        self.params['output_geometry'] = output_geometry
+        self.params['input_sinogram'] = input_sinogram
+        detectors, nangles, sliceZ = numpy.shape(input_sinogram)
+        self.params['detectors'] = detectors
+        self.params['number_og_angles'] = nangles
+        self.params['SlicesZ'] = sliceZ
+        
+        # Accepted input keywords
+        kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' ,
+              'weights' , 'region_of_interest' , 'initialize' , 
+              'regularizer' , 
+              'ring_lambda_R_L1',
+              'ring_alpha')
+        
+        # handle keyworded parameters
+        if kwargs is not None:
+            for key, value in kwargs.items():
+                if key in kw:
+                    #print("{0} = {1}".format(key, value))                        
+                    self.pars[key] = value
+                    
+        # set the default values for the parameters if not set
+        if 'number_of_iterations' in kwargs.keys():
+            self.pars['number_of_iterations'] = kwargs['number_of_iterations']
+        else:
+            self.pars['number_of_iterations'] = 40
+        if 'weights' in kwargs.keys():
+            self.pars['weights'] = kwargs['weights']
+        else:
+            self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram']))
+        if 'Lipschitz_constant' in kwargs.keys():
+            self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
+        else:
+            self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod()
+        
+        if not self.pars['ideal_image'] in kwargs.keys():
+            self.pars['ideal_image'] = None
+        
+        if not self.pars['region_of_interest'] :
+            if self.pars['ideal_image'] == None:
+                pass
+            else:
+                self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0)
+            
+        if not self.pars['regularizer'] :
+            self.pars['regularizer'] = None
+        else:
+            # the regularizer must be a correctly instantiated object
+            if not self.pars['ring_lambda_R_L1']:
+                self.pars['ring_lambda_R_L1'] = 0
+            if not self.pars['ring_alpha']:
+                self.pars['ring_alpha'] = 1
+        
+            
+            
+        
+    def calculateLipschitzConstantWithPowerMethod(self):
+        ''' using Power method (PM) to establish L constant'''
+        
+        #N = params.vol_geom.GridColCount
+        N = self.pars['output_geometry'].GridColCount
+        proj_geom = self.params['projector_geometry']
+        vol_geom = self.params['output_geometry']
+        weights = self.pars['weights']
+        SlicesZ = self.pars['SlicesZ']
+        
+        if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
+            #% for parallel geometry we can do just one slice
+            #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...');
+            niter = 15;# % number of iteration for the PM
+            #N = params.vol_geom.GridColCount;
+            #x1 = rand(N,N,1);
+            x1 = numpy.random.rand(1,N,N)
+            #sqweight = sqrt(weights(:,:,1));
+            sqweight = numpy.sqrt(weights.T[0])
+            proj_geomT = proj_geom.copy();
+            proj_geomT.DetectorRowCount = 1;
+            vol_geomT = vol_geom.copy();
+            vol_geomT['GridSliceCount'] = 1;
+            
+            
+            for i in range(niter):
+                if i == 0:
+                    #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+                    sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
+                    y = sqweight * y # element wise multiplication
+                    #astra_mex_data3d('delete', sino_id);
+                    astra.matlab.data3d('delete', sino_id)
+                    
+                idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT);
+                s = numpy.linalg.norm(x1)
+                ### this line?
+                x1 = x1/s;
+                ### this line?
+                sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+                y = sqweight*y;
+                astra.matlab.data3d('delete', sino_id);
+                astra.matlab.data3d('delete', idx);
+            #end
+            del proj_geomT
+            del vol_geomT
+        else
+            #% divergen beam geometry
+            #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...');
+            niter = 8; #% number of iteration for PM
+            x1 = numpy.random.rand(SlicesZ , N , N);
+            #sqweight = sqrt(weights);
+            sqweight = numpy.sqrt(weights.T[0])
+            
+            sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
+            y = sqweight*y;
+            #astra_mex_data3d('delete', sino_id);
+            astra.matlab.data3d('delete', sino_id);
+            
+            for i in range(niter):
+                #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom);
+                idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, 
+                                                                    proj_geom, 
+                                                                    vol_geom)
+                s = numpy.linalg.norm(x1)
+                ### this line?
+                x1 = x1/s;
+                ### this line?
+                #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom);
+                sino_id, y = astra.creators.create_sino3d_gpu(x1, 
+                                                              proj_geom, 
+                                                              vol_geom);
+                
+                y = sqweight*y;
+                #astra_mex_data3d('delete', sino_id);
+                #astra_mex_data3d('delete', id);
+                astra.matlab.data3d('delete', sino_id);
+                astra.matlab.data3d('delete', idx);
+            #end
+            #clear x1
+            del x1
+        
+        return s
+    
+    
+    def setRegularizer(self, regularizer):
+        if regularizer
+        self.pars['regularizer'] = regularizer
+        
+    
+    
+
+
+def getEntry(location):
+    for item in nx[location].keys():
+        print (item)
+
+
+print ("Loading Data")
+
+##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif"
+####ind = [i * 1049 for i in range(360)]
+#### use only 360 images
+##images = 200
+##ind = [int(i * 1049 / images) for i in range(images)]
+##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None)
+
+#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs"
+fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs"
+nx = h5py.File(fname, "r")
+
+# the data are stored in a particular location in the hdf5
+for item in nx['entry1/tomo_entry/data'].keys():
+    print (item)
+
+data = nx.get('entry1/tomo_entry/data/rotation_angle')
+angles = numpy.zeros(data.shape)
+data.read_direct(angles)
+print (angles)
+# angles should be in degrees
+
+data = nx.get('entry1/tomo_entry/data/data')
+stack = numpy.zeros(data.shape)
+data.read_direct(stack)
+print (data.shape)
+
+print ("Data Loaded")
+
+
+# Normalize
+data = nx.get('entry1/tomo_entry/instrument/detector/image_key')
+itype = numpy.zeros(data.shape)
+data.read_direct(itype)
+# 2 is dark field
+darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ]
+dark = darks[0]
+for i in range(1, len(darks)):
+    dark += darks[i]
+dark = dark / len(darks)
+#dark[0][0] = dark[0][1]
+
+# 1 is flat field
+flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ]
+flat = flats[0]
+for i in range(1, len(flats)):
+    flat += flats[i]
+flat = flat / len(flats)
+#flat[0][0] = dark[0][1]
+
+
+# 0 is projection data
+proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ]
+angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ]
+angle_proj = numpy.asarray (angle_proj)
+angle_proj = angle_proj.astype(numpy.float32)
+
+# normalized data are
+# norm = (projection - dark)/(flat-dark)
+
+def normalize(projection, dark, flat, def_val=0.1):
+    a = (projection - dark)
+    b = (flat-dark)
+    with numpy.errstate(divide='ignore', invalid='ignore'):
+        c = numpy.true_divide( a, b )
+        c[ ~ numpy.isfinite( c )] = def_val  # set to not zero if 0/0 
+    return c
+    
+
+norm = [normalize(projection, dark, flat) for projection in proj]
+norm = numpy.asarray (norm)
+norm = norm.astype(numpy.float32)
+
+#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm,
+#                 angles = angle_proj, center_of_rotation = 86.2 , 
+#                 flat_field = flat, dark_field = dark, 
+#                 iterations = 15, resolution = 1, isLogScale = False, threads = 3)
+
+#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj,
+#                 angles = angle_proj, center_of_rotation = 86.2 , 
+#                 flat_field = flat, dark_field = dark, 
+#                 iterations = 15, resolution = 1, isLogScale = False, threads = 3)
+#img_cgls = recon.reconstruct()
+#
+#pars = dict()
+#pars['algorithm'] = Reconstructor.Algorithm.SIRT
+#pars['projection_data'] = proj
+#pars['angles'] = angle_proj
+#pars['center_of_rotation'] = numpy.double(86.2)
+#pars['flat_field'] = flat
+#pars['iterations'] = 15
+#pars['dark_field'] = dark
+#pars['resolution'] = 1
+#pars['isLogScale'] = False
+#pars['threads'] = 3
+#
+#img_sirt = recon.reconstruct(pars)
+#
+#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM
+#img_mlem = recon.reconstruct()
+
+############################################################
+############################################################
+#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV
+#recon.pars['regularize'] = numpy.double(0.1)
+#img_cgls_conv = recon.reconstruct()
+
+niterations = 15
+threads = 3
+
+img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+
+iteration_values = numpy.zeros((niterations,))
+img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+                              iteration_values, False)
+print ("iteration values %s" % str(iteration_values))
+
+iteration_values = numpy.zeros((niterations,))
+img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+                                      numpy.double(1e-5), iteration_values , False)
+print ("iteration values %s" % str(iteration_values))
+iteration_values = numpy.zeros((niterations,))
+img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+                                      numpy.double(1e-5), iteration_values , False)
+print ("iteration values %s" % str(iteration_values))
+
+
+##numpy.save("cgls_recon.npy", img_data)
+import matplotlib.pyplot as plt
+fig, ax = plt.subplots(1,6,sharey=True)
+ax[0].imshow(img_cgls[80])
+ax[0].axis('off')  # clear x- and y-axes
+ax[1].imshow(img_sirt[80])
+ax[1].axis('off')  # clear x- and y-axes
+ax[2].imshow(img_mlem[80])
+ax[2].axis('off')  # clear x- and y-axesplt.show()
+ax[3].imshow(img_cgls_conv[80])
+ax[3].axis('off')  # clear x- and y-axesplt.show()
+ax[4].imshow(img_cgls_tikhonov[80])
+ax[4].axis('off')  # clear x- and y-axesplt.show()
+ax[5].imshow(img_cgls_TVreg[80])
+ax[5].axis('off')  # clear x- and y-axesplt.show()
+
+
+plt.show()
+
+#viewer = edo.CILViewer()
+#viewer.setInputAsNumpy(img_cgls2)
+#viewer.displaySliceActor(0)
+#viewer.startRenderLoop()
+
+import vtk
+
+def NumpyToVTKImageData(numpyarray):
+    if (len(numpy.shape(numpyarray)) == 3):
+        doubleImg = vtk.vtkImageData()
+        shape = numpy.shape(numpyarray)
+        doubleImg.SetDimensions(shape[0], shape[1], shape[2])
+        doubleImg.SetOrigin(0,0,0)
+        doubleImg.SetSpacing(1,1,1)
+        doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1)
+        #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation())
+        doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1)
+        
+        for i in range(shape[0]):
+            for j in range(shape[1]):
+                for k in range(shape[2]):
+                    doubleImg.SetScalarComponentFromDouble(
+                        i,j,k,0, numpyarray[i][j][k])
+    #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) )
+        # rescale to appropriate VTK_UNSIGNED_SHORT
+        stats = vtk.vtkImageAccumulate()
+        stats.SetInputData(doubleImg)
+        stats.Update()
+        iMin = stats.GetMin()[0]
+        iMax = stats.GetMax()[0]
+        scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin)
+
+        shiftScaler = vtk.vtkImageShiftScale ()
+        shiftScaler.SetInputData(doubleImg)
+        shiftScaler.SetScale(scale)
+        shiftScaler.SetShift(iMin)
+        shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT)
+        shiftScaler.Update()
+        return shiftScaler.GetOutput()
+        
+#writer = vtk.vtkMetaImageWriter()
+#writer.SetFileName(alg + "_recon.mha")
+#writer.SetInputData(NumpyToVTKImageData(img_cgls2))
+#writer.Write()
-- 
cgit v1.2.3


From c3b58791b906aa6a3b99f32fa5f69a09bb075527 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Wed, 23 Aug 2017 14:56:08 +0100
Subject: module rename to cpu_regularizers

---
 src/Python/setup.py             | 4 ++--
 src/Python/test_regularizers.py | 3 ++-
 2 files changed, 4 insertions(+), 3 deletions(-)

(limited to 'src/Python')

diff --git a/src/Python/setup.py b/src/Python/setup.py
index 0468722..94467c4 100644
--- a/src/Python/setup.py
+++ b/src/Python/setup.py
@@ -47,7 +47,7 @@ setup(
 	description='CCPi Core Imaging Library - FISTA Reconstruction Module',
 	version=cil_version,
     cmdclass = {'build_ext': build_ext},
-    ext_modules = [Extension("regularizers",
+    ext_modules = [Extension("ccpi.imaging.cpu_regularizers",
                              sources=["fista_module.cpp",
                                       "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c",
                                       "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c",
@@ -60,5 +60,5 @@ setup(
     
     ],
 	zip_safe = False,	
-	packages = {'ccpi','ccpi.reconstruction'},
+	packages = {'ccpi','ccpi.fistareconstruction'},
 )
diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py
index 5d25f02..755804a 100644
--- a/src/Python/test_regularizers.py
+++ b/src/Python/test_regularizers.py
@@ -14,7 +14,8 @@ import os
 from enum import Enum
 import timeit
 
-from Regularizer import Regularizer
+#from Regularizer import Regularizer
+from ccpi.imaging.Regularizer import Regularizer
 
 ###############################################################################
 #https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956
-- 
cgit v1.2.3


From 70d03d2c7567fac409086f015ca9e2ac47b0fc20 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Wed, 23 Aug 2017 14:58:11 +0100
Subject: changed the backward slash to forward

---
 src/Python/setup.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

(limited to 'src/Python')

diff --git a/src/Python/setup.py b/src/Python/setup.py
index 94467c4..154f979 100644
--- a/src/Python/setup.py
+++ b/src/Python/setup.py
@@ -49,12 +49,12 @@ setup(
     cmdclass = {'build_ext': build_ext},
     ext_modules = [Extension("ccpi.imaging.cpu_regularizers",
                              sources=["fista_module.cpp",
-                                      "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c",
-                                      "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c",
-                                      "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c",
-                                      "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c",
-                                      "..\\..\\main_func\\regularizers_CPU\\TGV_PD_core.c",
-                                      "..\\..\\main_func\\regularizers_CPU\\utils.c"
+                                      "../../main_func/regularizers_CPU/FGP_TV_core.c",
+                                      "../../main_func/regularizers_CPU/SplitBregman_TV_core.c",
+                                      "../../main_func/regularizers_CPU/LLT_model_core.c",
+                                      "../../main_func/regularizers_CPU/PatchBased_Regul_core.c",
+                                      "../../main_func/regularizers_CPU/TGV_PD_core.c",
+                                      "../../main_func/regularizers_CPU/utils.c"
                                         ],
                              include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), 
     
-- 
cgit v1.2.3


From 396c11bd2c8bde1197b708062590a9e3b95538bd Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Wed, 23 Aug 2017 15:01:28 +0100
Subject: added viewer for testing

---
 src/Python/ccpi/viewer/CILViewer.py                |  361 +++++++
 src/Python/ccpi/viewer/CILViewer2D.py              | 1126 ++++++++++++++++++++
 src/Python/ccpi/viewer/QVTKWidget.py               |  340 ++++++
 src/Python/ccpi/viewer/QVTKWidget2.py              |   84 ++
 src/Python/ccpi/viewer/__init__.py                 |    1 +
 .../viewer/__pycache__/CILViewer.cpython-35.pyc    |  Bin 0 -> 10542 bytes
 .../viewer/__pycache__/CILViewer2D.cpython-35.pyc  |  Bin 0 -> 35633 bytes
 .../viewer/__pycache__/QVTKWidget.cpython-35.pyc   |  Bin 0 -> 10099 bytes
 .../viewer/__pycache__/QVTKWidget2.cpython-35.pyc  |  Bin 0 -> 1316 bytes
 .../viewer/__pycache__/__init__.cpython-35.pyc     |  Bin 0 -> 210 bytes
 src/Python/ccpi/viewer/embedvtk.py                 |   75 ++
 11 files changed, 1987 insertions(+)
 create mode 100644 src/Python/ccpi/viewer/CILViewer.py
 create mode 100644 src/Python/ccpi/viewer/CILViewer2D.py
 create mode 100644 src/Python/ccpi/viewer/QVTKWidget.py
 create mode 100644 src/Python/ccpi/viewer/QVTKWidget2.py
 create mode 100644 src/Python/ccpi/viewer/__init__.py
 create mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc
 create mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc
 create mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc
 create mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc
 create mode 100644 src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc
 create mode 100644 src/Python/ccpi/viewer/embedvtk.py

(limited to 'src/Python')

diff --git a/src/Python/ccpi/viewer/CILViewer.py b/src/Python/ccpi/viewer/CILViewer.py
new file mode 100644
index 0000000..efcf8be
--- /dev/null
+++ b/src/Python/ccpi/viewer/CILViewer.py
@@ -0,0 +1,361 @@
+# -*- coding: utf-8 -*-
+#   Copyright 2017 Edoardo Pasca
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+   
+import vtk
+import numpy
+import math
+from vtk.util import numpy_support
+
+SLICE_ORIENTATION_XY = 2 # Z
+SLICE_ORIENTATION_XZ = 1 # Y
+SLICE_ORIENTATION_YZ = 0 # X
+
+
+
+class CILViewer():
+    '''Simple 3D Viewer based on VTK classes'''
+    
+    def __init__(self, dimx=600,dimy=600):
+        '''creates the rendering pipeline'''
+        
+        # create a rendering window and renderer
+        self.ren = vtk.vtkRenderer()
+        self.renWin = vtk.vtkRenderWindow()
+        self.renWin.SetSize(dimx,dimy)
+        self.renWin.AddRenderer(self.ren)
+
+        # img 3D as slice
+        self.img3D = None
+        self.sliceno = 0
+        self.sliceOrientation = SLICE_ORIENTATION_XY
+        self.sliceActor = None
+        self.voi = None
+        self.wl = None
+        self.ia = None
+        self.sliceActorNo = 0
+        # create a renderwindowinteractor
+        self.iren = vtk.vtkRenderWindowInteractor()
+        self.iren.SetRenderWindow(self.renWin)
+
+        self.style = vtk.vtkInteractorStyleTrackballCamera()
+        self.iren.SetInteractorStyle(self.style)
+
+        self.ren.SetBackground(.1, .2, .4)
+
+        self.actors = {}
+        self.iren.RemoveObservers('MouseWheelForwardEvent')
+        self.iren.RemoveObservers('MouseWheelBackwardEvent')
+        
+        self.iren.AddObserver('MouseWheelForwardEvent', self.mouseInteraction, 1.0)
+        self.iren.AddObserver('MouseWheelBackwardEvent', self.mouseInteraction, 1.0)
+
+        self.iren.RemoveObservers('KeyPressEvent')
+        self.iren.AddObserver('KeyPressEvent', self.keyPress, 1.0)
+        
+        
+        self.iren.Initialize()
+
+        
+
+    def getRenderer(self):
+        '''returns the renderer'''
+        return self.ren
+
+    def getRenderWindow(self):
+        '''returns the render window'''
+        return self.renWin
+
+    def getInteractor(self):
+        '''returns the render window interactor'''
+        return self.iren
+
+    def getCamera(self):
+        '''returns the active camera'''
+        return self.ren.GetActiveCamera()
+
+    def createPolyDataActor(self, polydata):
+        '''returns an actor for a given polydata'''
+        mapper = vtk.vtkPolyDataMapper()
+        if vtk.VTK_MAJOR_VERSION <= 5:
+            mapper.SetInput(polydata)
+        else:
+            mapper.SetInputData(polydata)
+   
+        # actor
+        actor = vtk.vtkActor()
+        actor.SetMapper(mapper)
+        #actor.GetProperty().SetOpacity(0.8)
+        return actor
+
+    def setPolyDataActor(self, actor):
+        '''displays the given polydata'''
+        
+        self.ren.AddActor(actor)
+        
+        self.actors[len(self.actors)+1] = [actor, True]
+        self.iren.Initialize()
+        self.renWin.Render()
+
+    def displayPolyData(self, polydata):
+        self.setPolyDataActor(self.createPolyDataActor(polydata))
+        
+    def hideActor(self, actorno):
+        '''Hides an actor identified by its number in the list of actors'''
+        try:
+            if self.actors[actorno][1]:
+                self.ren.RemoveActor(self.actors[actorno][0])
+                self.actors[actorno][1] = False
+        except KeyError as ke:
+            print ("Warning Actor not present")
+        
+    def showActor(self, actorno, actor = None):
+        '''Shows hidden actor identified by its number in the list of actors'''
+        try:
+            if not self.actors[actorno][1]:
+                self.ren.AddActor(self.actors[actorno][0])
+                self.actors[actorno][1] = True
+                return actorno
+        except KeyError as ke:
+            # adds it to the actors if not there already
+            if actor != None:
+                self.ren.AddActor(actor)
+                self.actors[len(self.actors)+1] = [actor, True]
+                return len(self.actors)
+
+    def addActor(self, actor):
+        '''Adds an actor to the render'''
+        return self.showActor(0, actor)
+            
+        
+    def saveRender(self, filename, renWin=None):
+        '''Save the render window to PNG file'''
+        # screenshot code:
+        w2if = vtk.vtkWindowToImageFilter()
+        if renWin == None:
+            renWin = self.renWin
+        w2if.SetInput(renWin)
+        w2if.Update()
+         
+        writer = vtk.vtkPNGWriter()
+        writer.SetFileName("%s.png" % (filename))
+        writer.SetInputConnection(w2if.GetOutputPort())
+        writer.Write()
+
+    
+    def startRenderLoop(self):
+        self.iren.Start()
+
+
+    def setupObservers(self, interactor):
+        interactor.RemoveObservers('LeftButtonPressEvent')
+        interactor.AddObserver('LeftButtonPressEvent', self.mouseInteraction)
+        interactor.Initialize()
+
+        
+    def mouseInteraction(self, interactor, event):
+        if event == 'MouseWheelForwardEvent':
+            maxSlice = self.img3D.GetDimensions()[self.sliceOrientation]
+            if (self.sliceno + 1 < maxSlice):
+                self.hideActor(self.sliceActorNo)
+                self.sliceno = self.sliceno + 1
+                self.displaySliceActor(self.sliceno)
+        else:
+            minSlice = 0
+            if (self.sliceno - 1 > minSlice):
+                self.hideActor(self.sliceActorNo)
+                self.sliceno = self.sliceno - 1
+                self.displaySliceActor(self.sliceno)
+                 
+
+    def keyPress(self, interactor, event):
+        #print ("Pressed key %s" % interactor.GetKeyCode())
+        # Slice Orientation 
+        if interactor.GetKeyCode() == "x":
+            # slice on the other orientation
+            self.sliceOrientation = SLICE_ORIENTATION_YZ
+            self.sliceno = int(self.img3D.GetDimensions()[1] / 2)
+            self.hideActor(self.sliceActorNo)
+            self.displaySliceActor(self.sliceno)
+        elif interactor.GetKeyCode() == "y":
+            # slice on the other orientation
+            self.sliceOrientation = SLICE_ORIENTATION_XZ
+            self.sliceno = int(self.img3D.GetDimensions()[1] / 2)
+            self.hideActor(self.sliceActorNo)
+            self.displaySliceActor(self.sliceno)
+        elif interactor.GetKeyCode() == "z":
+            # slice on the other orientation
+            self.sliceOrientation = SLICE_ORIENTATION_XY
+            self.sliceno = int(self.img3D.GetDimensions()[2] / 2)
+            self.hideActor(self.sliceActorNo)
+            self.displaySliceActor(self.sliceno)
+        if interactor.GetKeyCode() == "X":
+            # Change the camera view point
+            camera = vtk.vtkCamera()
+            camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint())
+            camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp())
+            newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()]
+            newposition[SLICE_ORIENTATION_YZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) 
+            camera.SetPosition(newposition)
+            camera.SetViewUp(0,0,-1)
+            self.ren.SetActiveCamera(camera)
+            self.ren.ResetCamera()
+            self.ren.Render()
+            interactor.SetKeyCode("x")
+            self.keyPress(interactor, event)
+        elif interactor.GetKeyCode() == "Y":
+             # Change the camera view point
+            camera = vtk.vtkCamera()
+            camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint())
+            camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp())
+            newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()]
+            newposition[SLICE_ORIENTATION_XZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) 
+            camera.SetPosition(newposition)
+            camera.SetViewUp(0,0,-1)
+            self.ren.SetActiveCamera(camera)
+            self.ren.ResetCamera()
+            self.ren.Render()
+            interactor.SetKeyCode("y")
+            self.keyPress(interactor, event)
+        elif interactor.GetKeyCode() == "Z":
+             # Change the camera view point
+            camera = vtk.vtkCamera()
+            camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint())
+            camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp())
+            newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()]
+            newposition[SLICE_ORIENTATION_XY] = math.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) 
+            camera.SetPosition(newposition)
+            camera.SetViewUp(0,0,-1)
+            self.ren.SetActiveCamera(camera)
+            self.ren.ResetCamera()
+            self.ren.Render()
+            interactor.SetKeyCode("z")
+            self.keyPress(interactor, event)
+        else :
+            print ("Unhandled event %s" % interactor.GetKeyCode())
+
+
+        
+    def setInput3DData(self, imageData):
+        self.img3D = imageData
+
+    def setInputAsNumpy(self, numpyarray):
+        if (len(numpy.shape(numpyarray)) == 3):
+            doubleImg = vtk.vtkImageData()
+            shape = numpy.shape(numpyarray)
+            doubleImg.SetDimensions(shape[0], shape[1], shape[2])
+            doubleImg.SetOrigin(0,0,0)
+            doubleImg.SetSpacing(1,1,1)
+            doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1)
+            #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation())
+            doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1)
+            
+            for i in range(shape[0]):
+                for j in range(shape[1]):
+                    for k in range(shape[2]):
+                        doubleImg.SetScalarComponentFromDouble(
+                            i,j,k,0, numpyarray[i][j][k])
+        #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) )
+            # rescale to appropriate VTK_UNSIGNED_SHORT
+            stats = vtk.vtkImageAccumulate()
+            stats.SetInputData(doubleImg)
+            stats.Update()
+            iMin = stats.GetMin()[0]
+            iMax = stats.GetMax()[0]
+            scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin)
+
+            shiftScaler = vtk.vtkImageShiftScale ()
+            shiftScaler.SetInputData(doubleImg)
+            shiftScaler.SetScale(scale)
+            shiftScaler.SetShift(iMin)
+            shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT)
+            shiftScaler.Update()
+            self.img3D = shiftScaler.GetOutput()
+            
+    def displaySliceActor(self, sliceno = 0):
+        self.sliceno = sliceno
+        first = False
+        
+        self.sliceActor , self.voi, self.wl , self.ia = \
+                        self.getSliceActor(self.img3D,
+                                                 sliceno,
+                                                 self.sliceActor,
+                                                 self.voi,
+                                                 self.wl,
+                                                 self.ia)
+        no = self.showActor(self.sliceActorNo, self.sliceActor)
+        self.sliceActorNo = no
+        
+        self.iren.Initialize()
+        self.renWin.Render()
+        
+        return self.sliceActorNo
+
+                      
+    def getSliceActor(self,
+                      imageData ,
+                      sliceno=0,
+                      imageActor=None ,
+                      voi=None,
+                      windowLevel=None,
+                      imageAccumulate=None):
+        '''Slices a 3D volume and then creates an actor to be rendered'''
+        if (voi==None):
+            voi = vtk.vtkExtractVOI()
+            #voi = vtk.vtkImageClip()
+        voi.SetInputData(imageData)
+        #select one slice in Z
+        extent = [ i for i in self.img3D.GetExtent()]
+        extent[self.sliceOrientation * 2] = sliceno
+        extent[self.sliceOrientation * 2 + 1] = sliceno 
+        voi.SetVOI(extent[0], extent[1],
+                   extent[2], extent[3],
+                   extent[4], extent[5])
+        
+        voi.Update()
+        # set window/level for all slices
+        if imageAccumulate == None:
+            imageAccumulate = vtk.vtkImageAccumulate()
+        
+        if (windowLevel == None):
+            windowLevel = vtk.vtkImageMapToWindowLevelColors()
+            imageAccumulate.SetInputData(imageData)
+            imageAccumulate.Update()
+            cmax = imageAccumulate.GetMax()[0]
+            cmin = imageAccumulate.GetMin()[0]
+            windowLevel.SetLevel((cmax+cmin)/2)
+            windowLevel.SetWindow(cmax-cmin)
+
+        windowLevel.SetInputData(voi.GetOutput())
+        windowLevel.Update()
+            
+        if imageActor == None:
+            imageActor = vtk.vtkImageActor()
+        imageActor.SetInputData(windowLevel.GetOutput())
+        imageActor.SetDisplayExtent(extent[0], extent[1],
+                   extent[2], extent[3],
+                   extent[4], extent[5])
+        imageActor.Update()
+        return (imageActor , voi, windowLevel, imageAccumulate)
+
+
+    # Set interpolation on
+    def setInterpolateOn(self):
+        self.sliceActor.SetInterpolate(True)
+        self.renWin.Render()
+
+    # Set interpolation off
+    def setInterpolateOff(self):
+        self.sliceActor.SetInterpolate(False)
+        self.renWin.Render()
\ No newline at end of file
diff --git a/src/Python/ccpi/viewer/CILViewer2D.py b/src/Python/ccpi/viewer/CILViewer2D.py
new file mode 100644
index 0000000..c1629af
--- /dev/null
+++ b/src/Python/ccpi/viewer/CILViewer2D.py
@@ -0,0 +1,1126 @@
+# -*- coding: utf-8 -*-
+#   Copyright 2017 Edoardo Pasca
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+   
+import vtk
+import numpy
+from vtk.util import numpy_support , vtkImageImportFromArray
+from enum import Enum
+
+SLICE_ORIENTATION_XY = 2 # Z
+SLICE_ORIENTATION_XZ = 1 # Y
+SLICE_ORIENTATION_YZ = 0 # X
+
+CONTROL_KEY = 8
+SHIFT_KEY = 4
+ALT_KEY = -128
+
+
+# Converter class
+class Converter():
+    
+    # Utility functions to transform numpy arrays to vtkImageData and viceversa
+    @staticmethod
+    def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)):
+        '''Creates a vtkImageImportFromArray object and returns it.
+        
+        It handles the different axis order from numpy to VTK'''
+        importer = vtkImageImportFromArray.vtkImageImportFromArray()
+        importer.SetArray(numpy.transpose(nparray).copy())
+        importer.SetDataSpacing(spacing)
+        importer.SetDataOrigin(origin)
+        return importer
+    
+    @staticmethod
+    def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)):
+        '''Converts a 3D numpy array to a vtkImageData'''
+        importer = Converter.numpy2vtkImporter(nparray, spacing, origin)
+        importer.Update()
+        return importer.GetOutput()
+    
+    @staticmethod
+    def vtk2numpy(imgdata):
+        '''Converts the VTK data to 3D numpy array'''
+        img_data = numpy_support.vtk_to_numpy(
+                imgdata.GetPointData().GetScalars())
+    
+        dims = imgdata.GetDimensions()
+        dims = (dims[2],dims[1],dims[0])
+        data3d = numpy.reshape(img_data, dims)
+        
+        return numpy.transpose(data3d).copy() 
+
+    @staticmethod
+    def tiffStack2numpy(filename, indices, 
+                        extent = None , sampleRate = None ,\
+                        flatField = None, darkField = None):
+        '''Converts a stack of TIFF files to numpy array.
+        
+        filename must contain the whole path. The filename is supposed to be named and
+        have a suffix with the ordinal file number, i.e. /path/to/projection_%03d.tif
+        
+        indices are the suffix, generally an increasing number
+        
+        Optionally extracts only a selection of the 2D images and (optionally)
+        normalizes.
+        '''
+        
+        stack = vtk.vtkImageData()
+        reader = vtk.vtkTIFFReader()
+        voi = vtk.vtkExtractVOI()
+        
+        #directory = "C:\\Users\\ofn77899\\Documents\\CCPi\\IMAT\\20170419_crabtomo\\crabtomo\\"
+        
+        stack_image = numpy.asarray([])
+        nreduced = len(indices)
+        
+        for num in range(len(indices)):
+            fn = filename % indices[num]
+            print ("resampling %s" % ( fn ) )
+            reader.SetFileName(fn)
+            reader.Update()     
+            print (reader.GetOutput().GetScalarTypeAsString())
+            if num == 0:
+                if (extent == None):
+                    sliced = reader.GetOutput().GetExtent()
+                    stack.SetExtent(sliced[0],sliced[1], sliced[2],sliced[3], 0, nreduced-1)
+                else:
+                    sliced = extent
+                    voi.SetVOI(extent)
+                   
+                    if sampleRate is not None:
+                        voi.SetSampleRate(sampleRate)
+                        ext = numpy.asarray([(sliced[2*i+1] - sliced[2*i])/sampleRate[i] for i in range(3)], dtype=int)
+                        print ("ext {0}".format(ext))
+                        stack.SetExtent(0, ext[0] , 0, ext[1], 0, nreduced-1)
+                    else:
+                         stack.SetExtent(0, sliced[1] - sliced[0] , 0, sliced[3]-sliced[2], 0, nreduced-1)
+                if (flatField != None and darkField != None):
+                    stack.AllocateScalars(vtk.VTK_FLOAT, 1)
+                else:
+                    stack.AllocateScalars(reader.GetOutput().GetScalarType(), 1)
+                print ("Image Size: %d" % ((sliced[1]+1)*(sliced[3]+1) ))
+                stack_image = Converter.vtk2numpy(stack)
+                print ("Stack shape %s" % str(numpy.shape(stack_image)))
+            
+            if extent!=None:
+                voi.SetInputData(reader.GetOutput())
+                voi.Update()
+                img = voi.GetOutput()
+            else:
+                img = reader.GetOutput()
+                
+            theSlice = Converter.vtk2numpy(img).T[0]
+            if darkField != None and flatField != None:
+                print("Try to normalize")
+                #if numpy.shape(darkField) == numpy.shape(flatField) and numpy.shape(flatField) == numpy.shape(theSlice):
+                theSlice = Converter.normalize(theSlice, darkField, flatField, 0.01)
+                print (theSlice.dtype)
+            
+                    
+            print ("Slice shape %s" % str(numpy.shape(theSlice)))
+            stack_image.T[num] = theSlice.copy()
+        
+        return stack_image
+    
+    @staticmethod
+    def normalize(projection, dark, flat, def_val=0):
+        a = (projection - dark)
+        b = (flat-dark)
+        with numpy.errstate(divide='ignore', invalid='ignore'):
+            c = numpy.true_divide( a, b )
+            c[ ~ numpy.isfinite( c )] = def_val  # set to not zero if 0/0 
+        return c
+
+
+
+## Utility functions to transform numpy arrays to vtkImageData and viceversa
+#def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)):
+#    return Converter.numpy2vtkImporter(nparray, spacing, origin)
+#
+#def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)):
+#    return Converter.numpy2vtk(nparray, spacing, origin)
+#
+#def vtk2numpy(imgdata):
+#    return Converter.vtk2numpy(imgdata)
+#
+#def tiffStack2numpy(filename, indices):
+#    return Converter.tiffStack2numpy(filename, indices)
+
+class ViewerEvent(Enum):
+    # left button
+    PICK_EVENT = 0 
+    # alt  + right button + move
+    WINDOW_LEVEL_EVENT = 1
+    # shift + right button
+    ZOOM_EVENT = 2
+    # control + right button
+    PAN_EVENT = 3
+    # control + left button
+    CREATE_ROI_EVENT = 4
+    # alt + left button
+    DELETE_ROI_EVENT = 5
+    # release button
+    NO_EVENT = -1
+
+
+#class CILInteractorStyle(vtk.vtkInteractorStyleTrackballCamera):
+class CILInteractorStyle(vtk.vtkInteractorStyleImage):
+    
+    def __init__(self, callback):
+        vtk.vtkInteractorStyleImage.__init__(self)
+        self.callback = callback
+        self._viewer = callback
+        priority = 1.0
+        
+#        self.AddObserver("MouseWheelForwardEvent" , callback.OnMouseWheelForward , priority)
+#        self.AddObserver("MouseWheelBackwardEvent" , callback.OnMouseWheelBackward, priority)
+#        self.AddObserver('KeyPressEvent', callback.OnKeyPress, priority)
+#        self.AddObserver('LeftButtonPressEvent', callback.OnLeftButtonPressEvent, priority)
+#        self.AddObserver('RightButtonPressEvent', callback.OnRightButtonPressEvent, priority)
+#        self.AddObserver('LeftButtonReleaseEvent', callback.OnLeftButtonReleaseEvent, priority)
+#        self.AddObserver('RightButtonReleaseEvent', callback.OnRightButtonReleaseEvent, priority)
+#        self.AddObserver('MouseMoveEvent', callback.OnMouseMoveEvent, priority)
+        
+        self.AddObserver("MouseWheelForwardEvent" , self.OnMouseWheelForward , priority)
+        self.AddObserver("MouseWheelBackwardEvent" , self.OnMouseWheelBackward, priority)
+        self.AddObserver('KeyPressEvent', self.OnKeyPress, priority)
+        self.AddObserver('LeftButtonPressEvent', self.OnLeftButtonPressEvent, priority)
+        self.AddObserver('RightButtonPressEvent', self.OnRightButtonPressEvent, priority)
+        self.AddObserver('LeftButtonReleaseEvent', self.OnLeftButtonReleaseEvent, priority)
+        self.AddObserver('RightButtonReleaseEvent', self.OnRightButtonReleaseEvent, priority)
+        self.AddObserver('MouseMoveEvent', self.OnMouseMoveEvent, priority)
+        
+        self.InitialEventPosition = (0,0)
+        
+        
+    def SetInitialEventPosition(self, xy):
+        self.InitialEventPosition = xy
+        
+    def GetInitialEventPosition(self):
+        return self.InitialEventPosition
+    
+    def GetKeyCode(self):
+        return self.GetInteractor().GetKeyCode()
+    
+    def SetKeyCode(self, keycode):
+        self.GetInteractor().SetKeyCode(keycode)
+        
+    def GetControlKey(self):
+        return self.GetInteractor().GetControlKey() == CONTROL_KEY
+    
+    def GetShiftKey(self):
+        return self.GetInteractor().GetShiftKey() == SHIFT_KEY
+    
+    def GetAltKey(self):
+        return self.GetInteractor().GetAltKey() == ALT_KEY
+    
+    def GetEventPosition(self):
+        return self.GetInteractor().GetEventPosition()
+    
+    def GetEventPositionInWorldCoordinates(self):
+        pass
+    
+    def GetDeltaEventPosition(self):
+        x,y = self.GetInteractor().GetEventPosition()
+        return (x - self.InitialEventPosition[0] , y - self.InitialEventPosition[1])
+    
+    def Dolly(self, factor):
+        self.callback.camera.Dolly(factor)
+        self.callback.ren.ResetCameraClippingRange()
+        
+    def GetDimensions(self):
+        return self._viewer.img3D.GetDimensions()
+    
+    def GetInputData(self):
+        return self._viewer.img3D
+    
+    def GetSliceOrientation(self):
+        return self._viewer.sliceOrientation
+    
+    def SetSliceOrientation(self, orientation):
+        self._viewer.sliceOrientation = orientation
+
+    def GetActiveSlice(self):
+        return self._viewer.sliceno
+    
+    def SetActiveSlice(self, sliceno):
+        self._viewer.sliceno = sliceno
+    
+    def UpdatePipeline(self, reset = False):
+        self._viewer.updatePipeline(reset)
+        
+    def GetActiveCamera(self):
+        return self._viewer.ren.GetActiveCamera()
+    
+    def SetActiveCamera(self, camera):
+        self._viewer.ren.SetActiveCamera(camera)
+    
+    def ResetCamera(self):
+        self._viewer.ren.ResetCamera()
+    
+    def Render(self):
+        self._viewer.renWin.Render()
+        
+    def UpdateSliceActor(self):
+        self._viewer.sliceActor.Update()
+    
+    def AdjustCamera(self):
+        self._viewer.AdjustCamera()
+        
+    def SaveRender(self, filename):
+        self._viewer.SaveRender(filename)
+        
+    def GetRenderWindow(self):
+        return self._viewer.renWin
+        
+    def GetRenderer(self):
+        return self._viewer.ren
+    
+    def GetROIWidget(self):
+        return self._viewer.ROIWidget
+    
+    def SetViewerEvent(self, event):
+        self._viewer.event = event
+        
+    def GetViewerEvent(self):
+        return self._viewer.event
+    
+    def SetInitialCameraPosition(self, position):
+        self._viewer.InitialCameraPosition = position
+        
+    def GetInitialCameraPosition(self):
+        return self._viewer.InitialCameraPosition
+
+    def SetInitialLevel(self, level):
+        self._viewer.InitialLevel = level
+    
+    def GetInitialLevel(self):
+        return self._viewer.InitialLevel
+    
+    def SetInitialWindow(self, window):
+        self._viewer.InitialWindow = window
+    
+    def GetInitialWindow(self):
+        return self._viewer.InitialWindow
+    
+    def GetWindowLevel(self):
+        return self._viewer.wl
+    
+    def SetROI(self, roi):
+        self._viewer.ROI = roi
+        
+    def GetROI(self):
+        return self._viewer.ROI
+    
+    def UpdateCornerAnnotation(self, text, corner):
+        self._viewer.updateCornerAnnotation(text, corner)
+
+    def GetPicker(self):
+        return self._viewer.picker
+    
+    def GetCornerAnnotation(self):
+        return self._viewer.cornerAnnotation
+    
+    def UpdateROIHistogram(self):
+        self._viewer.updateROIHistogram()
+        
+        
+    ############### Handle events
+    def OnMouseWheelForward(self, interactor, event):
+        maxSlice = self.GetDimensions()[self.GetSliceOrientation()]
+        shift = interactor.GetShiftKey()
+        advance = 1
+        if shift:
+            advance = 10
+            
+        if (self.GetActiveSlice() + advance < maxSlice):
+            self.SetActiveSlice(self.GetActiveSlice() + advance)
+            
+            self.UpdatePipeline()
+        else:
+            print ("maxSlice %d request %d" % (maxSlice, self.GetActiveSlice() + 1 ))
+    
+    def OnMouseWheelBackward(self, interactor, event):
+        minSlice = 0
+        shift = interactor.GetShiftKey()
+        advance = 1
+        if shift:
+            advance = 10
+        if (self.GetActiveSlice() - advance >= minSlice):
+            self.SetActiveSlice( self.GetActiveSlice() - advance)
+            self.UpdatePipeline()
+        else:
+            print ("minSlice %d request %d" % (minSlice, self.GetActiveSlice() + 1 ))
+        
+    def OnKeyPress(self, interactor, event):
+        #print ("Pressed key %s" % interactor.GetKeyCode())
+        # Slice Orientation 
+        if interactor.GetKeyCode() == "X":
+            # slice on the other orientation
+            self.SetSliceOrientation ( SLICE_ORIENTATION_YZ )
+            self.SetActiveSlice( int(self.GetDimensions()[1] / 2) )
+            self.UpdatePipeline(True)
+        elif interactor.GetKeyCode() == "Y":
+            # slice on the other orientation
+            self.SetSliceOrientation (  SLICE_ORIENTATION_XZ )
+            self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[1] / 2) )
+            self.UpdatePipeline(True)
+        elif interactor.GetKeyCode() == "Z":
+            # slice on the other orientation
+            self.SetSliceOrientation (  SLICE_ORIENTATION_XY )
+            self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[2] / 2) )
+            self.UpdatePipeline(True)
+        if interactor.GetKeyCode() == "x":
+            # Change the camera view point
+            camera = vtk.vtkCamera()
+            camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint())
+            camera.SetViewUp(self.GetActiveCamera().GetViewUp())
+            newposition = [i for i in self.GetActiveCamera().GetFocalPoint()]
+            newposition[SLICE_ORIENTATION_YZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) 
+            camera.SetPosition(newposition)
+            camera.SetViewUp(0,0,-1)
+            self.SetActiveCamera(camera)
+            self.Render()
+            interactor.SetKeyCode("X")
+            self.OnKeyPress(interactor, event)
+        elif interactor.GetKeyCode() == "y":
+             # Change the camera view point
+            camera = vtk.vtkCamera()
+            camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint())
+            camera.SetViewUp(self.GetActiveCamera().GetViewUp())
+            newposition = [i for i in self.GetActiveCamera().GetFocalPoint()]
+            newposition[SLICE_ORIENTATION_XZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) 
+            camera.SetPosition(newposition)
+            camera.SetViewUp(0,0,-1)
+            self.SetActiveCamera(camera)
+            self.Render()
+            interactor.SetKeyCode("Y")
+            self.OnKeyPress(interactor, event)
+        elif interactor.GetKeyCode() == "z":
+             # Change the camera view point
+            camera = vtk.vtkCamera()
+            camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint())
+            camera.SetViewUp(self.GetActiveCamera().GetViewUp())
+            newposition = [i for i in self.GetActiveCamera().GetFocalPoint()]
+            newposition[SLICE_ORIENTATION_XY] = numpy.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) 
+            camera.SetPosition(newposition)
+            camera.SetViewUp(0,1,0)
+            self.SetActiveCamera(camera)
+            self.ResetCamera()
+            self.Render()
+            interactor.SetKeyCode("Z")
+            self.OnKeyPress(interactor, event)
+        elif interactor.GetKeyCode() == "a":
+            # reset color/window
+            cmax = self._viewer.ia.GetMax()[0]
+            cmin = self._viewer.ia.GetMin()[0]
+            
+            self.SetInitialLevel( (cmax+cmin)/2 )
+            self.SetInitialWindow( cmax-cmin )
+            
+            self.GetWindowLevel().SetLevel(self.GetInitialLevel())
+            self.GetWindowLevel().SetWindow(self.GetInitialWindow())
+            
+            self.GetWindowLevel().Update()
+                
+            self.UpdateSliceActor()
+            self.AdjustCamera()
+            self.Render()
+            
+        elif interactor.GetKeyCode() == "s":
+            filename = "current_render"
+            self.SaveRender(filename)
+        elif interactor.GetKeyCode() == "q":
+            print ("Terminating by pressing q %s" % (interactor.GetKeyCode(), ))
+            interactor.SetKeyCode("e")
+            self.OnKeyPress(interactor, event)
+        else :
+            #print ("Unhandled event %s" % (interactor.GetKeyCode(), )))
+            pass 
+    
+    def OnLeftButtonPressEvent(self, interactor, event):
+        alt = interactor.GetAltKey()
+        shift = interactor.GetShiftKey()
+        ctrl = interactor.GetControlKey()
+#        print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt))
+#        print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift))
+#        print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl))
+        
+        interactor.SetInitialEventPosition(interactor.GetEventPosition())
+        
+        if ctrl and not (alt and shift): 
+            self.SetViewerEvent( ViewerEvent.CREATE_ROI_EVENT )
+            wsize = self.GetRenderWindow().GetSize()
+            position = interactor.GetEventPosition()
+            self.GetROIWidget().GetBorderRepresentation().SetPosition((position[0]/wsize[0] - 0.05) , (position[1]/wsize[1] - 0.05))
+            self.GetROIWidget().GetBorderRepresentation().SetPosition2( (0.1) , (0.1))
+            
+            self.GetROIWidget().On()
+            self.SetDisplayHistogram(True)
+            self.Render()
+            print ("Event %s is CREATE_ROI_EVENT" % (event))
+        elif alt and not (shift and ctrl):
+            self.SetViewerEvent( ViewerEvent.DELETE_ROI_EVENT )
+            self.GetROIWidget().Off()
+            self._viewer.updateCornerAnnotation("", 1, False)
+            self.SetDisplayHistogram(False)
+            self.Render()
+            print ("Event %s is DELETE_ROI_EVENT" % (event))
+        elif not (ctrl and alt and shift):
+            self.SetViewerEvent ( ViewerEvent.PICK_EVENT )
+            self.HandlePickEvent(interactor, event)
+            print ("Event %s is PICK_EVENT" % (event))
+        
+          
+    def SetDisplayHistogram(self, display):
+        if display:
+            if (self._viewer.displayHistogram == 0):
+                self.GetRenderer().AddActor(self._viewer.histogramPlotActor)
+                self.firstHistogram = 1
+                self.Render()
+                
+            self._viewer.histogramPlotActor.VisibilityOn()
+            self._viewer.displayHistogram = True
+        else:
+            self._viewer.histogramPlotActor.VisibilityOff()
+            self._viewer.displayHistogram = False
+            
+    
+    def OnLeftButtonReleaseEvent(self, interactor, event):
+        if self.GetViewerEvent() == ViewerEvent.CREATE_ROI_EVENT:
+            #bc = self.ROIWidget.GetBorderRepresentation().GetPositionCoordinate()
+            #print (bc.GetValue())
+            self.OnROIModifiedEvent(interactor, event)
+            
+        elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT:
+            self.HandlePickEvent(interactor, event)
+         
+        self.SetViewerEvent( ViewerEvent.NO_EVENT )
+
+    def OnRightButtonPressEvent(self, interactor, event):
+        alt = interactor.GetAltKey()
+        shift = interactor.GetShiftKey()
+        ctrl = interactor.GetControlKey()
+#        print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt))
+#        print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift))
+#        print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl))
+        
+        interactor.SetInitialEventPosition(interactor.GetEventPosition())
+        
+        
+        if alt and not (ctrl and shift):
+            self.SetViewerEvent( ViewerEvent.WINDOW_LEVEL_EVENT )
+            print ("Event %s is WINDOW_LEVEL_EVENT" % (event))
+            self.HandleWindowLevel(interactor, event)
+        elif shift and not (ctrl and alt):
+            self.SetViewerEvent( ViewerEvent.ZOOM_EVENT )
+            self.SetInitialCameraPosition( self.GetActiveCamera().GetPosition())
+            print ("Event %s is ZOOM_EVENT" % (event))
+        elif ctrl and not (shift and alt):
+            self.SetViewerEvent (ViewerEvent.PAN_EVENT )
+            self.SetInitialCameraPosition ( self.GetActiveCamera().GetPosition() )
+            print ("Event %s is PAN_EVENT" % (event))
+        
+    def OnRightButtonReleaseEvent(self, interactor, event):
+        print (event)
+        if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT:
+            self.SetInitialLevel( self.GetWindowLevel().GetLevel() )
+            self.SetInitialWindow ( self.GetWindowLevel().GetWindow() )
+        elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT or \
+             self.GetViewerEvent() == ViewerEvent.PAN_EVENT:
+            self.SetInitialCameraPosition( () )
+			
+        self.SetViewerEvent( ViewerEvent.NO_EVENT )
+        
+    
+    def OnROIModifiedEvent(self, interactor, event):
+        
+        #print ("ROI EVENT " + event)
+        p1 = self.GetROIWidget().GetBorderRepresentation().GetPositionCoordinate()
+        p2 = self.GetROIWidget().GetBorderRepresentation().GetPosition2Coordinate()
+        wsize = self.GetRenderWindow().GetSize()
+        
+        #print (p1.GetValue())
+        #print (p2.GetValue())
+        pp1 = [p1.GetValue()[0] * wsize[0] , p1.GetValue()[1] * wsize[1] , 0.0]
+        pp2 = [p2.GetValue()[0] * wsize[0] + pp1[0] , p2.GetValue()[1] * wsize[1] + pp1[1] , 0.0]
+        vox1 = self.viewport2imageCoordinate(pp1)
+        vox2 = self.viewport2imageCoordinate(pp2)
+        
+        self.SetROI( (vox1 , vox2) )
+        roi = self.GetROI()
+        print ("Pixel1 %d,%d,%d Value %f" % vox1 )
+        print ("Pixel2 %d,%d,%d Value %f" % vox2 )
+        if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: 
+            print ("slice orientation : XY")
+            x = abs(roi[1][0] - roi[0][0])
+            y = abs(roi[1][1] - roi[0][1])
+        elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ:
+            print ("slice orientation : XY")
+            x = abs(roi[1][0] - roi[0][0])
+            y = abs(roi[1][2] - roi[0][2])
+        elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ:
+            print ("slice orientation : XY")
+            x = abs(roi[1][1] - roi[0][1])
+            y = abs(roi[1][2] - roi[0][2])
+        
+        text = "ROI: %d x %d, %.2f kp" % (x,y,float(x*y)/1024.)
+        print (text)
+        self.UpdateCornerAnnotation(text, 1)
+        self.UpdateROIHistogram()
+        self.SetViewerEvent( ViewerEvent.NO_EVENT )
+        
+    def viewport2imageCoordinate(self, viewerposition):
+        #Determine point index
+        
+        self.GetPicker().Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer())
+        pickPosition = list(self.GetPicker().GetPickPosition())
+        pickPosition[self.GetSliceOrientation()] = \
+            self.GetInputData().GetSpacing()[self.GetSliceOrientation()] * self.GetActiveSlice() + \
+            self.GetInputData().GetOrigin()[self.GetSliceOrientation()]
+        print ("Pick Position " + str (pickPosition))
+        
+        if (pickPosition != [0,0,0]):
+            dims = self.GetInputData().GetDimensions()
+            print (dims)
+            spac = self.GetInputData().GetSpacing()
+            orig = self.GetInputData().GetOrigin()
+            imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ]
+            
+            pixelValue = self.GetInputData().GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0)
+            return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue)
+        else:
+            return (0,0,0,0)
+
+        
+    
+    
+    def OnMouseMoveEvent(self, interactor, event):        
+        if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT:
+            print ("Event %s is WINDOW_LEVEL_EVENT" % (event))
+            self.HandleWindowLevel(interactor, event)    
+        elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT:
+            self.HandlePickEvent(interactor, event)
+        elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT:
+            self.HandleZoomEvent(interactor, event)
+        elif self.GetViewerEvent() == ViewerEvent.PAN_EVENT:
+            self.HandlePanEvent(interactor, event)
+            
+            
+    def HandleZoomEvent(self, interactor, event):
+        dx,dy = interactor.GetDeltaEventPosition()   
+        size = self.GetRenderWindow().GetSize()
+        dy = - 4 * dy / size[1]
+        
+        print ("distance: " + str(self.GetActiveCamera().GetDistance()))
+        
+        print ("\ndy: %f\ncamera dolly %f\n" % (dy, 1 + dy))
+        
+        camera = vtk.vtkCamera()
+        camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint())
+        #print ("current position " + str(self.InitialCameraPosition))
+        camera.SetViewUp(self.GetActiveCamera().GetViewUp())
+        camera.SetPosition(self.GetInitialCameraPosition())
+        newposition = [i for i in self.GetInitialCameraPosition()]
+        if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: 
+            dist = newposition[SLICE_ORIENTATION_XY] * ( 1 + dy ) 
+            newposition[SLICE_ORIENTATION_XY] *= ( 1 + dy )
+        elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ:
+            newposition[SLICE_ORIENTATION_XZ] *= ( 1 + dy )
+        elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ:
+            newposition[SLICE_ORIENTATION_YZ] *= ( 1 + dy )
+        #print ("new position " + str(newposition))
+        camera.SetPosition(newposition)
+        self.SetActiveCamera(camera)
+        
+        self.Render()
+        	
+        print ("distance after: " + str(self.GetActiveCamera().GetDistance()))
+        
+    def HandlePanEvent(self, interactor, event):
+        x,y = interactor.GetEventPosition()
+        x0,y0 = interactor.GetInitialEventPosition()
+        
+        ic = self.viewport2imageCoordinate((x,y))
+        ic0 = self.viewport2imageCoordinate((x0,y0))
+        
+        dx = 4 *( ic[0] - ic0[0])
+        dy = 4* (ic[1] - ic0[1])
+        
+        camera = vtk.vtkCamera()
+        #print ("current position " + str(self.InitialCameraPosition))
+        camera.SetViewUp(self.GetActiveCamera().GetViewUp())
+        camera.SetPosition(self.GetInitialCameraPosition())
+        newposition = [i for i in self.GetInitialCameraPosition()]
+        newfocalpoint = [i for i in self.GetActiveCamera().GetFocalPoint()]
+        if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: 
+            newposition[0] -= dx
+            newposition[1] -= dy
+            newfocalpoint[0] = newposition[0]
+            newfocalpoint[1] = newposition[1]
+        elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ:
+            newposition[0] -= dx
+            newposition[2] -= dy
+            newfocalpoint[0] = newposition[0]
+            newfocalpoint[2] = newposition[2]
+        elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ:
+            newposition[1] -= dx
+            newposition[2] -= dy
+            newfocalpoint[2] = newposition[2]
+            newfocalpoint[1] = newposition[1]
+        #print ("new position " + str(newposition))
+        camera.SetFocalPoint(newfocalpoint)
+        camera.SetPosition(newposition)
+        self.SetActiveCamera(camera)
+        
+        self.Render()
+        
+    def HandleWindowLevel(self, interactor, event):
+        dx,dy = interactor.GetDeltaEventPosition()
+        print ("Event delta %d %d" % (dx,dy))
+        size = self.GetRenderWindow().GetSize()
+        
+        dx = 4 * dx / size[0]
+        dy = 4 * dy / size[1]
+        window = self.GetInitialWindow()
+        level = self.GetInitialLevel()
+        
+        if abs(window) > 0.01:
+            dx = dx * window
+        else:
+            dx = dx * (lambda x: -0.01 if x <0 else 0.01)(window);
+			
+        if abs(level) > 0.01:
+            dy = dy * level
+        else:
+            dy = dy * (lambda x: -0.01 if x <0 else 0.01)(level)
+			
+
+        # Abs so that direction does not flip
+
+        if window < 0.0:
+            dx = -1*dx
+        if level < 0.0:
+            dy = -1*dy
+
+		 # Compute new window level
+
+        newWindow = dx + window
+        newLevel = level - dy
+
+        # Stay away from zero and really
+
+        if abs(newWindow) < 0.01:
+            newWindow = 0.01 * (lambda x: -1 if x <0 else 1)(newWindow)
+
+        if abs(newLevel) < 0.01:
+            newLevel = 0.01 * (lambda x: -1 if x <0 else 1)(newLevel)
+
+        self.GetWindowLevel().SetWindow(newWindow)
+        self.GetWindowLevel().SetLevel(newLevel)
+        
+        self.GetWindowLevel().Update()
+        self.UpdateSliceActor()
+        self.AdjustCamera()
+        
+        self.Render()
+    
+    def HandlePickEvent(self, interactor, event):
+        position = interactor.GetEventPosition()
+        #print ("PICK " + str(position))
+        vox = self.viewport2imageCoordinate(position)
+        #print ("Pixel %d,%d,%d Value %f" % vox )
+        self._viewer.cornerAnnotation.VisibilityOn()
+        self.UpdateCornerAnnotation("[%d,%d,%d] : %.2f" % vox , 0)
+        self.Render()
+        
+###############################################################################
+    
+        
+
+class CILViewer2D():
+    '''Simple Interactive Viewer based on VTK classes'''
+    
+    def __init__(self, dimx=600,dimy=600, ren=None, renWin=None,iren=None):
+        '''creates the rendering pipeline'''
+        # create a rendering window and renderer
+        if ren == None:
+            self.ren = vtk.vtkRenderer()
+        else:
+            self.ren = ren
+        if renWin == None:
+            self.renWin = vtk.vtkRenderWindow()
+        else:
+            self.renWin = renWin
+        if iren == None:
+            self.iren = vtk.vtkRenderWindowInteractor()
+        else:
+            self.iren = iren
+            
+        self.renWin.SetSize(dimx,dimy)
+        self.renWin.AddRenderer(self.ren)
+        
+        self.style = CILInteractorStyle(self)
+        
+        self.iren.SetInteractorStyle(self.style)
+        self.iren.SetRenderWindow(self.renWin)
+        self.iren.Initialize()
+        self.ren.SetBackground(.1, .2, .4)
+        
+        self.camera = vtk.vtkCamera()
+        self.camera.ParallelProjectionOn()
+        self.ren.SetActiveCamera(self.camera)
+        
+        # data
+        self.img3D = None
+        self.sliceno = 0
+        self.sliceOrientation = SLICE_ORIENTATION_XY
+        
+        #Actors
+        self.sliceActor = vtk.vtkImageActor()
+        self.voi = vtk.vtkExtractVOI()
+        self.wl = vtk.vtkImageMapToWindowLevelColors()
+        self.ia = vtk.vtkImageAccumulate()
+        self.sliceActorNo = 0
+        
+        #initial Window/Level
+        self.InitialLevel = 0
+        self.InitialWindow = 0
+        
+        #ViewerEvent
+        self.event = ViewerEvent.NO_EVENT
+        
+        # ROI Widget
+        self.ROIWidget = vtk.vtkBorderWidget()
+        self.ROIWidget.SetInteractor(self.iren)
+        self.ROIWidget.CreateDefaultRepresentation()
+        self.ROIWidget.GetBorderRepresentation().GetBorderProperty().SetColor(0,1,0)
+        self.ROIWidget.AddObserver(vtk.vtkWidgetEvent.Select, self.style.OnROIModifiedEvent, 1.0)
+        
+        # edge points of the ROI
+        self.ROI = ()
+        
+        #picker
+        self.picker = vtk.vtkPropPicker()
+        self.picker.PickFromListOn()
+        self.picker.AddPickList(self.sliceActor)
+
+        self.iren.SetPicker(self.picker)
+        
+        # corner annotation
+        self.cornerAnnotation = vtk.vtkCornerAnnotation()
+        self.cornerAnnotation.SetMaximumFontSize(12);
+        self.cornerAnnotation.PickableOff();
+        self.cornerAnnotation.VisibilityOff();
+        self.cornerAnnotation.GetTextProperty().ShadowOn();
+        self.cornerAnnotation.SetLayerNumber(1);
+        
+        
+        
+        # cursor doesn't show up
+        self.cursor = vtk.vtkCursor2D()
+        self.cursorMapper = vtk.vtkPolyDataMapper2D()
+        self.cursorActor = vtk.vtkActor2D()
+        self.cursor.SetModelBounds(-10, 10, -10, 10, 0, 0)
+        self.cursor.SetFocalPoint(0, 0, 0)
+        self.cursor.AllOff()
+        self.cursor.AxesOn()
+        self.cursorActor.PickableOff()
+        self.cursorActor.VisibilityOn()
+        self.cursorActor.GetProperty().SetColor(1, 1, 1)
+        self.cursorActor.SetLayerNumber(1)
+        self.cursorMapper.SetInputData(self.cursor.GetOutput())
+        self.cursorActor.SetMapper(self.cursorMapper)
+        
+        # Zoom
+        self.InitialCameraPosition = ()
+        
+        # XY Plot actor for histogram
+        self.displayHistogram = False
+        self.firstHistogram = 0
+        self.roiIA = vtk.vtkImageAccumulate()
+        self.roiVOI = vtk.vtkExtractVOI()
+        self.histogramPlotActor = vtk.vtkXYPlotActor()
+        self.histogramPlotActor.ExchangeAxesOff();
+        self.histogramPlotActor.SetXLabelFormat( "%g" )
+        self.histogramPlotActor.SetXLabelFormat( "%g" )
+        self.histogramPlotActor.SetAdjustXLabels(3)
+        self.histogramPlotActor.SetXTitle( "Level" )
+        self.histogramPlotActor.SetYTitle( "N" )
+        self.histogramPlotActor.SetXValuesToValue()
+        self.histogramPlotActor.SetPlotColor(0, (0,1,1) )
+        self.histogramPlotActor.SetPosition(0.6,0.6)
+        self.histogramPlotActor.SetPosition2(0.4,0.4)
+ 
+        
+        
+    def GetInteractor(self):
+        return self.iren
+    
+    def GetRenderer(self):
+        return self.ren
+        
+    def setInput3DData(self, imageData):
+        self.img3D = imageData
+        self.installPipeline()
+
+    def setInputAsNumpy(self, numpyarray,  origin=(0,0,0), spacing=(1.,1.,1.), 
+                        rescale=True, dtype=vtk.VTK_UNSIGNED_SHORT):
+        importer = Converter.numpy2vtkImporter(numpyarray, spacing, origin)
+        importer.Update()
+        
+        if rescale:
+            # rescale to appropriate VTK_UNSIGNED_SHORT
+            stats = vtk.vtkImageAccumulate()
+            stats.SetInputData(importer.GetOutput())
+            stats.Update()
+            iMin = stats.GetMin()[0]
+            iMax = stats.GetMax()[0]
+            if (iMax - iMin == 0):
+                scale = 1
+            else:
+                if dtype == vtk.VTK_UNSIGNED_SHORT:
+                    scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin)
+                elif dtype == vtk.VTK_UNSIGNED_INT:
+                    scale = vtk.VTK_UNSIGNED_INT_MAX / (iMax - iMin)
+    
+            shiftScaler = vtk.vtkImageShiftScale ()
+            shiftScaler.SetInputData(importer.GetOutput())
+            shiftScaler.SetScale(scale)
+            shiftScaler.SetShift(-iMin)
+            shiftScaler.SetOutputScalarType(dtype)
+            shiftScaler.Update()
+            self.img3D = shiftScaler.GetOutput()
+        else:
+            self.img3D = importer.GetOutput()
+            
+        self.installPipeline()
+
+    def displaySlice(self, sliceno = 0):
+        self.sliceno = sliceno
+        
+        self.updatePipeline()
+        
+        self.renWin.Render()
+        
+        return self.sliceActorNo
+
+    def updatePipeline(self, resetcamera = False):
+        extent = [ i for i in self.img3D.GetExtent()]
+        extent[self.sliceOrientation * 2] = self.sliceno
+        extent[self.sliceOrientation * 2 + 1] = self.sliceno 
+        self.voi.SetVOI(extent[0], extent[1],
+                   extent[2], extent[3],
+                   extent[4], extent[5])
+        
+        self.voi.Update()
+        self.ia.Update()
+        self.wl.Update()
+        self.sliceActor.SetDisplayExtent(extent[0], extent[1],
+                   extent[2], extent[3],
+                   extent[4], extent[5])
+        self.sliceActor.Update()
+        
+        self.updateCornerAnnotation("Slice %d/%d" % (self.sliceno + 1 , self.img3D.GetDimensions()[self.sliceOrientation]))
+        
+        if self.displayHistogram:
+            self.updateROIHistogram()            
+            
+        self.AdjustCamera(resetcamera)
+        
+        self.renWin.Render()
+        
+        
+    def installPipeline(self):
+        '''Slices a 3D volume and then creates an actor to be rendered'''
+        
+        self.ren.AddViewProp(self.cornerAnnotation)
+        
+        self.voi.SetInputData(self.img3D)
+        #select one slice in Z
+        extent = [ i for i in self.img3D.GetExtent()]
+        extent[self.sliceOrientation * 2] = self.sliceno
+        extent[self.sliceOrientation * 2 + 1] = self.sliceno 
+        self.voi.SetVOI(extent[0], extent[1],
+                   extent[2], extent[3],
+                   extent[4], extent[5])
+        
+        self.voi.Update()
+        # set window/level for current slices
+         
+    
+        self.wl = vtk.vtkImageMapToWindowLevelColors()
+        self.ia.SetInputData(self.voi.GetOutput())
+        self.ia.Update()
+        cmax = self.ia.GetMax()[0]
+        cmin = self.ia.GetMin()[0]
+        
+        self.InitialLevel = (cmax+cmin)/2
+        self.InitialWindow = cmax-cmin
+
+        
+        self.wl.SetLevel(self.InitialLevel)
+        self.wl.SetWindow(self.InitialWindow)
+        
+        self.wl.SetInputData(self.voi.GetOutput())
+        self.wl.Update()
+            
+        self.sliceActor.SetInputData(self.wl.GetOutput())
+        self.sliceActor.SetDisplayExtent(extent[0], extent[1],
+                   extent[2], extent[3],
+                   extent[4], extent[5])
+        self.sliceActor.Update()
+        self.sliceActor.SetInterpolate(False)
+        self.ren.AddActor(self.sliceActor)
+        self.ren.ResetCamera()
+        self.ren.Render()
+        
+        self.AdjustCamera()
+        
+        self.ren.AddViewProp(self.cursorActor)
+        self.cursorActor.VisibilityOn()
+        
+        self.iren.Initialize()
+        self.renWin.Render()
+        #self.iren.Start()
+    
+    def AdjustCamera(self, resetcamera = False):
+        self.ren.ResetCameraClippingRange()
+        if resetcamera:
+            self.ren.ResetCamera()
+        
+            
+    def getROI(self):
+        return self.ROI
+    
+    def getROIExtent(self):
+        p0 = self.ROI[0]
+        p1 = self.ROI[1]
+        return (p0[0], p1[0],p0[1],p1[1],p0[2],p1[2])
+        
+    ############### Handle events are moved to the interactor style
+    
+        
+    def viewport2imageCoordinate(self, viewerposition):
+        #Determine point index
+        
+        self.picker.Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer())
+        pickPosition = list(self.picker.GetPickPosition())
+        pickPosition[self.sliceOrientation] = \
+            self.img3D.GetSpacing()[self.sliceOrientation] * self.sliceno + \
+            self.img3D.GetOrigin()[self.sliceOrientation]
+        print ("Pick Position " + str (pickPosition))
+        
+        if (pickPosition != [0,0,0]):
+            dims = self.img3D.GetDimensions()
+            print (dims)
+            spac = self.img3D.GetSpacing()
+            orig = self.img3D.GetOrigin()
+            imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ]
+            
+            pixelValue = self.img3D.GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0)
+            return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue)
+        else:
+            return (0,0,0,0)
+
+        
+    
+    def GetRenderWindow(self):
+        return self.renWin
+    
+    
+    def startRenderLoop(self):
+        self.iren.Start()
+        
+    def GetSliceOrientation(self):
+        return self.sliceOrientation
+    
+    def GetActiveSlice(self):
+        return self.sliceno
+    
+    def updateCornerAnnotation(self, text , idx=0, visibility=True):
+        if visibility:
+            self.cornerAnnotation.VisibilityOn()
+        else:
+            self.cornerAnnotation.VisibilityOff()
+            
+        self.cornerAnnotation.SetText(idx, text)
+        self.iren.Render()
+        
+    def saveRender(self, filename, renWin=None):
+        '''Save the render window to PNG file'''
+        # screenshot code:
+        w2if = vtk.vtkWindowToImageFilter()
+        if renWin == None:
+            renWin = self.renWin
+        w2if.SetInput(renWin)
+        w2if.Update()
+         
+        writer = vtk.vtkPNGWriter()
+        writer.SetFileName("%s.png" % (filename))
+        writer.SetInputConnection(w2if.GetOutputPort())
+        writer.Write()
+    
+    def updateROIHistogram(self):
+        
+        extent = [0 for i in range(6)]
+        if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: 
+            print ("slice orientation : XY")
+            extent[0] = self.ROI[0][0]
+            extent[1] = self.ROI[1][0]
+            extent[2] = self.ROI[0][1]
+            extent[3] = self.ROI[1][1]
+            extent[4] = self.GetActiveSlice()
+            extent[5] = self.GetActiveSlice()+1
+            #y = abs(roi[1][1] - roi[0][1])
+        elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ:
+            print ("slice orientation : XY")
+            extent[0] = self.ROI[0][0]
+            extent[1] = self.ROI[1][0]
+            #x = abs(roi[1][0] - roi[0][0])
+            extent[4] = self.ROI[0][2]
+            extent[5] = self.ROI[1][2]
+            #y = abs(roi[1][2] - roi[0][2])
+            extent[2] = self.GetActiveSlice()
+            extent[3] = self.GetActiveSlice()+1
+        elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ:
+            print ("slice orientation : XY")
+            extent[2] = self.ROI[0][1]
+            extent[3] = self.ROI[1][1]
+            #x = abs(roi[1][1] - roi[0][1])
+            extent[4] = self.ROI[0][2]
+            extent[5] = self.ROI[1][2]
+            #y = abs(roi[1][2] - roi[0][2])
+            extent[0] = self.GetActiveSlice()
+            extent[1] = self.GetActiveSlice()+1
+        
+        self.roiVOI.SetVOI(extent)
+        self.roiVOI.SetInputData(self.img3D)
+        self.roiVOI.Update()
+        irange = self.roiVOI.GetOutput().GetScalarRange()
+        
+        self.roiIA.SetInputData(self.roiVOI.GetOutput())
+        self.roiIA.IgnoreZeroOff()
+        self.roiIA.SetComponentExtent(0,int(irange[1]-irange[0]-1),0,0,0,0 )
+        self.roiIA.SetComponentOrigin( int(irange[0]),0,0 );
+        self.roiIA.SetComponentSpacing( 1,0,0 );
+        self.roiIA.Update()
+        
+        self.histogramPlotActor.AddDataSetInputConnection(self.roiIA.GetOutputPort())
+        self.histogramPlotActor.SetXRange(irange[0],irange[1])
+        
+        self.histogramPlotActor.SetYRange( self.roiIA.GetOutput().GetScalarRange() )
+        
+        
\ No newline at end of file
diff --git a/src/Python/ccpi/viewer/QVTKWidget.py b/src/Python/ccpi/viewer/QVTKWidget.py
new file mode 100644
index 0000000..906786b
--- /dev/null
+++ b/src/Python/ccpi/viewer/QVTKWidget.py
@@ -0,0 +1,340 @@
+################################################################################
+# File:         QVTKWidget.py
+# Author:       Edoardo Pasca
+# Description:  PyVE Viewer Qt widget
+#
+# License:
+#               This file is part of PyVE. PyVE is an open-source image 
+#               analysis and visualization environment focused on medical
+#               imaging. More info at http://pyve.sourceforge.net
+#	       
+#               Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau 
+#               All rights reserved.
+#	       
+#              Redistribution and use in source and binary forms, with or
+#              without modification, are permitted provided that the following
+#              conditions are met:
+#
+#              Redistributions of source code must retain the above copyright
+#              notice, this list of conditions and the following disclaimer.
+#              Redistributions in binary form must reproduce the above
+#              copyright notice, this list of conditions and the following
+#              disclaimer in the documentation and/or other materials provided
+#              with the distribution.  Neither name of Edoardo Pasca or Lukas
+#              Batteau nor the names of any contributors may be used to endorse
+#              or promote products derived from this software without specific
+#              prior written permission.
+#
+#              THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+#              CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES,
+#              INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
+#              MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+#              DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE
+#              LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
+#              OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+#              PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
+#              OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+#              THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
+#              TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
+#              OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
+#              OF SUCH DAMAGE.
+#
+# CHANGE HISTORY
+#
+# 20120118    Edoardo Pasca        Initial version
+#             
+###############################################################################
+
+import os
+from PyQt5 import QtCore, QtGui, QtWidgets
+#import itk
+import vtk
+#from viewer import PyveViewer
+from ccpi.viewer.CILViewer2D import CILViewer2D , Converter
+
+class QVTKWidget(QtWidgets.QWidget):
+
+    """ A QVTKWidget for Python and Qt."""
+
+    # Map between VTK and Qt cursors.
+    _CURSOR_MAP = {
+        0:  QtCore.Qt.ArrowCursor,          # VTK_CURSOR_DEFAULT
+        1:  QtCore.Qt.ArrowCursor,          # VTK_CURSOR_ARROW
+        2:  QtCore.Qt.SizeBDiagCursor,      # VTK_CURSOR_SIZENE
+        3:  QtCore.Qt.SizeFDiagCursor,      # VTK_CURSOR_SIZENWSE
+        4:  QtCore.Qt.SizeBDiagCursor,      # VTK_CURSOR_SIZESW
+        5:  QtCore.Qt.SizeFDiagCursor,      # VTK_CURSOR_SIZESE
+        6:  QtCore.Qt.SizeVerCursor,        # VTK_CURSOR_SIZENS
+        7:  QtCore.Qt.SizeHorCursor,        # VTK_CURSOR_SIZEWE
+        8:  QtCore.Qt.SizeAllCursor,        # VTK_CURSOR_SIZEALL
+        9:  QtCore.Qt.PointingHandCursor,   # VTK_CURSOR_HAND
+        10: QtCore.Qt.CrossCursor,          # VTK_CURSOR_CROSSHAIR
+    }
+
+    def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw):
+        # the current button
+        self._ActiveButton = QtCore.Qt.NoButton
+
+        # private attributes
+        self.__oldFocus = None
+        self.__saveX = 0
+        self.__saveY = 0
+        self.__saveModifiers = QtCore.Qt.NoModifier
+        self.__saveButtons = QtCore.Qt.NoButton
+        self.__timeframe = 0
+
+        # create qt-level widget
+        QtWidgets.QWidget.__init__(self, parent, wflags|QtCore.Qt.MSWindowsOwnDC)
+        
+        # Link to PyVE Viewer
+        self._PyveViewer = CILViewer2D()
+        #self._Viewer = self._PyveViewer._vtkPyveViewer
+        
+        self._Iren = self._PyveViewer.GetInteractor()
+        #self._Iren = self._Viewer.GetRenderWindow().GetInteractor()
+        self._RenderWindow = self._PyveViewer.GetRenderWindow()
+        #self._RenderWindow = self._Viewer.GetRenderWindow()
+        
+        self._Iren.Register(self._RenderWindow)
+        self._Iren.SetRenderWindow(self._RenderWindow)
+        self._RenderWindow.SetWindowInfo(str(int(self.winId())))
+
+        # do all the necessary qt setup
+        self.setAttribute(QtCore.Qt.WA_OpaquePaintEvent)
+        self.setAttribute(QtCore.Qt.WA_PaintOnScreen)
+        self.setMouseTracking(True) # get all mouse events
+        self.setFocusPolicy(QtCore.Qt.WheelFocus)
+        self.setSizePolicy(QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding))
+
+        self._Timer = QtCore.QTimer(self)
+        #self.connect(self._Timer, QtCore.pyqtSignal('timeout()'), self.TimerEvent)
+
+        self._Iren.AddObserver('CreateTimerEvent', self.CreateTimer)
+        self._Iren.AddObserver('DestroyTimerEvent', self.DestroyTimer)
+        self._Iren.GetRenderWindow().AddObserver('CursorChangedEvent',
+                                                 self.CursorChangedEvent)
+
+    # Destructor
+    def __del__(self):
+        self._Iren.UnRegister(self._RenderWindow)
+        #QtWidgets.QWidget.__del__(self)
+
+    # Display image data
+    def SetInput(self, imageData):
+        self._PyveViewer.setInput3DData(imageData)
+        
+    # GetInteractor
+    def GetInteractor(self):
+        return self._Iren
+    
+    # Display image data
+    def GetPyveViewer(self):
+        return self._PyveViewer
+
+    def __getattr__(self, attr):
+        """Makes the object behave like a vtkGenericRenderWindowInteractor"""
+        print (attr)
+        if attr == '__vtk__':
+            return lambda t=self._Iren: t
+        elif hasattr(self._Iren, attr):
+            return getattr(self._Iren, attr)
+#        else:
+#            raise AttributeError( self.__class__.__name__ + \
+#                  " has no attribute named " + attr )
+
+    def CreateTimer(self, obj, evt):
+        self._Timer.start(10)
+
+    def DestroyTimer(self, obj, evt):
+        self._Timer.stop()
+        return 1
+
+    def TimerEvent(self):
+        self._Iren.InvokeEvent("TimerEvent")
+
+    def CursorChangedEvent(self, obj, evt):
+        """Called when the CursorChangedEvent fires on the render window."""
+        # This indirection is needed since when the event fires, the current
+        # cursor is not yet set so we defer this by which time the current
+        # cursor should have been set.
+        QtCore.QTimer.singleShot(0, self.ShowCursor)
+
+    def HideCursor(self):
+        """Hides the cursor."""
+        self.setCursor(QtCore.Qt.BlankCursor)
+
+    def ShowCursor(self):
+        """Shows the cursor."""
+        vtk_cursor = self._Iren.GetRenderWindow().GetCurrentCursor()
+        qt_cursor = self._CURSOR_MAP.get(vtk_cursor, QtCore.Qt.ArrowCursor)
+        self.setCursor(qt_cursor)
+
+    def sizeHint(self):
+        return QtCore.QSize(400, 400)
+
+    def paintEngine(self):
+        return None
+
+    def paintEvent(self, ev):
+        self._RenderWindow.Render()
+
+    def resizeEvent(self, ev):
+        self._RenderWindow.Render()
+        w = self.width()
+        h = self.height()
+
+        self._RenderWindow.SetSize(w, h)
+        self._Iren.SetSize(w, h)
+
+    def _GetCtrlShiftAlt(self, ev):
+        ctrl = shift = alt = False
+
+        if hasattr(ev, 'modifiers'):
+            if ev.modifiers() & QtCore.Qt.ShiftModifier:
+                shift = True
+            if ev.modifiers() & QtCore.Qt.ControlModifier:
+                ctrl = True
+            if ev.modifiers() & QtCore.Qt.AltModifier:
+                alt = True
+        else:
+            if self.__saveModifiers & QtCore.Qt.ShiftModifier:
+                shift = True
+            if self.__saveModifiers & QtCore.Qt.ControlModifier:
+                ctrl = True
+            if self.__saveModifiers & QtCore.Qt.AltModifier:
+                alt = True
+
+        return ctrl, shift, alt
+
+    def enterEvent(self, ev):
+        if not self.hasFocus():
+            self.__oldFocus = self.focusWidget()
+            self.setFocus()
+
+        ctrl, shift, alt = self._GetCtrlShiftAlt(ev)
+        self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY,
+                                            ctrl, shift, chr(0), 0, None)
+        self._Iren.SetAltKey(alt)
+        self._Iren.InvokeEvent("EnterEvent")
+
+    def leaveEvent(self, ev):
+        if self.__saveButtons == QtCore.Qt.NoButton and self.__oldFocus:
+            self.__oldFocus.setFocus()
+            self.__oldFocus = None
+
+        ctrl, shift, alt = self._GetCtrlShiftAlt(ev)
+        self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY,
+                                            ctrl, shift, chr(0), 0, None)
+        self._Iren.SetAltKey(alt)
+        self._Iren.InvokeEvent("LeaveEvent")
+
+    def mousePressEvent(self, ev):
+        ctrl, shift, alt = self._GetCtrlShiftAlt(ev)
+        repeat = 0
+        if ev.type() == QtCore.QEvent.MouseButtonDblClick:
+            repeat = 1
+        self._Iren.SetEventInformationFlipY(ev.x(), ev.y(),
+                                            ctrl, shift, chr(0), repeat, None)
+
+        self._Iren.SetAltKey(alt)
+        self._ActiveButton = ev.button()
+
+        if self._ActiveButton == QtCore.Qt.LeftButton:
+            self._Iren.InvokeEvent("LeftButtonPressEvent")
+        elif self._ActiveButton == QtCore.Qt.RightButton:
+            self._Iren.InvokeEvent("RightButtonPressEvent")
+        elif self._ActiveButton == QtCore.Qt.MidButton:
+            self._Iren.InvokeEvent("MiddleButtonPressEvent")
+
+    def mouseReleaseEvent(self, ev):
+        ctrl, shift, alt = self._GetCtrlShiftAlt(ev)
+        self._Iren.SetEventInformationFlipY(ev.x(), ev.y(),
+                                            ctrl, shift, chr(0), 0, None)
+        self._Iren.SetAltKey(alt)
+
+        if self._ActiveButton == QtCore.Qt.LeftButton:
+            self._Iren.InvokeEvent("LeftButtonReleaseEvent")
+        elif self._ActiveButton == QtCore.Qt.RightButton:
+            self._Iren.InvokeEvent("RightButtonReleaseEvent")
+        elif self._ActiveButton == QtCore.Qt.MidButton:
+            self._Iren.InvokeEvent("MiddleButtonReleaseEvent")
+
+    def mouseMoveEvent(self, ev):
+        self.__saveModifiers = ev.modifiers()
+        self.__saveButtons = ev.buttons()
+        self.__saveX = ev.x()
+        self.__saveY = ev.y()
+
+        ctrl, shift, alt = self._GetCtrlShiftAlt(ev)
+        self._Iren.SetEventInformationFlipY(ev.x(), ev.y(),
+                                            ctrl, shift, chr(0), 0, None)
+        self._Iren.SetAltKey(alt)
+        self._Iren.InvokeEvent("MouseMoveEvent")
+
+    def keyPressEvent(self, ev):
+        ctrl, shift, alt = self._GetCtrlShiftAlt(ev)
+        if ev.key() < 256:
+            key = str(ev.text())
+        else:
+            key = chr(0)
+
+        self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY,
+                                            ctrl, shift, key, 0, None)
+        self._Iren.SetAltKey(alt)
+        self._Iren.InvokeEvent("KeyPressEvent")
+        self._Iren.InvokeEvent("CharEvent")
+
+    def keyReleaseEvent(self, ev):
+        ctrl, shift, alt = self._GetCtrlShiftAlt(ev)
+        if ev.key() < 256:
+            key = chr(ev.key())
+        else:
+            key = chr(0)
+
+        self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY,
+                                            ctrl, shift, key, 0, None)
+        self._Iren.SetAltKey(alt)
+        self._Iren.InvokeEvent("KeyReleaseEvent")
+
+    def wheelEvent(self, ev):
+        print ("angleDeltaX %d" % ev.angleDelta().x())
+        print ("angleDeltaY %d" % ev.angleDelta().y())
+        if ev.angleDelta().y() >= 0:
+            self._Iren.InvokeEvent("MouseWheelForwardEvent")
+        else:
+            self._Iren.InvokeEvent("MouseWheelBackwardEvent")
+
+    def GetRenderWindow(self):
+        return self._RenderWindow
+
+    def Render(self):
+        self.update()
+
+
+def QVTKExample():    
+    """A simple example that uses the QVTKWidget class."""
+
+    # every QT app needs an app
+    app = QtWidgets.QApplication(['PyVE QVTKWidget Example'])
+    page_VTK = QtWidgets.QWidget()
+    page_VTK.resize(500,500)
+    layout = QtWidgets.QVBoxLayout(page_VTK)
+    # create the widget
+    widget = QVTKWidget(parent=None)
+    layout.addWidget(widget)
+    
+    #reader = vtk.vtkPNGReader()
+    #reader.SetFileName("F:\Diagnostics\Images\PyVE\VTKData\Data\camscene.png")
+    reader = vtk.vtkMetaImageReader()
+    reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha")
+    reader.Update()
+    
+    widget.SetInput(reader.GetOutput())
+    
+    # show the widget
+    page_VTK.show()
+    # start event processing
+    app.exec_()
+
+if __name__ == "__main__":
+    QVTKExample()
diff --git a/src/Python/ccpi/viewer/QVTKWidget2.py b/src/Python/ccpi/viewer/QVTKWidget2.py
new file mode 100644
index 0000000..e32e1c2
--- /dev/null
+++ b/src/Python/ccpi/viewer/QVTKWidget2.py
@@ -0,0 +1,84 @@
+################################################################################
+# File:         QVTKWidget.py
+# Author:       Edoardo Pasca
+# Description:  PyVE Viewer Qt widget
+#
+# License:
+#               This file is part of PyVE. PyVE is an open-source image 
+#               analysis and visualization environment focused on medical
+#               imaging. More info at http://pyve.sourceforge.net
+#	       
+#               Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau 
+#               All rights reserved.
+#	       
+#              Redistribution and use in source and binary forms, with or
+#              without modification, are permitted provided that the following
+#              conditions are met:
+#
+#              Redistributions of source code must retain the above copyright
+#              notice, this list of conditions and the following disclaimer.
+#              Redistributions in binary form must reproduce the above
+#              copyright notice, this list of conditions and the following
+#              disclaimer in the documentation and/or other materials provided
+#              with the distribution.  Neither name of Edoardo Pasca or Lukas
+#              Batteau nor the names of any contributors may be used to endorse
+#              or promote products derived from this software without specific
+#              prior written permission.
+#
+#              THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+#              CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES,
+#              INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
+#              MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+#              DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE
+#              LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
+#              OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+#              PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
+#              OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+#              THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
+#              TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
+#              OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
+#              OF SUCH DAMAGE.
+#
+# CHANGE HISTORY
+#
+# 20120118    Edoardo Pasca        Initial version
+#             
+###############################################################################
+
+import os
+from PyQt5 import QtCore, QtGui, QtWidgets
+#import itk
+import vtk
+#from viewer import PyveViewer
+from ccpi.viewer.CILViewer2D import CILViewer2D , Converter
+from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor
+
+class QVTKWidget(QVTKRenderWindowInteractor):
+
+    
+    def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw):
+        kw = dict() 
+        super().__init__(parent, **kw)
+        
+        
+        # Link to PyVE Viewer
+        self._PyveViewer = CILViewer2D(400,400)
+        #self._Viewer = self._PyveViewer._vtkPyveViewer
+        
+        self._Iren = self._PyveViewer.GetInteractor()
+        kw['iren'] = self._Iren
+        #self._Iren = self._Viewer.GetRenderWindow().GetInteractor()
+        self._RenderWindow = self._PyveViewer.GetRenderWindow()
+        #self._RenderWindow = self._Viewer.GetRenderWindow()
+        kw['rw'] = self._RenderWindow
+       
+        
+        
+       
+    def GetInteractor(self):
+        return self._Iren
+    
+    # Display image data
+    def SetInput(self, imageData):
+        self._PyveViewer.setInput3DData(imageData)
+    
\ No newline at end of file
diff --git a/src/Python/ccpi/viewer/__init__.py b/src/Python/ccpi/viewer/__init__.py
new file mode 100644
index 0000000..946188b
--- /dev/null
+++ b/src/Python/ccpi/viewer/__init__.py
@@ -0,0 +1 @@
+from ccpi.viewer.CILViewer import CILViewer
\ No newline at end of file
diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc
new file mode 100644
index 0000000..711f77a
Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc differ
diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc
new file mode 100644
index 0000000..77c2ca8
Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc differ
diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc
new file mode 100644
index 0000000..3d11b87
Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc differ
diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc
new file mode 100644
index 0000000..2fa2eaf
Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc differ
diff --git a/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc
new file mode 100644
index 0000000..fcea537
Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc differ
diff --git a/src/Python/ccpi/viewer/embedvtk.py b/src/Python/ccpi/viewer/embedvtk.py
new file mode 100644
index 0000000..b5eb0a7
--- /dev/null
+++ b/src/Python/ccpi/viewer/embedvtk.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Jul 27 12:18:58 2017
+
+@author: ofn77899
+"""
+
+#!/usr/bin/env python
+ 
+import sys
+import vtk
+from PyQt5 import QtCore, QtWidgets
+from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor
+import QVTKWidget2
+ 
+class MainWindow(QtWidgets.QMainWindow):
+ 
+    def __init__(self, parent = None):
+        QtWidgets.QMainWindow.__init__(self, parent)
+ 
+        self.frame = QtWidgets.QFrame()
+ 
+        self.vl = QtWidgets.QVBoxLayout()
+#        self.vtkWidget = QVTKRenderWindowInteractor(self.frame)
+        
+        self.vtkWidget = QVTKWidget2.QVTKWidget(self.frame)
+        self.iren = self.vtkWidget.GetInteractor()
+        self.vl.addWidget(self.vtkWidget)
+        
+        
+        
+    
+        self.ren = vtk.vtkRenderer()
+        self.vtkWidget.GetRenderWindow().AddRenderer(self.ren)
+#        self.iren = self.vtkWidget.GetRenderWindow().GetInteractor()
+# 
+#        # Create source
+#        source = vtk.vtkSphereSource()
+#        source.SetCenter(0, 0, 0)
+#        source.SetRadius(5.0)
+# 
+#        # Create a mapper
+#        mapper = vtk.vtkPolyDataMapper()
+#        mapper.SetInputConnection(source.GetOutputPort())
+# 
+#        # Create an actor
+#        actor = vtk.vtkActor()
+#        actor.SetMapper(mapper)
+# 
+#        self.ren.AddActor(actor)
+# 
+#        self.ren.ResetCamera()
+# 
+        self.frame.setLayout(self.vl)
+        self.setCentralWidget(self.frame)
+        reader = vtk.vtkMetaImageReader()
+        reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha")
+        reader.Update()
+        
+        self.vtkWidget.SetInput(reader.GetOutput())
+        
+        #self.vktWidget.Initialize()
+        #self.vktWidget.Start()
+        
+        self.show()
+        #self.iren.Initialize()
+ 
+ 
+if __name__ == "__main__":
+ 
+    app = QtWidgets.QApplication(sys.argv)
+ 
+    window = MainWindow()
+ 
+    sys.exit(app.exec_())
\ No newline at end of file
-- 
cgit v1.2.3


From 1a841b967e1db92a04e8e12c52b83489da27be1c Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Wed, 23 Aug 2017 15:08:32 +0100
Subject: initial revision

---
 src/Python/ccpi/imaging/Regularizer.py | 322 +++++++++++++++++++++++++++++++++
 1 file changed, 322 insertions(+)
 create mode 100644 src/Python/ccpi/imaging/Regularizer.py

(limited to 'src/Python')

diff --git a/src/Python/ccpi/imaging/Regularizer.py b/src/Python/ccpi/imaging/Regularizer.py
new file mode 100644
index 0000000..fb9ae08
--- /dev/null
+++ b/src/Python/ccpi/imaging/Regularizer.py
@@ -0,0 +1,322 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Aug  8 14:26:00 2017
+
+@author: ofn77899
+"""
+
+from ccpi.imaging import cpu_regularizers
+import numpy as np
+from enum import Enum
+import timeit
+
+class Regularizer():
+    '''Class to handle regularizer algorithms to be used during reconstruction
+    
+    Currently 5 CPU (OMP) regularization algorithms are available:
+        
+    1) SplitBregman_TV
+    2) FGP_TV
+    3) LLT_model
+    4) PatchBased_Regul
+    5) TGV_PD
+    
+    Usage:
+        the regularizer can be invoked as object or as static method
+        Depending on the actual regularizer the input parameter may vary, and 
+        a different default setting is defined.
+        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+
+        out = reg(input=u0, regularization_parameter=10., number_of_iterations=30,
+          tolerance_constant=1e-4, 
+          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+
+        out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10.,
+          number_of_iterations=30, tolerance_constant=1e-4, 
+          TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+        
+        A number of optional parameters can be passed or skipped
+        out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. )
+
+    '''
+    class Algorithm(Enum):
+        SplitBregman_TV = cpu_regularizers.SplitBregman_TV
+        FGP_TV = cpu_regularizers.FGP_TV
+        LLT_model = cpu_regularizers.LLT_model
+        PatchBased_Regul = cpu_regularizers.PatchBased_Regul
+        TGV_PD = cpu_regularizers.TGV_PD
+    # Algorithm
+    
+    class TotalVariationPenalty(Enum):
+        isotropic = 0
+        l1 = 1
+    # TotalVariationPenalty
+        
+    def __init__(self , algorithm, debug = True):
+        self.setAlgorithm ( algorithm )
+        self.debug = debug
+    # __init__
+    
+    def setAlgorithm(self, algorithm):
+        self.algorithm = algorithm
+        self.pars = self.getDefaultParsForAlgorithm(algorithm)
+    # setAlgorithm
+        
+    def getDefaultParsForAlgorithm(self, algorithm):
+        pars = dict()
+        
+        if algorithm == Regularizer.Algorithm.SplitBregman_TV :
+            pars['algorithm'] = algorithm
+            pars['input'] = None
+            pars['regularization_parameter'] = None
+            pars['number_of_iterations'] = 35
+            pars['tolerance_constant'] = 0.0001
+            pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic
+            
+        elif algorithm == Regularizer.Algorithm.FGP_TV :
+            pars['algorithm'] = algorithm
+            pars['input'] = None
+            pars['regularization_parameter'] = None
+            pars['number_of_iterations'] = 50
+            pars['tolerance_constant'] = 0.001
+            pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic
+            
+        elif algorithm == Regularizer.Algorithm.LLT_model:
+            pars['algorithm'] = algorithm
+            pars['input'] = None
+            pars['regularization_parameter'] = None
+            pars['time_step'] = None
+            pars['number_of_iterations'] = None
+            pars['tolerance_constant'] = None
+            pars['restrictive_Z_smoothing'] = 0
+            
+        elif algorithm == Regularizer.Algorithm.PatchBased_Regul:
+            pars['algorithm'] = algorithm
+            pars['input'] = None
+            pars['searching_window_ratio'] = None
+            pars['similarity_window_ratio'] = None
+            pars['PB_filtering_parameter'] = None
+            pars['regularization_parameter'] = None
+            
+        elif algorithm == Regularizer.Algorithm.TGV_PD:
+            pars['algorithm'] = algorithm
+            pars['input'] = None
+            pars['first_order_term'] = None
+            pars['second_order_term'] = None
+            pars['number_of_iterations'] = None
+            pars['regularization_parameter'] = None
+            
+        else:
+            raise Exception('Unknown regularizer algorithm')
+            
+        return pars
+    # parsForAlgorithm
+    
+    def setParameter(self, **kwargs):
+        '''set named parameter for the regularization engine
+        
+        raises Exception if the named parameter is not recognized
+        Typical usage is:
+            
+        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+        reg.setParameter(input=u0)    
+        reg.setParameter(regularization_parameter=10.)
+        
+        it can be also used as
+        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+        reg.setParameter(input=u0 , regularization_parameter=10.)
+        '''
+        
+        for key , value in kwargs.items():
+            if key in self.pars.keys():
+                self.pars[key] = value
+            else:
+                raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key))
+    # setParameter
+	
+    def getParameter(self, **kwargs):
+        ret = {}
+        for key , value in kwargs.items():
+            if key in self.pars.keys():
+                ret[key] = self.pars[key]
+        else:
+            raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key))
+    # setParameter
+	
+        
+    def __call__(self, input = None, regularization_parameter = None, **kwargs):
+        '''Actual call for the regularizer. 
+        
+        One can either set the regularization parameters first and then call the
+        algorithm or set the regularization parameter during the call (as 
+        is done in the static methods). 
+        '''
+        
+        if kwargs is not None:
+            for key, value in kwargs.items():
+                #print("{0} = {1}".format(key, value))                        
+                self.pars[key] = value
+                    
+        if input is not None: 
+            self.pars['input'] = input
+        if regularization_parameter is not None:
+            self.pars['regularization_parameter'] = regularization_parameter
+            
+        if self.debug:
+            print ("--------------------------------------------------")
+            for key, value in self.pars.items():
+                if key== 'algorithm' :
+                    print("{0} = {1}".format(key, value.__name__))
+                elif key == 'input':
+                    print("{0} = {1}".format(key, np.shape(value)))
+                else:
+                    print("{0} = {1}".format(key, value))
+        
+            
+        if None in self.pars:
+                raise Exception("Not all parameters have been provided")
+        
+        input = self.pars['input']
+        regularization_parameter = self.pars['regularization_parameter']
+        if self.algorithm == Regularizer.Algorithm.SplitBregman_TV :
+            return self.algorithm(input, regularization_parameter,
+                              self.pars['number_of_iterations'],
+                              self.pars['tolerance_constant'],
+                              self.pars['TV_penalty'].value )    
+        elif self.algorithm == Regularizer.Algorithm.FGP_TV :
+            return self.algorithm(input, regularization_parameter,
+                              self.pars['number_of_iterations'],
+                              self.pars['tolerance_constant'],
+                              self.pars['TV_penalty'].value )
+        elif self.algorithm == Regularizer.Algorithm.LLT_model :
+            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
+            # no default
+            return self.algorithm(input, 
+                              regularization_parameter,
+                              self.pars['time_step'] , 
+                              self.pars['number_of_iterations'],
+                              self.pars['tolerance_constant'],
+                              self.pars['restrictive_Z_smoothing'] )
+        elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul :
+            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
+            # no default
+            return self.algorithm(input, regularization_parameter,
+                                  self.pars['searching_window_ratio'] , 
+                                  self.pars['similarity_window_ratio'] , 
+                                  self.pars['PB_filtering_parameter'])
+        elif self.algorithm == Regularizer.Algorithm.TGV_PD :
+            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
+            # no default
+            if len(np.shape(input)) == 2:
+                return self.algorithm(input, regularization_parameter,
+                                  self.pars['first_order_term'] , 
+                                  self.pars['second_order_term'] , 
+                                  self.pars['number_of_iterations'])
+            elif len(np.shape(input)) == 3:
+                #assuming it's 3D
+                # run independent calls on each slice
+                out3d = input.copy()
+                for i in range(np.shape(input)[2]):
+                    out = self.algorithm(input, regularization_parameter,
+                                 self.pars['first_order_term'] , 
+                                 self.pars['second_order_term'] , 
+                                 self.pars['number_of_iterations'])
+                    # copy the result in the 3D image
+                    out3d.T[i] = out[0].copy()
+                # append the rest of the info that the algorithm returns
+                output = [out3d]
+                for i in range(1,len(out)):
+                    output.append(out[i])
+                return output
+                
+                
+            
+            
+        
+    # __call__
+    
+    @staticmethod
+    def SplitBregman_TV(input, regularization_parameter , **kwargs):
+        start_time = timeit.default_timer()
+        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+        out = list( reg(input, regularization_parameter, **kwargs) )
+        out.append(reg.pars)
+        txt = reg.printParametersToString()
+        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
+        out.append(txt)
+        return out
+        
+    @staticmethod
+    def FGP_TV(input, regularization_parameter , **kwargs):
+        start_time = timeit.default_timer()
+        reg = Regularizer(Regularizer.Algorithm.FGP_TV)
+        out = list( reg(input, regularization_parameter, **kwargs) )
+        out.append(reg.pars)
+        txt = reg.printParametersToString()
+        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
+        out.append(txt)
+        return out
+    
+    @staticmethod
+    def LLT_model(input, regularization_parameter , time_step, number_of_iterations,
+                  tolerance_constant, restrictive_Z_smoothing=0):
+        start_time = timeit.default_timer()
+        reg = Regularizer(Regularizer.Algorithm.LLT_model)
+        out = list( reg(input, regularization_parameter, time_step=time_step, 
+                        number_of_iterations=number_of_iterations,
+                        tolerance_constant=tolerance_constant, 
+                        restrictive_Z_smoothing=restrictive_Z_smoothing) )
+        out.append(reg.pars)
+        txt = reg.printParametersToString()
+        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
+        out.append(txt)
+        return out
+    
+    @staticmethod
+    def PatchBased_Regul(input, regularization_parameter,
+                        searching_window_ratio, 
+                        similarity_window_ratio,
+                        PB_filtering_parameter):
+        start_time = timeit.default_timer()
+        reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul)   
+        out = list( reg(input, 
+                        regularization_parameter,
+                        searching_window_ratio=searching_window_ratio, 
+                        similarity_window_ratio=similarity_window_ratio,
+                        PB_filtering_parameter=PB_filtering_parameter )
+            )
+        out.append(reg.pars)
+        txt = reg.printParametersToString()
+        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
+        out.append(txt)
+        return out
+    
+    @staticmethod
+    def TGV_PD(input, regularization_parameter , first_order_term, 
+               second_order_term, number_of_iterations):
+        start_time = timeit.default_timer()
+        
+        reg = Regularizer(Regularizer.Algorithm.TGV_PD)
+        out = list( reg(input, regularization_parameter, 
+                        first_order_term=first_order_term, 
+                        second_order_term=second_order_term,
+                        number_of_iterations=number_of_iterations) )
+        out.append(reg.pars)
+        txt = reg.printParametersToString()
+        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
+        out.append(txt)
+        
+        return out
+    
+    def printParametersToString(self):
+        txt = r''
+        for key, value in self.pars.items():
+            if key== 'algorithm' :
+                txt += "{0} = {1}".format(key, value.__name__)
+            elif key == 'input':
+                txt += "{0} = {1}".format(key, np.shape(value))
+            else:
+                txt += "{0} = {1}".format(key, value)
+            txt += '\n'
+        return txt
+        
-- 
cgit v1.2.3


From 9f8fb57e1e89c1ad200d9c7eada5c653be34db66 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Wed, 23 Aug 2017 15:10:04 +0100
Subject: module rename

---
 src/Python/fista_module.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

(limited to 'src/Python')

diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp
index eacda3d..c36329e 100644
--- a/src/Python/fista_module.cpp
+++ b/src/Python/fista_module.cpp
@@ -1032,13 +1032,13 @@ bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_al
 	return result;
 }
 
-BOOST_PYTHON_MODULE(regularizers)
+BOOST_PYTHON_MODULE(cpu_regularizers)
 {
 	np::initialize();
 
 	//To specify that this module is a package
 	bp::object package = bp::scope();
-	package.attr("__path__") = "regularizers";
+	package.attr("__path__") = "cpu_regularizers";
 
 	np::dtype dt1 = np::dtype::get_builtin<uint8_t>();
 	np::dtype dt2 = np::dtype::get_builtin<uint16_t>();
-- 
cgit v1.2.3


From a203949c84484fe2641e39451f033d20d445b1f3 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Wed, 23 Aug 2017 16:51:18 +0100
Subject: export/import data from hdf5

Added file to export the data from DemoRD2.m to HDF5 to pass it to Python.
Added file to import the data from DemoRD2.m from HDF5.
---
 src/Python/test/readhd5.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)
 create mode 100644 src/Python/test/readhd5.py

(limited to 'src/Python')

diff --git a/src/Python/test/readhd5.py b/src/Python/test/readhd5.py
new file mode 100644
index 0000000..1e19e14
--- /dev/null
+++ b/src/Python/test/readhd5.py
@@ -0,0 +1,28 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Aug 23 16:34:49 2017
+
+@author: ofn77899
+"""
+
+import h5py
+import numpy
+
+def getEntry(nx, location):
+    for item in nx[location].keys():
+        print (item)
+        
+filename = r'C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\Demos\DendrData.h5'
+nx = h5py.File(filename, "r")
+#getEntry(nx, '/')
+# I have exported the entries as children of /
+entries = [entry for entry in nx['/'].keys()]
+print (entries)
+
+Sino3D = numpy.asarray(nx.get('/Sino3D'))
+Weights3D = numpy.asarray(nx.get('/Weights3D'))
+angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0]
+angles_rad = numpy.asarray(nx.get('/angles_rad'))
+recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0]
+size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0]
+slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0]
\ No newline at end of file
-- 
cgit v1.2.3


From 8d53e078d3dabf7107982a8d25b4d66b1d0e73ce Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Wed, 23 Aug 2017 16:54:59 +0100
Subject: initial revision for testing

---
 .../ccpi/reconstruction/FISTAReconstructor.py      | 354 +++++++++++++++++++++
 1 file changed, 354 insertions(+)
 create mode 100644 src/Python/ccpi/reconstruction/FISTAReconstructor.py

(limited to 'src/Python')

diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py
new file mode 100644
index 0000000..ea96b53
--- /dev/null
+++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py
@@ -0,0 +1,354 @@
+# -*- coding: utf-8 -*-
+###############################################################################
+#This work is part of the Core Imaging Library developed by
+#Visual Analytics and Imaging System Group of the Science Technology
+#Facilities Council, STFC
+#
+#Copyright 2017 Edoardo Pasca, Srikanth Nagella
+#Copyright 2017 Daniil Kazantsev
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#http://www.apache.org/licenses/LICENSE-2.0
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+###############################################################################
+
+
+
+import numpy
+import h5py
+#from ccpi.reconstruction.parallelbeam import alg
+
+from ccpi.imaging.Regularizer import Regularizer
+from enum import Enum
+
+import astra
+
+   
+    
+class FISTAReconstructor():
+    '''FISTA-based reconstruction algorithm using ASTRA-toolbox
+    
+    '''
+    # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>>
+    # ___Input___:
+    # params.[] file:
+    #       - .proj_geom (geometry of the projector) [required]
+    #       - .vol_geom (geometry of the reconstructed object) [required]
+    #       - .sino (vectorized in 2D or 3D sinogram) [required]
+    #       - .iterFISTA (iterations for the main loop, default 40)
+    #       - .L_const (Lipschitz constant, default Power method)                                                                                                    )
+    #       - .X_ideal (ideal image, if given)
+    #       - .weights (statisitcal weights, size of the sinogram)
+    #       - .ROI (Region-of-interest, only if X_ideal is given)
+    #       - .initialize (a 'warm start' using SIRT method from ASTRA)
+    #----------------Regularization choices------------------------
+    #       - .Regul_Lambda_FGPTV (FGP-TV regularization parameter)
+    #       - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter)
+    #       - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter)
+    #       - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04)
+    #       - .Regul_Iterations (iterations for the selected penalty, default 25)
+    #       - .Regul_tauLLT (time step parameter for LLT term)
+    #       - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal)
+    #       - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1)
+    #----------------Visualization parameters------------------------
+    #       - .show (visualize reconstruction 1/0, (0 default))
+    #       - .maxvalplot (maximum value to use for imshow[0 maxvalplot])
+    #       - .slice (for 3D volumes - slice number to imshow)
+    # ___Output___:
+    # 1. X - reconstructed image/volume
+    # 2. output - a structure with
+    #    - .Resid_error - residual error (if X_ideal is given)
+    #    - .objective: value of the objective function
+    #    - .L_const: Lipshitz constant to avoid recalculations
+    
+    # References:
+    # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
+    # Problems" by A. Beck and M Teboulle
+    # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo
+    # 3. "A novel tomographic reconstruction method based on the robust
+    # Student's t function for suppressing data outliers" D. Kazantsev et.al.
+    # D. Kazantsev, 2016-17
+    def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs):
+        self.params = dict()
+        self.params['projector_geometry'] = projector_geometry
+        self.params['output_geometry'] = output_geometry
+        self.params['input_sinogram'] = input_sinogram
+        detectors, nangles, sliceZ = numpy.shape(input_sinogram)
+        self.params['detectors'] = detectors
+        self.params['number_og_angles'] = nangles
+        self.params['SlicesZ'] = sliceZ
+        
+        # Accepted input keywords
+        kw = ('number_of_iterations', 
+              'Lipschitz_constant' , 
+              'ideal_image' ,
+              'weights' , 
+              'region_of_interest' , 
+              'initialize' , 
+              'regularizer' , 
+              'ring_lambda_R_L1',
+              'ring_alpha')
+        
+        # handle keyworded parameters
+        if kwargs is not None:
+            for key, value in kwargs.items():
+                if key in kw:
+                    #print("{0} = {1}".format(key, value))                        
+                    self.pars[key] = value
+                    
+        # set the default values for the parameters if not set
+        if 'number_of_iterations' in kwargs.keys():
+            self.pars['number_of_iterations'] = kwargs['number_of_iterations']
+        else:
+            self.pars['number_of_iterations'] = 40
+        if 'weights' in kwargs.keys():
+            self.pars['weights'] = kwargs['weights']
+        else:
+            self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram']))
+        if 'Lipschitz_constant' in kwargs.keys():
+            self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
+        else:
+            self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod()
+        
+        if not self.pars['ideal_image'] in kwargs.keys():
+            self.pars['ideal_image'] = None
+        
+        if not self.pars['region_of_interest'] :
+            if self.pars['ideal_image'] == None:
+                pass
+            else:
+                self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0)
+            
+        if not self.pars['regularizer'] :
+            self.pars['regularizer'] = None
+        else:
+            # the regularizer must be a correctly instantiated object
+            if not self.pars['ring_lambda_R_L1']:
+                self.pars['ring_lambda_R_L1'] = 0
+            if not self.pars['ring_alpha']:
+                self.pars['ring_alpha'] = 1
+        
+            
+            
+        
+    def calculateLipschitzConstantWithPowerMethod(self):
+        ''' using Power method (PM) to establish L constant'''
+        
+        #N = params.vol_geom.GridColCount
+        N = self.pars['output_geometry'].GridColCount
+        proj_geom = self.params['projector_geometry']
+        vol_geom = self.params['output_geometry']
+        weights = self.pars['weights']
+        SlicesZ = self.pars['SlicesZ']
+        
+        if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
+            #% for parallel geometry we can do just one slice
+            #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...');
+            niter = 15;# % number of iteration for the PM
+            #N = params.vol_geom.GridColCount;
+            #x1 = rand(N,N,1);
+            x1 = numpy.random.rand(1,N,N)
+            #sqweight = sqrt(weights(:,:,1));
+            sqweight = numpy.sqrt(weights.T[0])
+            proj_geomT = proj_geom.copy();
+            proj_geomT.DetectorRowCount = 1;
+            vol_geomT = vol_geom.copy();
+            vol_geomT['GridSliceCount'] = 1;
+            
+            
+            for i in range(niter):
+                if i == 0:
+                    #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+                    sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
+                    y = sqweight * y # element wise multiplication
+                    #astra_mex_data3d('delete', sino_id);
+                    astra.matlab.data3d('delete', sino_id)
+                    
+                idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT);
+                s = numpy.linalg.norm(x1)
+                ### this line?
+                x1 = x1/s;
+                ### this line?
+                sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
+                y = sqweight*y;
+                astra.matlab.data3d('delete', sino_id);
+                astra.matlab.data3d('delete', idx);
+            #end
+            del proj_geomT
+            del vol_geomT
+        else:
+            #% divergen beam geometry
+            #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...');
+            niter = 8; #% number of iteration for PM
+            x1 = numpy.random.rand(SlicesZ , N , N);
+            #sqweight = sqrt(weights);
+            sqweight = numpy.sqrt(weights.T[0])
+            
+            sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
+            y = sqweight*y;
+            #astra_mex_data3d('delete', sino_id);
+            astra.matlab.data3d('delete', sino_id);
+            
+            for i in range(niter):
+                #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom);
+                idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, 
+                                                                    proj_geom, 
+                                                                    vol_geom)
+                s = numpy.linalg.norm(x1)
+                ### this line?
+                x1 = x1/s;
+                ### this line?
+                #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom);
+                sino_id, y = astra.creators.create_sino3d_gpu(x1, 
+                                                              proj_geom, 
+                                                              vol_geom);
+                
+                y = sqweight*y;
+                #astra_mex_data3d('delete', sino_id);
+                #astra_mex_data3d('delete', id);
+                astra.matlab.data3d('delete', sino_id);
+                astra.matlab.data3d('delete', idx);
+            #end
+            #clear x1
+            del x1
+        
+        return s
+    
+    
+    def setRegularizer(self, regularizer):
+        if regularizer is not None:
+            self.pars['regularizer'] = regularizer
+        
+    
+    
+
+
+def getEntry(location):
+    for item in nx[location].keys():
+        print (item)
+
+
+print ("Loading Data")
+
+##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif"
+####ind = [i * 1049 for i in range(360)]
+#### use only 360 images
+##images = 200
+##ind = [int(i * 1049 / images) for i in range(images)]
+##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None)
+
+#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs"
+#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs"
+##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5"
+##nx = h5py.File(fname, "r")
+##
+### the data are stored in a particular location in the hdf5
+##for item in nx['entry1/tomo_entry/data'].keys():
+##    print (item)
+##
+##data = nx.get('entry1/tomo_entry/data/rotation_angle')
+##angles = numpy.zeros(data.shape)
+##data.read_direct(angles)
+##print (angles)
+### angles should be in degrees
+##
+##data = nx.get('entry1/tomo_entry/data/data')
+##stack = numpy.zeros(data.shape)
+##data.read_direct(stack)
+##print (data.shape)
+##
+##print ("Data Loaded")
+##
+##
+### Normalize
+##data = nx.get('entry1/tomo_entry/instrument/detector/image_key')
+##itype = numpy.zeros(data.shape)
+##data.read_direct(itype)
+### 2 is dark field
+##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ]
+##dark = darks[0]
+##for i in range(1, len(darks)):
+##    dark += darks[i]
+##dark = dark / len(darks)
+###dark[0][0] = dark[0][1]
+##
+### 1 is flat field
+##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ]
+##flat = flats[0]
+##for i in range(1, len(flats)):
+##    flat += flats[i]
+##flat = flat / len(flats)
+###flat[0][0] = dark[0][1]
+##
+##
+### 0 is projection data
+##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ]
+##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ]
+##angle_proj = numpy.asarray (angle_proj)
+##angle_proj = angle_proj.astype(numpy.float32)
+##
+### normalized data are
+### norm = (projection - dark)/(flat-dark)
+##
+##def normalize(projection, dark, flat, def_val=0.1):
+##    a = (projection - dark)
+##    b = (flat-dark)
+##    with numpy.errstate(divide='ignore', invalid='ignore'):
+##        c = numpy.true_divide( a, b )
+##        c[ ~ numpy.isfinite( c )] = def_val  # set to not zero if 0/0 
+##    return c
+##    
+##
+##norm = [normalize(projection, dark, flat) for projection in proj]
+##norm = numpy.asarray (norm)
+##norm = norm.astype(numpy.float32)
+
+
+##niterations = 15
+##threads = 3
+##
+##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+##
+##iteration_values = numpy.zeros((niterations,))
+##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+##                              iteration_values, False)
+##print ("iteration values %s" % str(iteration_values))
+##
+##iteration_values = numpy.zeros((niterations,))
+##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+##                                      numpy.double(1e-5), iteration_values , False)
+##print ("iteration values %s" % str(iteration_values))
+##iteration_values = numpy.zeros((niterations,))
+##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+##                                      numpy.double(1e-5), iteration_values , False)
+##print ("iteration values %s" % str(iteration_values))
+##
+##
+####numpy.save("cgls_recon.npy", img_data)
+##import matplotlib.pyplot as plt
+##fig, ax = plt.subplots(1,6,sharey=True)
+##ax[0].imshow(img_cgls[80])
+##ax[0].axis('off')  # clear x- and y-axes
+##ax[1].imshow(img_sirt[80])
+##ax[1].axis('off')  # clear x- and y-axes
+##ax[2].imshow(img_mlem[80])
+##ax[2].axis('off')  # clear x- and y-axesplt.show()
+##ax[3].imshow(img_cgls_conv[80])
+##ax[3].axis('off')  # clear x- and y-axesplt.show()
+##ax[4].imshow(img_cgls_tikhonov[80])
+##ax[4].axis('off')  # clear x- and y-axesplt.show()
+##ax[5].imshow(img_cgls_TVreg[80])
+##ax[5].axis('off')  # clear x- and y-axesplt.show()
+##
+##
+##plt.show()
+##
+
-- 
cgit v1.2.3


From 05bd227b56ec43c97c81630f50c3b741ef86ddcd Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Thu, 24 Aug 2017 16:39:37 +0100
Subject: bugfix

---
 src/Python/Matlab2Python_utils.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

(limited to 'src/Python')

diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp
index e15d738..ee76bc7 100644
--- a/src/Python/Matlab2Python_utils.cpp
+++ b/src/Python/Matlab2Python_utils.cpp
@@ -123,7 +123,7 @@ T * mxGetData(const np::ndarray pm) {
 	probably this would work.
 	A = reinterpret_cast<float *>(prhs[0]);
 	*/
-	return reinterpret_cast<T *>(prhs[0]);
+	//return reinterpret_cast<T *>(prhs[0]);
 }
 
 template<typename T>
@@ -273,4 +273,4 @@ BOOST_PYTHON_MODULE(prova)
 	//numpy_boost_python_register_type<double, 3>();
 	def("mexFunction", mexFunction);
 	def("doSomething", doSomething);
-}
\ No newline at end of file
+}
-- 
cgit v1.2.3


From 879c6723969eaea8e00f97291612fe22443c69f3 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Thu, 24 Aug 2017 16:41:10 +0100
Subject: initial facility to test the FISTA

---
 src/Python/test_reconstructor.py | 179 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 179 insertions(+)
 create mode 100644 src/Python/test_reconstructor.py

(limited to 'src/Python')

diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py
new file mode 100644
index 0000000..0fd08f5
--- /dev/null
+++ b/src/Python/test_reconstructor.py
@@ -0,0 +1,179 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Aug 23 16:34:49 2017
+
+@author: ofn77899
+Based on DemoRD2.m
+"""
+
+import h5py
+import numpy
+
+from ccpi.reconstruction_dev.FISTAReconstructor import FISTAReconstructor
+import astra
+
+##def getEntry(nx, location):
+##    for item in nx[location].keys():
+##        print (item)
+  
+filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5'
+nx = h5py.File(filename, "r")
+#getEntry(nx, '/')
+# I have exported the entries as children of /
+entries = [entry for entry in nx['/'].keys()]
+print (entries)
+
+Sino3D = numpy.asarray(nx.get('/Sino3D'))
+Weights3D = numpy.asarray(nx.get('/Weights3D'))
+angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0]
+angles_rad = numpy.asarray(nx.get('/angles_rad'))
+recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0]
+size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0]
+slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0]
+
+Z_slices = 3
+det_row_count = Z_slices
+# next definition is just for consistency of naming
+det_col_count = size_det
+
+detectorSpacingX = 1.0
+detectorSpacingY = detectorSpacingX
+
+
+proj_geom = astra.creators.create_proj_geom('parallel3d',
+                                            detectorSpacingX,
+                                            detectorSpacingY,
+                                            det_row_count,
+                                            det_col_count,
+                                            angles_rad)
+
+#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices);
+image_size_x = recon_size
+image_size_y = recon_size
+image_size_z = Z_slices
+vol_geom = astra.creators.create_vol_geom( image_size_x,
+                                           image_size_y,
+                                           image_size_z)
+
+## First pass the arguments to the FISTAReconstructor and test the
+## Lipschitz constant
+
+#fistaRecon = FISTAReconstructor(proj_geom, vol_geom, Sino3D )
+ #N = params.vol_geom.GridColCount
+ 
+pars = dict()
+pars['projector_geometry'] = proj_geom
+pars['output_geometry'] = vol_geom
+pars['input_sinogram'] = Sino3D
+sliceZ , nangles , detectors  = numpy.shape(Sino3D)
+pars['detectors'] = detectors
+pars['number_of_angles'] = nangles
+pars['SlicesZ'] = sliceZ
+    
+
+pars['weights'] = numpy.ones(numpy.shape(pars['input_sinogram']))
+         
+N = pars['output_geometry']['GridColCount']
+proj_geom = pars['projector_geometry']
+vol_geom = pars['output_geometry']
+weights = pars['weights']
+SlicesZ = pars['SlicesZ']
+
+if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
+    #% for parallel geometry we can do just one slice
+    print('Calculating Lipshitz constant for parallel beam geometry...')
+    niter = 15;# % number of iteration for the PM
+    #N = params.vol_geom.GridColCount;
+    #x1 = rand(N,N,1);
+    x1 = numpy.random.rand(1,N,N)
+    #sqweight = sqrt(weights(:,:,1));
+    sqweight = numpy.sqrt(weights[0])
+    proj_geomT = proj_geom.copy();
+    proj_geomT['DetectorRowCount'] = 1;
+    vol_geomT = vol_geom.copy();
+    vol_geomT['GridSliceCount'] = 1;
+    
+    #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+    
+    import matplotlib.pyplot as plt
+    fig = plt.figure()
+    
+    #a.set_title('Lipschitz')        
+    for i in range(niter):
+#        [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT);
+#            s = norm(x1(:));
+#            x1 = x1/s;
+#            [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+#            y = sqweight.*y;
+#            astra_mex_data3d('delete', sino_id);
+#            astra_mex_data3d('delete', id);
+        print ("iteration {0}".format(i))
+        sino_id, y = astra.creators.create_sino3d_gpu(x1,
+                                                  proj_geomT,
+                                                  vol_geomT)
+        #a=fig.add_subplot(2,1,1)
+        #imgplot = plt.imshow(y[0])
+        
+        y = sqweight * y # element wise multiplication
+        
+        #b=fig.add_subplot(2,1,2)
+        #imgplot = plt.imshow(x1[0])
+        #plt.show()
+        
+        #astra_mex_data3d('delete', sino_id);
+        astra.matlab.data3d('delete', sino_id)
+            
+        idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, 
+                                                            proj_geomT,
+                                                            vol_geomT);
+        print ("shape {1} x1 {0}".format(x1.T[:4].T, numpy.shape(x1)))
+        s = numpy.linalg.norm(x1)
+        ### this line?
+        x1 = x1/s;
+        print ("x1 {0}".format(x1.T[:4].T))
+        
+#        ### this line?
+#        sino_id, y = astra.creators.create_sino3d_gpu(x1, 
+#                                                      proj_geomT, 
+#                                                      vol_geomT);
+#        y = sqweight * y;
+        astra.matlab.data3d('delete', sino_id);
+        astra.matlab.data3d('delete', idx);
+    #end
+    del proj_geomT
+    del vol_geomT
+else:
+    #% divergen beam geometry
+    print('Calculating Lipshitz constant for divergen beam geometry...')
+    niter = 8; #% number of iteration for PM
+    x1 = numpy.random.rand(SlicesZ , N , N);
+    #sqweight = sqrt(weights);
+    sqweight = numpy.sqrt(weights[0])
+    
+    sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
+    y = sqweight*y;
+    #astra_mex_data3d('delete', sino_id);
+    astra.matlab.data3d('delete', sino_id);
+    
+    for i in range(niter):
+        #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom);
+        idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, 
+                                                            proj_geom, 
+                                                            vol_geom)
+        s = numpy.linalg.norm(x1)
+        ### this line?
+        x1 = x1/s;
+        ### this line?
+        #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom);
+        sino_id, y = astra.creators.create_sino3d_gpu(x1, 
+                                                      proj_geom, 
+                                                      vol_geom);
+        
+        y = sqweight*y;
+        #astra_mex_data3d('delete', sino_id);
+        #astra_mex_data3d('delete', id);
+        astra.matlab.data3d('delete', sino_id);
+        astra.matlab.data3d('delete', idx);
+    #end
+    #clear x1
+    del x1
-- 
cgit v1.2.3


From e58f774938edd3664dfb1f3905964b3add050bc9 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Thu, 24 Aug 2017 16:42:05 +0100
Subject: initial revision

---
 src/Python/ccpi/imaging/__init__.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 src/Python/ccpi/imaging/__init__.py

(limited to 'src/Python')

diff --git a/src/Python/ccpi/imaging/__init__.py b/src/Python/ccpi/imaging/__init__.py
new file mode 100644
index 0000000..e69de29
-- 
cgit v1.2.3


From 9c974a26c0fc8060008745796fbe9f7ef5c250eb Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Thu, 24 Aug 2017 16:42:27 +0100
Subject: initial revision

---
 src/Python/ccpi/__init__.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 src/Python/ccpi/__init__.py

(limited to 'src/Python')

diff --git a/src/Python/ccpi/__init__.py b/src/Python/ccpi/__init__.py
new file mode 100644
index 0000000..e69de29
-- 
cgit v1.2.3


From c8693f530e95e140a3fba85fc65d879b51b79e6d Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Wed, 11 Oct 2017 15:11:00 +0100
Subject: table with regularizers output

---
 src/Python/test_regularizers.py | 66 ++++++++++++++++++++---------------------
 1 file changed, 33 insertions(+), 33 deletions(-)

(limited to 'src/Python')

diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py
index 755804a..5804897 100644
--- a/src/Python/test_regularizers.py
+++ b/src/Python/test_regularizers.py
@@ -163,52 +163,52 @@ imgplot = plt.imshow(reg_output[-1][0])
 # #   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
 # #   ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); 
 
-# out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05,
-                          # searching_window_ratio=3,
-                          # similarity_window_ratio=1,
-                          # PB_filtering_parameter=0.08)
-# pars = out2[-2]
-# reg_output.append(out2)
+out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05,
+                           searching_window_ratio=3,
+                           similarity_window_ratio=1,
+                           PB_filtering_parameter=0.08)
+pars = out2[-2]
+reg_output.append(out2)
 
-# a=fig.add_subplot(2,3,5)
+a=fig.add_subplot(2,3,5)
 
 
-# textstr = out2[-1]
+textstr = out2[-1]
 
-# # these are matplotlib.patch.Patch properties
-# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
-# # place a text box in upper left in axes coords
-# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
-        # verticalalignment='top', bbox=props)
-# imgplot = plt.imshow(reg_output[-1][0])
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+        verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0])
 
 
-# ###################### TGV_PD #########################################
-# # Quick 2D denoising example in Matlab:   
-# #   Im = double(imread('lena_gray_256.tif'))/255;  % loading image
-# #   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
-# #   u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550);
+###################### TGV_PD #########################################
+# Quick 2D denoising example in Matlab:   
+#   Im = double(imread('lena_gray_256.tif'))/255;  % loading image
+#   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+#   u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550);
 
 
-# out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05,
-                          # first_order_term=1.3,
-                          # second_order_term=1,
-                          # number_of_iterations=550)
-# pars = out2[-2]
-# reg_output.append(out2)
+out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05,
+                          first_order_term=1.3,
+                          second_order_term=1,
+                          number_of_iterations=550)
+pars = out2[-2]
+reg_output.append(out2)
 
-# a=fig.add_subplot(2,3,6)
+a=fig.add_subplot(2,3,6)
 
 
-# textstr = out2[-1]
+textstr = out2[-1]
 
 
-# # these are matplotlib.patch.Patch properties
-# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
-# # place a text box in upper left in axes coords
-# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
-        # verticalalignment='top', bbox=props)
-# imgplot = plt.imshow(reg_output[-1][0])
+# these are matplotlib.patch.Patch properties
+props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+        verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0])
 
 
 plt.show()
-- 
cgit v1.2.3


From 776070e22bf95491275a023f3a5ac00cea356714 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Wed, 11 Oct 2017 15:12:42 +0100
Subject: read and plot the hdf5

---
 src/Python/test/readhd5.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

(limited to 'src/Python')

diff --git a/src/Python/test/readhd5.py b/src/Python/test/readhd5.py
index 1e19e14..b042341 100644
--- a/src/Python/test/readhd5.py
+++ b/src/Python/test/readhd5.py
@@ -25,4 +25,17 @@ angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0]
 angles_rad = numpy.asarray(nx.get('/angles_rad'))
 recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0]
 size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0]
-slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0]
\ No newline at end of file
+slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0]
+
+#from ccpi.viewer.CILViewer2D import CILViewer2D
+#v = CILViewer2D()
+#v.setInputAsNumpy(Weights3D)
+#v.startRenderLoop()
+
+import matplotlib.pyplot as plt
+fig = plt.figure()
+
+a=fig.add_subplot(1,1,1)
+a.set_title('noise')
+imgplot = plt.imshow(Weights3D[0].T)
+plt.show()
-- 
cgit v1.2.3


From 5c978b706192bc5885c7e5001a4bc4626f63d29f Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Wed, 11 Oct 2017 15:49:18 +0100
Subject: initial revision

---
 src/Python/test/simple_astra_test.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)
 create mode 100644 src/Python/test/simple_astra_test.py

(limited to 'src/Python')

diff --git a/src/Python/test/simple_astra_test.py b/src/Python/test/simple_astra_test.py
new file mode 100644
index 0000000..905eeea
--- /dev/null
+++ b/src/Python/test/simple_astra_test.py
@@ -0,0 +1,25 @@
+import astra
+import numpy
+
+detectorSpacingX = 1.0
+detectorSpacingY = 1.0
+det_row_count = 128
+det_col_count = 128
+
+angles_rad = numpy.asarray([i for i in range(360)], dtype=float) / 180. * numpy.pi
+
+proj_geom = astra.creators.create_proj_geom('parallel3d',
+                                            detectorSpacingX,
+                                            detectorSpacingY,
+                                            det_row_count,
+                                            det_col_count,
+                                            angles_rad)
+
+image_size_x = 64
+image_size_y = 64
+image_size_z = 32
+
+vol_geom = astra.creators.create_vol_geom(image_size_x,image_size_y,image_size_z)
+
+x1 = numpy.random.rand(image_size_z,image_size_y,image_size_x)
+sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom)
-- 
cgit v1.2.3