From d91b51f6d58003de84a9d6dd8189fceba0e81a5a Mon Sep 17 00:00:00 2001
From: "Daniel M. Pelt" <D.M.Pelt@cwi.nl>
Date: Mon, 20 Jul 2015 14:07:21 +0200
Subject: Allow registering plugins without explicit name, and fix exception
 handling when running in Matlab

---
 include/astra/PluginAlgorithm.h   |  3 ++
 matlab/mex/astra_mex_plugin_c.cpp | 23 ++++------
 python/astra/plugin.py            | 71 ++++++++++++-----------------
 python/astra/plugin_c.pyx         | 14 ++++--
 samples/python/s018_plugin.py     | 23 +++++-----
 src/PluginAlgorithm.cpp           | 95 +++++++++++++++++++++++++++++++--------
 6 files changed, 138 insertions(+), 91 deletions(-)

diff --git a/include/astra/PluginAlgorithm.h b/include/astra/PluginAlgorithm.h
index a82c579..b56228e 100644
--- a/include/astra/PluginAlgorithm.h
+++ b/include/astra/PluginAlgorithm.h
@@ -64,9 +64,12 @@ public:
     CPluginAlgorithm * getPlugin(std::string name);
 
     bool registerPlugin(std::string name, std::string className);
+    bool registerPlugin(std::string className);
     bool registerPluginClass(std::string name, PyObject * className);
+    bool registerPluginClass(PyObject * className);
     
     PyObject * getRegistered();
+    std::map<std::string, std::string> getRegisteredMap();
     
     std::string getHelp(std::string name);
 
diff --git a/matlab/mex/astra_mex_plugin_c.cpp b/matlab/mex/astra_mex_plugin_c.cpp
index 2d9b9a0..177fcf4 100644
--- a/matlab/mex/astra_mex_plugin_c.cpp
+++ b/matlab/mex/astra_mex_plugin_c.cpp
@@ -37,9 +37,6 @@ $Id$
 
 #include "astra/PluginAlgorithm.h"
 
-#include "Python.h"
-#include "bytesobject.h"
-
 using namespace std;
 using namespace astra;
 
@@ -52,29 +49,25 @@ using namespace astra;
 void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
 {
     astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr();
-    PyObject *dict = fact->getRegistered();
-    PyObject *key, *value;
-    Py_ssize_t pos = 0;
-    while (PyDict_Next(dict, &pos, &key, &value)) {
-        mexPrintf("%s: %s\n",PyBytes_AsString(key),PyBytes_AsString(value));
+    std::map<std::string, std::string> mp = fact->getRegisteredMap();
+    for(std::map<std::string,std::string>::iterator it=mp.begin();it!=mp.end();it++){
+        mexPrintf("%s: %s\n",it->first.c_str(), it->second.c_str());
     }
-    Py_DECREF(dict);
 }
 
 //-----------------------------------------------------------------------------------------
-/** astra_mex_plugin('register', name, class_name);
+/** astra_mex_plugin('register', class_name);
  *
  * Register plugin.
  */
 void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
 {
-    if (3 <= nrhs) {
-        string name = mexToString(prhs[1]);
-        string class_name = mexToString(prhs[2]);
+    if (2 <= nrhs) {
+        string class_name = mexToString(prhs[1]);
         astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr();
-        fact->registerPlugin(name, class_name);
+        fact->registerPlugin(class_name);
     }else{
-        mexPrintf("astra_mex_plugin('register', name, class_name);\n");
+        mexPrintf("astra_mex_plugin('register', class_name);\n");
     }
 }
 
diff --git a/python/astra/plugin.py b/python/astra/plugin.py
index f8fc3bd..4b32e6e 100644
--- a/python/astra/plugin.py
+++ b/python/astra/plugin.py
@@ -32,60 +32,47 @@ import traceback
 class base(object):
 
     def astra_init(self, cfg):
-        try:
-            args, varargs, varkw, defaults = inspect.getargspec(self.initialize)
-            if not defaults is None:
-                nopt = len(defaults)
-            else:
-                nopt = 0
-            if nopt>0:
-                req = args[2:-nopt]
-                opt = args[-nopt:]
-            else:
-                req = args[2:]
-                opt = []
+        args, varargs, varkw, defaults = inspect.getargspec(self.initialize)
+        if not defaults is None:
+            nopt = len(defaults)
+        else:
+            nopt = 0
+        if nopt>0:
+            req = args[2:-nopt]
+            opt = args[-nopt:]
+        else:
+            req = args[2:]
+            opt = []
 
-            try:
-                optDict = cfg['options']
-            except KeyError:
-                optDict = {}
+        try:
+            optDict = cfg['options']
+        except KeyError:
+            optDict = {}
 
-            cfgKeys = set(optDict.keys())
-            reqKeys = set(req)
-            optKeys = set(opt)
+        cfgKeys = set(optDict.keys())
+        reqKeys = set(req)
+        optKeys = set(opt)
 
-            if not reqKeys.issubset(cfgKeys):
-                for key in reqKeys.difference(cfgKeys):
-                    log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified")
-                raise ValueError("Missing required options")
+        if not reqKeys.issubset(cfgKeys):
+            for key in reqKeys.difference(cfgKeys):
+                log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified")
+            raise ValueError("Missing required options")
 
-            if not cfgKeys.issubset(reqKeys | optKeys):
-                log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys))))
+        if not cfgKeys.issubset(reqKeys | optKeys):
+            log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys))))
 
