示例#1
0
def plotCost(cost_list, output=None):
    """Plot cost function

    Save a plot of the base-10 logarithm of the cost values against the
    iteration index.

    Parameters
    ----------
    cost_list : list
        List of cost function values, one per iteration
    output : str, optional
        Prefix for the output file name. Without it the figure is written to
        ``cost_function.png``; with it, to ``<output>_cost_function.png``

    Notes
    -----
    Nothing is plotted when the module-level ``import_fail`` flag is set
    (presumably because Matplotlib could not be imported); a warning is
    issued instead.

    """
    if import_fail:
        warn('Matplotlib not installed.')
        return

    file_name = ('cost_function.png' if output is None
                 else output + '_cost_function.png')

    plt.figure()
    plt.plot(np.log10(cost_list), 'r-')
    for set_text, text in ((plt.title, 'Cost Function'),
                           (plt.xlabel, 'Iteration'),
                           (plt.ylabel, r'$\log_{10}$ Cost')):
        set_text(text)
    plt.savefig(file_name)
    plt.close()

    print(' - Saving cost function data to:', file_name)
示例#2
0
def get_backend(backend):
    """Get backend.

    Look up the computation module registered for ``backend``.

    Parameters
    ----------
    backend: str
        Name of the requested backend, one of ``'tensorflow'``,
        ``'numpy'`` or ``'cupy'``.

    Returns
    -------
    tuple
        ``(module, backend_name)``: the module used to carry out
        calculations and the name of the backend actually selected. When
        ``backend`` is not a key of ``LIBRARIES`` or maps to ``None``
        (its optional library is unavailable), a warning is issued and the
        ``'numpy'`` entry is returned instead.
    """
    usable = backend in LIBRARIES and LIBRARIES[backend] is not None
    if not usable:
        warn(
            ('{0} backend not possible, please ensure that '
             'the optional libraries are installed.\n'
             'Reverting to numpy.').format(backend),
        )
        backend = 'numpy'
    return LIBRARIES[backend], backend
示例#3
0
    def __init__(self, shape, n_coils=1, samples=None, mask=None, n_jobs=1):
        """Initialize the 'FFT' class.

        Parameters
        ----------
        shape: tuple of int
            shape of the image (not necessarily a square matrix).
        n_coils: int, default 1
            Number of coils used to acquire the signal in case of
            multiarray receiver coils acquisition. If n_coils > 1,
            data shape must be equal to [n_coils, Nx, Ny, NZ].
            Values <= 0 trigger a warning and are replaced by 1.
        samples: np.ndarray, default None
            the mask samples in the Fourier domain.
        mask: np.ndarray, default None
            the mask as a matrix with 1 at sample locations.
            At least one of ``samples`` and ``mask`` is required; when
            both are given, ``mask`` is used and ``samples`` is recomputed
            from it.
        n_jobs: int, default 1
            Number of parallel workers to use for fourier computation.
            All cores are used if -1

        Raises
        ------
        ValueError
            If neither ``samples`` nor ``mask`` is given.
        """
        self.shape = shape
        if samples is None and mask is None:
            raise ValueError("Please pass either samples or mask as input")

        if mask is not None:
            # The mask is authoritative: derive sample locations from it.
            self.mask = mask
            self.samples = convert_mask_to_locations(mask)
        else:
            self.mask = convert_locations_to_mask(samples, self.shape)
            self.samples = samples

        coil_count = n_coils
        if coil_count <= 0:
            warn("The number of coils should be strictly positive")
            coil_count = 1
        self.n_coils = coil_count
        self.n_jobs = n_jobs
示例#4
0
    def _check_operator(self, operator):
        """Check set-Up.

        Warn when ``operator`` is given but none of the classes in its
        method resolution order has a name listed in ``self._op_parents``.

        Parameters
        ----------
        operator : object or None
            Algorithm operator instance to check; ``None`` is accepted and
            skipped

        """
        if operator is None:
            return

        ancestry = tuple(
            klass.__name__ for klass in getmro(operator.__class__)
        )
        if any(parent in ancestry for parent in self._op_parents):
            return

        message = '{0} does not inherit an operator parent.'
        warn(message.format(str(operator.__class__)))
