diff options
Diffstat (limited to 'Wrappers')
-rwxr-xr-x | Wrappers/Python/ccpi/framework/framework.py | 4 | ||||
-rw-r--r-- | Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py | 89 | ||||
-rw-r--r-- | Wrappers/Python/ccpi/optimisation/algorithms/__init__.py | 1 | ||||
-rwxr-xr-x | Wrappers/Python/ccpi/optimisation/algs.py | 9 | ||||
-rw-r--r-- | Wrappers/Python/wip/demo_test_sirt.py | 90 |
5 files changed, 176 insertions, 17 deletions
diff --git a/Wrappers/Python/ccpi/framework/framework.py b/Wrappers/Python/ccpi/framework/framework.py index 7516447..ffc91ae 100755 --- a/Wrappers/Python/ccpi/framework/framework.py +++ b/Wrappers/Python/ccpi/framework/framework.py @@ -707,6 +707,10 @@ class DataContainer(object): def maximum(self, x2, *args, **kwargs): return self.pixel_wise_binary(numpy.maximum, x2, *args, **kwargs) + def minimum(self,x2, out=None, *args, **kwargs): + return self.pixel_wise_binary(numpy.minimum, x2=x2, out=out, *args, **kwargs) + + ## unary operations def pixel_wise_unary(self, pwop, *args, **kwargs): out = kwargs.get('out', None) diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py b/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py new file mode 100644 index 0000000..389ec99 --- /dev/null +++ b/Wrappers/Python/ccpi/optimisation/algorithms/SIRT.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Apr 10 13:39:35 2019 + @author: jakob +""" + + # -*- coding: utf-8 -*- +# This work is part of the Core Imaging Library developed by +# Visual Analytics and Imaging System Group of the Science Technology +# Facilities Council, STFC + + # Copyright 2018 Edoardo Pasca + + # Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + + # http://www.apache.org/licenses/LICENSE-2.0 + + # Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Created on Thu Feb 21 11:11:23 2019 + @author: ofn77899 +""" + + from ccpi.optimisation.algorithms import Algorithm +from ccpi.framework import ImageData, AcquisitionData + + #from collections.abc import Iterable +class SIRT(Algorithm): + + '''Simultaneous Iterative Reconstruction Technique + Parameters: + x_init: initial guess + operator: operator for forward/backward projections + data: data to operate on + constraint: Function with prox-method, for example IndicatorBox to + enforce box constraints. + ''' + def __init__(self, **kwargs): + super(SIRT, self).__init__() + self.x = kwargs.get('x_init', None) + self.operator = kwargs.get('operator', None) + self.data = kwargs.get('data', None) + self.constraint = kwargs.get('data', None) + if self.x is not None and self.operator is not None and \ + self.data is not None: + print ("Calling from creator") + self.set_up(x_init =kwargs['x_init'], + operator=kwargs['operator'], + data =kwargs['data']) + + def set_up(self, x_init, operator , data, constraint=None ): + + self.x = x_init.copy() + self.operator = operator + self.data = data + self.constraint = constraint + + self.r = data.copy() + + self.relax_par = 1.0 + + # Set up scaling matrices D and M. + im1 = ImageData(geometry=self.x.geometry) + im1.array[:] = 1.0 + self.M = 1/operator.direct(im1) + aq1 = AcquisitionData(geometry=self.M.geometry) + aq1.array[:] = 1.0 + self.D = 1/operator.adjoint(aq1) + + + def update(self): + + self.r = self.data - self.operator.direct(self.x) + + self.x += self.relax_par * (self.D*self.operator.adjoint(self.M*self.r)) + + if self.constraint != None: + self.x = self.constraint.prox(self.x,None) + + def update_objective(self): + self.loss.append(self.r.squared_norm()) +
\ No newline at end of file diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py b/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py index f562973..2dbacfc 100644 --- a/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/__init__.py @@ -24,6 +24,7 @@ Created on Thu Feb 21 11:03:13 2019 from .Algorithm import Algorithm from .CGLS import CGLS +from .SIRT import SIRT from .GradientDescent import GradientDescent from .FISTA import FISTA from .FBPD import FBPD diff --git a/Wrappers/Python/ccpi/optimisation/algs.py b/Wrappers/Python/ccpi/optimisation/algs.py index 2f819d3..5a95c5c 100755 --- a/Wrappers/Python/ccpi/optimisation/algs.py +++ b/Wrappers/Python/ccpi/optimisation/algs.py @@ -280,10 +280,6 @@ def SIRT(x_init, operator , data , opt=None, constraint=None): tol = opt['tol'] max_iter = opt['iter'] - # Set default constraint to unconstrained - if constraint==None: - constraint = Function() - x = x_init.clone() timing = numpy.zeros(max_iter) @@ -307,7 +303,10 @@ def SIRT(x_init, operator , data , opt=None, constraint=None): t = time.time() r = data - operator.direct(x) - x = constraint.prox(x + relax_par * (D*operator.adjoint(M*r)),None) + x = x + relax_par * (D*operator.adjoint(M*r)) + + if constraint != None: + x = constraint.prox(x,None) timing[it] = time.time() - t if it > 0: diff --git a/Wrappers/Python/wip/demo_test_sirt.py b/Wrappers/Python/wip/demo_test_sirt.py index 6f5a44d..c7a3c76 100644 --- a/Wrappers/Python/wip/demo_test_sirt.py +++ b/Wrappers/Python/wip/demo_test_sirt.py @@ -8,6 +8,9 @@ from ccpi.optimisation.algs import FISTA, FBPD, CGLS, SIRT from ccpi.optimisation.funcs import Norm2sq, Norm1, TV2D, IndicatorBox from ccpi.astra.ops import AstraProjectorSimple +from ccpi.optimisation.algorithms import CGLS as CGLSALG +from ccpi.optimisation.algorithms import SIRT as SIRTALG + import numpy as np import matplotlib.pyplot as plt @@ -25,6 +28,7 @@ x = Phantom.as_array() x[round(N/4):round(3*N/4),round(N/4):round(3*N/4)] = 0.5 x[round(N/8):round(7*N/8),round(3*N/8):round(5*N/8)] = 1 +plt.figure() plt.imshow(x) plt.title('Phantom image') plt.show() @@ -69,10 +73,12 @@ Aop = AstraProjectorSimple(ig, ag, 'gpu') b = Aop.direct(Phantom) z = Aop.adjoint(b) +plt.figure() plt.imshow(b.array) plt.title('Simulated data') plt.show() +plt.figure() plt.imshow(z.array) plt.title('Backprojected data') plt.show() @@ -81,96 +87,156 @@ plt.show() # demonstrated in the rest of this file. In general all methods need an initial # guess and some algorithm options to be set: x_init = ImageData(np.zeros(x.shape),geometry=ig) -opt = {'tol': 1e-4, 'iter': 1000} +opt = {'tol': 1e-4, 'iter': 100} # First a CGLS reconstruction can be done: x_CGLS, it_CGLS, timing_CGLS, criter_CGLS = CGLS(x_init, Aop, b, opt) +plt.figure() plt.imshow(x_CGLS.array) plt.title('CGLS') plt.colorbar() plt.show() +plt.figure() plt.semilogy(criter_CGLS) plt.title('CGLS criterion') plt.show() + +my_CGLS_alg = CGLSALG() +my_CGLS_alg.set_up(x_init, Aop, b ) +my_CGLS_alg.max_iteration = 2000 +my_CGLS_alg.run(opt['iter']) +x_CGLS_alg = my_CGLS_alg.get_output() + +plt.figure() +plt.imshow(x_CGLS_alg.array) +plt.title('CGLS ALG') +plt.colorbar() +plt.show() + + # A SIRT unconstrained reconstruction can be done: similarly: x_SIRT, it_SIRT, timing_SIRT, criter_SIRT = SIRT(x_init, Aop, b, opt) +plt.figure() plt.imshow(x_SIRT.array) plt.title('SIRT unconstrained') plt.colorbar() plt.show() +plt.figure() plt.semilogy(criter_SIRT) plt.title('SIRT unconstrained criterion') plt.show() + + +my_SIRT_alg = SIRTALG() +my_SIRT_alg.set_up(x_init, Aop, b ) +my_SIRT_alg.max_iteration = 2000 +my_SIRT_alg.run(opt['iter']) +x_SIRT_alg = my_SIRT_alg.get_output() + +plt.figure() +plt.imshow(x_SIRT_alg.array) +plt.title('SIRT ALG') +plt.colorbar() +plt.show() + # A SIRT nonnegativity constrained reconstruction can be done using the # additional input "constraint" set to a box indicator function with 0 as the # lower bound and the default upper bound of infinity: x_SIRT0, it_SIRT0, timing_SIRT0, criter_SIRT0 = SIRT(x_init, Aop, b, opt, constraint=IndicatorBox(lower=0)) - +plt.figure() plt.imshow(x_SIRT0.array) plt.title('SIRT nonneg') plt.colorbar() plt.show() +plt.figure() plt.semilogy(criter_SIRT0) plt.title('SIRT nonneg criterion') plt.show() + +my_SIRT_alg0 = SIRTALG() +my_SIRT_alg0.set_up(x_init, Aop, b, constraint=IndicatorBox(lower=0) ) +my_SIRT_alg0.max_iteration = 2000 +my_SIRT_alg0.run(opt['iter']) +x_SIRT_alg0 = my_SIRT_alg0.get_output() + +plt.figure() +plt.imshow(x_SIRT_alg0.array) +plt.title('SIRT ALG0') +plt.colorbar() +plt.show() + + # A SIRT reconstruction with box constraints on [0,1] can also be done: x_SIRT01, it_SIRT01, timing_SIRT01, criter_SIRT01 = SIRT(x_init, Aop, b, opt, constraint=IndicatorBox(lower=0,upper=1)) +plt.figure() plt.imshow(x_SIRT01.array) plt.title('SIRT box(0,1)') plt.colorbar() plt.show() +plt.figure() plt.semilogy(criter_SIRT01) plt.title('SIRT box(0,1) criterion') plt.show() +my_SIRT_alg01 = SIRTALG() +my_SIRT_alg01.set_up(x_init, Aop, b, constraint=IndicatorBox(lower=0,upper=1) ) +my_SIRT_alg01.max_iteration = 2000 +my_SIRT_alg01.run(opt['iter']) +x_SIRT_alg01 = my_SIRT_alg01.get_output() + +plt.figure() +plt.imshow(x_SIRT_alg01.array) +plt.title('SIRT ALG01') +plt.colorbar() +plt.show() + # The indicator function can also be used with the FISTA algorithm to do # least squares with nonnegativity constraint. +''' # Create least squares object instance with projector, test data and a constant # coefficient of 0.5: f = Norm2sq(Aop,b,c=0.5) - # Run FISTA for least squares without constraints x_fista, it, timing, criter = FISTA(x_init, f, None,opt) - +plt.figure() plt.imshow(x_fista.array) plt.title('FISTA Least squares') plt.show() - +plt.figure() plt.semilogy(criter) plt.title('FISTA Least squares criterion') plt.show() - # Run FISTA for least squares with nonnegativity constraint x_fista0, it0, timing0, criter0 = FISTA(x_init, f, IndicatorBox(lower=0),opt) - +plt.figure() plt.imshow(x_fista0.array) plt.title('FISTA Least squares nonneg') plt.show() - +plt.figure() plt.semilogy(criter0) plt.title('FISTA Least squares nonneg criterion') plt.show() - # Run FISTA for least squares with box constraint [0,1] x_fista01, it01, timing01, criter01 = FISTA(x_init, f, IndicatorBox(lower=0,upper=1),opt) - +plt.figure() plt.imshow(x_fista01.array) plt.title('FISTA Least squares box(0,1)') plt.show() - +plt.figure() plt.semilogy(criter01) plt.title('FISTA Least squares box(0,1) criterion') -plt.show()
\ No newline at end of file +plt.show() +'''
\ No newline at end of file |