-            args = [optDict[k] for k in req]
-            kwargs = dict((k,optDict[k]) for k in opt if k in optDict)
-            self.initialize(cfg, *args, **kwargs)
-        except Exception:
-            log.error(traceback.format_exc().replace("%","%%"))
-            raise
+        args = [optDict[k] for k in req]
+        kwargs = dict((k,optDict[k]) for k in opt if k in optDict)
+        self.initialize(cfg, *args, **kwargs)
 
-    def astra_run(self, its):
-        try:
-            self.run(its)
-        except Exception:
-            log.error(traceback.format_exc().replace("%","%%"))
-            raise
-
-def register(name, className):
+def register(className):
     """Register plugin with ASTRA.
     
-    :param name: Plugin name to register
-    :type name: :class:`str`
     :param className: Class name or class object to register
     :type className: :class:`str` or :class:`class`
     
     """
-    p.register(name,className)
+    p.register(className)
 
 def get_registered():
     """Get dictionary of registered plugins.
diff --git a/python/astra/plugin_c.pyx b/python/astra/plugin_c.pyx
index 91b3cd5..8d6816b 100644
--- a/python/astra/plugin_c.pyx
+++ b/python/astra/plugin_c.pyx
@@ -38,7 +38,9 @@ from . import utils
 
 cdef extern from "astra/PluginAlgorithm.h" namespace "astra":
     cdef cppclass CPluginAlgorithmFactory:
+        bool registerPlugin(string className)
         bool registerPlugin(string name, string className)
+        bool registerPluginClass(object className)
         bool registerPluginClass(string name, object className)
         object getRegistered()
         string getHelp(string name)
@@ -46,11 +48,17 @@ cdef extern from "astra/PluginAlgorithm.h" namespace "astra":
 cdef extern from "astra/PluginAlgorithm.h" namespace "astra::CPluginAlgorithmFactory":
     cdef CPluginAlgorithmFactory* getSingletonPtr()
 
-def register(name, className):
+def register(className, name=None):
     if inspect.isclass(className):
-        fact.registerPluginClass(six.b(name), className)
+        if name==None:
+            fact.registerPluginClass(className)
+        else:
+            fact.registerPluginClass(six.b(name), className)
     else:
-        fact.registerPlugin(six.b(name), six.b(className))
+        if name==None:
+            fact.registerPlugin(six.b(className))
+        else:
+            fact.registerPlugin(six.b(name), six.b(className))
 
 def get_registered():
     return fact.getRegistered()
diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py
index 90e09ac..31cca95 100644
--- a/samples/python/s018_plugin.py
+++ b/samples/python/s018_plugin.py
@@ -38,6 +38,10 @@ class SIRTPlugin(astra.plugin.base):
     'rel_factor': relaxation factor (optional)
     """
 
