/*
-----------------------------------------------------------------------
Copyright: 2010-2021, imec Vision Lab, University of Antwerp
2014-2021, CWI, Amsterdam
Contact: astra@astra-toolbox.com
Website: http://www.astra-toolbox.com/
This file is part of the ASTRA Toolbox.
The ASTRA Toolbox is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
The ASTRA Toolbox is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>.
-----------------------------------------------------------------------
*/
/** \file astra_mex_matrix_c.cpp
*
* \brief Create sparse (projection) matrices in the ASTRA workspace
*/
#include
#include "mexHelpFunctions.h"
#include "mexInitFunctions.h"
#include
#include "astra/Globals.h"
#include "astra/AstraObjectManager.h"
#include "astra/SparseMatrix.h"
using namespace std;
using namespace astra;
//-----------------------------------------------------------------------------------------
/** astra_mex_matrix('delete', id1, id2, ...);
 *
 * Remove one or more matrix objects from the astra-library.
 * id1, id2, ... : identifiers of the matrix objects to remove.
 */
void astra_mex_matrix_delete(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
    // At least one identifier has to follow the first argument.
    if (nrhs < 2) {
        mexErrMsgTxt("Not enough arguments. See the help document for a detailed argument list. \n");
        return;
    }

    // Each remaining argument is read as a scalar identifier and removed in turn.
    int iArg = 1;
    while (iArg < nrhs) {
        const int iMatrixID = static_cast<int>(mxGetScalar(prhs[iArg]));
        CMatrixManager::getSingleton().remove(iMatrixID);
        ++iArg;
    }
}
//-----------------------------------------------------------------------------------------
/** astra_mex_matrix('clear');
*
* Delete all data objects currently stored in the astra-library.
*/
void astra_mex_matrix_clear(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
CMatrixManager::getSingleton().clear();
}
static bool matlab_to_astra(const mxArray* _rhs, CSparseMatrix* _pMatrix)
{
// Check input
if (!mxIsSparse (_rhs)) {
mexErrMsgTxt("Argument is not a valid MATLAB sparse matrix.\n");
return false;
}
if (!_pMatrix->isInitialized()) {
mexErrMsgTxt("Couldn't initialize data object.\n");
return false;
}
unsigned int iHeight = mxGetM(_rhs);
unsigned int iWidth = mxGetN(_rhs);
unsigned long lSize = mxGetNzmax(_rhs);
if (_pMatrix->m_lSize < lSize || _pMatrix->m_iHeight < iHeight) {
// TODO: support resizing?
mexErrMsgTxt("Matrix too large to store in this object.\n");
return false;
}
// Transpose matrix, as matlab stores a matrix column-by-column
// but we want it row-by-row.
// 1. Compute sizes of rows. We store these in _pMatrix->m_plRowStarts.
// 2. Fill data structure
// Complexity: O( #rows + #entries )
for (unsigned int i = 0; i <= iHeight; ++i)
_pMatrix->m_plRowStarts[i] = 0;
mwIndex *colStarts = mxGetJc(_rhs);
mwIndex *rowIndices = mxGetIr(_rhs);
double *floatValues = 0;
mxLogical *boolValues = 0;
bool bLogical = mxIsLogical(_rhs);
if (bLogical)
boolValues = mxGetLogicals(_rhs);
else
floatValues = mxGetPr(_rhs);
for (mwIndex i = 0; i < colStarts[iWidth]; ++i) {
unsigned int iRow = rowIndices[i];
assert(iRow < iHeight);
_pMatrix->m_plRowStarts[iRow+1]++;
}
// Now _pMatrix->m_plRowStarts[i+1] is the number of entries in row i
for (unsigned int i = 1; i <= iHeight; ++i)
_pMatrix->m_plRowStarts[i] += _pMatrix->m_plRowStarts[i-1];
// Now _pMatrix->m_plRowStarts[i+1] is the number of entries in rows <= i,
// so the intended start of row i+1
int iCol = 0;
for (mwIndex i = 0; i < colStarts[iWidth]; ++i) {
while (i >= colStarts[iCol+1])
iCol++;
unsigned int iRow = rowIndices[i];
assert(iRow < iHeight);
float32 fVal;
if (bLogical)
fVal = (float32)boolValues[i];
else
fVal = (float32)floatValues[i];
unsigned long lIndex = _pMatrix->m_plRowStarts[iRow]++;
_pMatrix->m_pfValues[lIndex] = fVal;
_pMatrix->m_piColIndices[lIndex] = iCol;
}
// Now _pMatrix->m_plRowStarts[i] is the start of row i+1
for (unsigned int i = iHeight; i > 0; --i)
_pMatrix->m_plRowStarts[i] = _pMatrix->m_plRowStarts[i-1];
_pMatrix->m_plRowStarts[0] = 0;
#if 0
// Debugging: dump matrix
for (unsigned int i = 0; i < iHeight; ++i) {
printf("Row %d: %ld-%ld\n", i, _pMatrix->m_plRowStarts[i], _pMatrix->m_plRowStarts[i+1]);
for (unsigned long j = _pMatrix->m_plRowStarts[i]; j < _pMatrix->m_plRowStarts[i+1]; ++j) {
printf("(%d,%d) = %f\n", i, _pMatrix->m_piColIndices[j], _pMatrix->m_pfValues[j]);
}
}
#endif
return true;
}
static bool astra_to_matlab(const CSparseMatrix* _pMatrix, mxArray*& _lhs)
{
if (!_pMatrix->isInitialized()) {
mexErrMsgTxt("Uninitialized data object.\n");
return false;
}
unsigned int iHeight = _pMatrix->m_iHeight;
unsigned int iWidth = _pMatrix->m_iWidth;
unsigned long lSize = _pMatrix->m_lSize;
_lhs = mxCreateSparse(iHeight, iWidth, lSize, mxREAL);
if (!mxIsSparse (_lhs)) {
mexErrMsgTxt("Couldn't initialize matlab sparse matrix.\n");
return false;
}
mwIndex *colStarts = mxGetJc(_lhs);
mwIndex *rowIndices = mxGetIr(_lhs);
double *floatValues = mxGetPr(_lhs);
for (unsigned int i = 0; i <= iWidth; ++i)
colStarts[i] = 0;
for (unsigned int i = 0; i < _pMatrix->m_plRowStarts[iHeight]; ++i) {
unsigned int iCol = _pMatrix->m_piColIndices[i];
assert(iCol < iWidth);
colStarts[iCol+1]++;
}
// Now _pMatrix->m_plRowStarts[i+1] is the number of entries in row i
for (unsigned int i = 1; i <= iWidth; ++i)
colStarts[i] += colStarts[i-1];
// Now _pMatrix->m_plRowStarts[i+1] is the number of entries in rows <= i,
// so the intended start of row i+1
unsigned int iRow = 0;
for (unsigned int i = 0; i < _pMatrix->m_plRowStarts[iHeight]; ++i) {
while (i >= _pMatrix->m_plRowStarts[iRow+1])
iRow++;
unsigned int iCol = _pMatrix->m_piColIndices[i];
assert(iCol < iWidth);
double fVal = _pMatrix->m_pfValues[i];
unsigned long lIndex = colStarts[iCol]++;
floatValues[lIndex] = fVal;
rowIndices[lIndex] = iRow;
}
// Now _pMatrix->m_plRowStarts[i] is the start of row i+1
for (unsigned int i = iWidth; i > 0; --i)
colStarts[i] = colStarts[i-1];
colStarts[0] = 0;
return true;
}
//-----------------------------------------------------------------------------------------
/** id = astra_mex_matrix('create', data);
 *
 * Create a new matrix object in the astra-library.
 * data: a sparse MATLAB matrix containing the data.
 * id: identifier of the matrix object as it is now stored in the astra-library.
 *     Only assigned when an output argument is requested.
 */
