diff options
Diffstat (limited to 'src/Python')
-rw-r--r-- | src/Python/fista_module.cpp | 576 |
1 files changed, 473 insertions, 103 deletions
diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index 5344083..2492884 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -26,12 +26,10 @@ limitations under the License. #include <boost/python/numpy.hpp> #include "boost/tuple/tuple.hpp" -// include the regularizers -#include "FGP_TV_core.h" -#include "LLT_model_core.h" -#include "PatchBased_Regul_core.h" #include "SplitBregman_TV_core.h" -#include "TGV_PD_core.h" +#include "FGP_TV_core.h" + + #if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64) #include <windows.h> @@ -43,7 +41,6 @@ __if_not_exists(uint16_t) { typedef __int8 uint16_t; } namespace bp = boost::python; namespace np = boost::python::numpy; - /*! in the Matlab implementation this is called as void mexFunction( int nlhs, mxArray *plhs[], @@ -56,9 +53,7 @@ nlhs Array of pointers to the OUTPUT mxArrays plhs int number of OUTPUT mxArrays *********************************************************** -mxGetData -args: pm Pointer to an mxArray -Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. + *********************************************************** double mxGetScalar(const mxArray *pm); args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. @@ -92,77 +87,143 @@ double *mxGetPr(const mxArray *pm); args: pm Pointer to an mxArray of type double Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. **************************************************************** -mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); -args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. - dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. - For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. - classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. - For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. - ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). - Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). - +mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, +mxClassID classid, mxComplexity ComplexFlag); +args: ndimNumber of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. +dims Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. +For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. +classid Identifier for the class of the array, which determines the way the numerical data is represented in memory. +For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. +ComplexFlag If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran). - If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not - enough free heap space to create the mxArray. +If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not +enough free heap space to create the mxArray. +*/ + +void mexErrMessageText(char* text) { + std::cerr << text << std::endl; +} + +/* +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. */ template<typename T> -np::ndarray zeros(int dims, int * dim_array, T el) { - bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); - np::dtype dtype = np::dtype::get_builtin<T>(); - np::ndarray zz = np::zeros(shape, dtype); - return zz; +double mxGetScalar(const np::ndarray plh) { + return (double)bp::extract<T>(plh[0]); } -bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, double d_epsil, int TV_type) { - /* C-OMP implementation of Split Bregman - TV denoising-regularization model (2D/3D) - * - * Input Parameters: - * 1. Noisy image/volume - * 2. lambda - regularization parameter - * 3. Number of iterations [OPTIONAL parameter] - * 4. eplsilon - tolerance constant [OPTIONAL parameter] - * 5. TV-type: 'iso' or 'l1' [OPTIONAL parameter] - * - * Output: - * Filtered/regularized image - * - * All sanity checks and default values are set in Python + +template<typename T> +T * mxGetData(const np::ndarray pm) { + //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. + //Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double. + /*Access the numpy array pointer: + char * get_data() const; + Returns: Array’s raw data pointer as a char + Note: This returns char so stride math works properly on it.User will have to reinterpret_cast it. + probably this would work. + A = reinterpret_cast<float *>(prhs[0]); */ + return reinterpret_cast<T *>(prhs[0]); +} + + + + +bp::list mexFunction(np::ndarray input) { + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + /**************************************************************************/ + np::ndarray zz = zeros(3, dim_array, (int)0); + np::ndarray fzz = zeros(3, dim_array, (float)0); + /**************************************************************************/ + + int * A = reinterpret_cast<int *>(input.get_data()); + int * B = reinterpret_cast<int *>(zz.get_data()); + float * C = reinterpret_cast<float *>(fzz.get_data()); + + //Copy data and cast + for (int i = 0; i < dim_array[0]; i++) { + for (int j = 0; j < dim_array[1]; j++) { + for (int k = 0; k < dim_array[2]; k++) { + int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; + int val = (*(A + index)); + float fval = (float)val; + std::memcpy(B + index, &val, sizeof(int)); + std::memcpy(C + index, &fval, sizeof(float)); + } + } + } + + + bp::list result; + + result.append<int>(number_of_dims); + result.append<int>(dim_array[0]); + result.append<int>(dim_array[1]); + result.append<int>(dim_array[2]); + result.append<np::ndarray>(zz); + result.append<np::ndarray>(fzz); + + //result.append<bp::tuple>(tup); + return result; + +} + +bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { + + // the result is in the following list + bp::list result; + int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; - const int dim_array[3]; + const int *dim_array; float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; - + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - number_of_dims = input.get_nd(); + + int number_of_dims = input.get_nd(); + int dim_array[3]; dim_array[0] = input.shape(0); dim_array[1] = input.shape(1); if (number_of_dims == 2) { - dim_array[2] = -11; + dim_array[2] = -1; } else { dim_array[2] = input.shape(2); } - /*Handling Matlab input data*/ + // Parameter handling is be done in Python + ///*Handling Matlab input data*/ //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); - /*Handling Matlab input data*/ + ///*Handling Matlab input data*/ //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ A = reinterpret_cast<float *>(input.get_data()); - //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ mu = (float)d_mu; + //iter = 35; /* default iterations number */ - iter = niterations; + //epsil = 0.0001; /* default tolerance constant */ epsil = (float)d_epsil; //methTV = 0; /* default isotropic TV penalty */ - methTV = TV_type; //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */ //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ //if (nrhs == 5) { @@ -182,34 +243,31 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, doub if (number_of_dims == 2) { dimZ = 1; /*2D case*/ - /* - mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); -args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. - dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. - For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. - classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. - For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. - ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). - Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). - - mxCreateNumericArray initializes all its real data elements to 0. -*/ - -/* - U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); -*/ //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - U = A = reinterpret_cast<float *>input.get_data(); - U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin<float>(); + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); + + U = reinterpret_cast<float *>(npU.get_data()); + U_old = reinterpret_cast<float *>(npU_old.get_data()); + Dx = reinterpret_cast<float *>(npDx.get_data()); + Dy = reinterpret_cast<float *>(npDy.get_data()); + Bx = reinterpret_cast<float *>(npBx.get_data()); + By = reinterpret_cast<float *>(npBy.get_data()); + + + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ /* begin outer SB iterations */ @@ -245,59 +303,370 @@ args: ndim: Number of dimensions. If you specify a value for ndim that is less /*printf("%f %i %i \n", re, ll, count); */ /*copyIm(U_old, U, dimX, dimY, dimZ); */ + result.append<np::ndarray>(npU); + result.append<int>(ll); + } + //printf("SB iterations stopped at iteration: %i\n", ll); + if (number_of_dims == 3) { + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin<float>(); + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npDz = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); + np::ndarray npBz = np::zeros(shape, dtype); + + U = reinterpret_cast<float *>(npU.get_data()); + U_old = reinterpret_cast<float *>(npU_old.get_data()); + Dx = reinterpret_cast<float *>(npDx.get_data()); + Dy = reinterpret_cast<float *>(npDy.get_data()); + Dz = reinterpret_cast<float *>(npDz.get_data()); + Bx = reinterpret_cast<float *>(npBx.get_data()); + By = reinterpret_cast<float *>(npBy.get_data()); + Bz = reinterpret_cast<float *>(npBz.get_data()); + + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ + + /* begin outer SB iterations */ + for (ll = 0; ll<iter; ll++) { + + /*storing old values*/ + copyIm(U, U_old, dimX, dimY, dimZ); + + /*GS iteration */ + gauss_seidel3D(U, A, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda, mu); + + if (methTV == 1) updDxDyDz_shrinkAniso3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda); + else updDxDyDz_shrinkIso3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda); + + updBxByBz3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ); + + /* calculate norm to terminate earlier */ + re = 0.0f; re1 = 0.0f; + for (j = 0; j<dimX*dimY*dimZ; j++) + { + re += pow(U[j] - U_old[j], 2); + re1 += pow(U[j], 2); + } + re = sqrt(re) / sqrt(re1); + if (re < epsil) count++; + if (count > 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + /*printf("%f %i %i \n", re, ll, count); */ + re_old = re; + } + //printf("SB iterations stopped at iteration: %i\n", ll); + result.append<np::ndarray>(npU); + result.append<int>(ll); } - printf("SB iterations stopped at iteration: %i\n", ll); } - if (number_of_dims == 3) { - U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + return result; - copyIm(A, U, dimX, dimY, dimZ); /*initialize */ +} - /* begin outer SB iterations */ +bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { + + // the result is in the following list + bp::list result; + + int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; + float *A, *D = NULL, *D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL; + float lambda, tk, tkp1, re, re1, re_old, epsil, funcval; + + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); + //dim_array = mxGetDimensions(prhs[0]); + + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + // Parameter handling is be done in Python + ///*Handling Matlab input data*/ + //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); + + ///*Handling Matlab input data*/ + //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ + A = reinterpret_cast<float *>(input.get_data()); + + //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ + mu = (float)d_mu; + + //iter = 35; /* default iterations number */ + + //epsil = 0.0001; /* default tolerance constant */ + epsil = (float)d_epsil; + //methTV = 0; /* default isotropic TV penalty */ + //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */ + //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ + //if (nrhs == 5) { + // char *penalty_type; + // penalty_type = mxArrayToString(prhs[4]); /* choosing TV penalty: 'iso' or 'l1', 'iso' is the default */ + // if ((strcmp(penalty_type, "l1") != 0) && (strcmp(penalty_type, "iso") != 0)) mexErrMsgTxt("Choose TV type: 'iso' or 'l1',"); + // if (strcmp(penalty_type, "l1") == 0) methTV = 1; /* enable 'l1' penalty */ + // mxFree(penalty_type); + //} + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + + //plhs[1] = mxCreateNumericMatrix(1, 1, mxSINGLE_CLASS, mxREAL); + bp::tuple shape1 = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin<float>(); + np::ndarray out1 = np::zeros(shape1, dtype); + + //float *funcvalA = (float *)mxGetData(plhs[1]); + float * funcvalA = reinterpret_cast<float *>(out1.get_data()); + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; dimZ = dim_array[2]; + + tk = 1.0f; + tkp1 = 1.0f; + count = 1; + re_old = 0.0f; + + if (number_of_dims == 2) { + dimZ = 1; /*2D case*/ + /*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + R1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + R2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin<float>(); + + + np::ndarray npD = np::zeros(shape, dtype); + np::ndarray npD_old = np::zeros(shape, dtype); + np::ndarray npP1 = np::zeros(shape, dtype); + np::ndarray npP2 = np::zeros(shape, dtype); + np::ndarray npP1_old = np::zeros(shape, dtype); + np::ndarray npP2_old = np::zeros(shape, dtype); + np::ndarray npR1 = np::zeros(shape, dtype); + np::ndarray npR2 = zeros(2, dim_array, (float)0); + + D = reinterpret_cast<float *>(npD.get_data()); + D_old = reinterpret_cast<float *>(npD_old.get_data()); + P1 = reinterpret_cast<float *>(npP1.get_data()); + P2 = reinterpret_cast<float *>(npP2.get_data()); + P1_old = reinterpret_cast<float *>(npP1_old.get_data()); + P2_old = reinterpret_cast<float *>(npP2_old.get_data()); + R1 = reinterpret_cast<float *>(npR1.get_data()); + R2 = reinterpret_cast<float *>(npR2.get_data()); + + /* begin iterations */ for (ll = 0; ll<iter; ll++) { - /*storing old values*/ - copyIm(U, U_old, dimX, dimY, dimZ); + /* computing the gradient of the objective function */ + Obj_func2D(A, D, R1, R2, lambda, dimX, dimY); - /*GS iteration */ - gauss_seidel3D(U, A, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda, mu); + /*Taking a step towards minus of the gradient*/ + Grad_func2D(P1, P2, D, R1, R2, lambda, dimX, dimY); - if (methTV == 1) updDxDyDz_shrinkAniso3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda); - else updDxDyDz_shrinkIso3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda); + /* projection step */ + Proj_func2D(P1, P2, methTV, dimX, dimY); - updBxByBz3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ); + /*updating R and t*/ + tkp1 = (1.0f + sqrt(1.0f + 4.0f*tk*tk))*0.5f; + Rupd_func2D(P1, P1_old, P2, P2_old, R1, R2, tkp1, tk, dimX, dimY); - /* calculate norm to terminate earlier */ + /* calculate norm */ re = 0.0f; re1 = 0.0f; for (j = 0; j<dimX*dimY*dimZ; j++) { - re += pow(U[j] - U_old[j], 2); - re1 += pow(U[j], 2); + re += pow(D[j] - D_old[j], 2); + re1 += pow(D[j], 2); } re = sqrt(re) / sqrt(re1); if (re < epsil) count++; - if (count > 4) break; + if (count > 3) { + Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); + funcval = 0.0f; + for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); + //funcvalA[0] = sqrt(funcval); + float fv = sqrt(funcval); + std::memcpy(funcvalA, &fv), sizeof(float)); + break; + } /* check that the residual norm is decreasing */ if (ll > 2) { - if (re > re_old) break; + if (re > re_old) { + Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); + funcval = 0.0f; + for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); + //funcvalA[0] = sqrt(funcval); + float fv = sqrt(funcval); + std::memcpy(funcvalA, &fv), sizeof(float)); + break; + } } + re_old = re; /*printf("%f %i %i \n", re, ll, count); */ + + /*storing old values*/ + copyIm(D, D_old, dimX, dimY, dimZ); + copyIm(P1, P1_old, dimX, dimY, dimZ); + copyIm(P2, P2_old, dimX, dimY, dimZ); + tk = tkp1; + + /* calculating the objective function value */ + if (ll == (iter - 1)) { + Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); + funcval = 0.0f; + for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); + //funcvalA[0] = sqrt(funcval); + float fv = sqrt(funcval); + std::memcpy(funcvalA, &fv), sizeof(float)); + } + } + //printf("FGP-TV iterations stopped at iteration %i with the function value %f \n", ll, funcvalA[0]); + result.append<np::ndarray>(npD); + result.append<np::ndarray>(out1); + result.append<int>(ll); + } + if (number_of_dims == 3) { + /*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P1_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P2_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P3_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + R1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + R2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + R3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin<float>(); + + np::ndarray npD = np::zeros(shape, dtype); + np::ndarray npD_old = np::zeros(shape, dtype); + np::ndarray npP1 = np::zeros(shape, dtype); + np::ndarray npP2 = np::zeros(shape, dtype); + np::ndarray npP3 = np::zeros(shape, dtype); + np::ndarray npP1_old = np::zeros(shape, dtype); + np::ndarray npP2_old = np::zeros(shape, dtype); + np::ndarray npP3_old = np::zeros(shape, dtype); + np::ndarray npR1 = np::zeros(shape, dtype); + np::ndarray npR2 = np::zeros(shape, dtype); + np::ndarray npR3 = np::zeros(shape, dtype); + + D = reinterpret_cast<float *>(npD.get_data()); + D_old = reinterpret_cast<float *>(npD_old.get_data()); + P1 = reinterpret_cast<float *>(npP1.get_data()); + P2 = reinterpret_cast<float *>(npP2.get_data()); + P3 = reinterpret_cast<float *>(npP3.get_data()); + P1_old = reinterpret_cast<float *>(npP1_old.get_data()); + P2_old = reinterpret_cast<float *>(npP2_old.get_data()); + P3_old = reinterpret_cast<float *>(npP3_old.get_data()); + R1 = reinterpret_cast<float *>(npR1.get_data()); + R2 = reinterpret_cast<float *>(npR2.get_data()); + R2 = reinterpret_cast<float *>(npR3.get_data()); + /* begin iterations */ + for (ll = 0; ll<iter; ll++) { + + /* computing the gradient of the objective function */ + Obj_func3D(A, D, R1, R2, R3, lambda, dimX, dimY, dimZ); + + /*Taking a step towards minus of the gradient*/ + Grad_func3D(P1, P2, P3, D, R1, R2, R3, lambda, dimX, dimY, dimZ); + + /* projection step */ + Proj_func3D(P1, P2, P3, dimX, dimY, dimZ); + + /*updating R and t*/ + tkp1 = (1.0f + sqrt(1.0f + 4.0f*tk*tk))*0.5f; + Rupd_func3D(P1, P1_old, P2, P2_old, P3, P3_old, R1, R2, R3, tkp1, tk, dimX, dimY, dimZ); + + /* calculate norm - stopping rules*/ + re = 0.0f; re1 = 0.0f; + for (j = 0; j<dimX*dimY*dimZ; j++) + { + re += pow(D[j] - D_old[j], 2); + re1 += pow(D[j], 2); + } + re = sqrt(re) / sqrt(re1); + /* stop if the norm residual is less than the tolerance EPS */ + if (re < epsil) count++; + if (count > 3) { + Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ); + funcval = 0.0f; + for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); + //funcvalA[0] = sqrt(funcval); + float fv = sqrt(funcval); + std::memcpy(funcvalA, &fv), sizeof(float)); + break; + } + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) { + Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ); + funcval = 0.0f; + for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); + //funcvalA[0] = sqrt(funcval); + float fv = sqrt(funcval); + std::memcpy(funcvalA, &fv), sizeof(float)); + break; + } + } + re_old = re; + /*printf("%f %i %i \n", re, ll, count); */ + + /*storing old values*/ + copyIm(D, D_old, dimX, dimY, dimZ); + copyIm(P1, P1_old, dimX, dimY, dimZ); + copyIm(P2, P2_old, dimX, dimY, dimZ); + copyIm(P3, P3_old, dimX, dimY, dimZ); + tk = tkp1; + + if (ll == (iter - 1)) { + Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ); + funcval = 0.0f; + for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); + //funcvalA[0] = sqrt(funcval); + float fv = sqrt(funcval); + std::memcpy(funcvalA, &fv), sizeof(float)); + } + } - printf("SB iterations stopped at iteration: %i\n", ll); + //printf("FGP-TV iterations stopped at iteration %i with the function value %f \n", ll, funcvalA[0]); + result.append<np::ndarray>(npD); + result.append<np::ndarray>(out1); + result.append<int>(ll); } - bp::list result; + return result; } - BOOST_PYTHON_MODULE(fista) { @@ -310,6 +679,7 @@ BOOST_PYTHON_MODULE(fista) np::dtype dt1 = np::dtype::get_builtin<uint8_t>(); np::dtype dt2 = np::dtype::get_builtin<uint16_t>(); - def("mexFunction", mexFunction); + def("SplitBregman_TV", SplitBregman_TV); + def("FGP_TV", FGP_TV); }
\ No newline at end of file |