+    # The astra_name variable defines the name to use to
+    # call the plugin from ASTRA
+    astra_name = "SIRT-PLUGIN"
+
     def initialize(self,cfg, rel_factor = 1):
         self.W = astra.OpTomo(cfg['ProjectorId'])
         self.vid = cfg['ReconstructionDataId']
@@ -68,18 +72,13 @@ if __name__=='__main__':
     sinogram = sinogram.reshape([180, 384])
 
     # Register the plugin with ASTRA
-    # A default set of plugins to load can be defined in:
-    #     - /etc/astra-toolbox/plugins.txt
-    #     - [ASTRA_INSTALL_PATH]/python/astra/plugins.txt
-    #     - [USER_HOME_PATH]/.astra-toolbox/plugins.txt
-    #     - [ASTRA_PLUGIN_PATH environment variable]/plugins.txt
-    # In these files, create a separate line for each plugin with:
-    # [PLUGIN_ASTRA_NAME] [FULL_PLUGIN_CLASS]
-    #
-    # So in this case, it would be a line:
-    # SIRT-PLUGIN s018_plugin.SIRTPlugin
-    #
-    astra.plugin.register('SIRT-PLUGIN','s018_plugin.SIRTPlugin')
+    # First we import the package that contains the plugin
+    import s018_plugin
+    # Then, we register the plugin class with ASTRA
+    astra.plugin.register(s018_plugin.SIRTPlugin)
+
+    # Get a list of registered plugins
+    six.print_(astra.plugin.get_registered())
 
     # To get help on a registered plugin, use get_help
     six.print_(astra.plugin.get_help('SIRT-PLUGIN'))
