summaryrefslogtreecommitdiffstats
path: root/Wrappers
diff options
context:
space:
mode:
Diffstat (limited to 'Wrappers')
-rwxr-xr-xWrappers/Python/ccpi/framework/framework.py4
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py89
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/__init__.py1
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algs.py9
-rw-r--r--Wrappers/Python/wip/demo_test_sirt.py90
5 files changed, 176 insertions, 17 deletions
diff --git a/Wrappers/Python/ccpi/framework/framework.py b/Wrappers/Python/ccpi/framework/framework.py
index 7516447..ffc91ae 100755
--- a/Wrappers/Python/ccpi/framework/framework.py
+++ b/Wrappers/Python/ccpi/framework/framework.py
@@ -707,6 +707,10 @@ class DataContainer(object):
def maximum(self, x2, *args, **kwargs):
return self.pixel_wise_binary(numpy.maximum, x2, *args, **kwargs)
+ def minimum(self,x2, out=None, *args, **kwargs):
+ return self.pixel_wise_binary(numpy.minimum, x2=x2, out=out, *args, **kwargs)
+
+
## unary operations
def pixel_wise_unary(self, pwop, *args, **kwargs):
out = kwargs.get('out', None)
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py b/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py
new file mode 100644
index 0000000..389ec99
--- /dev/null
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Apr 10 13:39:35 2019
+ @author: jakob
+"""
+
+ # -*- coding: utf-8 -*-
+# This work is part of the Core Imaging Library developed by
+# Visual Analytics and Imaging System Group of the Science Technology
+# Facilities Council, STFC
+
+ # Copyright 2018 Edoardo Pasca
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Created on Thu Feb 21 11:11:23 2019
+ @author: ofn77899
+"""
+
+ from ccpi.optimisation.algorithms import Algorithm
+from ccpi.framework import ImageData, AcquisitionData
+
+ #from collections.abc import Iterable
+class SIRT(Algorithm):
+
+ '''Simultaneous Iterative Reconstruction Technique
+ Parameters:
+ x_init: initial guess
+ operator: operator for forward/backward projections
+ data: data to operate on
+ constraint: Function with prox-method, for example IndicatorBox to
+ enforce box constraints.
+ '''
+ def __init__(self, **kwargs):
+ super(SIRT, self).__init__()
+ self.x = kwargs.get('x_init', None)
+ self.operator = kwargs.get('operator', None)
+ self.data = kwargs.get('data', None)
+ self.constraint = kwargs.get('data', None)
+ if self.x is not None and self.operator is not None and \
+ self.data is not None:
+ print ("Calling from creator")
+ self.set_up(x_init =kwargs['x_init'],
+ operator=kwargs['operator'],
+ data =kwargs['data'])
+
+ def set_up(self, x_init, operator , data, constraint=None ):
+
+ self.x = x_init.copy()
+ self.operator = operator
+ self.data = data
+ self.constraint = constraint
+
+ self.r = data.copy()
+
+ self.relax_par = 1.0
+
+ # Set up scaling matrices D and M.
+ im1 = ImageData(geometry=self.x.geometry)
+ im1.array[:] = 1.0
+ self.M = 1/operator.direct(im1)
+ aq1 = AcquisitionData(geometry=self.M.geometry)
+ aq1.array[:] = 1.0
+ self.D = 1/operator.adjoint(aq1)
+
+
+ def update(self):
+
+ self.r = self.data - self.operator.direct(self.x)
+
+ self.x += self.relax_par * (self.D*self.operator.adjoint(self.M*self.r))
+
+ if self.constraint != None:
+ self.x = self.constraint.prox(self.x,None)
+
+ def update_objective(self):
+ self.loss.append(self.r.squared_norm())
+ \ No newline at end of file
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py b/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py
index f562973..2dbacfc 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py
@@ -24,6 +24,7 @@ Created on Thu Feb 21 11:03:13 2019
from .Algorithm import Algorithm
from .CGLS import CGLS
+from .SIRT import SIRT
from .GradientDescent import GradientDescent
from .FISTA import FISTA
from .FBPD import FBPD
diff --git a/Wrappers/Python/ccpi/optimisation/algs.py b/Wrappers/Python/ccpi/optimisation/algs.py
index 2f819d3..5a95c5c 100755
--- a/Wrappers/Python/ccpi/optimisation/algs.py
+++ b/Wrappers/Python/ccpi/optimisation/algs.py
@@ -280,10 +280,6 @@ def SIRT(x_init, operator , data , opt=None, constraint=None):
tol = opt['tol']
max_iter = opt['iter']
- # Set default constraint to unconstrained
- if constraint==None:
- constraint = Function()
-
x = x_init.clone()
timing = numpy.zeros(max_iter)
@@ -307,7 +303,10 @@ def SIRT(x_init, operator , data , opt=None, constraint=None):
t = time.time()
r = data - operator.direct(x)
- x = constraint.prox(x + relax_par * (D*operator.adjoint(M*r)),None)
+ x = x + relax_par * (D*operator.adjoint(M*r))
+
+ if constraint != None:
+ x = constraint.prox(x,None)
timing[it] = time.time() - t
if it > 0:
diff --git a/Wrappers/Python/wip/demo_test_sirt.py b/Wrappers/Python/wip/demo_test_sirt.py
index 6f5a44d..c7a3c76 100644
--- a/Wrappers/Python/wip/demo_test_sirt.py
+++ b/Wrappers/Python/wip/demo_test_sirt.py
@@ -8,6 +8,9 @@ from ccpi.optimisation.algs import FISTA, FBPD, CGLS, SIRT
from ccpi.optimisation.funcs import Norm2sq, Norm1, TV2D, IndicatorBox
from ccpi.astra.ops import AstraProjectorSimple
+from ccpi.optimisation.algorithms import CGLS as CGLSALG
+from ccpi.optimisation.algorithms import SIRT as SIRTALG
+
import numpy as np
import matplotlib.pyplot as plt
@@ -25,6 +28,7 @@ x = Phantom.as_array()
x[round(N/4):round(3*N/4),round(N/4):round(3*N/4)] = 0.5
x[round(N/8):round(7*N/8),round(3*N/8):round(5*N/8)] = 1
+plt.figure()
plt.imshow(x)
plt.title('Phantom image')
plt.show()
@@ -69,10 +73,12 @@ Aop = AstraProjectorSimple(ig, ag, 'gpu')
b = Aop.direct(Phantom)
z = Aop.adjoint(b)
+plt.figure()
plt.imshow(b.array)
plt.title('Simulated data')
plt.show()
+plt.figure()
plt.imshow(z.array)
plt.title('Backprojected data')
plt.show()
@@ -81,96 +87,156 @@ plt.show()
# demonstrated in the rest of this file. In general all methods need an initial
# guess and some algorithm options to be set:
x_init = ImageData(np.zeros(x.shape),geometry=ig)
-opt = {'tol': 1e-4, 'iter': 1000}
+opt = {'tol': 1e-4, 'iter': 100}
# First a CGLS reconstruction can be done:
x_CGLS, it_CGLS, timing_CGLS, criter_CGLS = CGLS(x_init, Aop, b, opt)
+plt.figure()
plt.imshow(x_CGLS.array)
plt.title('CGLS')
plt.colorbar()
plt.show()
+plt.figure()
plt.semilogy(criter_CGLS)
plt.title('CGLS criterion')
plt.show()
+
+my_CGLS_alg = CGLSALG()
+my_CGLS_alg.set_up(x_init, Aop, b )
+my_CGLS_alg.max_iteration = 2000
+my_CGLS_alg.run(opt['iter'])
+x_CGLS_alg = my_CGLS_alg.get_output()
+
+plt.figure()
+plt.imshow(x_CGLS_alg.array)
+plt.title('CGLS ALG')
+plt.colorbar()
+plt.show()
+
+
# A SIRT unconstrained reconstruction can be done: similarly:
x_SIRT, it_SIRT, timing_SIRT, criter_SIRT = SIRT(x_init, Aop, b, opt)
+plt.figure()
plt.imshow(x_SIRT.array)
plt.title('SIRT unconstrained')
plt.colorbar()
plt.show()
+plt.figure()
plt.semilogy(criter_SIRT)
plt.title('SIRT unconstrained criterion')
plt.show()
+
+
+my_SIRT_alg = SIRTALG()
+my_SIRT_alg.set_up(x_init, Aop, b )
+my_SIRT_alg.max_iteration = 2000
+my_SIRT_alg.run(opt['iter'])
+x_SIRT_alg = my_SIRT_alg.get_output()
+
+plt.figure()
+plt.imshow(x_SIRT_alg.array)
+plt.title('SIRT ALG')
+plt.colorbar()
+plt.show()
+
# A SIRT nonnegativity constrained reconstruction can be done using the
# additional input "constraint" set to a box indicator function with 0 as the
# lower bound and the default upper bound of infinity:
x_SIRT0, it_SIRT0, timing_SIRT0, criter_SIRT0 = SIRT(x_init, Aop, b, opt,
constraint=IndicatorBox(lower=0))
-
+plt.figure()
plt.imshow(x_SIRT0.array)
plt.title('SIRT nonneg')
plt.colorbar()
plt.show()
+plt.figure()
plt.semilogy(criter_SIRT0)
plt.title('SIRT nonneg criterion')
plt.show()
+
+my_SIRT_alg0 = SIRTALG()
+my_SIRT_alg0.set_up(x_init, Aop, b, constraint=IndicatorBox(lower=0) )
+my_SIRT_alg0.max_iteration = 2000
+my_SIRT_alg0.run(opt['iter'])
+x_SIRT_alg0 = my_SIRT_alg0.get_output()
+
+plt.figure()
+plt.imshow(x_SIRT_alg0.array)
+plt.title('SIRT ALG0')
+plt.colorbar()
+plt.show()
+
+
# A SIRT reconstruction with box constraints on [0,1] can also be done:
x_SIRT01, it_SIRT01, timing_SIRT01, criter_SIRT01 = SIRT(x_init, Aop, b, opt,
constraint=IndicatorBox(lower=0,upper=1))
+plt.figure()
plt.imshow(x_SIRT01.array)
plt.title('SIRT box(0,1)')
plt.colorbar()
plt.show()
+plt.figure()
plt.semilogy(criter_SIRT01)
plt.title('SIRT box(0,1) criterion')
plt.show()
+my_SIRT_alg01 = SIRTALG()
+my_SIRT_alg01.set_up(x_init, Aop, b, constraint=IndicatorBox(lower=0,upper=1) )
+my_SIRT_alg01.max_iteration = 2000
+my_SIRT_alg01.run(opt['iter'])
+x_SIRT_alg01 = my_SIRT_alg01.get_output()
+
+plt.figure()
+plt.imshow(x_SIRT_alg01.array)
+plt.title('SIRT ALG01')
+plt.colorbar()
+plt.show()
+
# The indicator function can also be used with the FISTA algorithm to do
# least squares with nonnegativity constraint.
+'''
# Create least squares object instance with projector, test data and a constant
# coefficient of 0.5:
f = Norm2sq(Aop,b,c=0.5)
-
# Run FISTA for least squares without constraints
x_fista, it, timing, criter = FISTA(x_init, f, None,opt)
-
+plt.figure()
plt.imshow(x_fista.array)
plt.title('FISTA Least squares')
plt.show()
-
+plt.figure()
plt.semilogy(criter)
plt.title('FISTA Least squares criterion')
plt.show()
-
# Run FISTA for least squares with nonnegativity constraint
x_fista0, it0, timing0, criter0 = FISTA(x_init, f, IndicatorBox(lower=0),opt)
-
+plt.figure()
plt.imshow(x_fista0.array)
plt.title('FISTA Least squares nonneg')
plt.show()
-
+plt.figure()
plt.semilogy(criter0)
plt.title('FISTA Least squares nonneg criterion')
plt.show()
-
# Run FISTA for least squares with box constraint [0,1]
x_fista01, it01, timing01, criter01 = FISTA(x_init, f, IndicatorBox(lower=0,upper=1),opt)
-
+plt.figure()
plt.imshow(x_fista01.array)
plt.title('FISTA Least squares box(0,1)')
plt.show()
-
+plt.figure()
plt.semilogy(criter01)
plt.title('FISTA Least squares box(0,1) criterion')
-plt.show() \ No newline at end of file
+plt.show()
+''' \ No newline at end of file