示例#5
0
def check_npndarray(input_obj, dtype=None, writeable=True, verbose=True):
    """Check Numpy ND-Array.

    Check that the input object is a numpy array (optionally of a given
    data type) and set its ``writeable`` flag.

    Parameters
    ----------
    input_obj : numpy.ndarray
        Input object
    dtype : type, optional
        Expected numpy data type, tested with ``numpy.issubdtype``; no dtype
        check is made when ``None`` (default)
    writeable : bool, optional
        Value assigned to ``input_obj.flags.writeable``; ``False`` makes the
        array immutable (default is ``True``)
    verbose : bool, optional
        Warn when a currently writeable array is made immutable (default is
        ``True``)

    Raises
    ------
    TypeError
        For invalid input type
    TypeError
        For invalid numpy.ndarray dtype

    """
    if not isinstance(input_obj, np.ndarray):
        raise TypeError('Input is not a numpy array.')

    if dtype is not None and not np.issubdtype(input_obj.dtype, dtype):
        # Raise the exception instance itself: raising a tuple is rejected
        # by Python 3 with an unrelated TypeError that hides this message.
        raise TypeError(
            'The numpy array elements are not of type: {0}'.format(dtype),
        )

    if not writeable and verbose and input_obj.flags.writeable:
        warn('Making input data immutable.')

    input_obj.flags.writeable = writeable
示例#6
0
    def __init__(self, samples, shape, n_coils=1):
        """Initialize the 'FFT' class.

        Parameters
        ----------
        samples: np.ndarray
            the mask samples in the Fourier domain.
        shape: tuple of int
            shape of the image (not necessarily a square matrix).
        n_coils: int, default 1
            Number of coils used to acquire the signal in case of
            multiarray receiver coils acquisition. If n_coils > 1,
            data shape must be equal to [n_coils, Nx, Ny, NZ].
            Values <= 0 trigger a warning and are replaced by 1.
        """
        self.samples = samples
        self.shape = shape
        # Binary mask built from the sample locations on the image grid.
        self._mask = convert_locations_to_mask(self.samples, self.shape)

        valid_coils = not n_coils <= 0
        if not valid_coils:
            warn("The number of coils should be strictly positive")
        self.n_coils = n_coils if valid_coils else 1
示例#7
0
    def _op_method(self, input_data, extra_factor=1.0):
        """Operator.

        Return the proximity operator of the squared k-support norm,
        following (Alg. 1) in :cite:`mcdonald2014`.

        If ``self._k_value`` exceeds the number of elements of
        ``input_data``, a warning is issued and ``self._k_value`` is lowered
        to that number (the change persists on the instance).

        Parameters
        ----------
        input_data : numpy.ndarray
            Input data array
        extra_factor : float
            Additional multiplication factor (default is ``1.0``)

        Returns
        -------
        numpy.ndarray
            Proximal map, with the same shape as ``input_data``

        """
        original_shape = input_data.shape
        n_elements = np.prod(original_shape)
        if self._k_value > n_elements:
            warn(
                'K value of the K-support norm is greater than the input '
                + 'dimension, its value will be set to {0}'.format(n_elements),
            )
            self._k_value = n_elements

        flat = input_data.flatten()

        # Lines 1., 2. and 3. of Algorithm 1: find alpha.
        alpha = self._find_alpha(np.abs(flat), extra_factor)

        # Line 4. of Algorithm 1: compute theta.
        theta = self._compute_theta(np.abs(flat), alpha)

        # Line 5. of Algorithm 1; NaN/inf from the division are replaced.
        denominator = theta + self.beta * extra_factor
        shrunk = np.nan_to_num((flat * theta) / denominator)
        return shrunk.reshape(original_shape)