diff --git a/src/PluginAlgorithm.cpp b/src/PluginAlgorithm.cpp
index d6cf731..7f7ff61 100644
--- a/src/PluginAlgorithm.cpp
+++ b/src/PluginAlgorithm.cpp
@@ -100,7 +100,10 @@ bool CPluginAlgorithm::initialize(const Config& _cfg){
     PyObject *cfgDict = XMLNode2dict(_cfg.self);
     PyObject *retVal = PyObject_CallMethod(instance, "astra_init", "O",cfgDict);
     Py_DECREF(cfgDict);
-    if(retVal==NULL) return false;
+    if(retVal==NULL){
+        logPythonError();
+        return false;
+    }
     m_bIsInitialized = true;
     Py_DECREF(retVal);
     return m_bIsInitialized;
@@ -108,8 +111,11 @@ bool CPluginAlgorithm::initialize(const Config& _cfg){
 
 void CPluginAlgorithm::run(int _iNrIterations){
     if(instance==NULL) return;
-    PyObject *retVal = PyObject_CallMethod(instance, "astra_run", "i",_iNrIterations);
-    if(retVal==NULL) return;
+    PyObject *retVal = PyObject_CallMethod(instance, "run", "i",_iNrIterations);
+    if(retVal==NULL){
+        logPythonError();
+        return;
+    }
     Py_DECREF(retVal);
 }
 
@@ -157,18 +163,6 @@ CPluginAlgorithmFactory::~CPluginAlgorithmFactory(){
     if(six!=NULL) Py_DECREF(six);
 }
 
-bool CPluginAlgorithmFactory::registerPlugin(std::string name, std::string className){
-    PyObject *str = PyBytes_FromString(className.c_str());
-    PyDict_SetItemString(pluginDict, name.c_str(), str);
-    Py_DECREF(str);
-    return true;
-}
-
-bool CPluginAlgorithmFactory::registerPluginClass(std::string name, PyObject * className){
-    PyDict_SetItemString(pluginDict, name.c_str(), className);
-    return true;
-}
-
 PyObject * getClassFromString(std::string str){
     std::vector<std::string> items;
     boost::split(items, str, boost::is_any_of("."));
@@ -190,6 +184,43 @@ PyObject * getClassFromString(std::string str){
     return pyclass;
 }
 
+bool CPluginAlgorithmFactory::registerPlugin(std::string name, std::string className){
+    PyObject *str = PyBytes_FromString(className.c_str());
+    PyDict_SetItemString(pluginDict, name.c_str(), str);
+    Py_DECREF(str);
+    return true;
+}
+
+bool CPluginAlgorithmFactory::registerPlugin(std::string className){
+    PyObject *pyclass = getClassFromString(className);
+    if(pyclass==NULL) return false;
+    bool ret = registerPluginClass(pyclass);
+    Py_DECREF(pyclass);
+    return ret;
+}
+
+bool CPluginAlgorithmFactory::registerPluginClass(std::string name, PyObject * className){
+    PyDict_SetItemString(pluginDict, name.c_str(), className);
+    return true;
+}
+
+bool CPluginAlgorithmFactory::registerPluginClass(PyObject * className){
+    PyObject *astra_name = PyObject_GetAttrString(className,"astra_name");
+    if(astra_name==NULL){
+        logPythonError();
+        return false;
+    }
+    PyObject *retb = PyObject_CallMethod(six,"b","O",astra_name);
+    if(retb!=NULL){
+        PyDict_SetItemString(pluginDict,PyBytes_AsString(retb),className);
+        Py_DECREF(retb);
+    }else{
+        logPythonError();
+    }
+    Py_DECREF(astra_name);
+    return true;
+}
+
 CPluginAlgorithm * CPluginAlgorithmFactory::getPlugin(std::string name){
     PyObject *className = PyDict_GetItemString(pluginDict, name.c_str());
     if(className==NULL) return NULL;
@@ -212,12 +243,34 @@ PyObject * CPluginAlgorithmFactory::getRegistered(){
     return pluginDict;
 }
 
+std::map<std::string, std::string> CPluginAlgorithmFactory::getRegisteredMap(){
+    std::map<std::string, std::string> ret;
+    PyObject *key, *value;
+    Py_ssize_t pos = 0;
+    while (PyDict_Next(pluginDict, &pos, &key, &value)) {
+        PyObject * keyb = PyObject_Bytes(key);
+        PyObject * valb = PyObject_Bytes(value);
+        ret[PyBytes_AsString(keyb)] = PyBytes_AsString(valb);
+        Py_DECREF(keyb);
+        Py_DECREF(valb);
+    }
+    return ret;
+}
+
 std::string CPluginAlgorithmFactory::getHelp(std::string name){
     PyObject *className = PyDict_GetItemString(pluginDict, name.c_str());
-    if(className==NULL) return "";
-    std::string str = std::string(PyBytes_AsString(className));
+    if(className==NULL){
+        ASTRA_ERROR("Plugin %s not found!",name.c_str());
+        return "";
+    }
     std::string ret = "";
-    PyObject *pyclass = getClassFromString(str);
+    PyObject *pyclass;
+    if(PyBytes_Check(className)){
+        std::string str = std::string(PyBytes_AsString(className));
+        pyclass = getClassFromString(str);
+    }else{
+        pyclass = className;
+    }
     if(pyclass==NULL) return "";
     if(inspect!=NULL && six!=NULL){
         PyObject *retVal = PyObject_CallMethod(inspect,"getdoc","O",pyclass);
@@ -228,9 +281,13 @@ std::string CPluginAlgorithmFactory::getHelp(std::string name){
                 ret = std::string(PyBytes_AsString(retb));
                 Py_DECREF(retb);
             }
+        }else{
+            logPythonError();
         }
     }
-    Py_DECREF(pyclass);
+    if(PyBytes_Check(className)){
+        Py_DECREF(pyclass);
+    }
     return ret;
 }
 
-- 
cgit v1.2.3