Exemplo n.º 1
0
    def test_resize(self):
        """Check util.resize zero-padding and cropping on 1-D arrays.

        Covers the default (centered) placement for odd and even input and
        output lengths, plus an explicit ``oshift`` for padding and an
        explicit ``ishift`` for cropping.
        """
        # Each case: (input values, output shape, resize kwargs, expected).
        cases = [
            # Zero-pad, centered.
            ([1, 2, 3], [5], {}, [0, 1, 2, 3, 0]),
            ([1, 2, 3], [4], {}, [0, 1, 2, 3]),
            ([1, 2], [5], {}, [0, 1, 2, 0, 0]),
            ([1, 2], [4], {}, [0, 1, 2, 0]),
            # Zero-pad, non-centered (output offset 0).
            ([1, 2, 3], [5], {'oshift': [0]}, [1, 2, 3, 0, 0]),
            # Crop, centered.
            ([0, 1, 2, 3, 0], [3], {}, [1, 2, 3]),
            ([0, 1, 2, 3], [3], {}, [1, 2, 3]),
            ([0, 1, 2, 0, 0], [2], {}, [1, 2]),
            ([0, 1, 2, 0], [2], {}, [1, 2]),
            # Crop, non-centered (input offset 0).
            ([1, 2, 3, 0, 0], [3], {'ishift': [0]}, [1, 2, 3]),
        ]
        for values, oshape, kwargs, expected in cases:
            resized = util.resize(np.array(values), oshape, **kwargs)
            npt.assert_allclose(resized, expected)
Exemplo n.º 2
0
def _fft_convolve_adjoint_filter(x, y, mode='full'):
    ndim = x.ndim - 2
    batch_size = len(x)
    output_channel = y.shape[1]
    input_channel = x.shape[1]
    output_shape = y.shape[-ndim:]
    input_shape = x.shape[-ndim:]
    if mode == 'full':
        filter_shape = tuple(p - m + 1
                             for m, p in zip(input_shape, output_shape))
        pad_shape = output_shape
    elif mode == 'valid':
        filter_shape = tuple(m - p + 1
                             for m, p in zip(input_shape, output_shape))
        pad_shape = input_shape

    dtype = x.dtype
    device = backend.get_device(x)
    xp = device.xp
    with device:
        x = xp.conj(util.flip(x, axes=range(-ndim, 0)))
        x = x.reshape((batch_size, 1, input_channel) + input_shape)
        y = y.reshape((batch_size, output_channel, 1) + output_shape)

        x_pad = util.resize(x, (batch_size, 1, input_channel) + pad_shape,
                            oshift=[0] * x.ndim)
        y_pad = util.resize(y, (batch_size, output_channel, 1) + pad_shape,
                            oshift=[0] * y.ndim)

        if np.issubdtype(dtype, np.floating):
            x_fft = xp.fft.rfftn(x_pad, axes=range(-ndim, 0), norm='ortho')
            y_fft = xp.fft.rfftn(y_pad, axes=range(-ndim, 0), norm='ortho')
            W_fft = xp.sum(x_fft * y_fft, axis=0)
            W = xp.fft.irfftn(W_fft,
                              pad_shape,
                              axes=range(-ndim, 0),
                              norm='ortho').astype(dtype)
        else:
            x_fft = fourier.fft(x_pad, axes=range(-ndim, 0), center=False)
            y_fft = fourier.fft(y_pad, axes=range(-ndim, 0), center=False)
            W_fft = xp.sum(x_fft * y_fft, axis=0)
            W = fourier.ifft(W_fft, axes=range(-ndim, 0), center=False)

        if mode == 'full':
            shift = [0, 0] + [m - 1 for m in input_shape]
        elif mode == 'valid':
            shift = [0, 0] + [p - 1 for p in output_shape]

        W = util.resize(W, (output_channel, input_channel) + filter_shape,
                        ishift=shift)
        W *= util.prod(pad_shape)**0.5
        return W
