diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2019-06-04 16:29:45 +0100 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2019-06-04 16:29:45 +0100 |
commit | 910a81042068b084278783b53bac45ad63b852d2 (patch) | |
tree | d3f82cd74b9d8c4e4cf41a4c22fa7b9c5400e2e8 /Wrappers | |
parent | 3271aa9c4ca50177bf2f9e37269efa03f52f635e (diff) | |
download | framework-910a81042068b084278783b53bac45ad63b852d2.tar.gz framework-910a81042068b084278783b53bac45ad63b852d2.tar.bz2 framework-910a81042068b084278783b53bac45ad63b852d2.tar.xz framework-910a81042068b084278783b53bac45ad63b852d2.zip |
progress
Diffstat (limited to 'Wrappers')
7 files changed, 149 insertions, 12 deletions
diff --git a/Wrappers/Python/ccpi/framework/VectorData.py b/Wrappers/Python/ccpi/framework/VectorData.py new file mode 100755 index 0000000..fdce3a5 --- /dev/null +++ b/Wrappers/Python/ccpi/framework/VectorData.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# This work is part of the Core Imaging Library developed by +# Visual Analytics and Imaging System Group of the Science Technology +# Facilities Council, STFC + +# Copyright 2018-2019 Edoardo Pasca + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy +import sys +from datetime import timedelta, datetime +import warnings +from functools import reduce +from numbers import Number +from ccpi.framework import DataContainer, VectorGeometry + +class VectorData(DataContainer): + def __init__(self, array=None, **kwargs): + self.geometry = kwargs.get('geometry', None) + self.dtype = kwargs.get('dtype', numpy.float32) + + if self.geometry is None: + if array is None: + raise ValueError('Please specify either a geometry or an array') + else: + if len(array.shape) > 1: + raise ValueError('Incompatible size: expected 1D got {}'.format(array.shape)) + out = array + self.geometry = VectorGeometry.VectorGeometry(array.shape[0]) + self.length = self.geometry.length + else: + self.length = self.geometry.length + + if array is None: + out = numpy.zeros((self.length,), dtype=self.dtype) + else: + if self.length == array.shape[0]: + out = array + else: + raise ValueError('Incompatible size: expecting {} got {}'.format((self.length,), array.shape)) + deep_copy = False + super(VectorData, self).__init__(out, deep_copy, None) diff --git a/Wrappers/Python/ccpi/framework/VectorGeometry.py b/Wrappers/Python/ccpi/framework/VectorGeometry.py new file mode 100755 index 0000000..255d2a0 --- /dev/null +++ b/Wrappers/Python/ccpi/framework/VectorGeometry.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# This work is part of the Core Imaging Library developed by +# Visual Analytics and Imaging System Group of the Science Technology +# Facilities Council, STFC + +# Copyright 2018-2019 Edoardo Pasca + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy +import sys +from datetime import timedelta, datetime +import warnings +from functools import reduce +from numbers import Number +from ccpi.framework.VectorData import VectorData + +class VectorGeometry(object): + RANDOM = 'random' + RANDOM_INT = 'random_int' + + def __init__(self, + length): + + self.length = length + self.shape = (length, ) + + + def clone(self): + '''returns a copy of VectorGeometry''' + return VectorGeometry(self.length) + + def allocate(self, value=0, **kwargs): + '''allocates an VectorData according to the size expressed in the instance''' + self.dtype = kwargs.get('dtype', numpy.float32) + out = VectorData(geometry=self, dtype=self.dtype) + if isinstance(value, Number): + if value != 0: + out += value + else: + if value == VectorGeometry.RANDOM: + seed = kwargs.get('seed', None) + if seed is not None: + numpy.random.seed(seed) + out.fill(numpy.random.random_sample(self.shape)) + elif value == VectorGeometry.RANDOM_INT: + seed = kwargs.get('seed', None) + if seed is not None: + numpy.random.seed(seed) + max_value = kwargs.get('max_value', 100) + out.fill(numpy.random.randint(max_value,size=self.shape)) + else: + raise ValueError('Value {} unknown'.format(value)) + return out
\ No newline at end of file diff --git a/Wrappers/Python/ccpi/framework/__init__.py b/Wrappers/Python/ccpi/framework/__init__.py index 229edb5..9cec708 100755 --- a/Wrappers/Python/ccpi/framework/__init__.py +++ b/Wrappers/Python/ccpi/framework/__init__.py @@ -24,3 +24,5 @@ from .framework import DataProcessor from .framework import AX, PixelByPixelDataProcessor, CastDataContainer
from .BlockDataContainer import BlockDataContainer
from .BlockGeometry import BlockGeometry
+from .VectorGeometry import VectorGeometry
+from .VectorData import VectorData
diff --git a/Wrappers/Python/ccpi/framework/framework.py b/Wrappers/Python/ccpi/framework/framework.py index 7236e0e..b972be6 100755 --- a/Wrappers/Python/ccpi/framework/framework.py +++ b/Wrappers/Python/ccpi/framework/framework.py @@ -29,7 +29,6 @@ import warnings from functools import reduce from numbers import Number - def find_key(dic, val): """return the key of dictionary dic given the value""" return [k for k, v in dic.items() if v == val][0] diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py index 04e7c38..db4e8b7 100755 --- a/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/FISTA.py @@ -34,7 +34,7 @@ class FISTA(Algorithm): self.invL = None self.t_old = 1 if self.f is not None and self.g is not None: - print ("Calling from creator") + print ("FISTA initialising from creator") self.set_up(self.x_init, self.f, self.g) @@ -46,6 +46,8 @@ class FISTA(Algorithm): # algorithmic parameters if opt is None: opt = {'tol': 1e-4} + print(self.x_init.as_array()) + print(x_init.as_array()) self.y = x_init.copy() self.x_old = x_init.copy() @@ -60,18 +62,28 @@ class FISTA(Algorithm): def update(self): self.f.gradient(self.y, out=self.u) + print ('update, self.u' , self.u.as_array()) self.u.__imul__( -self.invL ) self.u.__iadd__( self.y ) + print ('update, self.u' , self.u.as_array()) + # x = g.prox(u,invL) + print ('update, self.x pre prox' , self.x.as_array()) self.g.proximal(self.u, self.invL, out=self.x) + print ('update, self.x post prox' , self.x.as_array()) self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2))) self.x.subtract(self.x_old, out=self.y) + print ('update, self.y' , self.y.as_array()) + self.y.__imul__ ((self.t_old-1)/self.t) + print ('update, self.x' , self.x.as_array()) self.y.__iadd__( self.x ) + print ('update, self.y' , self.y.as_array()) self.x_old.fill(self.x) + print ('update, self.x_old' , self.x_old.as_array()) self.t_old = self.t def update_objective(self): diff --git a/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py b/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py index d9d9010..8e77f56 100755 --- a/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py +++ b/Wrappers/Python/ccpi/optimisation/functions/Norm2Sq.py @@ -88,7 +88,6 @@ class Norm2Sq(Function): def gradient(self, x, out = None): if self.memopt: #return 2.0*self.c*self.A.adjoint( self.A.direct(x) - self.b ) - print (self.range_tmp, self.range_tmp.as_array()) self.A.direct(x, out=self.range_tmp) self.range_tmp -= self.b self.A.adjoint(self.range_tmp, out=out) diff --git a/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py b/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py index 6306192..292db2c 100644 --- a/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py +++ b/Wrappers/Python/ccpi/optimisation/operators/LinearOperatorMatrix.py @@ -2,8 +2,8 @@ import numpy from scipy.sparse.linalg import svds from ccpi.framework import DataContainer from ccpi.framework import AcquisitionData -from ccpi.framework import ImageData -from ccpi.framework import ImageGeometry +from ccpi.framework import VectorData +from ccpi.framework import VectorGeometry from ccpi.framework import AcquisitionGeometry from numbers import Number from ccpi.optimisation.operators import LinearOperator @@ -11,8 +11,8 @@ class LinearOperatorMatrix(LinearOperator): def __init__(self,A): self.A = A M_A, N_A = self.A.shape - self.gm_domain = ImageGeometry(0, N_A) - self.gm_range = ImageGeometry(M_A,0) + self.gm_domain = VectorGeometry(N_A) + self.gm_range = VectorGeometry(M_A) self.s1 = None # Largest singular value, initially unknown super(LinearOperatorMatrix, self).__init__() @@ -21,18 +21,16 @@ class LinearOperatorMatrix(LinearOperator): return type(x)(numpy.dot(self.A,x.as_array())) else: numpy.dot(self.A, x.as_array(), out=out.as_array()) - - + def adjoint(self,x, out=None): if out is None: return type(x)(numpy.dot(self.A.transpose(),x.as_array())) else: numpy.dot(self.A.transpose(),x.as_array(), out=out.as_array()) - - + def size(self): return self.A.shape - + def norm(self): # If unknown, compute and store. If known, simply return it. if self.s1 is None: |