summaryrefslogtreecommitdiffstats
path: root/matlab/mex
diff options
context:
space:
mode:
authorWillem Jan Palenstijn <wjp@usecode.org>2016-02-22 12:12:27 +0100
committerWillem Jan Palenstijn <wjp@usecode.org>2016-02-22 12:12:27 +0100
commitcd395d2af23530f9da471fd6c512e9890c08f69d (patch)
tree0fb1b8431ef1773d0f4fcee9e780a19a371021a9 /matlab/mex
parent3743fdc534b39958c105f4124ad1130d3e8b042a (diff)
parentd2705f52766716b4d4627a62de92e7a2480b6b86 (diff)
downloadastra-cd395d2af23530f9da471fd6c512e9890c08f69d.tar.gz
astra-cd395d2af23530f9da471fd6c512e9890c08f69d.tar.bz2
astra-cd395d2af23530f9da471fd6c512e9890c08f69d.tar.xz
astra-cd395d2af23530f9da471fd6c512e9890c08f69d.zip
Merge pull request #111 from wjp/pluginbuild
Remove dependency of libastra on libpython
Diffstat (limited to 'matlab/mex')
-rw-r--r--matlab/mex/astra_mex_plugin_c.cpp86
-rw-r--r--matlab/mex/mexInitFunctions.cpp8
2 files changed, 85 insertions, 9 deletions
diff --git a/matlab/mex/astra_mex_plugin_c.cpp b/matlab/mex/astra_mex_plugin_c.cpp
index 177fcf4..4ed534e 100644
--- a/matlab/mex/astra_mex_plugin_c.cpp
+++ b/matlab/mex/astra_mex_plugin_c.cpp
@@ -37,9 +37,63 @@ $Id$
#include "astra/PluginAlgorithm.h"
+#include <Python.h>
+
using namespace std;
using namespace astra;
+static void fixLapackLoading()
+{
+ // When running in Matlab, we need to force numpy
+ // to use its internal lapack library instead of
+ // Matlab's MKL library to avoid errors. To do this,
+ // we set Python's dlopen flags to RTLD_NOW|RTLD_DEEPBIND
+ // and import 'numpy.linalg.lapack_lite' here. We reset
+ // Python's dlopen flags afterwards.
+ PyObject *sys = PyImport_ImportModule("sys");
+ if (sys != NULL) {
+ PyObject *curFlags = PyObject_CallMethod(sys, "getdlopenflags", NULL);
+ if (curFlags != NULL) {
+ PyObject *retVal = PyObject_CallMethod(sys, "setdlopenflags", "i", 10); // RTLD_NOW|RTLD_DEEPBIND
+ if (retVal != NULL) {
+ PyObject *lapack = PyImport_ImportModule("numpy.linalg.lapack_lite");
+ if (lapack != NULL) {
+ Py_DECREF(lapack);
+ }
+ PyObject *retVal2 = PyObject_CallMethod(sys, "setdlopenflags", "O",curFlags);
+ if (retVal2 != NULL) {
+ Py_DECREF(retVal2);
+ }
+ Py_DECREF(retVal);
+ }
+ Py_DECREF(curFlags);
+ }
+ Py_DECREF(sys);
+ }
+}
+
+//-----------------------------------------------------------------------------------------
+/** astra_mex_plugin('init');
+ *
+ * Initialize plugin support by initializing python and importing astra
+ */
+void astra_mex_plugin_init()
+{
+ if(!Py_IsInitialized()){
+ Py_Initialize();
+ PyEval_InitThreads();
+ }
+
+#ifndef _MSC_VER
+ fixLapackLoading();
+#endif
+
+ // Importing astra may be overkill, since we only need to initialize
+ // PythonPluginAlgorithmFactory from astra.plugin_c.
+ PyObject *mod = PyImport_ImportModule("astra");
+ Py_XDECREF(mod);
+}
+
//-----------------------------------------------------------------------------------------
/** astra_mex_plugin('get_registered');
@@ -48,7 +102,11 @@ using namespace astra;
*/
void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
- astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr();
+ astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getFactory();
+ if (!fact) {
+ mexPrintf("Plugin support not initialized.");
+ return;
+ }
std::map<std::string, std::string> mp = fact->getRegisteredMap();
for(std::map<std::string,std::string>::iterator it=mp.begin();it!=mp.end();it++){
mexPrintf("%s: %s\n",it->first.c_str(), it->second.c_str());
@@ -62,9 +120,13 @@ void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const
*/
void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
+ astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getFactory();
+ if (!fact) {
+ mexPrintf("Plugin support not initialized.");
+ return;
+ }
if (2 <= nrhs) {
string class_name = mexToString(prhs[1]);
- astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr();
fact->registerPlugin(class_name);
}else{
mexPrintf("astra_mex_plugin('register', class_name);\n");
@@ -78,9 +140,13 @@ void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArra
*/
void astra_mex_plugin_get_help(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
+ astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getFactory();
+ if (!fact) {
+ mexPrintf("Plugin support not initialized.");
+ return;
+ }
if (2 <= nrhs) {
string name = mexToString(prhs[1]);
- astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr();
mexPrintf((fact->getHelp(name)+"\n").c_str());
}else{
mexPrintf("astra_mex_plugin('get_help', name);\n");
@@ -116,12 +182,14 @@ void mexFunction(int nlhs, mxArray* plhs[],
initASTRAMex();
// SWITCH (MODE)
- if (sMode == std::string("get_registered")) {
- astra_mex_plugin_get_registered(nlhs, plhs, nrhs, prhs);
- }else if (sMode == std::string("get_help")) {
- astra_mex_plugin_get_help(nlhs, plhs, nrhs, prhs);
- }else if (sMode == std::string("register")) {
- astra_mex_plugin_register(nlhs, plhs, nrhs, prhs);
+ if (sMode == "init") {
+ astra_mex_plugin_init();
+ } else if (sMode == std::string("get_registered")) {
+ astra_mex_plugin_get_registered(nlhs, plhs, nrhs, prhs);
+ }else if (sMode == std::string("get_help")) {
+ astra_mex_plugin_get_help(nlhs, plhs, nrhs, prhs);
+ }else if (sMode == std::string("register")) {
+ astra_mex_plugin_register(nlhs, plhs, nrhs, prhs);
} else {
printHelp();
}
diff --git a/matlab/mex/mexInitFunctions.cpp b/matlab/mex/mexInitFunctions.cpp
index bd3df2c..7245af2 100644
--- a/matlab/mex/mexInitFunctions.cpp
+++ b/matlab/mex/mexInitFunctions.cpp
@@ -23,5 +23,13 @@ void initASTRAMex(){
if(!astra::CLogger::setCallbackScreen(&logCallBack)){
mexErrMsgTxt("Error initializing mex functions.");
}
+
mexIsInitialized=true;
+
+
+ // If we have support for plugins, initialize them.
+ // (NB: Call this after setting mexIsInitialized, to avoid recursively
+ // calling initASTRAMex)
+ mexEvalString("if exist('astra_mex_plugin_c') == 3; astra_mex_plugin_c('init'); end");
+
}