void astra_mex_matrix_create(int& nlhs, mxArray* plhs[], int& nrhs, const mxArray* prhs[])
{
    // step1: check input
    if (nrhs < 2) {
        mexErrMsgTxt("Not enough arguments. See the help document for a detailed argument list. \n");
        return;
    }

    if (!mxIsSparse(prhs[1])) {
        mexErrMsgTxt("Argument is not a valid MATLAB sparse matrix.\n");
        return;
    }

    // step2: allocate an object sized after the MATLAB matrix
    unsigned int iHeight = mxGetM(prhs[1]);
    unsigned int iWidth = mxGetN(prhs[1]);
    unsigned long lSize = mxGetNzmax(prhs[1]);

    CSparseMatrix* pMatrix = new CSparseMatrix(iHeight, iWidth, lSize);

    // Check initialization.
    // mexErrMsgTxt hands control back to MATLAB and does not return here, so
    // the object has to be released before the error is raised.
    if (!pMatrix->isInitialized()) {
        delete pMatrix;
        mexErrMsgTxt("Couldn't initialize data object.\n");
        return;
    }

    // step3: copy the data
    bool bResult = matlab_to_astra(prhs[1], pMatrix);

    if (!bResult) {
        delete pMatrix;
        mexErrMsgTxt("Failed to create data object.\n");
        return;
    }

    // store data object
    int iIndex = CMatrixManager::getSingleton().store(pMatrix);

    // return data id
    if (1 <= nlhs) {
        plhs[0] = mxCreateDoubleScalar(iIndex);
    }
}
//-----------------------------------------------------------------------------------------
/** astra_mex_matrix('store', id, data);
*
* Store a sparse MATLAB matrix in an existing astra matrix dataobject.
* id: identifier of the 2d data object as stored in the astra-library.
* data: a sparse MATLAB matrix.
*/
void astra_mex_matrix_store(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
// step1: input
if (nrhs < 3) {
mexErrMsgTxt("Not enough arguments. See the help document for a detailed argument list. \n");
return;
}
if (!mxIsDouble(prhs[1])) {
mexErrMsgTxt("Identifier should be a scalar value. \n");
return;
}
int iDataID = (int)(mxGetScalar(prhs[1]));
// step2: get data object
CSparseMatrix* pMatrix = astra::CMatrixManager::getSingleton().get(iDataID);
if (!pMatrix || !pMatrix->isInitialized()) {
mexErrMsgTxt("Data object not found or not initialized properly.\n");
return;
}
bool bResult = matlab_to_astra(prhs[2], pMatrix);
if (!bResult) {
mexErrMsgTxt("Failed to store matrix.\n");
}
}
//-----------------------------------------------------------------------------------------
/** geom = astra_mex_matrix('get_size', id);
*
* Fetch the dimensions and size of a matrix stored in the astra-library.
* id: identifier of the 2d data object as stored in the astra-library.
* geom: a 1x2 matrix containing [rows, columns]
*/
void astra_mex_matrix_get_size(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
// step1: input
if (nrhs < 2) {
mexErrMsgTxt("Not enough arguments. See the help document for a detailed argument list. \n");
return;
}
if (!mxIsDouble(prhs[1])) {
mexErrMsgTxt("Identifier should be a scalar value. \n");
return;
}
int iDataID = (int)(mxGetScalar(prhs[1]));
// step2: get data object
CSparseMatrix* pMatrix = astra::CMatrixManager::getSingleton().get(iDataID);
if (!pMatrix || !pMatrix->isInitialized()) {
mexErrMsgTxt("Data object not found or not initialized properly.\n");
return;
}
// create output
// TODO
}
//-----------------------------------------------------------------------------------------
/** data = astra_mex_matrix('get', id);
*
* Fetch data from the astra-library to a MATLAB matrix.
* id: identifier of the matrix data object as stored in the astra-library.
* data: MATLAB
*/
void astra_mex_matrix_get(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
// step1: check input
if (nrhs < 2) {
mexErrMsgTxt("Not enough arguments. See the help document for a detailed argument list. \n");
return;
}
if (!mxIsDouble(prhs[1])) {
mexErrMsgTxt("Identifier should be a scalar value. \n");
return;
}
int iDataID = (int)(mxGetScalar(prhs[1]));
// step2: get data object
CSparseMatrix* pMatrix = astra::CMatrixManager::getSingleton().get(iDataID);
if (!pMatrix || !pMatrix->isInitialized()) {
mexErrMsgTxt("Data object not found or not initialized properly.\n");
return;
}
// create output
if (1 <= nlhs) {
bool bResult = astra_to_matlab(pMatrix, plhs[0]);
if (!bResult) {
mexErrMsgTxt("Failed to get matrix.\n");
}
}
}
//-----------------------------------------------------------------------------------------
/** astra_mex_matrix('info');
 *
 * Print the matrix manager's info text, describing the matrix objects
 * currently stored in the astra-library, to the MATLAB console.
 */
