diff options
| -rw-r--r-- | include/astra/PluginAlgorithm.h | 3 | ||||
| -rw-r--r-- | matlab/mex/astra_mex_plugin_c.cpp | 23 | ||||
| -rw-r--r-- | python/astra/plugin.py | 71 | ||||
| -rw-r--r-- | python/astra/plugin_c.pyx | 14 | ||||
| -rw-r--r-- | samples/python/s018_plugin.py | 23 | ||||
| -rw-r--r-- | src/PluginAlgorithm.cpp | 95 | 
6 files changed, 138 insertions, 91 deletions
diff --git a/include/astra/PluginAlgorithm.h b/include/astra/PluginAlgorithm.h index a82c579..b56228e 100644 --- a/include/astra/PluginAlgorithm.h +++ b/include/astra/PluginAlgorithm.h @@ -64,9 +64,12 @@ public:      CPluginAlgorithm * getPlugin(std::string name);      bool registerPlugin(std::string name, std::string className); +    bool registerPlugin(std::string className);      bool registerPluginClass(std::string name, PyObject * className); +    bool registerPluginClass(PyObject * className);      PyObject * getRegistered(); +    std::map<std::string, std::string> getRegisteredMap();      std::string getHelp(std::string name); diff --git a/matlab/mex/astra_mex_plugin_c.cpp b/matlab/mex/astra_mex_plugin_c.cpp index 2d9b9a0..177fcf4 100644 --- a/matlab/mex/astra_mex_plugin_c.cpp +++ b/matlab/mex/astra_mex_plugin_c.cpp @@ -37,9 +37,6 @@ $Id$  #include "astra/PluginAlgorithm.h" -#include "Python.h" -#include "bytesobject.h" -  using namespace std;  using namespace astra; @@ -52,29 +49,25 @@ using namespace astra;  void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])  {      astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); -    PyObject *dict = fact->getRegistered(); -    PyObject *key, *value; -    Py_ssize_t pos = 0; -    while (PyDict_Next(dict, &pos, &key, &value)) { -        mexPrintf("%s: %s\n",PyBytes_AsString(key),PyBytes_AsString(value)); +    std::map<std::string, std::string> mp = fact->getRegisteredMap(); +    for(std::map<std::string,std::string>::iterator it=mp.begin();it!=mp.end();it++){ +        mexPrintf("%s: %s\n",it->first.c_str(), it->second.c_str());      } -    Py_DECREF(dict);  }  //----------------------------------------------------------------------------------------- -/** astra_mex_plugin('register', name, class_name); +/** astra_mex_plugin('register', class_name);   *   * Register plugin.   */  void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])  { -    if (3 <= nrhs) { -        string name = mexToString(prhs[1]); -        string class_name = mexToString(prhs[2]); +    if (2 <= nrhs) { +        string class_name = mexToString(prhs[1]);          astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); -        fact->registerPlugin(name, class_name); +        fact->registerPlugin(class_name);      }else{ -        mexPrintf("astra_mex_plugin('register', name, class_name);\n"); +        mexPrintf("astra_mex_plugin('register', class_name);\n");      }  } diff --git a/python/astra/plugin.py b/python/astra/plugin.py index f8fc3bd..4b32e6e 100644 --- a/python/astra/plugin.py +++ b/python/astra/plugin.py @@ -32,60 +32,47 @@ import traceback  class base(object):      def astra_init(self, cfg): -        try: -            args, varargs, varkw, defaults = inspect.getargspec(self.initialize) -            if not defaults is None: -                nopt = len(defaults) -            else: -                nopt = 0 -            if nopt>0: -                req = args[2:-nopt] -                opt = args[-nopt:] -            else: -                req = args[2:] -                opt = [] +        args, varargs, varkw, defaults = inspect.getargspec(self.initialize) +        if not defaults is None: +            nopt = len(defaults) +        else: +            nopt = 0 +        if nopt>0: +            req = args[2:-nopt] +            opt = args[-nopt:] +        else: +            req = args[2:] +            opt = [] -            try: -                optDict = cfg['options'] -            except KeyError: -                optDict = {} +        try: +            optDict = cfg['options'] +        except KeyError: +            optDict = {} -            cfgKeys = set(optDict.keys()) -            reqKeys = set(req) -            optKeys = set(opt) +        cfgKeys = set(optDict.keys()) +        reqKeys = set(req) +        optKeys = set(opt) -            if not reqKeys.issubset(cfgKeys): -                for key in reqKeys.difference(cfgKeys): -                    log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified") -                raise ValueError("Missing required options") +        if not reqKeys.issubset(cfgKeys): +            for key in reqKeys.difference(cfgKeys): +                log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified") +            raise ValueError("Missing required options") -            if not cfgKeys.issubset(reqKeys | optKeys): -                log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys)))) +        if not cfgKeys.issubset(reqKeys | optKeys): +            log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys)))) -            args = [optDict[k] for k in req] -            kwargs = dict((k,optDict[k]) for k in opt if k in optDict) -            self.initialize(cfg, *args, **kwargs) -        except Exception: -            log.error(traceback.format_exc().replace("%","%%")) -            raise +        args = [optDict[k] for k in req] +        kwargs = dict((k,optDict[k]) for k in opt if k in optDict) +        self.initialize(cfg, *args, **kwargs) -    def astra_run(self, its): -        try: -            self.run(its) -        except Exception: -            log.error(traceback.format_exc().replace("%","%%")) -            raise - -def register(name, className): +def register(className):      """Register plugin with ASTRA. -    :param name: Plugin name to register -    :type name: :class:`str`      :param className: Class name or class object to register      :type className: :class:`str` or :class:`class`      """ -    p.register(name,className) +    p.register(className)  def get_registered():      """Get dictionary of registered plugins. diff --git a/python/astra/plugin_c.pyx b/python/astra/plugin_c.pyx index 91b3cd5..8d6816b 100644 --- a/python/astra/plugin_c.pyx +++ b/python/astra/plugin_c.pyx @@ -38,7 +38,9 @@ from . import utils  cdef extern from "astra/PluginAlgorithm.h" namespace "astra":      cdef cppclass CPluginAlgorithmFactory: +        bool registerPlugin(string className)          bool registerPlugin(string name, string className) +        bool registerPluginClass(object className)          bool registerPluginClass(string name, object className)          object getRegistered()          string getHelp(string name) @@ -46,11 +48,17 @@ cdef extern from "astra/PluginAlgorithm.h" namespace "astra":  cdef extern from "astra/PluginAlgorithm.h" namespace "astra::CPluginAlgorithmFactory":      cdef CPluginAlgorithmFactory* getSingletonPtr() -def register(name, className): +def register(className, name=None):      if inspect.isclass(className): -        fact.registerPluginClass(six.b(name), className) +        if name==None: +            fact.registerPluginClass(className) +        else: +            fact.registerPluginClass(six.b(name), className)      else: -        fact.registerPlugin(six.b(name), six.b(className)) +        if name==None: +            fact.registerPlugin(six.b(className)) +        else: +            fact.registerPlugin(six.b(name), six.b(className))  def get_registered():      return fact.getRegistered() diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py index 90e09ac..31cca95 100644 --- a/samples/python/s018_plugin.py +++ b/samples/python/s018_plugin.py @@ -38,6 +38,10 @@ class SIRTPlugin(astra.plugin.base):      'rel_factor': relaxation factor (optional)      """ +    # The astra_name variable defines the name to use to +    # call the plugin from ASTRA +    astra_name = "SIRT-PLUGIN" +      def initialize(self,cfg, rel_factor = 1):          self.W = astra.OpTomo(cfg['ProjectorId'])          self.vid = cfg['ReconstructionDataId'] @@ -68,18 +72,13 @@ if __name__=='__main__':      sinogram = sinogram.reshape([180, 384])      # Register the plugin with ASTRA -    # A default set of plugins to load can be defined in: -    #     - /etc/astra-toolbox/plugins.txt -    #     - [ASTRA_INSTALL_PATH]/python/astra/plugins.txt -    #     - [USER_HOME_PATH]/.astra-toolbox/plugins.txt -    #     - [ASTRA_PLUGIN_PATH environment variable]/plugins.txt -    # In these files, create a separate line for each plugin with: -    # [PLUGIN_ASTRA_NAME] [FULL_PLUGIN_CLASS] -    # -    # So in this case, it would be a line: -    # SIRT-PLUGIN s018_plugin.SIRTPlugin -    # -    astra.plugin.register('SIRT-PLUGIN','s018_plugin.SIRTPlugin') +    # First we import the package that contains the plugin +    import s018_plugin +    # Then, we register the plugin class with ASTRA +    astra.plugin.register(s018_plugin.SIRTPlugin) + +    # Get a list of registered plugins +    six.print_(astra.plugin.get_registered())      # To get help on a registered plugin, use get_help      six.print_(astra.plugin.get_help('SIRT-PLUGIN')) diff --git a/src/PluginAlgorithm.cpp b/src/PluginAlgorithm.cpp index d6cf731..7f7ff61 100644 --- a/src/PluginAlgorithm.cpp +++ b/src/PluginAlgorithm.cpp @@ -100,7 +100,10 @@ bool CPluginAlgorithm::initialize(const Config& _cfg){      PyObject *cfgDict = XMLNode2dict(_cfg.self);      PyObject *retVal = PyObject_CallMethod(instance, "astra_init", "O",cfgDict);      Py_DECREF(cfgDict); -    if(retVal==NULL) return false; +    if(retVal==NULL){ +        logPythonError(); +        return false; +    }      m_bIsInitialized = true;      Py_DECREF(retVal);      return m_bIsInitialized; @@ -108,8 +111,11 @@ bool CPluginAlgorithm::initialize(const Config& _cfg){  void CPluginAlgorithm::run(int _iNrIterations){      if(instance==NULL) return; -    PyObject *retVal = PyObject_CallMethod(instance, "astra_run", "i",_iNrIterations); -    if(retVal==NULL) return; +    PyObject *retVal = PyObject_CallMethod(instance, "run", "i",_iNrIterations); +    if(retVal==NULL){ +        logPythonError(); +        return; +    }      Py_DECREF(retVal);  } @@ -157,18 +163,6 @@ CPluginAlgorithmFactory::~CPluginAlgorithmFactory(){      if(six!=NULL) Py_DECREF(six);  } -bool CPluginAlgorithmFactory::registerPlugin(std::string name, std::string className){ -    PyObject *str = PyBytes_FromString(className.c_str()); -    PyDict_SetItemString(pluginDict, name.c_str(), str); -    Py_DECREF(str); -    return true; -} - -bool CPluginAlgorithmFactory::registerPluginClass(std::string name, PyObject * className){ -    PyDict_SetItemString(pluginDict, name.c_str(), className); -    return true; -} -  PyObject * getClassFromString(std::string str){      std::vector<std::string> items;      boost::split(items, str, boost::is_any_of(".")); @@ -190,6 +184,43 @@ PyObject * getClassFromString(std::string str){      return pyclass;  } +bool CPluginAlgorithmFactory::registerPlugin(std::string name, std::string className){ +    PyObject *str = PyBytes_FromString(className.c_str()); +    PyDict_SetItemString(pluginDict, name.c_str(), str); +    Py_DECREF(str); +    return true; +} + +bool CPluginAlgorithmFactory::registerPlugin(std::string className){ +    PyObject *pyclass = getClassFromString(className); +    if(pyclass==NULL) return false; +    bool ret = registerPluginClass(pyclass); +    Py_DECREF(pyclass); +    return ret; +} + +bool CPluginAlgorithmFactory::registerPluginClass(std::string name, PyObject * className){ +    PyDict_SetItemString(pluginDict, name.c_str(), className); +    return true; +} + +bool CPluginAlgorithmFactory::registerPluginClass(PyObject * className){ +    PyObject *astra_name = PyObject_GetAttrString(className,"astra_name"); +    if(astra_name==NULL){ +        logPythonError(); +        return false; +    } +    PyObject *retb = PyObject_CallMethod(six,"b","O",astra_name); +    if(retb!=NULL){ +        PyDict_SetItemString(pluginDict,PyBytes_AsString(retb),className); +        Py_DECREF(retb); +    }else{ +        logPythonError(); +    } +    Py_DECREF(astra_name); +    return true; +} +  CPluginAlgorithm * CPluginAlgorithmFactory::getPlugin(std::string name){      PyObject *className = PyDict_GetItemString(pluginDict, name.c_str());      if(className==NULL) return NULL; @@ -212,12 +243,34 @@ PyObject * CPluginAlgorithmFactory::getRegistered(){      return pluginDict;  } +std::map<std::string, std::string> CPluginAlgorithmFactory::getRegisteredMap(){ +    std::map<std::string, std::string> ret; +    PyObject *key, *value; +    Py_ssize_t pos = 0; +    while (PyDict_Next(pluginDict, &pos, &key, &value)) { +        PyObject * keyb = PyObject_Bytes(key); +        PyObject * valb = PyObject_Bytes(value); +        ret[PyBytes_AsString(keyb)] = PyBytes_AsString(valb); +        Py_DECREF(keyb); +        Py_DECREF(valb); +    } +    return ret; +} +  std::string CPluginAlgorithmFactory::getHelp(std::string name){      PyObject *className = PyDict_GetItemString(pluginDict, name.c_str()); -    if(className==NULL) return ""; -    std::string str = std::string(PyBytes_AsString(className)); +    if(className==NULL){ +        ASTRA_ERROR("Plugin %s not found!",name.c_str()); +        return ""; +    }      std::string ret = ""; -    PyObject *pyclass = getClassFromString(str); +    PyObject *pyclass; +    if(PyBytes_Check(className)){ +        std::string str = std::string(PyBytes_AsString(className)); +        pyclass = getClassFromString(str); +    }else{ +        pyclass = className; +    }      if(pyclass==NULL) return "";      if(inspect!=NULL && six!=NULL){          PyObject *retVal = PyObject_CallMethod(inspect,"getdoc","O",pyclass); @@ -228,9 +281,13 @@ std::string CPluginAlgorithmFactory::getHelp(std::string name){                  ret = std::string(PyBytes_AsString(retb));                  Py_DECREF(retb);              } +        }else{ +            logPythonError();          }      } -    Py_DECREF(pyclass); +    if(PyBytes_Check(className)){ +        Py_DECREF(pyclass); +    }      return ret;  }  | 
