From 5a81584cb37fe578d5985b8e78b511076fa75b1c Mon Sep 17 00:00:00 2001
From: Folkert Bleichrodt <F.Bleichrodt@cwi.nl>
Date: Wed, 2 Mar 2016 16:47:25 +0100
Subject: opTomo: output type matches input type

If opTomo is used to do a forward or backprojection, the precision
of the output data now matches the precision of the input data.
So the output will be single precision only if the input is
stored in single precision.
---
 matlab/tools/opTomo.m | 86 +++++++++++++++++----------------------------------
 1 file changed, 29 insertions(+), 57 deletions(-)

(limited to 'matlab')

diff --git a/matlab/tools/opTomo.m b/matlab/tools/opTomo.m
index 71dfb1e..33c8565 100644
--- a/matlab/tools/opTomo.m
+++ b/matlab/tools/opTomo.m
@@ -44,11 +44,9 @@ classdef opTomo < opSpot
         vol_id
         fp_alg_id
         bp_alg_id
+        proj_id
         % ASTRA IDs handle
         astra_handle
-        % geometries
-        proj_geom;
-        vol_geom;
     end % properties
     
     properties ( SetAccess = private, GetAccess = public )
@@ -139,6 +137,17 @@ classdef opTomo < opSpot
                     error(['Only type ' 39 'cuda' 39 ' is supported ' ...
                            'for 3D geometries.'])
                 end
+
+                % setup projector
+                cfg = astra_struct('cuda3d');
+                cfg.ProjectionGeometry = proj_geom;
+                cfg.VolumeGeometry = vol_geom;
+                cfg.option.GPUindex = gpu_index;
+
+                % create projector
+                op.proj_id = astra_mex_projector3d('create', cfg);
+                % create handle to ASTRA object, for cleaning up
+                op.astra_handle = opTomo_helper_handle(op.proj_id);
                 
                 % create a function handle
                 op.funHandle = @opTomo_intrnl3D;
@@ -148,8 +157,6 @@ classdef opTomo < opSpot
             % pass object properties
             op.proj_size   = proj_size;
             op.vol_size    = vol_size;
-            op.proj_geom   = proj_geom;
-            op.vol_geom    = vol_geom;
             op.cflag       = false;
             op.sweepflag   = false;
 
@@ -170,11 +177,6 @@ classdef opTomo < opSpot
                 x = full(x);
             end
             
-            % convert input to single
-            if isa(x, 'single') == false
-                x = single(x);
-            end
-            
             % the multiplication
             y = op.funHandle(op, x, mode);
             
@@ -194,7 +196,7 @@ classdef opTomo < opSpot
         function y = opTomo_intrnl2D(op,x,mode)
                        
             if mode == 1              
-                % X is passed as a vector, reshape it into an image.             
+                % x is passed as a vector, reshape it into an image.             
                 x = reshape(x, op.vol_size);
                 
                 % Matlab data copied to ASTRA data
@@ -204,9 +206,13 @@ classdef opTomo < opSpot
                 astra_mex_algorithm('iterate', op.fp_alg_id);
                 
                 % retrieve Matlab array
-                y = astra_mex_data2d('get_single', op.sino_id);
+                if isa(x, 'single')
+                    y = astra_mex_data2d('get_single', op.sino_id);
+                else
+                    y = astra_mex_data2d('get', op.sino_id);
+                end
             else
-                % X is passed as a vector, reshape it into a sinogram.
+                % x is passed as a vector, reshape it into a sinogram.
                 x = reshape(x, op.proj_size);
                 
                 % Matlab data copied to ASTRA data
@@ -216,8 +222,13 @@ classdef opTomo < opSpot
                 astra_mex_algorithm('iterate', op.bp_alg_id);
                 
                 % retrieve Matlab array
-                y = astra_mex_data2d('get_single', op.vol_id);
+                if isa(x, 'single')
+                    y = astra_mex_data2d('get_single', op.vol_id);
+                else
+                    y = astra_mex_data2d('get', op.vol_id);
+                end
             end
+
         end % opTomo_intrnl2D
         
         
@@ -225,55 +236,16 @@ classdef opTomo < opSpot
         function y = opTomo_intrnl3D(op,x,mode)
             
             if mode == 1
-                % X is passed as a vector, reshape it into an image
+                % x is passed as a vector, reshape it into an image
                 x = reshape(x, op.vol_size);
                 
-                % initialize output
-                y = zeros(op.proj_size, 'single');
-                
-                % link matlab array to ASTRA
-                vol_id  = astra_mex_data3d_c('link', '-vol', op.vol_geom, x, 0);
-                sino_id = astra_mex_data3d_c('link', '-sino', op.proj_geom, y, 1);
-                
-                % initialize fp algorithm
-                cfg = astra_struct('FP3D_CUDA');
-                cfg.ProjectionDataId = sino_id;
-                cfg.VolumeDataId     = vol_id;
-                
-                alg_id = astra_mex_algorithm('create', cfg);
-                             
                 % forward projection
-                astra_mex_algorithm('iterate', alg_id);
-                
-                % cleanup
-                astra_mex_data3d('delete', vol_id);
-                astra_mex_data3d('delete', sino_id);
-                astra_mex_algorithm('delete', alg_id);
+                y = astra_mex_direct('FP3D', op.proj_id, x);
             else
-                % X is passed as a vector, reshape it into projection data
+                % x is passed as a vector, reshape it into projection data
                 x = reshape(x, op.proj_size);
-                
-                % initialize output
-                y = zeros(op.vol_size,'single');
-                
-                % link matlab array to ASTRA
-                vol_id  = astra_mex_data3d_c('link', '-vol', op.vol_geom, y, 1);
-                sino_id = astra_mex_data3d_c('link', '-sino', op.proj_geom, x, 0);
                                 
-                % initialize bp algorithm
-                cfg = astra_struct('BP3D_CUDA');
-                cfg.ProjectionDataId     = sino_id;
-                cfg.ReconstructionDataId = vol_id;
-                
-                alg_id = astra_mex_algorithm('create', cfg);
-
-                % backprojection
-                astra_mex_algorithm('iterate', alg_id);
-                
-                % cleanup
-                astra_mex_data3d('delete', vol_id);
-                astra_mex_data3d('delete', sino_id);
-                astra_mex_algorithm('delete', alg_id);
+                y = astra_mex_direct('BP3D', op.proj_id, x);
             end 
         end % opTomo_intrnl3D
         
-- 
cgit v1.2.3