From 3cae1d138c53a3fd042de3d2c9d9a07cf0650e0f Mon Sep 17 00:00:00 2001 From: "Daniel M. Pelt" Date: Tue, 24 Feb 2015 12:35:45 +0100 Subject: Added Python interface --- python/astra/utils.pyx | 260 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 260 insertions(+) create mode 100644 python/astra/utils.pyx (limited to 'python/astra/utils.pyx') diff --git a/python/astra/utils.pyx b/python/astra/utils.pyx new file mode 100644 index 0000000..53e84a9 --- /dev/null +++ b/python/astra/utils.pyx @@ -0,0 +1,260 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see . +# +#----------------------------------------------------------------------- +# distutils: language = c++ +# distutils: libraries = astra + +import numpy as np +import six +from libcpp.string cimport string +from libcpp.list cimport list +from libcpp.vector cimport vector +from cython.operator cimport dereference as deref +from cpython.version cimport PY_MAJOR_VERSION + +cimport PyXMLDocument +from .PyXMLDocument cimport XMLDocument +from .PyXMLDocument cimport XMLNode +from .PyIncludes cimport * + + +cdef XMLDocument * dict2XML(string rootname, dc): + cdef XMLDocument * doc = PyXMLDocument.createDocument(rootname) + cdef XMLNode * node = doc.getRootNode() + try: + readDict(node, dc) + except: + six.print_('Error reading XML') + del doc + doc = NULL + finally: + del node + return doc + +def convert_item(item): + if isinstance(item, six.string_types): + return item.encode('ascii') + + if type(item) is not dict: + return item + + out_dict = {} + for k in item: + out_dict[convert_item(k)] = convert_item(item[k]) + return out_dict + + +def wrap_to_bytes(value): + if isinstance(value, six.binary_type): + return value + s = str(value) + if PY_MAJOR_VERSION == 3: + s = s.encode('ascii') + return s + + +def wrap_from_bytes(value): + s = value + if PY_MAJOR_VERSION == 3: + s = s.decode('ascii') + return s + + +cdef void readDict(XMLNode * root, _dc): + cdef XMLNode * listbase + cdef XMLNode * itm + cdef int i + cdef int j + + dc = convert_item(_dc) + for item in dc: + val = dc[item] + if isinstance(val, np.ndarray): + if val.size == 0: + break + listbase = root.addChildNode(item) + listbase.addAttribute(< string > six.b('listsize'), < float32 > val.size) + index = 0 + if val.ndim == 2: + for i in range(val.shape[0]): + for j in range(val.shape[1]): + itm = listbase.addChildNode(six.b('ListItem')) + itm.addAttribute(< string > six.b('index'), < float32 > index) + itm.addAttribute( < string > six.b('value'), < float32 > val[i, j]) + index += 1 + del itm + elif val.ndim == 1: + for i in range(val.shape[0]): + itm = listbase.addChildNode(six.b('ListItem')) + itm.addAttribute(< string > six.b('index'), < float32 > index) + itm.addAttribute(< string > six.b('value'), < float32 > val[i]) + index += 1 + del itm + else: + raise Exception("Only 1 or 2 dimensions are allowed") + del listbase + elif isinstance(val, dict): + if item == six.b('option') or item == six.b('options') or item == six.b('Option') or item == six.b('Options'): + readOptions(root, val) + else: + itm = root.addChildNode(item) + readDict(itm, val) + del itm + else: + if item == six.b('type'): + root.addAttribute(< string > six.b('type'), wrap_to_bytes(val)) + else: + itm = root.addChildNode(item, wrap_to_bytes(val)) + del itm + +cdef void readOptions(XMLNode * node, dc): + cdef XMLNode * listbase + cdef XMLNode * itm + cdef int i + cdef int j + for item in dc: + val = dc[item] + if node.hasOption(item): + raise Exception('Duplicate Option: %s' % item) + if isinstance(val, np.ndarray): + if val.size == 0: + break + listbase = node.addChildNode(six.b('Option')) + listbase.addAttribute(< string > six.b('key'), < string > item) + listbase.addAttribute(< string > six.b('listsize'), < float32 > val.size) + index = 0 + if val.ndim == 2: + for i in range(val.shape[0]): + for j in range(val.shape[1]): + itm = listbase.addChildNode(six.b('ListItem')) + itm.addAttribute(< string > six.b('index'), < float32 > index) + itm.addAttribute( < string > six.b('value'), < float32 > val[i, j]) + index += 1 + del itm + elif val.ndim == 1: + for i in range(val.shape[0]): + itm = listbase.addChildNode(six.b('ListItem')) + itm.addAttribute(< string > six.b('index'), < float32 > index) + itm.addAttribute(< string > six.b('value'), < float32 > val[i]) + index += 1 + del itm + else: + raise Exception("Only 1 or 2 dimensions are allowed") + del listbase + else: + node.addOption(item, wrap_to_bytes(val)) + +cdef vectorToNumpy(vector[float32] inp): + cdef int i + cdef int sz = inp.size() + ret = np.empty(sz) + for i in range(sz): + ret[i] = inp[i] + return ret + +cdef XMLNode2dict(XMLNode * node): + cdef XMLNode * subnode + cdef list[XMLNode * ] nodes + cdef list[XMLNode * ].iterator it + dct = {} + if node.hasAttribute(six.b('type')): + dct['type'] = node.getAttribute(six.b('type')) + nodes = node.getNodes() + it = nodes.begin() + while it != nodes.end(): + subnode = deref(it) + if subnode.hasAttribute(six.b('listsize')): + dct[subnode.getName( + )] = vectorToNumpy(subnode.getContentNumericalArray()) + else: + dct[subnode.getName()] = subnode.getContent() + del subnode + return dct + +cdef XML2dict(XMLDocument * xml): + cdef XMLNode * node = xml.getRootNode() + dct = XMLNode2dict(node) + del node; + return dct; + +cdef createProjectionGeometryStruct(CProjectionGeometry2D * geom): + cdef int i + cdef CFanFlatVecProjectionGeometry2D * fanvecGeom + # cdef SFanProjection* p + dct = {} + dct['DetectorCount'] = geom.getDetectorCount() + if not geom.isOfType(< string > six.b('fanflat_vec')): + dct['DetectorWidth'] = geom.getDetectorWidth() + angles = np.empty(geom.getProjectionAngleCount()) + for i in range(geom.getProjectionAngleCount()): + angles[i] = geom.getProjectionAngle(i) + dct['ProjectionAngles'] = angles + else: + raise Exception("Not yet implemented") + # fanvecGeom = geom + # vecs = np.empty(fanvecGeom.getProjectionAngleCount()*6) + # iDetCount = pVecGeom.getDetectorCount() + # for i in range(fanvecGeom.getProjectionAngleCount()): + # p = &fanvecGeom.getProjectionVectors()[i]; + # out[6*i + 0] = p.fSrcX + # out[6*i + 1] = p.fSrcY + # out[6*i + 2] = p.fDetSX + 0.5f*iDetCount*p.fDetUX + # out[6*i + 3] = p.fDetSY + 0.5f*iDetCount*p.fDetUY + # out[6*i + 4] = p.fDetUX + # out[6*i + 5] = p.fDetUY + # dct['Vectors'] = vecs + if (geom.isOfType(< string > six.b('parallel'))): + dct["type"] = "parallel" + elif (geom.isOfType(< string > six.b('fanflat'))): + raise Exception("Not yet implemented") + # astra::CFanFlatProjectionGeometry2D* pFanFlatGeom = dynamic_cast(_pProjGeom) + # mGeometryInfo["DistanceOriginSource"] = mxCreateDoubleScalar(pFanFlatGeom->getOriginSourceDistance()) + # mGeometryInfo["DistanceOriginDetector"] = + # mxCreateDoubleScalar(pFanFlatGeom->getOriginDetectorDistance()) + dct["type"] = "fanflat" + elif (geom.isOfType(< string > six.b('sparse_matrix'))): + raise Exception("Not yet implemented") + # astra::CSparseMatrixProjectionGeometry2D* pSparseMatrixGeom = + # dynamic_cast(_pProjGeom); + dct["type"] = "sparse_matrix" + # dct["MatrixID"] = + # mxCreateDoubleScalar(CMatrixManager::getSingleton().getIndex(pSparseMatrixGeom->getMatrix())) + elif(geom.isOfType(< string > six.b('fanflat_vec'))): + dct["type"] = "fanflat_vec" + return dct + +cdef createVolumeGeometryStruct(CVolumeGeometry2D * geom): + mGeometryInfo = {} + mGeometryInfo["GridColCount"] = geom.getGridColCount() + mGeometryInfo["GridRowCount"] = geom.getGridRowCount() + + mGeometryOptions = {} + mGeometryOptions["WindowMinX"] = geom.getWindowMinX() + mGeometryOptions["WindowMaxX"] = geom.getWindowMaxX() + mGeometryOptions["WindowMinY"] = geom.getWindowMinY() + mGeometryOptions["WindowMaxY"] = geom.getWindowMaxY() + + mGeometryInfo["option"] = mGeometryOptions + return mGeometryInfo -- cgit v1.2.3