diff options
| author | Willem Jan Palenstijn <wjp@usecode.org> | 2016-04-26 15:45:01 +0200 | 
|---|---|---|
| committer | Willem Jan Palenstijn <wjp@usecode.org> | 2016-04-26 15:45:01 +0200 | 
| commit | d60df8bbd0e17016036c279720d6e3464a4d295c (patch) | |
| tree | 281f9beb8b11f891b4f5aa03caa844e94a1d6f74 | |
| parent | c659dd8c2f1d5dcb9cc00e2a8786588ae8427278 (diff) | |
| parent | ed717202a0c917958892e26322d6ea5173f7b32c (diff) | |
| download | astra-d60df8bbd0e17016036c279720d6e3464a4d295c.tar.gz astra-d60df8bbd0e17016036c279720d6e3464a4d295c.tar.bz2 astra-d60df8bbd0e17016036c279720d6e3464a4d295c.tar.xz astra-d60df8bbd0e17016036c279720d6e3464a4d295c.zip | |
Merge pull request #47 from wjp/OpTomo_out
Give OpTomo FP/BP functions with optional out argument
| -rw-r--r-- | python/astra/optomo.py | 96 | ||||
| -rw-r--r-- | samples/python/s018_plugin.py | 34 | 
2 files changed, 87 insertions, 43 deletions
| diff --git a/python/astra/optomo.py b/python/astra/optomo.py index 5a92998..dde719e 100644 --- a/python/astra/optomo.py +++ b/python/astra/optomo.py @@ -111,21 +111,7 @@ class OpTomo(scipy.sparse.linalg.LinearOperator):          :param v: Volume to forward project.          :type v: :class:`numpy.ndarray`          """ -        v = self.__checkArray(v, self.vshape) -        vid = self.data_mod.link('-vol',self.vg,v) -        s = np.zeros(self.sshape,dtype=np.float32) -        sid = self.data_mod.link('-sino',self.pg,s) - -        cfg = creators.astra_dict('FP'+self.appendString) -        cfg['ProjectionDataId'] = sid -        cfg['VolumeDataId'] = vid -        cfg['ProjectorId'] = self.proj_id -        fp_id = algorithm.create(cfg) -        algorithm.run(fp_id) - -        algorithm.delete(fp_id) -        self.data_mod.delete([vid,sid]) -        return s.ravel() +        return self.FP(v, out=None).ravel()      def rmatvec(self,s):          """Implements the transpose operator. @@ -133,21 +119,7 @@ class OpTomo(scipy.sparse.linalg.LinearOperator):          :param s: The projection data.          :type s: :class:`numpy.ndarray`          """ -        s = self.__checkArray(s, self.sshape) -        sid = self.data_mod.link('-sino',self.pg,s) -        v = np.zeros(self.vshape,dtype=np.float32) -        vid = self.data_mod.link('-vol',self.vg,v) - -        cfg = creators.astra_dict('BP'+self.appendString) -        cfg['ProjectionDataId'] = sid -        cfg['ReconstructionDataId'] = vid -        cfg['ProjectorId'] = self.proj_id -        bp_id = algorithm.create(cfg) -        algorithm.run(bp_id) - -        algorithm.delete(bp_id) -        self.data_mod.delete([vid,sid]) -        return v.ravel() +        return self.BP(s, out=None).ravel()      def __mul__(self,v):          """Provides easy forward operator by *. @@ -189,6 +161,70 @@ class OpTomo(scipy.sparse.linalg.LinearOperator):          self.data_mod.delete([vid,sid])          return v +    def FP(self,v,out=None): +        """Perform forward projection. + +        Output must have the right 2D/3D shape. Input may also be flattened. + +        Output must also be contiguous and float32. This isn't required for the +        input, but it is more efficient if it is. + +        :param v: Volume to forward project. +        :type v: :class:`numpy.ndarray` +        :param out: Array to store result in. +        :type out: :class:`numpy.ndarray` +        """ + +        v = self.__checkArray(v, self.vshape) +        vid = self.data_mod.link('-vol',self.vg,v) +        if out is None: +            out = np.zeros(self.sshape,dtype=np.float32) +        sid = self.data_mod.link('-sino',self.pg,out) + +        cfg = creators.astra_dict('FP'+self.appendString) +        cfg['ProjectionDataId'] = sid +        cfg['VolumeDataId'] = vid +        cfg['ProjectorId'] = self.proj_id +        fp_id = algorithm.create(cfg) +        algorithm.run(fp_id) + +        algorithm.delete(fp_id) +        self.data_mod.delete([vid,sid]) +        return out + +    def BP(self,s,out=None): +        """Perform backprojection. + +        Output must have the right 2D/3D shape. Input may also be flattened. + +        Output must also be contiguous and float32. This isn't required for the +        input, but it is more efficient if it is. + +        :param : The projection data. +        :type s: :class:`numpy.ndarray` +        :param out: Array to store result in. +        :type out: :class:`numpy.ndarray` +        """ +        s = self.__checkArray(s, self.sshape) +        sid = self.data_mod.link('-sino',self.pg,s) +        if out is None: +            out = np.zeros(self.vshape,dtype=np.float32) +        vid = self.data_mod.link('-vol',self.vg,out) + +        cfg = creators.astra_dict('BP'+self.appendString) +        cfg['ProjectionDataId'] = sid +        cfg['ReconstructionDataId'] = vid +        cfg['ProjectorId'] = self.proj_id +        bp_id = algorithm.create(cfg) +        algorithm.run(bp_id) + +        algorithm.delete(bp_id) +        self.data_mod.delete([vid,sid]) +        return out + + + +  class OpTomoTranspose(scipy.sparse.linalg.LinearOperator):      """This object provides the transpose operation (``.T``) of the OpTomo object. diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py index 31cca95..85b5486 100644 --- a/samples/python/s018_plugin.py +++ b/samples/python/s018_plugin.py @@ -30,30 +30,38 @@ import six  # Define the plugin class (has to subclass astra.plugin.base)  # Note that usually, these will be defined in a separate package/module -class SIRTPlugin(astra.plugin.base): -    """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm. +class LandweberPlugin(astra.plugin.base): +    """Example of an ASTRA plugin class, implementing a simple 2D Landweber algorithm.      Options: -    'rel_factor': relaxation factor (optional) +    'Relaxation': relaxation factor (optional)      """      # The astra_name variable defines the name to use to      # call the plugin from ASTRA -    astra_name = "SIRT-PLUGIN" +    astra_name = "LANDWEBER-PLUGIN" -    def initialize(self,cfg, rel_factor = 1): +    def initialize(self,cfg, Relaxation = 1):          self.W = astra.OpTomo(cfg['ProjectorId'])          self.vid = cfg['ReconstructionDataId']          self.sid = cfg['ProjectionDataId'] -        self.rel = rel_factor +        self.rel = Relaxation      def run(self, its):          v = astra.data2d.get_shared(self.vid)          s = astra.data2d.get_shared(self.sid) +        tv = np.zeros(v.shape, dtype=np.float32) +        ts = np.zeros(s.shape, dtype=np.float32)          W = self.W          for i in range(its): -            v[:] += self.rel*(W.T*(s - (W*v).reshape(s.shape))).reshape(v.shape)/s.size +            W.FP(v,out=ts) +            ts -= s # ts = W*v - s + +            W.BP(ts,out=tv) +            tv *= self.rel / s.size + +            v -= tv # v = v - rel * W'*(W*v-s) / s.size  if __name__=='__main__': @@ -75,20 +83,20 @@ if __name__=='__main__':      # First we import the package that contains the plugin      import s018_plugin      # Then, we register the plugin class with ASTRA -    astra.plugin.register(s018_plugin.SIRTPlugin) +    astra.plugin.register(s018_plugin.LandweberPlugin)      # Get a list of registered plugins      six.print_(astra.plugin.get_registered())      # To get help on a registered plugin, use get_help -    six.print_(astra.plugin.get_help('SIRT-PLUGIN')) +    six.print_(astra.plugin.get_help('LANDWEBER-PLUGIN'))      # Create data structures      sid = astra.data2d.create('-sino', proj_geom, sinogram)      vid = astra.data2d.create('-vol', vol_geom)      # Create config using plugin name -    cfg = astra.astra_dict('SIRT-PLUGIN') +    cfg = astra.astra_dict('LANDWEBER-PLUGIN')      cfg['ProjectorId'] = proj_id      cfg['ProjectionDataId'] = sid      cfg['ReconstructionDataId'] = vid @@ -103,18 +111,18 @@ if __name__=='__main__':      rec = astra.data2d.get(vid)      # Options for the plugin go in cfg['option'] -    cfg = astra.astra_dict('SIRT-PLUGIN') +    cfg = astra.astra_dict('LANDWEBER-PLUGIN')      cfg['ProjectorId'] = proj_id      cfg['ProjectionDataId'] = sid      cfg['ReconstructionDataId'] = vid      cfg['option'] = {} -    cfg['option']['rel_factor'] = 1.5 +    cfg['option']['Relaxation'] = 1.5      alg_id_rel = astra.algorithm.create(cfg)      astra.algorithm.run(alg_id_rel, 100)      rec_rel = astra.data2d.get(vid)      # We can also use OpTomo to call the plugin -    rec_op = W.reconstruct('SIRT-PLUGIN', sinogram, 100, extraOptions={'rel_factor':1.5}) +    rec_op = W.reconstruct('LANDWEBER-PLUGIN', sinogram, 100, extraOptions={'Relaxation':1.5})      import pylab as pl      pl.gray() | 