示例#8
0
文件: types.py 项目: zwq1230/ModOpt
def check_npndarray(val, dtype=None, writeable=True, verbose=True):
    """Check if input object is a numpy array.

    Optionally verify the element data type, then set the array's
    ``writeable`` flag.

    Parameters
    ----------
    val : numpy.ndarray
        Input object
    dtype : type, optional
        Required data type, tested with ``numpy.issubdtype``; skipped when
        ``None``
    writeable : bool, optional
        Value given to ``val.flags.writeable`` (``False`` makes the array
        immutable)
    verbose : bool, optional
        Warn when a currently writeable array is made immutable

    Raises
    ------
    TypeError
        If ``val`` is not a numpy array, or its dtype does not match
        ``dtype``

    """
    if not isinstance(val, np.ndarray):
        raise TypeError('Input is not a numpy array.')

    if dtype is not None and not np.issubdtype(val.dtype, dtype):
        raise TypeError(
            'The numpy array elements are not of type: {}'.format(dtype),
        )

    about_to_freeze = not writeable and val.flags.writeable
    if verbose and about_to_freeze:
        warn('Making input data immutable.')

    val.flags.writeable = writeable
示例#9
0
def check_psf(data, log):
    """Check PSF

    Verify that the input PSFs are normalised and record the PSF type for a
    single PSF.

    The sum of ``data`` over its last two axes must be within ``1e-5`` of
    ``1.0`` everywhere; otherwise a warning is raised and the same message
    is logged. When ``data`` is 2D, ``opts.psf_type`` is set to
    ``'fixed'``. ``opts`` is not an argument: it is presumably a
    module-level namespace defined elsewhere (confirm).

    Parameters
    ----------
    data : np.ndarray
        Input data array: a single 2D PSF, or presumably a stack of 2D PSFs
        along the leading axes
    log : logging.Logger
        Log instance

    """
    last_two_axes = (data.ndim - 2, data.ndim - 1)
    integrals = np.sum(data, axis=last_two_axes)
    normalised = np.all(np.abs(integrals - 1) < 1e-5)

    if not normalised:
        warn('Not all PSFs integrate to 1.0.')
        log.info(' - Not all PSFs integrate to 1.0.')

    if data.ndim == 2:
        opts.psf_type = 'fixed'
示例#10
0
def check_psf(data, log):
    """Check PSF

    Verify that the input PSFs are normalised and record the PSF type for a
    single PSF.

    The sum of ``data`` over its last two axes must be within ``1e-5`` of
    ``1.0`` everywhere; otherwise a warning is raised and the same message
    is logged. When ``data`` is 2D, ``opts.psf_type`` is set to
    ``'fixed'``. ``opts`` is not an argument: it is presumably a
    module-level namespace defined elsewhere (confirm).

    Parameters
    ----------
    data : np.ndarray
        Input data array: a single 2D PSF, or presumably a stack of 2D PSFs
        along the leading axes
    log : logging.Logger
        Log instance

    """
    last_two_axes = (data.ndim - 2, data.ndim - 1)
    integrals = np.sum(data, axis=last_two_axes)
    normalised = np.all(np.abs(integrals - 1) < 1e-5)

    if not normalised:
        warn('Not all PSFs integrate to 1.0.')
        log.info(' - Not all PSFs integrate to 1.0.')

    if data.ndim == 2:
        opts.psf_type = 'fixed'
