/* ----------------------------------------------------------------------- Copyright: 2010-2021, imec Vision Lab, University of Antwerp 2014-2021, CWI, Amsterdam Contact: astra@astra-toolbox.com Website: http://www.astra-toolbox.com/ This file is part of the ASTRA Toolbox. The ASTRA Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. The ASTRA Toolbox is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with the ASTRA Toolbox. If not, see . ----------------------------------------------------------------------- */ /** \file astra_mex_matrix_c.cpp * * \brief Create sparse (projection) matrices in the ASTRA workspace */ #include #include "mexHelpFunctions.h" #include "mexInitFunctions.h" #include #include "astra/Globals.h" #include "astra/AstraObjectManager.h" #include "astra/SparseMatrix.h" using namespace std; using namespace astra; //----------------------------------------------------------------------------------------- /** astra_mex_matrix('delete', id1, id2, ...); * * Delete one or more data objects currently stored in the astra-library. * id1, id2, ... : identifiers of the 2d data objects as stored in the astra-library. */ void astra_mex_matrix_delete(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { // step1: read input if (nrhs < 2) { mexErrMsgTxt("Not enough arguments. See the help document for a detailed argument list. \n"); return; } // step2: delete all specified data objects for (int i = 1; i < nrhs; i++) { int iDataID = (int)(mxGetScalar(prhs[i])); CMatrixManager::getSingleton().remove(iDataID); } } //----------------------------------------------------------------------------------------- /** astra_mex_matrix('clear'); * * Delete all data objects currently stored in the astra-library. */ void astra_mex_matrix_clear(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { CMatrixManager::getSingleton().clear(); } static bool matlab_to_astra(const mxArray* _rhs, CSparseMatrix* _pMatrix) { // Check input if (!mxIsSparse (_rhs)) { mexErrMsgTxt("Argument is not a valid MATLAB sparse matrix.\n"); return false; } if (!_pMatrix->isInitialized()) { mexErrMsgTxt("Couldn't initialize data object.\n"); return false; } unsigned int iHeight = mxGetM(_rhs); unsigned int iWidth = mxGetN(_rhs); unsigned long lSize = mxGetNzmax(_rhs); if (_pMatrix->m_lSize < lSize || _pMatrix->m_iHeight < iHeight) { // TODO: support resizing? mexErrMsgTxt("Matrix too large to store in this object.\n"); return false; } // Transpose matrix, as matlab stores a matrix column-by-column // but we want it row-by-row. // 1. Compute sizes of rows. We store these in _pMatrix->m_plRowStarts. // 2. Fill data structure // Complexity: O( #rows + #entries ) for (unsigned int i = 0; i <= iHeight; ++i) _pMatrix->m_plRowStarts[i] = 0; mwIndex *colStarts = mxGetJc(_rhs); mwIndex *rowIndices = mxGetIr(_rhs); double *floatValues = 0; mxLogical *boolValues = 0; bool bLogical = mxIsLogical(_rhs); if (bLogical) boolValues = mxGetLogicals(_rhs); else floatValues = mxGetPr(_rhs); for (mwIndex i = 0; i < colStarts[iWidth]; ++i) { unsigned int iRow = rowIndices[i]; assert(iRow < iHeight); _pMatrix->m_plRowStarts[iRow+1]++; } // Now _pMatrix->m_plRowStarts[i+1] is the number of entries in row i for (unsigned int i = 1; i <= iHeight; ++i) _pMatrix->m_plRowStarts[i] += _pMatrix->m_plRowStarts[i-1]; // Now _pMatrix->m_plRowStarts[i+1] is the number of entries in rows <= i, // so the intended start of row i+1 int iCol = 0; for (mwIndex i = 0; i < colStarts[iWidth]; ++i) { while (i >= colStarts[iCol+1]) iCol++; unsigned int iRow = rowIndices[i]; assert(iRow < iHeight); float32 fVal; if (bLogical) fVal = (float32)boolValues[i]; else fVal = (float32)floatValues[i]; unsigned long lIndex = _pMatrix->m_plRowStarts[iRow]++; _pMatrix->m_pfValues[lIndex] = fVal; _pMatrix->m_piColIndices[lIndex] = iCol; } // Now _pMatrix->m_plRowStarts[i] is the start of row i+1 for (unsigned int i = iHeight; i > 0; --i) _pMatrix->m_plRowStarts[i] = _pMatrix->m_plRowStarts[i-1]; _pMatrix->m_plRowStarts[0] = 0; #if 0 // Debugging: dump matrix for (unsigned int i = 0; i < iHeight; ++i) { printf("Row %d: %ld-%ld\n", i, _pMatrix->m_plRowStarts[i], _pMatrix->m_plRowStarts[i+1]); for (unsigned long j = _pMatrix->m_plRowStarts[i]; j < _pMatrix->m_plRowStarts[i+1]; ++j) { printf("(%d,%d) = %f\n", i, _pMatrix->m_piColIndices[j], _pMatrix->m_pfValues[j]); } } #endif return true; } static bool astra_to_matlab(const CSparseMatrix* _pMatrix, mxArray*& _lhs) { if (!_pMatrix->isInitialized()) { mexErrMsgTxt("Uninitialized data object.\n"); return false; } unsigned int iHeight = _pMatrix->m_iHeight; unsigned int iWidth = _pMatrix->m_iWidth; unsigned long lSize = _pMatrix->m_lSize; _lhs = mxCreateSparse(iHeight, iWidth, lSize, mxREAL); if (!mxIsSparse (_lhs)) { mexErrMsgTxt("Couldn't initialize matlab sparse matrix.\n"); return false; } mwIndex *colStarts = mxGetJc(_lhs); mwIndex *rowIndices = mxGetIr(_lhs); double *floatValues = mxGetPr(_lhs); for (unsigned int i = 0; i <= iWidth; ++i) colStarts[i] = 0; for (unsigned int i = 0; i < _pMatrix->m_plRowStarts[iHeight]; ++i) { unsigned int iCol = _pMatrix->m_piColIndices[i]; assert(iCol < iWidth); colStarts[iCol+1]++; } // Now _pMatrix->m_plRowStarts[i+1] is the number of entries in row i for (unsigned int i = 1; i <= iWidth; ++i) colStarts[i] += colStarts[i-1]; // Now _pMatrix->m_plRowStarts[i+1] is the number of entries in rows <= i, // so the intended start of row i+1 unsigned int iRow = 0; for (unsigned int i = 0; i < _pMatrix->m_plRowStarts[iHeight]; ++i) { while (i >= _pMatrix->m_plRowStarts[iRow+1]) iRow++; unsigned int iCol = _pMatrix->m_piColIndices[i]; assert(iCol < iWidth); double fVal = _pMatrix->m_pfValues[i]; unsigned long lIndex = colStarts[iCol]++; floatValues[lIndex] = fVal; rowIndices[lIndex] = iRow; } // Now _pMatrix->m_plRowStarts[i] is the start of row i+1 for (unsigned int i = iWidth; i > 0; --i) colStarts[i] = colStarts[i-1]; colStarts[0] = 0; return true; } //----------------------------------------------------------------------------------------- /** id = astra_mex_matrix('create', data); * * Create a new matrix object in the astra-library. * data: a sparse MATLAB matrix containing the data. * id: identifier of the matrix object as it is now stored in the astra-library. */ void astra_mex_matrix_create(int& nlhs, mxArray* plhs[], int& nrhs, const mxArray* prhs[]) { // step1: get datatype if (nrhs < 2) { mexErrMsgTxt("Not enough arguments. See the help document for a detailed argument list. \n"); return; } if (!mxIsSparse (prhs[1])) { mexErrMsgTxt("Argument is not a valid MATLAB sparse matrix.\n"); return; } unsigned int iHeight = mxGetM(prhs[1]); unsigned int iWidth = mxGetN(prhs[1]); unsigned long lSize = mxGetNzmax(prhs[1]); CSparseMatrix* pMatrix = new CSparseMatrix(iHeight, iWidth, lSize); // Check initialization if (!pMatrix->isInitialized()) { mexErrMsgTxt("Couldn't initialize data object.\n"); delete pMatrix; return; } bool bResult = matlab_to_astra(prhs[1], pMatrix); if (!bResult) { mexErrMsgTxt("Failed to create data object.\n"); delete pMatrix; return; } // store data object int iIndex = CMatrixManager::getSingleton().store(pMatrix); // return data id if (1 <= nlhs) { plhs[0] = mxCreateDoubleScalar(iIndex); } } //----------------------------------------------------------------------------------------- /** astra_mex_matrix('store', id, data); * * Store a sparse MATLAB matrix in an existing astra matrix dataobject. * id: identifier of the 2d data object as stored in the astra-library. * data: a sparse MATLAB matrix. */ void astra_mex_matrix_store(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { // step1: input if (nrhs < 3) { mexErrMsgTxt("Not enough arguments. See the help document for a detailed argument list. \n"); return; } if (!mxIsDouble(prhs[1])) { mexErrMsgTxt("Identifier should be a scalar value. \n"); return; } int iDataID = (int)(mxGetScalar(prhs[1])); // step2: get data object CSparseMatrix* pMatrix = astra::CMatrixManager::getSingleton().get(iDataID); if (!pMatrix || !pMatrix->isInitialized()) { mexErrMsgTxt("Data object not found or not initialized properly.\n"); return; } bool bResult = matlab_to_astra(prhs[2], pMatrix); if (!bResult) { mexErrMsgTxt("Failed to store matrix.\n"); } } //----------------------------------------------------------------------------------------- /** geom = astra_mex_matrix('get_size', id); * * Fetch the dimensions and size of a matrix stored in the astra-library. * id: identifier of the 2d data object as stored in the astra-library. * geom: a 1x2 matrix containing [rows, columns] */ void astra_mex_matrix_get_size(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { // step1: input if (nrhs < 2) { mexErrMsgTxt("Not enough arguments. See the help document for a detailed argument list. \n"); return; } if (!mxIsDouble(prhs[1])) { mexErrMsgTxt("Identifier should be a scalar value. \n"); return; } int iDataID = (int)(mxGetScalar(prhs[1])); // step2: get data object CSparseMatrix* pMatrix = astra::CMatrixManager::getSingleton().get(iDataID); if (!pMatrix || !pMatrix->isInitialized()) { mexErrMsgTxt("Data object not found or not initialized properly.\n"); return; } // create output // TODO } //----------------------------------------------------------------------------------------- /** data = astra_mex_matrix('get', id); * * Fetch data from the astra-library to a MATLAB matrix. * id: identifier of the matrix data object as stored in the astra-library. * data: MATLAB */ void astra_mex_matrix_get(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { // step1: check input if (nrhs < 2) { mexErrMsgTxt("Not enough arguments. See the help document for a detailed argument list. \n"); return; } if (!mxIsDouble(prhs[1])) { mexErrMsgTxt("Identifier should be a scalar value. \n"); return; } int iDataID = (int)(mxGetScalar(prhs[1])); // step2: get data object CSparseMatrix* pMatrix = astra::CMatrixManager::getSingleton().get(iDataID); if (!pMatrix || !pMatrix->isInitialized()) { mexErrMsgTxt("Data object not found or not initialized properly.\n"); return; } // create output if (1 <= nlhs) { bool bResult = astra_to_matlab(pMatrix, plhs[0]); if (!bResult) { mexErrMsgTxt("Failed to get matrix.\n"); } } } //----------------------------------------------------------------------------------------- /** astra_mex_matrix('info'); * * Print information about all the matrix objects currently stored in the astra-library. */ void astra_mex_matrix_info(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { mexPrintf("%s", astra::CMatrixManager::getSingleton().info().c_str()); } //----------------------------------------------------------------------------------------- static void printHelp() { mexPrintf("Please specify a mode of operation.\n"); mexPrintf("Valid modes: get, delete, clear, store, create, get_size, info\n"); } //----------------------------------------------------------------------------------------- /** * ... = astra_mex_matrix(type,...); */ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { // INPUT0: Mode string sMode = ""; if (1 <= nrhs) { sMode = mexToString(prhs[0]); } else { printHelp(); return; } initASTRAMex(); // SWITCH (MODE) if (sMode == std::string("get")) { astra_mex_matrix_get(nlhs, plhs, nrhs, prhs); } else if (sMode == std::string("delete")) { astra_mex_matrix_delete(nlhs, plhs, nrhs, prhs); } else if (sMode == "clear") { astra_mex_matrix_clear(nlhs, plhs, nrhs, prhs); } else if (sMode == std::string("store")) { astra_mex_matrix_store(nlhs, plhs, nrhs, prhs); } else if (sMode == std::string("create")) { astra_mex_matrix_create(nlhs, plhs, nrhs, prhs); } else if (sMode == std::string("get_size")) { astra_mex_matrix_get_size(nlhs, plhs, nrhs, prhs); } else if (sMode == std::string("info")) { astra_mex_matrix_info(nlhs, plhs, nrhs, prhs); } else { printHelp(); } return; }