summaryrefslogtreecommitdiffstats
path: root/python/astra/operator.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/astra/operator.py')
-rw-r--r--python/astra/operator.py33
1 files changed, 11 insertions, 22 deletions
diff --git a/python/astra/operator.py b/python/astra/operator.py
index a3abd5a..0c37353 100644
--- a/python/astra/operator.py
+++ b/python/astra/operator.py
@@ -91,7 +91,7 @@ class OpTomo(scipy.sparse.linalg.LinearOperator):
arr = np.ascontiguousarray(arr)
return arr
- def matvec(self,v):
+ def _matvec(self,v):
"""Implements the forward operator.
:param v: Volume to forward project.
@@ -135,24 +135,16 @@ class OpTomo(scipy.sparse.linalg.LinearOperator):
self.data_mod.delete([vid,sid])
return v.flatten()
- def matmat(self,m):
- """Implements the forward operator with a matrix.
-
- :param m: Volumes to forward project, arranged in columns.
- :type m: :class:`numpy.ndarray`
- """
- out = np.zeros((self.ssize,m.shape[1]),dtype=np.float32)
- for i in range(m.shape[1]):
- out[:,i] = self.matvec(m[:,i].flatten())
- return out
-
def __mul__(self,v):
"""Provides easy forward operator by *.
:param v: Volume to forward project.
:type v: :class:`numpy.ndarray`
"""
- return self.matvec(v)
+ # Catch the case of a forward projection of a 2D/3D image
+ if isinstance(v, np.ndarray) and v.shape==self.vshape:
+ return self._matvec(v)
+ return scipy.sparse.linalg.LinearOperator.__mul__(self, v)
def reconstruct(self, method, s, iterations=1, extraOptions = {}):
"""Reconstruct an object.
@@ -192,17 +184,14 @@ class OpTomoTranspose(scipy.sparse.linalg.LinearOperator):
self.dtype = np.float32
self.shape = (parent.shape[1], parent.shape[0])
- def matvec(self, s):
+ def _matvec(self, s):
return self.parent.rmatvec(s)
def rmatvec(self, v):
return self.parent.matvec(v)
- def matmat(self, m):
- out = np.zeros((self.vsize,m.shape[1]),dtype=np.float32)
- for i in range(m.shape[1]):
- out[:,i] = self.matvec(m[:,i].flatten())
- return out
-
- def __mul__(self,v):
- return self.matvec(v)
+ def __mul__(self,s):
+ # Catch the case of a backprojection of 2D/3D data
+ if isinstance(s, np.ndarray) and s.shape==self.parent.sshape:
+ return self._matvec(s)
+ return scipy.sparse.linalg.LinearOperator.__mul__(self, s)