def l1_proj(eps, input):
    """Projection onto L1 ball.

    Args:
        eps (float, or array): L1 ball scaling.
        input (array)

    Returns:
        array: Result.

    References:
        J. Duchi, S. Shalev-Shwartz, Y. Singer, and T. Chandra, "Efficient
        projections onto the l1-ball for learning in high dimensions," ICML 2008.

    """
    xp = backend.get_array_module(input)
    shape = input.shape
    input = input.ravel()

    if xp.linalg.norm(input, 1) < eps:
        return input.reshape(shape)
    else:
        size = len(input)
        s = xp.sort(xp.abs(input))[::-1]
        st = (xp.cumsum(s) - eps) / (xp.arange(size) + 1)
        idx = xp.flatnonzero((s - st) > 0).max()
        return soft_thresh(st[idx], input.reshape(shape))
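A minimal NumPy-only check of the Duchi et al. algorithm above (soft thresholding, shown in Example #18, is inlined here): projecting a vector onto the L1 ball of radius 2 lands exactly on the boundary.

import numpy as np

def _soft(t, x):
    # soft threshold: sgn(x) * (|x| - t)_+
    return np.sign(x) * np.maximum(np.abs(x) - t, 0)

x = np.array([0.5, -1.2, 3.0, -0.3])
eps = 2.0
s = np.sort(np.abs(x))[::-1]
st = (np.cumsum(s) - eps) / (np.arange(x.size) + 1)
idx = np.flatnonzero((s - st) > 0).max()
proj = _soft(st[idx], x)
print(np.abs(proj).sum())  # 2.0: the projection sits on the L1 ball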
Example #2
def ifft(input, oshape=None, axes=None, center=True, norm='ortho'):
    """IFFT function that supports centering.

    Args:
        input (array): input array.
        oshape (None or array of ints): output shape.
        axes (None or array of ints): Axes over which to compute
            the inverse FFT.
        center (bool): Toggle centered iFFT.
        norm (None or ``"ortho"``): Keyword to specify the normalization mode.

    Returns:
        array of dimension oshape.

    See Also:
        :func:`numpy.fft.ifftn`

    """
    xp = backend.get_array_module(input)
    if not np.issubdtype(input.dtype, np.complexfloating):
        input = input.astype(np.complex128)

    if center:
        output = _ifftc(input, oshape=oshape, axes=axes, norm=norm)
    else:
        output = xp.fft.ifftn(input, s=oshape, axes=axes, norm=norm)

    if np.issubdtype(input.dtype,
                     np.complexfloating) and input.dtype != output.dtype:
        output = output.astype(input.dtype)

    return output
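For reference, center=True is the shifted convention sketched below in plain NumPy for a 1-D signal; this mirrors the _ifftc helper in Example #9.

import numpy as np

x = np.arange(8, dtype=np.complex128)
out = np.fft.fftshift(np.fft.ifftn(np.fft.ifftshift(x), norm='ortho'))
# same as ifft(x) with center=True, assuming this module is importable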
Example #3
def convolve_data_adjoint(output,
                          filt,
                          data_shape,
                          mode='full',
                          strides=None,
                          multi_channel=False):
    """Adjoint convolution operation with respect to data.

    Args:
        output (array): output array of shape
            :math:`[..., p_1, ..., p_D]` if multi_channel is False,
            :math:`[..., c_o, p_1, ..., p_D]` otherwise.
        filt (array): filter array of shape
            :math:`[n_1, ..., n_D]` if multi_channel is False
            :math:`[c_o, c_i, n_1, ..., n_D]` otherwise.
        data_shape (tuple of ints): shape of the data array.
        mode (str): {'full', 'valid'}.
        strides (None or tuple of ints): convolution strides of length D.
        multi_channel (bool): specify if data/output has multiple channels.

    Returns:
        array: data array of shape
            :math:`[..., m_1, ..., m_D]` if multi_channel is False,
            :math:`[..., c_i, m_1, ..., m_D]` otherwise.

    """
    data_shape = tuple(data_shape)

    xp = backend.get_array_module(output)
    if xp == np:
        data = _convolve_data_adjoint(output,
                                      filt,
                                      data_shape,
                                      mode=mode,
                                      strides=strides,
                                      multi_channel=multi_channel)
    else:  # pragma: no cover
        if config.cudnn_enabled:
            if np.issubdtype(output.dtype, np.floating):
                data = _convolve_data_adjoint_cuda(output,
                                                   filt,
                                                   data_shape,
                                                   mode=mode,
                                                   strides=strides,
                                                   multi_channel=multi_channel)
            else:
                data = _complex(_convolve_data_adjoint_cuda,
                                output,
                                filt.conj(),
                                data_shape,
                                mode=mode,
                                strides=strides,
                                multi_channel=multi_channel)
        else:
            raise RuntimeError(
                'cudnn must be installed to perform convolution on GPU.')

    return data
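A 1-D dot-product test in 'full' mode illustrates the adjoint identity this function implements; the sketch uses plain NumPy, where np.correlate in 'valid' mode is the adjoint of full convolution with respect to the data.

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(10)             # data
f = rng.standard_normal(4)              # filter
y = rng.standard_normal(10 + 4 - 1)     # vector in the output space

fwd = np.convolve(x, f, mode='full')    # A x
adj = np.correlate(y, f, mode='valid')  # A^H y for real inputs
print(np.allclose(fwd @ y, x @ adj))    # True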
Example #4
    def __mul__(self, input):
        if isinstance(input, Linop):
            return Compose([self, input])
        elif np.isscalar(input):
            M = Multiply(self.ishape, input)
            return Compose([self, M])
        elif isinstance(input, backend.get_array_module(input).ndarray):
            return self.apply(input)

        return NotImplemented
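A sketch of the three branches, assuming sigpy is installed and using its FFT Linop (names follow sigpy's linop module).

import numpy as np
import sigpy as sp

F = sp.linop.FFT((8,))
x = np.ones(8, dtype=np.complex64)

y = (F.H * F)(x)  # Linop * Linop  -> Compose (F.H F is the identity here)
z = (F * 2.0)(x)  # Linop * scalar -> input scaled by 2, then F applied
w = F * x         # Linop * array  -> F.apply(x)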
Example #5
def vec(inputs):
    """Vectorize inputs.

    Args:
        inputs (list of arrays): Input arrays.

    Returns:
        array: Vectorized result.
    """
    xp = backend.get_array_module(inputs[0])
    return xp.concatenate([i.ravel() for i in inputs])
Example #6
def rss(input, axes=(0, )):
    """Root sum of squares.

    Args:
        input (array): Input array.
        axes (None or tuple of ints): Axes to perform operation.

    Returns:
        array: Result.
    """
    xp = backend.get_array_module(input)
    return xp.sum(xp.abs(input)**2, axis=axes)**0.5
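For instance, combining eight unit-magnitude coil images gives sqrt(8) everywhere; a plain NumPy check of the expression above:

import numpy as np

imgs = np.ones((8, 4, 4))  # (coil, y, x)
print((np.sum(np.abs(imgs)**2, axis=(0, ))**0.5)[0, 0])  # 2.828... == sqrt(8)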
Example #7
    def _convolve_data_adjoint_cuda(output,
                                    filt,
                                    data_shape,
                                    mode='full',
                                    strides=None,
                                    multi_channel=False):
        xp = backend.get_array_module(output)

        D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
            data_shape, filt.shape, mode, strides, multi_channel)

        if D == 1:
            return _convolve_data_adjoint_cuda(
                xp.expand_dims(output, -1),
                xp.expand_dims(filt, -1),
                list(data_shape) + [1],
                mode=mode,
                strides=list(strides) + [1] if strides is not None else None,
                multi_channel=multi_channel).squeeze(-1)
        elif D > 3:
            raise ValueError(
                f'cuDNN convolution only supports 1, 2 or 3D, got {D}.')

        dilations = (1, ) * D
        groups = 1
        auto_tune = True
        tensor_core = 'auto'
        deterministic = False
        if mode == 'full':
            pads = tuple(n_d - 1 for n_d in n)
        elif mode == 'valid':
            pads = (0, ) * D

        output = output.reshape((B, c_o) + p)
        filt = filt.reshape((c_o, c_i) + n)
        data = xp.empty((B, c_i) + m, dtype=output.dtype)
        filt = util.flip(filt, axes=range(-D, 0))
        cudnn.convolution_backward_data(filt,
                                        output,
                                        None,
                                        data,
                                        pads,
                                        s,
                                        dilations,
                                        groups,
                                        deterministic=deterministic,
                                        auto_tune=auto_tune,
                                        tensor_core=tensor_core)

        # Reshape.
        data = data.reshape(data_shape)

        return data
Example #8
    def _convolve_cuda(data,
                       filt,
                       mode='full',
                       strides=None,
                       multi_channel=False):
        xp = backend.get_array_module(data)

        D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
            data.shape, filt.shape, mode, strides, multi_channel)

        if D == 1:
            return _convolve_cuda(xp.expand_dims(data, -1),
                                  xp.expand_dims(filt, -1),
                                  mode=mode,
                                  strides=list(strides) +
                                  [1] if strides is not None else None,
                                  multi_channel=multi_channel).squeeze(-1)
        elif D > 3:
            raise ValueError(
                f'cuDNN convolution only supports 1, 2, or 3D, got {D}.')

        dilations = (1, ) * D
        groups = 1
        auto_tune = True
        tensor_core = 'auto'
        if mode == 'full':
            pads = tuple(n_d - 1 for n_d in n)
        elif mode == 'valid':
            pads = (0, ) * D

        data = data.reshape((B, c_i) + m)
        filt = filt.reshape((c_o, c_i) + n)
        output = xp.empty((B, c_o) + p, dtype=data.dtype)
        filt = util.flip(filt, axes=range(-D, 0))
        cudnn.convolution_forward(data,
                                  filt,
                                  None,
                                  output,
                                  pads,
                                  s,
                                  dilations,
                                  groups,
                                  auto_tune=auto_tune,
                                  tensor_core=tensor_core)

        # Reshape.
        if multi_channel:
            output = output.reshape(b + (c_o, ) + p)
        else:
            output = output.reshape(b + p)

        return output
Example #9
def _ifftc(input, oshape=None, axes=None, norm='ortho'):
    ndim = input.ndim
    axes = util._normalize_axes(axes, ndim)
    xp = backend.get_array_module(input)

    if oshape is None:
        oshape = input.shape

    tmp = util.resize(input, oshape)
    tmp = xp.fft.ifftshift(tmp, axes=axes)
    tmp = xp.fft.ifftn(tmp, axes=axes, norm=norm)
    output = xp.fft.fftshift(tmp, axes=axes)
    return output
Example #10
def _apodize(input, ndim, oversamp, width, beta):
    xp = backend.get_array_module(input)
    output = input
    for a in range(-ndim, 0):
        i = output.shape[a]
        os_i = ceil(oversamp * i)
        idx = xp.arange(i, dtype=output.dtype)

        # Calculate apodization
        apod = (beta**2 - (np.pi * width * (idx - i // 2) / os_i)**2)**0.5
        apod /= xp.sinh(apod)
        output *= apod.reshape([i] + [1] * (-a - 1))

    return output
Example #11
def convolve(data, filt, mode='full', strides=None, multi_channel=False):
    r"""Convolution that supports multi-dimensional and multi-channel inputs.

    This function follows the signal processing definition of convolution.
    Note that the cuDNN version only supports inputs with D=1, 2 or 3.

    Args:
        data (array): data array of shape:
            :math:`[..., m_1, ..., m_D]` if multi_channel is False,
            :math:`[..., c_i, m_1, ..., m_D]` otherwise.
        filt (array): filter array of shape:
            :math:`[n_1, ..., n_D]` if multi_channel is False
            :math:`[c_o, c_i, n_1, ..., n_D]` otherwise.
        mode (str): {'full', 'valid'}.
        strides (None or tuple of ints): convolution strides of length D.
        multi_channel (bool): specify if input/output has multiple channels.

    Returns:
        array: output array of shape:
            :math:`[..., p_1, ..., p_D]` if multi_channel is False,
            :math:`[..., c_o, p_1, ..., p_D]` otherwise.

    """
    xp = backend.get_array_module(data)
    if xp == np:
        output = _convolve(data,
                           filt,
                           mode=mode,
                           strides=strides,
                           multi_channel=multi_channel)
    else:  # pragma: no cover
        if config.cudnn_enabled:
            if np.issubdtype(data.dtype, np.floating):
                output = _convolve_cuda(data,
                                        filt,
                                        mode=mode,
                                        strides=strides,
                                        multi_channel=multi_channel)
            else:
                output = _complex(_convolve_cuda,
                                  data,
                                  filt,
                                  mode=mode,
                                  strides=strides,
                                  multi_channel=multi_channel)
        else:
            raise RuntimeError(
                'cudnn must be installed to perform convolution on GPU.')

    return output
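On CPU this matches NumPy's signal-processing convolution; a quick 1-D check, assuming sigpy is installed and convolve is importable from it:

import numpy as np
import sigpy as sp

x = np.array([1.0, 2.0, 3.0])
f = np.array([1.0, -1.0])
print(sp.convolve(x, f, mode='full'))  # [ 1.  1.  1. -3.]
print(np.convolve(x, f, mode='full'))  # same result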
Example #12
def psd_proj(input):
    """Projection onto postiive semi-definite matrices.

    Args:
        input (array): a two-dimensional matrix.

    Returns:
        array: Result.

    """
    xp = backend.get_array_module(input)
    w, v = xp.linalg.eigh((input + xp.conj(input).T) / 2)
    w[w < 0] = 0
    return (v * w) @ v.conjugate().T
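A NumPy-only check on a 2x2 symmetric matrix with eigenvalues 3 and -1: the projection clips the negative eigenvalue to zero.

import numpy as np

A = np.array([[1.0, 2.0], [2.0, 1.0]])  # eigenvalues 3 and -1
w, v = np.linalg.eigh((A + A.conj().T) / 2)
w[w < 0] = 0
P = (v * w) @ v.conj().T
print(np.linalg.eigvalsh(P))  # [0. 3.]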
Example #13
def xpay(y, a, x):
    """Compute y = x + a * y.

    Args:
        y (array): Output array.
        a (scalar): Input scalar.
        x (array): Input array.
    """
    xp = backend.get_array_module(y)

    if xp == np:
        _xpay(y, a, x, out=y)
    else:
        _xpay_cuda(a, x, y)
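The update runs in place on y; a plain NumPy equivalent of y <- x + a * y:

import numpy as np

y = np.array([1.0, 2.0])
x = np.array([10.0, 20.0])
a = 3.0
y *= a
y += x
print(y)  # [13. 26.]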
Example #14
def hard_thresh(lamda, input):
    """Hard threshold.

    Args:
        lamda (float, or array): Threshold parameter.
        input (array)

    Returns:
        array: hard-thresholded result.

    """
    xp = backend.get_array_module(input)
    if xp == np:
        return _hard_thresh(lamda, input)
    else:  # pragma: no cover
        return _hard_thresh_cuda(lamda, input)
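Hard thresholding zeroes entries with magnitude at most lamda and keeps the rest unchanged; in plain NumPy:

import numpy as np

x = np.array([0.5, -1.5, 2.0])
lamda = 1.0
print(np.where(np.abs(x) > lamda, x, 0))  # [ 0.  -1.5  2. ]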
Example #15
    def _complex(func, data1, data2, *args, **kwargs):
        """Helper function to convert func to support complex floats.
        """
        xp = backend.get_array_module(data1)
        data1r = xp.real(data1)
        data1i = xp.imag(data1)
        data2r = xp.real(data2)
        data2i = xp.imag(data2)

        outputr = func(data1r, data2r, *args, **kwargs)
        outputr -= func(data1i, data2i, *args, **kwargs)
        outputi = func(data1i, data2r, *args, **kwargs)
        outputi += func(data1r, data2i, *args, **kwargs)

        output = outputr + 1j * outputi
        output = output.astype(data1.dtype, copy=False)
        return output
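The four real-valued calls expand the bilinear identity (a + ib)(c + id) = (ac - bd) + i(ad + bc); a NumPy check with func = np.multiply:

import numpy as np

a = np.array([1 + 2j])
b = np.array([3 - 1j])
re = np.real(a) * np.real(b) - np.imag(a) * np.imag(b)
im = np.imag(a) * np.real(b) + np.real(a) * np.imag(b)
print(re + 1j * im, a * b)  # both [5.+5.j]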
Example #16
def toeplitz_psf(coord, shape, oversamp=1.25, width=4):
    """Toeplitz PSF for fast Normal non-uniform Fast Fourier Transform.

    While fast, this is more memory intensive.

    Args:
        coord (array): Fourier domain coordinate array of shape (..., ndim).
            ndim determines the number of dimensions over which the adjoint
            NUFFT is applied. coord[..., i] should be scaled to have its
            range between -n_i // 2, and n_i // 2.
        shape (tuple of ints): shape of the form
            (..., n_{ndim - 1}, ..., n_1, n_0).
            This is the shape of the input array of the forward nufft.
        oversamp (float): oversampling factor.
        width (float): interpolation kernel full-width in terms of
            oversampled grid.

    Returns:
        array: PSF to be used by the normal operator defined in
            `sigpy.linop.NUFFT`

    See Also:
        :func:`sigpy.linop.NUFFT`

    """
    xp = backend.get_array_module(coord)
    with backend.get_device(coord):
        ndim = coord.shape[-1]

        new_shape = _get_oversamp_shape(shape, ndim, 2)
        new_coord = _scale_coord(coord, new_shape, 2)

        idx = [slice(None)] * len(new_shape)
        for k in range(-1, -(ndim + 1), -1):
            idx[k] = new_shape[k] // 2

        d = xp.zeros(new_shape, dtype=xp.complex64)
        d[tuple(idx)] = 1

        psf = nufft(d, new_coord, oversamp, width)
        psf = nufft_adjoint(psf, new_coord, d.shape, oversamp, width)
        fft_axes = tuple(range(-1, -(ndim + 1), -1))
        psf = fft(psf, axes=fft_axes, norm=None) * (2**ndim)

        return psf
Example #17
def l2_proj(eps, input, axes=None):
    """Projection onto L2 ball.

    Args:
        eps (float, or array): L2 ball scaling.
        input (array)
        axes (None or tuple of ints): Axes over which to compute the norm.

    Returns:
        array: Result.

    """
    axes = util._normalize_axes(axes, input.ndim)

    xp = backend.get_array_module(input)
    norm = xp.sum(xp.abs(input)**2, axis=axes, keepdims=True)**0.5
    mask = norm < eps
    output = mask * input + (1 - mask) * (eps * input / (norm + mask))

    return output
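A NumPy check with axes=None: a vector of norm 5 projected onto the unit L2 ball is rescaled to norm 1.

import numpy as np

x = np.array([3.0, 4.0])  # norm 5
eps = 1.0
norm = np.sum(np.abs(x)**2)**0.5
out = x if norm < eps else eps * x / norm
print(out, np.linalg.norm(out))  # [0.6 0.8] 1.0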
Example #18
def soft_thresh(lamda, input):
    r"""Soft threshold.

    Performs:

    .. math::
        (| x | - \lambda)_+  \text{sgn}(x)

    Args:
        lamda (float, or array): Threshold parameter.
        input (array)

    Returns:
        array: soft-thresholded result.

    """
    xp = backend.get_array_module(input)
    if xp == np:
        return _soft_thresh(lamda, input)
    else:  # pragma: no cover
        return _soft_thresh_cuda(lamda, input)
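The formula from the docstring in plain NumPy:

import numpy as np

x = np.array([0.5, -1.5, 2.0])
lamda = 1.0
print(np.sign(x) * np.maximum(np.abs(x) - lamda, 0))  # [ 0.  -0.5  1. ]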
Example #19
    def _convolve_data_adjoint_cuda(output,
                                    filt,
                                    data_shape,
                                    mode='full',
                                    strides=None,
                                    multi_channel=False):
        xp = backend.get_array_module(output)

        D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
            data_shape, filt.shape, mode, strides, multi_channel)
        dilations = (1, ) * D
        groups = 1
        auto_tune = True
        tensor_core = 'auto'
        deterministic = False
        if mode == 'full':
            pads = tuple(n_d - 1 for n_d in n)
        elif mode == 'valid':
            pads = (0, ) * D

        output = output.reshape((B, c_o) + p)
        filt = filt.reshape((c_o, c_i) + n)
        data = xp.empty((B, c_i) + m, dtype=output.dtype)
        filt = util.flip(filt, axes=range(-D, 0))
        cudnn.convolution_backward_data(filt,
                                        output,
                                        None,
                                        data,
                                        pads,
                                        s,
                                        dilations,
                                        groups,
                                        deterministic=deterministic,
                                        auto_tune=auto_tune,
                                        tensor_core=tensor_core)

        # Reshape.
        data = data.reshape(data_shape)

        return data
Example #20
    def _convolve_cuda(data,
                       filt,
                       mode='full',
                       strides=None,
                       multi_channel=False):
        xp = backend.get_array_module(data)

        D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
            data.shape, filt.shape, mode, strides, multi_channel)
        dilations = (1, ) * D
        groups = 1
        auto_tune = True
        tensor_core = 'auto'
        if mode == 'full':
            pads = tuple(n_d - 1 for n_d in n)
        elif mode == 'valid':
            pads = (0, ) * D

        data = data.reshape((B, c_i) + m)
        filt = filt.reshape((c_o, c_i) + n)
        output = xp.empty((B, c_o) + p, dtype=data.dtype)
        filt = util.flip(filt, axes=range(-D, 0))
        cudnn.convolution_forward(data,
                                  filt,
                                  None,
                                  output,
                                  pads,
                                  s,
                                  dilations,
                                  groups,
                                  auto_tune=auto_tune,
                                  tensor_core=tensor_core)

        # Reshape.
        if multi_channel:
            output = output.reshape(b + (c_o, ) + p)
        else:
            output = output.reshape(b + p)

        return output
Example #21
def upsample(input, oshape, factors, shift=None):
    """Upsample input.

    Args:
        input (array): Input array.
        oshape (tuple of ints): Output shape.
        factors (tuple of ints): Upsampling factors.
        shift (None or tuple of ints): Shifts.

    Returns:
        array: Result.
    """

    if shift is None:
        shift = [0] * len(factors)

    slc = tuple(slice(s, None, f) for s, f in zip(shift, factors))

    xp = backend.get_array_module(input)
    output = xp.zeros(oshape, dtype=input.dtype)
    output[slc] = input

    return output
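Zero-filling upsampling by a factor of 2 in 1-D; a plain NumPy equivalent of the strided slicing above:

import numpy as np

x = np.array([1, 2, 3])
out = np.zeros(6, dtype=x.dtype)
out[0::2] = x  # shift 0, factor 2
print(out)     # [1 0 2 0 3 0]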
Example #22
def circshift(input, shifts, axes=None):
    """Circular shift input.

    Args:
        input (array): Input array.
        shifts (tuple of ints): Shifts.
        axes (None or tuple of ints): Axes to perform operation.

    Returns:
        array: Result.
    """

    if axes is None:
        axes = range(input.ndim)

    assert (len(axes) == len(shifts))
    xp = backend.get_array_module(input)

    for axis, shift in zip(axes, shifts):
        input = xp.roll(input, shift, axis=axis)

    return input
Example #23
def resize(input, oshape, ishift=None, oshift=None):
    """Resize with zero-padding or cropping.

    Args:
        input (array): Input array.
        oshape (tuple of ints): Output shape.
        ishift (None or tuple of ints): Input shift.
        oshift (None or tuple of ints): Output shift.

    Returns:
        array: Zero-padded or cropped result.
    """

    ishape1, oshape1 = _expand_shapes(input.shape, oshape)

    if ishape1 == oshape1:
        return input.reshape(oshape)

    if ishift is None:
        ishift = [max(i // 2 - o // 2, 0) for i, o in zip(ishape1, oshape1)]

    if oshift is None:
        oshift = [max(o // 2 - i // 2, 0) for i, o in zip(ishape1, oshape1)]

    copy_shape = [
        min(i - si, o - so)
        for i, si, o, so in zip(ishape1, ishift, oshape1, oshift)
    ]
    islice = tuple([slice(si, si + c) for si, c in zip(ishift, copy_shape)])
    oslice = tuple([slice(so, so + c) for so, c in zip(oshift, copy_shape)])

    xp = backend.get_array_module(input)
    output = xp.zeros(oshape1, dtype=input.dtype)
    input = input.reshape(ishape1)
    output[oslice] = input[islice]

    return output.reshape(oshape)
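With the default centered shifts, zero-padding a length-4 array to length 6 places it at offset 1; a plain NumPy trace of the index arithmetic above:

import numpy as np

x = np.arange(4)
# oshift = max(6 // 2 - 4 // 2, 0) = 1, ishift = 0, copy length = 4
out = np.zeros(6, dtype=x.dtype)
out[1:5] = x
print(out)  # [0 0 1 2 3 0]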
Example #24
def gridding(input, coord, shape, kernel="spline", width=2, param=1):
    r"""Gridding of points specified by coordinates to array.

    Let :math:`y` be the input, :math:`x` be the output,
    :math:`c` be the coordinates, :math:`W` be the kernel width,
    and :math:`K` be the interpolation kernel, then the function computes,

    .. math ::
        x[i] = \sum_{j : \| i - c[j] \|_\infty \leq W / 2}
               K\left(\frac{i - c[j]}{W / 2}\right) y[j]

    There are two types of kernels: 'spline' and 'kaiser_bessel'.

    'spline' uses the cardinal B-spline functions as kernels.
    The order of the spline can be specified using param.
    For example, param=1 performs linear interpolation.
    Concretely, for param=0, :math:`K(x) = 1`,
    for param=1, :math:`K(x) = 1 - |x|`, and
    for param=2, :math:`K(x) = \frac{9}{8} (1 - |x|)^2`
    for :math:`|x| > \frac{1}{3}`
    and :math:`K(x) = \frac{3}{4} (1 - 3 x^2)` for :math:`|x| < \frac{1}{3}`.

    These expressions are derived from the referenced Wikipedia page by
    shifting and scaling the domain to the range -1 to 1.
    When the coordinates specify a uniformly spaced grid,
    it is recommended to use the original scaling with width=param + 1
    so that the interpolation weights add up to one.

    'kaiser_bessel' uses the Kaiser-Bessel function as kernel.
    Concretely, :math:`K(x) = I_0(\beta \sqrt{1 - x^2})`,
    where :math:`I_0` is the modified Bessel function of the first kind.
    The beta parameter can be specified with param.
    The modified Bessel function of the first kind is approximated
    using the power series, following the reference.

    Args:
        input (array): Input array.
        coord (array): Coordinate array of shape [..., ndim].
        shape (tuple of ints): Output shape.
        kernel (str): Interpolation kernel, {"spline", "kaiser_bessel"}.
        width (float or tuple of floats): Interpolation kernel full-width.
        param (float or tuple of floats): Kernel parameter.

    Returns:
        output (array): Output array.

    References:
        https://en.wikipedia.org/wiki/Spline_wavelet#Cardinal_B-splines_of_small_orders
        http://people.math.sfu.ca/~cbm/aands/page_378.htm
    """
    ndim = coord.shape[-1]

    batch_shape = shape[:-ndim]
    batch_size = util.prod(batch_shape)

    pts_shape = coord.shape[:-1]
    npts = util.prod(pts_shape)

    xp = backend.get_array_module(input)
    isreal = np.issubdtype(input.dtype, np.floating)

    input = input.reshape([batch_size, npts])
    coord = coord.reshape([npts, ndim])
    output = xp.zeros([batch_size] + list(shape[-ndim:]), dtype=input.dtype)

    if np.isscalar(param):
        param = xp.array([param] * ndim, coord.dtype)
    else:
        param = xp.array(param, coord.dtype)

    if np.isscalar(width):
        width = xp.array([width] * ndim, coord.dtype)
    else:
        width = xp.array(width, coord.dtype)

    if xp == np:
        _gridding[kernel][ndim - 1](output, input, coord, width, param)
    else:  # pragma: no cover
        if isreal:
            _gridding_cuda[kernel][ndim - 1](
                input, coord, width, param, output, size=npts)
        else:
            _gridding_cuda_complex[kernel][ndim - 1](
                input, coord, width, param, output, size=npts)

    return output.reshape(shape)
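A 1-D illustration of the weights for the linear 'spline' kernel (param=1, width=2): a point at coordinate 2.3 spreads onto grid samples 2 and 3 with weights K((2 - 2.3) / 1) = 0.7 and K((3 - 2.3) / 1) = 0.3. In plain NumPy:

import numpy as np

y = np.array([1.0])  # point value
c = 2.3              # point coordinate
x = np.zeros(5)
lo = int(np.floor(c))
x[lo] += (1 - (c - lo)) * y[0]  # weight 0.7 at grid sample 2
x[lo + 1] += (c - lo) * y[0]    # weight 0.3 at grid sample 3
print(x)                        # [0.  0.  0.7 0.3 0. ]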
Example #25
def array_to_blocks(input, blk_shape, blk_strides):
    """Extract blocks from an array in a sliding window manner.

    Args:
        input (array): input array of shape [..., N_1, ..., N_ndim]
        blk_shape (tuple): block shape of length ndim, with ndim={1, 2, 3}.
        blk_strides (tuple): block strides of length ndim.

    Returns:
        array: array of shape [...] + num_blks + blk_shape, where
            num_blks = (N - blk_shape + blk_strides) // blk_strides.

    Example:

        >>> input = np.array([0, 1, 2, 3, 4, 5])
        >>> print(array_to_blocks(input, [2], [2]))
        [[0 1]
         [2 3]
         [4 5]]

    """
    if len(blk_shape) != len(blk_strides):
        raise ValueError('blk_shape must have the same length as blk_strides.')

    ndim = len(blk_shape)
    num_blks = [(i - b + s) // s
                for i, b, s in zip(input.shape[-ndim:], blk_shape, blk_strides)
                ]
    batch_shape = list(input.shape[:-ndim])
    batch_size = util.prod(batch_shape)
    xp = backend.get_array_module(input)
    output = xp.zeros([batch_size] + num_blks + list(blk_shape),
                      dtype=input.dtype)
    input = input.reshape([batch_size] + list(input.shape[-ndim:]))

    if ndim == 1:
        if xp == np:
            _array_to_blocks1(output, input, batch_size, blk_shape[-1],
                              blk_strides[-1], num_blks[-1])
        else:  # pragma: no cover
            _array_to_blocks1_cuda(input,
                                   batch_size,
                                   blk_shape[-1],
                                   blk_strides[-1],
                                   num_blks[-1],
                                   output,
                                   size=output.size)
    elif ndim == 2:
        if xp == np:
            _array_to_blocks2(output, input, batch_size, blk_shape[-1],
                              blk_shape[-2], blk_strides[-1], blk_strides[-2],
                              num_blks[-1], num_blks[-2])
        else:  # pragma: no cover
            _array_to_blocks2_cuda(input,
                                   batch_size,
                                   blk_shape[-1],
                                   blk_shape[-2],
                                   blk_strides[-1],
                                   blk_strides[-2],
                                   num_blks[-1],
                                   num_blks[-2],
                                   output,
                                   size=output.size)
    elif ndim == 3:
        if xp == np:
            _array_to_blocks3(output, input, batch_size, blk_shape[-1],
                              blk_shape[-2], blk_shape[-3], blk_strides[-1],
                              blk_strides[-2], blk_strides[-3], num_blks[-1],
                              num_blks[-2], num_blks[-3])
        else:  # pragma: no cover
            _array_to_blocks3_cuda(input,
                                   batch_size,
                                   blk_shape[-1],
                                   blk_shape[-2],
                                   blk_shape[-3],
                                   blk_strides[-1],
                                   blk_strides[-2],
                                   blk_strides[-3],
                                   num_blks[-1],
                                   num_blks[-2],
                                   num_blks[-3],
                                   output,
                                   size=output.size)
    else:
        raise ValueError(
            'Only ndim=1, 2, or 3 is supported, got {}.'.format(ndim))

    return output.reshape(batch_shape + num_blks + list(blk_shape))
Example #26
def blocks_to_array(input, oshape, blk_shape, blk_strides):
    """Accumulate blocks into an array in a sliding window manner.

    Args:
        input (array): input array of shape [...] + num_blks + blk_shape
        oshape (tuple): output shape.
        blk_shape (tuple): block shape of length ndim.
        blk_strides (tuple): block strides of length ndim.

    Returns:
        array: array of shape oshape.

    """
    if len(blk_shape) != len(blk_strides):
        raise ValueError('blk_shape must have the same length as blk_strides.')

    ndim = len(blk_shape)
    num_blks = input.shape[-(2 * ndim):-ndim]
    batch_shape = list(oshape[:-ndim])
    batch_size = util.prod(batch_shape)
    xp = backend.get_array_module(input)
    output = xp.zeros([batch_size] + list(oshape[-ndim:]), dtype=input.dtype)
    input = input.reshape([batch_size] + list(input.shape[-2 * ndim:]))

    if ndim == 1:
        if xp == np:
            _blocks_to_array1(output, input, batch_size, blk_shape[-1],
                              blk_strides[-1], num_blks[-1])
        else:  # pragma: no cover
            _blocks_to_array1_cuda(input,
                                   batch_size,
                                   blk_shape[-1],
                                   blk_strides[-1],
                                   num_blks[-1],
                                   output,
                                   size=output.size)
    elif ndim == 2:
        if xp == np:
            _blocks_to_array2(output, input, batch_size, blk_shape[-1],
                              blk_shape[-2], blk_strides[-1], blk_strides[-2],
                              num_blks[-1], num_blks[-2])
        else:  # pragma: no cover
            _blocks_to_array2_cuda(input,
                                   batch_size,
                                   blk_shape[-1],
                                   blk_shape[-2],
                                   blk_strides[-1],
                                   blk_strides[-2],
                                   num_blks[-1],
                                   num_blks[-2],
                                   output,
                                   size=output.size)
    elif ndim == 3:
        if xp == np:
            _blocks_to_array3(output, input, batch_size, blk_shape[-1],
                              blk_shape[-2], blk_shape[-3], blk_strides[-1],
                              blk_strides[-2], blk_strides[-3], num_blks[-1],
                              num_blks[-2], num_blks[-3])
        else:  # pragma: no cover
            _blocks_to_array3_cuda(input,
                                   batch_size,
                                   blk_shape[-1],
                                   blk_shape[-2],
                                   blk_shape[-3],
                                   blk_strides[-1],
                                   blk_strides[-2],
                                   blk_strides[-3],
                                   num_blks[-1],
                                   num_blks[-2],
                                   num_blks[-3],
                                   output,
                                   size=output.size)
    else:
        raise ValueError(
            'Only ndim=1, 2, or 3 is supported, got {}.'.format(ndim))

    return output.reshape(oshape)
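With overlapping blocks, accumulation adds the overlapped samples (the adjoint of array_to_blocks); a 1-D NumPy sketch with block length 2 and stride 1 over a length-3 output:

import numpy as np

blocks = np.ones((2, 2))  # num_blks = 2, blk_shape = 2
out = np.zeros(3)
for b in range(2):
    out[b:b + 2] += blocks[b]
print(out)  # [1. 2. 1.]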