Exemplo n.º 3
0
def _fft_convolve_adjoint_input(W, y, mode='full'):
    ndim = y.ndim - 2
    batch_size = len(y)
    output_channel, input_channel = W.shape[:2]
    output_shape = y.shape[-ndim:]
    filter_shape = W.shape[-ndim:]
    if mode == 'full':
        input_shape = tuple(p - n + 1
                            for p, n in zip(output_shape, filter_shape))
        pad_shape = output_shape
    elif mode == 'valid':
        input_shape = tuple(p + n - 1
                            for p, n in zip(output_shape, filter_shape))
        pad_shape = input_shape

    dtype = y.dtype
    device = backend.get_device(y)
    xp = device.xp
    with device:
        y = y.reshape((batch_size, output_channel, 1) + output_shape)
        W = xp.conj(util.flip(W, axes=range(-ndim, 0)))

        y_pad = util.resize(y, (batch_size, output_channel, 1) + pad_shape,
                            oshift=[0] * y.ndim)
        W_pad = util.resize(W, (output_channel, input_channel) + pad_shape,
                            oshift=[0] * W.ndim)

        if np.issubdtype(dtype, np.floating):
            y_fft = xp.fft.rfftn(y_pad, axes=range(-ndim, 0), norm='ortho')
            W_fft = xp.fft.rfftn(W_pad, axes=range(-ndim, 0), norm='ortho')
            x_fft = xp.sum(y_fft * W_fft, axis=-ndim - 2)
            x = xp.fft.irfftn(x_fft,
                              pad_shape,
                              axes=range(-ndim, 0),
                              norm='ortho').astype(dtype)
        else:
            y_fft = fourier.fft(y_pad, axes=range(-ndim, 0), center=False)
            W_fft = fourier.fft(W_pad, axes=range(-ndim, 0), center=False)
            x_fft = xp.sum(y_fft * W_fft, axis=-ndim - 2)
            x = fourier.ifft(x_fft, axes=range(-ndim, 0), center=False)

        if mode == 'full':
            shift = [0, 0] + [n - 1 for n in filter_shape]
        elif mode == 'valid':
            shift = [0] * x.ndim

        x = util.resize(x, (batch_size, input_channel) + input_shape,
                        ishift=shift)
        x *= util.prod(pad_shape)**0.5
        return x
Exemplo n.º 4
0
def _fft_convolve(x, W, mode='full'):
    ndim = x.ndim - 2
    batch_size = len(x)
    output_channel, input_channel = W.shape[:2]
    input_shape = x.shape[-ndim:]
    filter_shape = W.shape[-ndim:]
    if mode == 'full':
        output_shape = tuple(m + n - 1
                             for m, n in zip(input_shape, filter_shape))
        pad_shape = output_shape
    elif mode == 'valid':
        output_shape = tuple(m - n + 1
                             for m, n in zip(input_shape, filter_shape))
        pad_shape = input_shape

    dtype = x.dtype
    device = backend.get_device(x)
    xp = device.xp
    with device:
        x = x.reshape((batch_size, 1, input_channel) + input_shape)
        x_pad = util.resize(x, (batch_size, 1, input_channel) + pad_shape,
                            oshift=[0] * x.ndim)
        W_pad = util.resize(W, (output_channel, input_channel) + pad_shape,
                            oshift=[0] * W.ndim)

        if np.issubdtype(dtype, np.floating):
            x_fft = xp.fft.rfftn(x_pad, axes=range(-ndim, 0), norm='ortho')
            W_fft = xp.fft.rfftn(W_pad, axes=range(-ndim, 0), norm='ortho')
            y_fft = xp.sum(x_fft * W_fft, axis=-ndim - 1)
            y = xp.fft.irfftn(y_fft,
                              pad_shape,
                              axes=range(-ndim, 0),
                              norm='ortho').astype(dtype)
        else:
            x_fft = fourier.fft(x_pad, axes=range(-ndim, 0), center=False)
            W_fft = fourier.fft(W_pad, axes=range(-ndim, 0), center=False)
            y_fft = xp.sum(x_fft * W_fft, axis=-ndim - 1)
            y = fourier.ifft(y_fft, axes=range(-ndim, 0), center=False)

        if mode == 'full':
            shift = [0] * y.ndim
        elif mode == 'valid':
            shift = [0, 0] + [n - 1 for n in filter_shape]

        y = util.resize(y, (batch_size, output_channel) + output_shape,
                        ishift=shift)
        y *= util.prod(pad_shape)**0.5
        return y