示例#11
0
文件: svd.py 项目: ljpyanne/ModOpt
def svd_thresh(data, threshold=None, n_pc=None, thresh_type='hard'):
    r"""Threshold the singular values

    This method thresholds the input data using singular value decomposition

    Parameters
    ----------
    data : np.ndarray
        Input data array, 2D matrix
    threshold : float or np.ndarray, optional
        Threshold value(s). When omitted, the threshold is the ``n_pc``-th
        entry of the singular values returned by ``calculate_svd``
    n_pc : int or str, optional
        Number of principal components, specify a positive integer (Python
        or numpy integer) or 'all'. Only used when ``threshold`` is not
        given; if omitted it is estimated with ``find_n_pc``. Values larger
        than the number of singular values are clamped to it
    thresh_type : str {'hard', 'soft'}, optional
        Type of thresholding (default is 'hard')

    Returns
    -------
    np.ndarray
        Thresholded data

    Raises
    ------
    ValueError
        For invalid n_pc value

    Examples
    --------
    >>> from modopt.signal.svd import svd_thresh
    >>> x = np.arange(18).reshape(9, 2).astype(float)
    >>> svd_thresh(x, n_pc=1)
    array([[  0.49815487,   0.54291537],
           [  2.40863386,   2.62505584],
           [  4.31911286,   4.70719631],
           [  6.22959185,   6.78933678],
           [  8.14007085,   8.87147725],
           [ 10.05054985,  10.95361772],
           [ 11.96102884,  13.03575819],
           [ 13.87150784,  15.11789866],
           [ 15.78198684,  17.20003913]])

    """
    # Accept numpy integer scalars (e.g. np.int64) as well as Python ints.
    integer_types = (int, np.integer)

    if ((not isinstance(n_pc, integer_types + (str, type(None)))) or
            (isinstance(n_pc, integer_types) and n_pc <= 0) or
            (isinstance(n_pc, str) and n_pc != 'all')):
        raise ValueError('Invalid value for "n_pc", specify a positive '
                         'integer value or "all"')

    # Get SVD of input data.
    u, s, v = calculate_svd(data)

    # Find the threshold if not provided.
    if threshold is None:

        # Find the required number of principal components if not specified.
        if n_pc is None:
            n_pc = find_n_pc(u, factor=0.1)

        # If the number of PCs is too large use all of the singular values.
        if ((isinstance(n_pc, integer_types) and n_pc >= s.size) or
                (isinstance(n_pc, str) and n_pc == 'all')):
            n_pc = s.size
            warn('Using all singular values.')

        threshold = s[n_pc - 1]

    # Threshold the singular values.
    s_new = thresh(s, threshold, thresh_type)

    if np.all(s_new == s):
        warn('No change to singular values.')

    # Diagonalize the svd
    s_new = np.diag(s_new)

    # Return the thresholded data.
    return np.dot(u, np.dot(s_new, v))
示例#12
0
文件: wavelet.py 项目: NiCadet/ModOpt
def call_mr_transform(data,
                      opt='',
                      path='./',
                      remove_files=True):  # pragma: no cover
    r"""Call mr_transform

    This method calls the iSAP module mr_transform

    Parameters
    ----------
    data : np.ndarray
        Input data, 2D array
    opt : list, tuple or str, optional
        Options to be passed to mr_transform; a string is split on
        whitespace
    path : str, optional
        Directory for the temporary files (default is './'); a trailing
        separator is not required
    remove_files : bool, optional
        Option to remove output files after a successful run (default is
        'True'). Temporary files are always removed when the run fails

    Returns
    -------
    np.ndarray or None
        Results of mr_transform, or ``None`` if the executable failed (a
        warning is issued in that case)

    Raises
    ------
    ImportError
        If Astropy is not available
    ValueError
        If the input data is not a 2D numpy array

    Examples
    --------
    >>> from modopt.signal.wavelet import *
    >>> a = np.arange(9).reshape(3, 3).astype(float)
    >>> call_mr_transform(a)
    array([[[-1.5       , -1.125     , -0.75      ],
            [-0.375     ,  0.        ,  0.375     ],
            [ 0.75      ,  1.125     ,  1.5       ]],

           [[-1.5625    , -1.171875  , -0.78125   ],
            [-0.390625  ,  0.        ,  0.390625  ],
            [ 0.78125   ,  1.171875  ,  1.5625    ]],

           [[-0.5859375 , -0.43945312, -0.29296875],
            [-0.14648438,  0.        ,  0.14648438],
            [ 0.29296875,  0.43945312,  0.5859375 ]],

           [[ 3.6484375 ,  3.73632812,  3.82421875],
            [ 3.91210938,  4.        ,  4.08789062],
            [ 4.17578125,  4.26367188,  4.3515625 ]]], dtype=float32)

    """
    import os

    if not import_astropy:
        raise ImportError('Astropy package not found.')

    if (not isinstance(data, np.ndarray)) or (data.ndim != 2):
        raise ValueError('Input data must be a 2D numpy array.')

    executable = 'mr_transform'

    # Make sure mr_transform is installed.
    is_executable(executable)

    # Unique string from the current date and time; microseconds avoid
    # collisions between calls made within the same second.
    unique_string = datetime.now().strftime('%Y.%m.%d_%H.%M.%S.%f')

    # Set the output file names (join works with or without a trailing
    # separator on ``path``).
    file_name = os.path.join(path, 'mr_temp_' + unique_string)
    file_fits = file_name + '.fits'
    file_mr = file_name + '.mr'

    # Write the input data to a fits file.
    fits.writeto(file_fits, data)

    if isinstance(opt, str):
        opt = opt.split()

    keep_files = False
    try:
        check_call([executable] + list(opt) + [file_fits, file_mr])
    except Exception:
        warn('{} failed to run with the options provided.'.format(executable))
        result = None
    else:
        # Retrieve wavelet transformed data.
        result = fits.getdata(file_mr)
        keep_files = not remove_files
    finally:
        # Clean up on failure (including a failing getdata) and, after a
        # successful run, unless the caller asked to keep the files.
        if not keep_files:
            for temp_file in (file_fits, file_mr):
                if os.path.isfile(temp_file):
                    remove(temp_file)

    # Return the mr_transform results.
    return result