void astra_mex_matrix_info(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
    const std::string sInfo = astra::CMatrixManager::getSingleton().info();
    mexPrintf("%s", sInfo.c_str());
}
//-----------------------------------------------------------------------------------------
/** Print a short usage hint listing the supported modes of operation. */
static void printHelp()
{
    static const char* const helpLines[] = {
        "Please specify a mode of operation.\n",
        "Valid modes: get, delete, clear, store, create, get_size, info\n"
    };
    const size_t nLines = sizeof(helpLines) / sizeof(helpLines[0]);

    for (size_t iLine = 0; iLine < nLines; ++iLine)
        mexPrintf("%s", helpLines[iLine]);
}
//-----------------------------------------------------------------------------------------
/**
* ... = astra_mex_matrix(type,...);
*/
void mexFunction(int nlhs, mxArray* plhs[],
int nrhs, const mxArray* prhs[])
{
// INPUT0: Mode
string sMode = "";
if (1 <= nrhs) {
sMode = mexToString(prhs[0]);
} else {
printHelp();
return;
}
initASTRAMex();
// SWITCH (MODE)
if (sMode == std::string("get")) {
astra_mex_matrix_get(nlhs, plhs, nrhs, prhs);
} else if (sMode == std::string("delete")) {
astra_mex_matrix_delete(nlhs, plhs, nrhs, prhs);
} else if (sMode == "clear") {
astra_mex_matrix_clear(nlhs, plhs, nrhs, prhs);
} else if (sMode == std::string("store")) {
astra_mex_matrix_store(nlhs, plhs, nrhs, prhs);
} else if (sMode == std::string("create")) {
astra_mex_matrix_create(nlhs, plhs, nrhs, prhs);
} else if (sMode == std::string("get_size")) {
astra_mex_matrix_get_size(nlhs, plhs, nrhs, prhs);
} else if (sMode == std::string("info")) {
astra_mex_matrix_info(nlhs, plhs, nrhs, prhs);
} else {
printHelp();
}
return;
}