Exemplo n.º 5
0
def nufft(input, coord, oversamp=1.25, width=4):
    """Non-uniform Fast Fourier Transform.

    Args:
        input (array): input signal domain array of shape
            (..., n_{ndim - 1}, ..., n_1, n_0),
            where ndim is specified by coord.shape[-1]. The nufft
            is applied on the last ndim axes, and looped over
            the remaining axes.
        coord (array): Fourier domain coordinate array of shape (..., ndim).
            ndim determines the number of dimensions to apply the nufft.
            coord[..., i] should be scaled to have its range between
            -n_i // 2, and n_i // 2.
        oversamp (float): oversampling factor.
        width (float): interpolation kernel full-width in terms of
            oversampled grid.

    Returns:
        array: Fourier domain data of shape
            input.shape[:-ndim] + coord.shape[:-1].

    References:
        Fessler, J. A., & Sutton, B. P. (2003).
        Nonuniform fast Fourier transforms using min-max interpolation
        IEEE Transactions on Signal Processing, 51(2), 560-574.
        Beatty, P. J., Nishimura, D. G., & Pauly, J. M. (2005).
        Rapid gridding reconstruction with a minimal oversampling ratio.
        IEEE transactions on medical imaging, 24(6), 799-808.

    """
    ndim = coord.shape[-1]
    # Kaiser-Bessel shape parameter for this kernel width and oversampling
    # ratio (formula from Beatty et al., 2005, listed in the references).
    beta = np.pi * (((width / oversamp) * (oversamp - 0.5))**2 - 0.8)**0.5
    os_shape = _get_oversamp_shape(input.shape, ndim, oversamp)

    # Work on a copy; _apodize's return value is unused, so it presumably
    # modifies its argument in place.
    output = input.copy()

    # Apodize
    _apodize(output, ndim, oversamp, width, beta)

    # Zero-pad to the oversampled grid. Dividing by sqrt(N) here pairs with
    # the unnormalized (norm=None) FFT below.
    output /= util.prod(input.shape[-ndim:])**0.5
    output = util.resize(output, os_shape)

    # FFT
    output = fft(output, axes=range(-ndim, 0), norm=None)

    # Interpolate onto the non-uniform coordinates with a Kaiser-Bessel
    # kernel, then divide by width**ndim.
    coord = _scale_coord(coord, input.shape, oversamp)
    output = interp.interpolate(output,
                                coord,
                                kernel='kaiser_bessel',
                                width=width,
                                param=beta)
    output /= width**ndim

    return output