示例#13
0
import numpy as np
import matplotlib
from sf_tools.image.stamp import FetchStamps
matplotlib.use("Agg")
try:
    from pysap import load_transform
except ImportError:
    from astropy.convolution import convolve_fft
    from modopt.interface.errors import warn
    warn('PySAP not installed, using backup filters.')
    import_pysap = False
else:
    import_pysap = True


def convolve(data, kernel):
    """ Convolve

    Convolve input data with kernel.

    Parameters
    ----------
    data : np.ndarray
        Input 2D data array
    kernel : np.ndarray
        Input 2D kernel array

    Returns
    -------
    np.ndarray
        Convolved array
示例#14
0
"""

from __future__ import division
from builtins import zip
import numpy as np
import scipy.signal
from astropy.convolution import convolve_fft
from modopt.base.np_adjust import rotate_stack
from modopt.interface.errors import warn
try:
    import pyfftw
except ImportError:  # pragma: no cover
    pass
else:  # pragma: no cover
    scipy.fftpack = pyfftw.interfaces.scipy_fftpack
    warn('Using pyFFTW "monkey patch" for scipy.fftpack')


def convolve(data, kernel, method='astropy'):
    r"""Convolve data with kernel

    This method convolves the input data with a given kernel using FFT and
    is the default convolution used for all routines

    Parameters
    ----------
    data : np.ndarray
        Input data array, normally a 2D image
    kernel : np.ndarray
        Input kernel array, normally a 2D kernel
    method : str {'astropy', 'scipy'}, optional
示例#15
0
    def __init__(self,
                 samples,
                 shape,
                 platform='cuda',
                 Kd=None,
                 Jd=None,
                 n_coils=1,
                 verbosity=0):
        """Initialize the 'NUFFT' class.

        Parameters
        ----------
        samples: np.ndarray
            the mask samples in the Fourier domain, in [-0.5, 0.5[ (they
            are rescaled by 2*pi for PyNUFFT). ``samples.shape[1]`` is taken
            as the number of image dimensions.
        shape: tuple of int
            shape of the image (necessarily a square/cubic matrix).
        platform: str, 'opencl' or 'cuda'
            hardware platform used to compute the NUFFT (default 'cuda')
        Kd: int, tuple or list of int, optional
            size of the frequency grid used for regridding. An int is
            expanded to ``(Kd,) * nb_dim``; defaults to twice the image
            size along each dimension.
        Jd: int, tuple or list of int, optional
            size of the interpolator kernel. An int is expanded to
            ``(Jd,) * nb_dim``; defaults to 5 along each dimension.
        n_coils: int
            Number of coils used to acquire the signal in case of multiarray
            receiver coils acquisition
        verbosity: int
            print progress messages when > 0; also forwarded to NUFFT_hsa

        Raises
        ------
        ValueError
            If ``n_coils`` is not an integer >= 1, if PyNUFFT is not
            installed, if ``Kd``/``Jd`` has an unsupported type, or if
            ``platform`` is neither 'opencl' nor 'cuda'.
        AssertionError
            If the frequency grid is smaller than the image along some
            dimension.
        """
        # Check the type before comparing so that non-numeric input gets
        # the documented ValueError instead of a TypeError from ``<``;
        # ``bool`` is excluded explicitly because it subclasses ``int``.
        if (isinstance(n_coils, bool)
                or not isinstance(n_coils, (int, np.integer))
                or n_coils < 1):
            raise ValueError('The number of coils should be an integer >= 1')
        if not pynufft_available:
            raise ValueError('PyNUFFT Package is not installed, please '
                             'consider using `gpuNUFFT` or install the '
                             'PyNUFFT package')
        self.nb_coils = n_coils
        self.shape = shape
        self.platform = platform
        # PyNUFFT uses samples in [-pi, pi[ instead of [-0.5, 0.5[
        self.samples = samples * (2 * np.pi)
        self.dim = samples.shape[1]  # number of dimensions of the image

        def per_dimension(value, name):
            """Expand an int to one entry per image dimension."""
            is_integer = (isinstance(value, (int, np.integer))
                          and not isinstance(value, bool))
            if is_integer:
                return (value, ) * self.dim
            if isinstance(value, (tuple, list)):
                return tuple(value)
            raise ValueError('{0} should be an int or a tuple of int, '
                             'got {1!r}'.format(name, value))

        # Preferential options: grid twice the image size, kernel size 5.
        self.Kd = (tuple(2 * ix for ix in shape) if Kd is None
                   else per_dimension(Kd, 'Kd'))
        self.Jd = (5, ) * self.dim if Jd is None else per_dimension(Jd, 'Jd')

        for i in range(len(shape)):
            if self.shape[i] > self.Kd[i]:
                # Raised explicitly (not via ``assert``) so the check
                # survives ``python -O``; the exception type is kept for
                # existing callers.
                raise AssertionError('size of frequency grid must be '
                                     'greater or equal than the image size')

        if verbosity > 0:
            print('Creating the NUFFT object...')

        if self.platform == 'opencl':
            api = 'ocl'
            warn('Attemping to use OpenCL plateform. Make sure to '
                 'have  all the dependecies installed')
        elif self.platform == 'cuda':
            api = 'cuda'
            warn('Attemping to use Cuda plateform. Make sure to '
                 'have  all the dependecies installed and '
                 'to create only one instance of NUFFT GPU')
        else:
            raise ValueError('Wrong type of platform. Platform must be '
                             '\'opencl\' or \'cuda\'')

        Singleton.__init__(self)
        if self.getNumInstances() > 1:
            warn('You have created more than one NUFFT object. '
                 'This could cause memory leaks')
        self.nufftObj = NUFFT_hsa(API=api,
                                  platform_number=None,
                                  device_number=None,
                                  verbosity=verbosity)

        self.nufftObj.plan(
            om=self.samples,
            Nd=self.shape,
            Kd=self.Kd,
            Jd=self.Jd,
            batch=1,  # TODO self.nb_coils,
            ft_axes=tuple(range(samples.shape[1])),
            radix=None)
示例#16
0
:Author: Samuel Farrens <*****@*****.**>

"""

from __future__ import division
from builtins import zip
import numpy as np
import scipy.signal
from modopt.base.np_adjust import rotate_stack
from modopt.interface.errors import warn
try:
    from astropy.convolution import convolve_fft
except ImportError:  # pragma: no cover
    import_astropy = False
    warn('astropy not found, will default to scipy for convolution')
else:
    import_astropy = True
try:
    import pyfftw
except ImportError:  # pragma: no cover
    pass
else:  # pragma: no cover
    scipy.fftpack = pyfftw.interfaces.scipy_fftpack
    warn('Using pyFFTW "monkey patch" for scipy.fftpack')


def convolve(data, kernel, method='scipy'):
    r"""Convolve data with kernel

    This method convolves the input data with a given kernel using FFT and