summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--matlab/tools/opTomo.m81
-rw-r--r--python/astra/utils.pyx12
-rw-r--r--src/CudaDataOperationAlgorithm.cpp2
-rw-r--r--src/FilteredBackProjectionAlgorithm.cpp14
-rw-r--r--src/Float32VolumeData3DMemory.cpp1
-rw-r--r--src/SartAlgorithm.cpp24
-rw-r--r--src/XMLNode.cpp4
7 files changed, 66 insertions, 72 deletions
diff --git a/matlab/tools/opTomo.m b/matlab/tools/opTomo.m
index 71dfb1e..04b3634 100644
--- a/matlab/tools/opTomo.m
+++ b/matlab/tools/opTomo.m
@@ -44,11 +44,9 @@ classdef opTomo < opSpot
vol_id
fp_alg_id
bp_alg_id
+ proj_id
% ASTRA IDs handle
astra_handle
- % geometries
- proj_geom;
- vol_geom;
end % properties
properties ( SetAccess = private, GetAccess = public )
@@ -139,6 +137,17 @@ classdef opTomo < opSpot
error(['Only type ' 39 'cuda' 39 ' is supported ' ...
'for 3D geometries.'])
end
+
+ % setup projector
+ cfg = astra_struct('cuda3d');
+ cfg.ProjectionGeometry = proj_geom;
+ cfg.VolumeGeometry = vol_geom;
+ cfg.option.GPUindex = gpu_index;
+
+ % create projector
+ op.proj_id = astra_mex_projector3d('create', cfg);
+ % create handle to ASTRA object, for cleaning up
+ op.astra_handle = opTomo_helper_handle(op.proj_id);
% create a function handle
op.funHandle = @opTomo_intrnl3D;
@@ -148,8 +157,6 @@ classdef opTomo < opSpot
% pass object properties
op.proj_size = proj_size;
op.vol_size = vol_size;
- op.proj_geom = proj_geom;
- op.vol_geom = vol_geom;
op.cflag = false;
op.sweepflag = false;
@@ -169,10 +176,12 @@ classdef opTomo < opSpot
if issparse(x)
x = full(x);
end
-
- % convert input to single
- if isa(x, 'single') == false
+
+ if isa(x, 'double')
+ isdouble = true;
x = single(x);
+ else
+ isdouble = false;
end
% the multiplication
@@ -180,6 +189,10 @@ classdef opTomo < opSpot
% make sure output is column vector
y = y(:);
+
+ if isdouble
+ y = double(y);
+ end
end % multiply
@@ -194,7 +207,7 @@ classdef opTomo < opSpot
function y = opTomo_intrnl2D(op,x,mode)
if mode == 1
- % X is passed as a vector, reshape it into an image.
+ % x is passed as a vector, reshape it into an image.
x = reshape(x, op.vol_size);
% Matlab data copied to ASTRA data
@@ -206,7 +219,7 @@ classdef opTomo < opSpot
% retrieve Matlab array
y = astra_mex_data2d('get_single', op.sino_id);
else
- % X is passed as a vector, reshape it into a sinogram.
+ % x is passed as a vector, reshape it into a sinogram.
x = reshape(x, op.proj_size);
% Matlab data copied to ASTRA data
@@ -218,6 +231,7 @@ classdef opTomo < opSpot
% retrieve Matlab array
y = astra_mex_data2d('get_single', op.vol_id);
end
+
end % opTomo_intrnl2D
@@ -225,55 +239,16 @@ classdef opTomo < opSpot
function y = opTomo_intrnl3D(op,x,mode)
if mode == 1
- % X is passed as a vector, reshape it into an image
+ % x is passed as a vector, reshape it into an image
x = reshape(x, op.vol_size);
- % initialize output
- y = zeros(op.proj_size, 'single');
-
- % link matlab array to ASTRA
- vol_id = astra_mex_data3d_c('link', '-vol', op.vol_geom, x, 0);
- sino_id = astra_mex_data3d_c('link', '-sino', op.proj_geom, y, 1);
-
- % initialize fp algorithm
- cfg = astra_struct('FP3D_CUDA');
- cfg.ProjectionDataId = sino_id;
- cfg.VolumeDataId = vol_id;
-
- alg_id = astra_mex_algorithm('create', cfg);
-
% forward projection
- astra_mex_algorithm('iterate', alg_id);
-
- % cleanup
- astra_mex_data3d('delete', vol_id);
- astra_mex_data3d('delete', sino_id);
- astra_mex_algorithm('delete', alg_id);
+ y = astra_mex_direct('FP3D', op.proj_id, x);
else
- % X is passed as a vector, reshape it into projection data
+ % x is passed as a vector, reshape it into projection data
x = reshape(x, op.proj_size);
-
- % initialize output
- y = zeros(op.vol_size,'single');
-
- % link matlab array to ASTRA
- vol_id = astra_mex_data3d_c('link', '-vol', op.vol_geom, y, 1);
- sino_id = astra_mex_data3d_c('link', '-sino', op.proj_geom, x, 0);
- % initialize bp algorithm
- cfg = astra_struct('BP3D_CUDA');
- cfg.ProjectionDataId = sino_id;
- cfg.ReconstructionDataId = vol_id;
-
- alg_id = astra_mex_algorithm('create', cfg);
-
- % backprojection
- astra_mex_algorithm('iterate', alg_id);
-
- % cleanup
- astra_mex_data3d('delete', vol_id);
- astra_mex_data3d('delete', sino_id);
- astra_mex_algorithm('delete', alg_id);
+ y = astra_mex_direct('BP3D', op.proj_id, x);
end
end % opTomo_intrnl3D
diff --git a/python/astra/utils.pyx b/python/astra/utils.pyx
index 52c2a8d..34d1902 100644
--- a/python/astra/utils.pyx
+++ b/python/astra/utils.pyx
@@ -32,7 +32,7 @@ import six
if six.PY3:
import builtins
else:
- import __builtin__
+ import __builtin__ as builtins
from libcpp.string cimport string
from libcpp.vector cimport vector
from libcpp.list cimport list
@@ -95,14 +95,14 @@ cdef void readDict(XMLNode root, _dc):
dc = convert_item(_dc)
for item in dc:
val = dc[item]
- if isinstance(val, __builtins__.list) or isinstance(val, tuple):
+ if isinstance(val, builtins.list) or isinstance(val, tuple):
val = np.array(val,dtype=np.float64)
if isinstance(val, np.ndarray):
if val.size == 0:
break
listbase = root.addChildNode(item)
contig_data = np.ascontiguousarray(val,dtype=np.float64)
- data = <double*>np.PyArray_DATA(contig_data)
+ data = <double*>np.PyArray_DATA(contig_data)
if val.ndim == 2:
listbase.setContent(data, val.shape[1], val.shape[0], False)
elif val.ndim == 1:
@@ -119,6 +119,8 @@ cdef void readDict(XMLNode root, _dc):
if item == six.b('type'):
root.addAttribute(< string > six.b('type'), <string> wrap_to_bytes(val))
else:
+ if isinstance(val, builtins.bool):
+ val = int(val)
itm = root.addChildNode(item, wrap_to_bytes(val))
cdef void readOptions(XMLNode node, dc):
@@ -131,7 +133,7 @@ cdef void readOptions(XMLNode node, dc):
val = dc[item]
if node.hasOption(item):
raise Exception('Duplicate Option: %s' % item)
- if isinstance(val, __builtins__.list) or isinstance(val, tuple):
+ if isinstance(val, builtins.list) or isinstance(val, tuple):
val = np.array(val,dtype=np.float64)
if isinstance(val, np.ndarray):
if val.size == 0:
@@ -147,6 +149,8 @@ cdef void readOptions(XMLNode node, dc):
else:
raise Exception("Only 1 or 2 dimensions are allowed")
else:
+ if isinstance(val, builtins.bool):
+ val = int(val)
node.addOption(item, wrap_to_bytes(val))
cdef configToDict(Config *cfg):
diff --git a/src/CudaDataOperationAlgorithm.cpp b/src/CudaDataOperationAlgorithm.cpp
index 15886a4..82b676b 100644
--- a/src/CudaDataOperationAlgorithm.cpp
+++ b/src/CudaDataOperationAlgorithm.cpp
@@ -76,7 +76,7 @@ bool CCudaDataOperationAlgorithm::initialize(const Config& _cfg)
node = _cfg.self.getSingleNode("DataId");
ASTRA_CONFIG_CHECK(node, "CCudaDataOperationAlgorithm", "No DataId tag specified.");
vector<string> data = node.getContentArray();
- for (vector<string>::iterator it = data.begin(); it != data.end(); it++){
+ for (vector<string>::iterator it = data.begin(); it != data.end(); ++it){
int id = StringUtil::stringToInt(*it);
m_pData.push_back(dynamic_cast<CFloat32Data2D*>(CData2DManager::getSingleton().get(id)));
}
diff --git a/src/FilteredBackProjectionAlgorithm.cpp b/src/FilteredBackProjectionAlgorithm.cpp
index c195578..ccbfec6 100644
--- a/src/FilteredBackProjectionAlgorithm.cpp
+++ b/src/FilteredBackProjectionAlgorithm.cpp
@@ -117,12 +117,10 @@ bool CFilteredBackProjectionAlgorithm::initialize(const Config& _cfg)
int angleCount = projectionIndex.size();
int detectorCount = m_pProjector->getProjectionGeometry()->getDetectorCount();
+ // TODO: There is no need to allocate this. Better just
+ // create the CFloat32ProjectionData2D object directly, and use its
+ // memory.
float32 * sinogramData2D = new float32[angleCount* detectorCount];
- float32 ** sinogramData = new float32*[angleCount];
- for (int i = 0; i < angleCount; i++)
- {
- sinogramData[i] = &(sinogramData2D[i * detectorCount]);
- }
float32 * projectionAngles = new float32[angleCount];
float32 detectorWidth = m_pProjector->getProjectionGeometry()->getDetectorWidth();
@@ -130,6 +128,8 @@ bool CFilteredBackProjectionAlgorithm::initialize(const Config& _cfg)
for (int i = 0; i < angleCount; i ++) {
if (projectionIndex[i] > m_pProjector->getProjectionGeometry()->getProjectionAngleCount() -1 )
{
+ delete[] sinogramData2D;
+ delete[] projectionAngles;
ASTRA_ERROR("Invalid Projection Index");
return false;
} else {
@@ -139,7 +139,6 @@ bool CFilteredBackProjectionAlgorithm::initialize(const Config& _cfg)
{
sinogramData2D[i*detectorCount+ iDetector] = m_pSinogram->getData2D()[orgIndex][iDetector];
}
-// sinogramData[i] = m_pSinogram->getSingleProjectionData(projectionIndex[i]);
projectionAngles[i] = m_pProjector->getProjectionGeometry()->getProjectionAngle((int)projectionIndex[i] );
}
@@ -148,6 +147,9 @@ bool CFilteredBackProjectionAlgorithm::initialize(const Config& _cfg)
CParallelProjectionGeometry2D * pg = new CParallelProjectionGeometry2D(angleCount, detectorCount,detectorWidth,projectionAngles);
m_pProjector = new CParallelBeamLineKernelProjector2D(pg,m_pReconstruction->getGeometry());
m_pSinogram = new CFloat32ProjectionData2D(pg, sinogramData2D);
+
+ delete[] sinogramData2D;
+ delete[] projectionAngles;
}
// TODO: check that the angles are linearly spaced between 0 and pi
diff --git a/src/Float32VolumeData3DMemory.cpp b/src/Float32VolumeData3DMemory.cpp
index af45cb9..14adb1a 100644
--- a/src/Float32VolumeData3DMemory.cpp
+++ b/src/Float32VolumeData3DMemory.cpp
@@ -136,7 +136,6 @@ CFloat32VolumeData2D * CFloat32VolumeData3DMemory::fetchSliceZ(int _iSliceIndex)
CFloat32VolumeData2D* res = new CFloat32VolumeData2D(&volGeom);
// copy data
- int iSliceCount = m_pGeometry->getGridSliceCount();
float * pfTargetData = res->getData();
for(int iRowIndex = 0; iRowIndex < iRowCount; iRowIndex++)
{
diff --git a/src/SartAlgorithm.cpp b/src/SartAlgorithm.cpp
index 403f851..8df3342 100644
--- a/src/SartAlgorithm.cpp
+++ b/src/SartAlgorithm.cpp
@@ -278,9 +278,8 @@ void CSartAlgorithm::run(int _iNrIterations)
m_bShouldAbort = false;
- int iIteration = 0;
-
// data projectors
+ CDataProjectorInterface* pFirstForwardProjector;
CDataProjectorInterface* pForwardProjector;
CDataProjectorInterface* pBackProjector;
@@ -298,7 +297,7 @@ void CSartAlgorithm::run(int _iNrIterations)
// first time forward projection data projector,
// also computes total pixel weight and total ray length
- pForwardProjector = dispatchDataProjector(
+ pFirstForwardProjector = dispatchDataProjector(
m_pProjector,
SinogramMaskPolicy(m_pSinogramMask), // sinogram mask
ReconstructionMaskPolicy(m_pReconstructionMask), // reconstruction mask
@@ -309,16 +308,30 @@ void CSartAlgorithm::run(int _iNrIterations)
m_bUseSinogramMask, m_bUseReconstructionMask, true // options on/off
);
+ // forward projection data projector
+ pForwardProjector = dispatchDataProjector(
+ m_pProjector,
+ SinogramMaskPolicy(m_pSinogramMask), // sinogram mask
+ ReconstructionMaskPolicy(m_pReconstructionMask), // reconstruction mask
+ CombinePolicy<DiffFPPolicy, TotalPixelWeightPolicy>( // 2 basic operations
+ DiffFPPolicy(m_pReconstruction, m_pDiffSinogram, m_pSinogram), // forward projection with difference calculation
+ TotalPixelWeightPolicy(m_pTotalPixelWeight)), // calculate the total pixel weights
+ m_bUseSinogramMask, m_bUseReconstructionMask, true // options on/off
+ );
+
// iteration loop
- for (; iIteration < _iNrIterations && !m_bShouldAbort; ++iIteration) {
+ for (int iIteration = 0; iIteration < _iNrIterations && !m_bShouldAbort; ++iIteration) {
int iProjection = m_piProjectionOrder[m_iIterationCount % m_iProjectionCount];
// forward projection and difference calculation
m_pTotalPixelWeight->setData(0.0f);
- pForwardProjector->projectSingleProjection(iProjection);
+ if (iIteration < m_iProjectionCount)
+ pFirstForwardProjector->projectSingleProjection(iProjection);
+ else
+ pForwardProjector->projectSingleProjection(iProjection);
// backprojection
pBackProjector->projectSingleProjection(iProjection);
// update iteration count
@@ -331,6 +344,7 @@ void CSartAlgorithm::run(int _iNrIterations)
}
+ ASTRA_DELETE(pFirstForwardProjector);
ASTRA_DELETE(pForwardProjector);
ASTRA_DELETE(pBackProjector);
diff --git a/src/XMLNode.cpp b/src/XMLNode.cpp
index 40a9b22..cf268c2 100644
--- a/src/XMLNode.cpp
+++ b/src/XMLNode.cpp
@@ -158,7 +158,7 @@ vector<string> XMLNode::getContentArray() const
vector<string> res(iSize);
// loop all list item nodes
list<XMLNode> nodes = getNodes("ListItem");
- for (list<XMLNode>::iterator it = nodes.begin(); it != nodes.end(); it++) {
+ for (list<XMLNode>::iterator it = nodes.begin(); it != nodes.end(); ++it) {
int iIndex = it->getAttributeNumerical("index");
string sValue = it->getAttribute("value");
ASTRA_ASSERT(iIndex < iSize);
@@ -290,7 +290,7 @@ vector<float32> XMLNode::getOptionNumericalArray(string _sKey) const
if (!hasOption(_sKey)) return vector<float32>();
list<XMLNode> nodes = getNodes("Option");
- for (list<XMLNode>::iterator it = nodes.begin(); it != nodes.end(); it++) {
+ for (list<XMLNode>::iterator it = nodes.begin(); it != nodes.end(); ++it) {
if (it->getAttribute("key") == _sKey) {
vector<float32> vals = it->getContentNumericalArray();
return vals;