Exemplo n.º 6
0
def nufft_adjoint(input, coord, oshape=None, oversamp=1.25, width=4):
    """Adjoint non-uniform Fast Fourier Transform.

    Args:
        input (array): input Fourier domain array of shape
            (...) + coord.shape[:-1]. That is, the last dimensions
            of input must match the first dimensions of coord.
            The nufft_adjoint is applied on the last coord.ndim - 1 axes,
            and looped over the remaining axes.
        coord (array): Fourier domain coordinate array of shape (..., ndim).
            ndim determines the number of dimension to apply nufft adjoint.
            coord[..., i] should be scaled to have its range between
            -n_i // 2, and n_i // 2.
        oshape (None or tuple of ints): output shape of the form
            (..., n_{ndim - 1}, ..., n_1, n_0). If None, it is the leading
            (batch) axes of input followed by estimate_shape(coord).
        oversamp (float): oversampling factor.
        width (float): interpolation kernel full-width in terms of
            oversampled grid.

    Returns:
        array: signal domain array with shape specified by oshape.

    See Also:
        :func:`sigpy.nufft.nufft`

    """
    ndim = coord.shape[-1]
    # Kaiser-Bessel shape parameter; must match the one used by nufft.
    beta = np.pi * (((width / oversamp) * (oversamp - 0.5))**2 - 0.8)**0.5
    if oshape is None:
        # Batch axes are those of input before the coord.shape[:-1] axes.
        # Use an explicit non-negative stop: the negative stop
        # -coord.ndim + 1 becomes 0 when coord.ndim == 1 and would drop
        # every batch axis.
        n_batch = len(input.shape) - (coord.ndim - 1)
        oshape = list(input.shape[:n_batch]) + estimate_shape(coord)
    else:
        oshape = list(oshape)

    os_shape = _get_oversamp_shape(oshape, ndim, oversamp)

    # Gridding onto the oversampled Cartesian grid
    coord = _scale_coord(coord, oshape, oversamp)
    output = interp.gridding(input,
                             coord,
                             os_shape,
                             kernel='kaiser_bessel',
                             width=width,
                             param=beta)
    output /= width**ndim

    # IFFT
    output = ifft(output, axes=range(-ndim, 0), norm=None)

    # Crop to oshape. Multiplying by prod(os_shape) presumably cancels the
    # 1/N of the norm=None inverse FFT. Dividing by sqrt(prod(oshape))
    # mirrors the 1/sqrt(N) applied before the FFT in nufft.
    output = util.resize(output, oshape)
    output *= util.prod(os_shape[-ndim:]) / util.prod(oshape[-ndim:])**0.5

    # Apodize
    _apodize(output, ndim, oversamp, width, beta)

    return output
Exemplo n.º 7
0
def _ifftc(input, oshape=None, axes=None, norm='ortho'):
    """Centered inverse FFT with optional resizing.

    The input is first resized with util.resize to ``oshape``. When
    ``oshape`` is None it keeps its own shape. The inverse FFT is then
    taken over ``axes``, with ifftshift before and fftshift after, so the
    array center is treated as the origin.

    Args:
        input (array): array to transform.
        oshape (None or tuple of ints): shape to resize to before the IFFT.
        axes (None or tuple of ints): axes to transform, normalized by
            util._normalize_axes (None presumably means all axes).
        norm (str or None): normalization mode forwarded to ifftn.

    Returns:
        array: centered inverse FFT of the resized input.
    """
    xp = backend.get_array_module(input)
    axes = util._normalize_axes(axes, input.ndim)
    target_shape = input.shape if oshape is None else oshape

    resized = util.resize(input, target_shape)
    # Move the center to index 0, transform, then move it back.
    uncentered = xp.fft.ifftshift(resized, axes=axes)
    transformed = xp.fft.ifftn(uncentered, axes=axes, norm=norm)
    return xp.fft.fftshift(transformed, axes=axes)
Exemplo n.º 8
0
def _ifftc(input, oshape=None, axes=None, norm='ortho'):
    """Centered inverse FFT with optional resizing, on the input's device.

    Resizes ``input`` to ``oshape`` (or leaves it at its own shape when
    ``oshape`` is None). Then it applies ifftshift, ifftn and fftshift over
    ``axes``. All array work runs inside the device context of ``input``.

    Args:
        input (array): array to transform.
        oshape (None or tuple of ints): shape to resize to before the IFFT.
        axes (None or tuple of ints): axes to transform, normalized by
            util._normalize_axes.
        norm (str or None): normalization mode forwarded to ifftn.

    Returns:
        array: centered inverse FFT of the resized input.
    """
    axes = util._normalize_axes(axes, input.ndim)
    device = backend.get_device(input)
    xp = device.xp
    target_shape = input.shape if oshape is None else oshape

    with device:
        # Shift the center to index 0, transform, then shift it back.
        uncentered = xp.fft.ifftshift(util.resize(input, target_shape),
                                      axes=axes)
        return xp.fft.fftshift(
            xp.fft.ifftn(uncentered, axes=axes, norm=norm), axes=axes)
Exemplo n.º 9
0
def iwt(input, oshape, coeff_slices, wave_name='db4', axes=None, level=None):
    """Inverse wavelet transform.

    The packed coefficient array is moved to the CPU, unpacked with
    pywt.array_to_coeffs and reconstructed with pywt.waverecn (zero boundary
    mode). The reconstruction is resized to ``oshape`` and moved back to the
    device ``input`` came from.

    Args:
        input (array): Packed wavelet coefficient array.
        oshape (tuple of ints): Output shape.
        coeff_slices (list of slice): Slices to split coefficients.
        wave_name (str): Wavelet name.
        axes (None or tuple of int): Axes to perform wavelet transform.
        level (None or int): Number of wavelet levels. Not used by this
            function.

    Returns:
        array: reconstructed array of shape ``oshape``.
    """
    source_device = backend.get_device(input)
    cpu_array = backend.to_device(input, backend.cpu_device)

    coeffs = pywt.array_to_coeffs(cpu_array, coeff_slices,
                                  output_format='wavedecn')
    recon = util.resize(pywt.waverecn(coeffs, wave_name, mode='zero',
                                      axes=axes),
                        oshape)

    return backend.to_device(recon, source_device)
Exemplo n.º 10
0
def fwt(input, wave_name='db4', axes=None, level=None):
    """Forward wavelet transform.

    Works on a CPU copy of ``input``. Every axis is zero-padded with
    util.resize up to an even length, then decomposed with pywt.wavedecn
    (zero boundary mode). The coefficients are packed into a single array
    with pywt.coeffs_to_array, and that packed array is moved back to the
    device of ``input``. The coefficient slices are discarded.

    Args:
        input (array): Input array.
        wave_name (str): Wavelet name.
        axes (None or tuple of int): Axes to perform wavelet transform.
        level (None or int): Number of wavelet levels.

    Returns:
        array: packed wavelet coefficients.
    """
    source_device = backend.get_device(input)
    cpu_array = backend.to_device(input, backend.cpu_device)

    # Round every dimension (not only those in ``axes``) up to even length.
    even_shape = [size + size % 2 for size in cpu_array.shape]
    padded = util.resize(cpu_array, even_shape)

    coeffs = pywt.wavedecn(padded, wave_name, mode='zero', axes=axes,
                           level=level)
    packed = pywt.coeffs_to_array(coeffs, axes=axes)[0]

    return backend.to_device(packed, source_device)
Exemplo n.º 11
0
    def _apply(self, input):
        """Resize ``input`` to ``self.oshape`` with the stored shifts."""
        target_shape = self.oshape
        in_shift, out_shift = self.ishift, self.oshift
        return util.resize(input, target_shape,
                           ishift=in_shift, oshift=out_shift)
Exemplo n.º 12
0
 def _apply(self, input):
     """Resize ``input`` to ``self.oshape`` within the input's device context."""
     device = backend.get_device(input)
     with device:
         resized = util.resize(input,
                               self.oshape,
                               ishift=self.ishift,
                               oshift=self.oshift)
         return resized
Exemplo n.º 13
0
 def _get_dataset(self, i):
     """Load the ``i``-th file with np.load and resize it to self.shape[1:].

     self.shape is presumably (num_items, *item_shape); confirm with the
     class definition.
     """
     path = self.filepaths[i]
     item_shape = self.shape[1:]
     return util.resize(np.load(path), item_shape)