コード例 #1
0
ファイル: conv.py プロジェクト: jychengmri/sigpy
def _fft_convolve_adjoint_input(W, y, mode='full'):
    """Apply the adjoint (w.r.t. the input) of a multi-channel convolution via FFT.

    Parameters
    ----------
    W : array
        Filter of shape ``(output_channel, input_channel) + filter_shape``.
    y : array
        Array of shape ``(batch_size, output_channel) + output_shape``.
    mode : {'full', 'valid'}
        Mode of the forward convolution whose adjoint is applied.

    Returns
    -------
    x : array
        Array of shape ``(batch_size, input_channel) + input_shape``.

    Raises
    ------
    ValueError
        If ``mode`` is not 'full' or 'valid'.
    """
    if mode not in ('full', 'valid'):
        raise ValueError(f"mode must be 'full' or 'valid', got {mode!r}.")

    ndim = y.ndim - 2
    batch_size = len(y)
    output_channel, input_channel = W.shape[:2]
    output_shape = y.shape[-ndim:]
    filter_shape = W.shape[-ndim:]
    if mode == 'full':
        input_shape = tuple(p - n + 1
                            for p, n in zip(output_shape, filter_shape))
        # Circular convolution over output_shape; the kept correlation
        # samples start at offset n - 1 on each spatial axis.
        pad_shape = output_shape
        shift = [0, 0] + [n - 1 for n in filter_shape]
    else:  # mode == 'valid'
        input_shape = tuple(p + n - 1
                            for p, n in zip(output_shape, filter_shape))
        pad_shape = input_shape
        shift = [0] * (ndim + 2)

    dtype = y.dtype
    device = backend.get_device(y)
    xp = device.xp
    with device:
        # Insert an input-channel axis so y broadcasts against W.
        y = y.reshape((batch_size, output_channel, 1) + output_shape)
        # Adjoint of convolution is correlation: flip and conjugate W.
        W = xp.conj(util.flip(W, axes=range(-ndim, 0)))

        y_pad = util.resize(y, (batch_size, output_channel, 1) + pad_shape,
                            oshift=[0] * y.ndim)
        W_pad = util.resize(W, (output_channel, input_channel) + pad_shape,
                            oshift=[0] * W.ndim)

        if np.issubdtype(dtype, np.floating):
            y_fft = xp.fft.rfftn(y_pad, axes=range(-ndim, 0), norm='ortho')
            W_fft = xp.fft.rfftn(W_pad, axes=range(-ndim, 0), norm='ortho')
            # Sum over output channels: (B, c_o, c_i, ...) -> (B, c_i, ...).
            x_fft = xp.sum(y_fft * W_fft, axis=-ndim - 2)
            x = xp.fft.irfftn(x_fft,
                              pad_shape,
                              axes=range(-ndim, 0),
                              norm='ortho').astype(dtype)
        else:
            y_fft = fourier.fft(y_pad, axes=range(-ndim, 0), center=False)
            W_fft = fourier.fft(W_pad, axes=range(-ndim, 0), center=False)
            x_fft = xp.sum(y_fft * W_fft, axis=-ndim - 2)
            x = fourier.ifft(x_fft, axes=range(-ndim, 0), center=False)

        x = util.resize(x, (batch_size, input_channel) + input_shape,
                        ishift=shift)
        # Orthonormal FFTs leave a 1/sqrt(N) factor on the circular
        # convolution; undo it. (Assumes fourier.fft/ifft are orthonormal
        # by default, like the rfftn branch -- TODO confirm.)
        x *= util.prod(pad_shape)**0.5
        return x
コード例 #2
0
ファイル: conv.py プロジェクト: mikgroup/sigpy
    def _convolve_filter_adjoint_cuda(output,
                                      data,
                                      filt_shape,
                                      mode='full',
                                      strides=None,
                                      multi_channel=False):
        """Compute the filter-side adjoint of a convolution with cuDNN.

        Args:
            output (array): convolution output.
            data (array): convolution input.
            filt_shape (tuple of ints): shape of the filter to return.
            mode (str): ``'full'`` or ``'valid'``.
            strides (None or tuple of ints): convolution strides.
            multi_channel (bool): forwarded to ``_get_convolve_params``;
                presumably marks leading channel axes on the filter.

        Returns:
            array: result of shape ``filt_shape`` with the dtype of ``output``.

        Raises:
            ValueError: if there are more than 3 spatial dimensions.
        """
        xp = backend.get_array_module(data)
        (ndim, _, batch, data_sp, filt_sp, stride,
         chan_in, chan_out, out_sp) = _get_convolve_params(
             data.shape, filt_shape, mode, strides, multi_channel)

        if ndim > 3:
            raise ValueError(
                f'cuDNN convolution only supports 1, 2 or 3D, got {ndim}.')
        if ndim == 1:
            # Solve the 1D case as a 2D one with a trailing singleton
            # spatial axis, then drop that axis from the result.
            lifted_strides = (None if strides is None
                              else list(strides) + [1])
            lifted = _convolve_filter_adjoint_cuda(
                xp.expand_dims(output, -1),
                xp.expand_dims(data, -1),
                list(filt_shape) + [1],
                mode=mode,
                strides=lifted_strides,
                multi_channel=multi_channel)
            return lifted.squeeze(-1)

        if mode == 'valid':
            pads = (0, ) * ndim
        elif mode == 'full':
            pads = tuple(k - 1 for k in filt_sp)

        data_mat = data.reshape((batch, chan_in) + data_sp)
        output_mat = output.reshape((batch, chan_out) + out_sp)
        grad = xp.empty((chan_out, chan_in) + filt_sp, dtype=output.dtype)
        cudnn.convolution_backward_filter(
            data_mat, output_mat, grad, pads, stride, (1, ) * ndim, 1,
            deterministic=False, auto_tune=True, tensor_core='auto')

        # Spatially flip the computed filter before handing it back.
        grad = util.flip(grad, axes=range(-ndim, 0))
        return grad.reshape(filt_shape)
コード例 #3
0
    def _convolve_data_adjoint_cuda(output,
                                    filt,
                                    data_shape,
                                    mode='full',
                                    strides=None,
                                    multi_channel=False):
        """Compute the data-side adjoint of a convolution with cuDNN.

        Args:
            output (array): convolution output.
            filt (array): convolution filter.
            data_shape (tuple of ints): shape of the data array to return.
            mode (str): ``'full'`` or ``'valid'``.
            strides (None or tuple of ints): convolution strides.
            multi_channel (bool): forwarded to ``_get_convolve_params``.

        Returns:
            array: result of shape ``data_shape`` with the dtype of ``output``.

        Raises:
            ValueError: if there are more than 3 spatial dimensions.
        """
        device = backend.get_device(output)
        xp = device.xp

        D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
            data_shape, filt.shape, mode, strides, multi_channel)
        if D == 1:
            # Solve 1D problems as 2D ones with a trailing singleton spatial
            # axis, then drop that axis from the result.
            return _convolve_data_adjoint_cuda(
                xp.expand_dims(output, -1),
                xp.expand_dims(filt, -1),
                list(data_shape) + [1],
                mode=mode,
                strides=list(strides) + [1] if strides is not None else None,
                multi_channel=multi_channel).squeeze(-1)
        elif D > 3:
            raise ValueError(
                f'cuDNN convolution only supports 1, 2 or 3D, got {D}.')

        dilations = (1, ) * D
        groups = 1
        auto_tune = True
        tensor_core = 'auto'
        deterministic = False
        if mode == 'full':
            pads = tuple(n_d - 1 for n_d in n)
        elif mode == 'valid':
            pads = (0, ) * D

        with device:
            output = output.reshape((B, c_o) + p)
            filt = filt.reshape((c_o, c_i) + n)
            data = xp.empty((B, c_i) + m, dtype=output.dtype)
            # Flip the filter spatially before handing it to cuDNN.
            filt = util.flip(filt, axes=range(-D, 0))
            cudnn.convolution_backward_data(filt,
                                            output,
                                            None,
                                            data,
                                            pads,
                                            s,
                                            dilations,
                                            groups,
                                            deterministic=deterministic,
                                            auto_tune=auto_tune,
                                            tensor_core=tensor_core)

            # Reshape.
            data = data.reshape(data_shape)

        return data
コード例 #4
0
    def _convolve_cuda(data,
                       filt,
                       mode='full',
                       strides=None,
                       multi_channel=False):
        """Convolve ``data`` with ``filt`` using cuDNN.

        Args:
            data (array): convolution input.
            filt (array): convolution filter.
            mode (str): ``'full'`` or ``'valid'``.
            strides (None or tuple of ints): convolution strides.
            multi_channel (bool): if True the result has shape
                ``batch + (c_o,) + out``, otherwise ``batch + out``.

        Returns:
            array: convolution output with the dtype of ``data``.

        Raises:
            ValueError: if there are more than 3 spatial dimensions.
        """
        device = backend.get_device(data)
        xp = device.xp

        D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
            data.shape, filt.shape, mode, strides, multi_channel)
        if D == 1:
            # Solve 1D problems as 2D ones with a trailing singleton spatial
            # axis, then drop that axis from the result.
            return _convolve_cuda(
                xp.expand_dims(data, -1),
                xp.expand_dims(filt, -1),
                mode=mode,
                strides=list(strides) + [1] if strides is not None else None,
                multi_channel=multi_channel).squeeze(-1)
        elif D > 3:
            raise ValueError(
                f'cuDNN convolution only supports 1, 2 or 3D, got {D}.')

        dilations = (1, ) * D
        groups = 1
        auto_tune = True
        tensor_core = 'auto'
        if mode == 'full':
            pads = tuple(n_d - 1 for n_d in n)
        elif mode == 'valid':
            pads = (0, ) * D

        with device:
            data = data.reshape((B, c_i) + m)
            filt = filt.reshape((c_o, c_i) + n)
            output = xp.empty((B, c_o) + p, dtype=data.dtype)
            # Flip the filter spatially before handing it to cuDNN.
            filt = util.flip(filt, axes=range(-D, 0))
            cudnn.convolution_forward(data,
                                      filt,
                                      None,
                                      output,
                                      pads,
                                      s,
                                      dilations,
                                      groups,
                                      auto_tune=auto_tune,
                                      tensor_core=tensor_core)

            # Reshape.
            if multi_channel:
                output = output.reshape(b + (c_o, ) + p)
            else:
                output = output.reshape(b + p)

        return output
コード例 #5
0
    def _convolve_filter_adjoint_cuda(output,
                                      data,
                                      filt_shape,
                                      mode='full',
                                      strides=None,
                                      multi_channel=False):
        """Compute the filter-side adjoint of a convolution with cuDNN.

        Args:
            output (array): convolution output.
            data (array): convolution input.
            filt_shape (tuple of ints): shape of the filter to return.
            mode (str): ``'full'`` or ``'valid'``.
            strides (None or tuple of ints): convolution strides.
            multi_channel (bool): forwarded to ``_get_convolve_params``.

        Returns:
            array: result of shape ``filt_shape`` with the dtype of ``output``.

        Raises:
            ValueError: if there are more than 3 spatial dimensions.
        """
        xp = backend.get_array_module(data)

        D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
            data.shape, filt_shape, mode, strides, multi_channel)
        if D == 1:
            # Solve 1D problems as 2D ones with a trailing singleton spatial
            # axis, then drop that axis from the result.
            return _convolve_filter_adjoint_cuda(
                xp.expand_dims(output, -1),
                xp.expand_dims(data, -1),
                list(filt_shape) + [1],
                mode=mode,
                strides=list(strides) + [1] if strides is not None else None,
                multi_channel=multi_channel).squeeze(-1)
        elif D > 3:
            raise ValueError(
                f'cuDNN convolution only supports 1, 2 or 3D, got {D}.')

        dilations = (1, ) * D
        groups = 1
        auto_tune = True
        tensor_core = 'auto'
        deterministic = False
        if mode == 'full':
            pads = tuple(n_d - 1 for n_d in n)
        elif mode == 'valid':
            pads = (0, ) * D

        data = data.reshape((B, c_i) + m)
        output = output.reshape((B, c_o) + p)
        filt = xp.empty((c_o, c_i) + n, dtype=output.dtype)
        cudnn.convolution_backward_filter(data,
                                          output,
                                          filt,
                                          pads,
                                          s,
                                          dilations,
                                          groups,
                                          deterministic=deterministic,
                                          auto_tune=auto_tune,
                                          tensor_core=tensor_core)
        # Spatially flip the computed filter before returning it.
        filt = util.flip(filt, axes=range(-D, 0))
        filt = filt.reshape(filt_shape)

        return filt
コード例 #6
0
ファイル: linop.py プロジェクト: ShannonZ/sigpy
 def _apply(self, input):
     """Return ``input`` flipped along ``self.axes``.

     The flip runs inside the device context returned by
     ``backend.get_device(input)``.
     """
     device = backend.get_device(input)
     with device:
         return util.flip(input, self.axes)
コード例 #7
0
ファイル: conv.py プロジェクト: jychengmri/sigpy
    def _cudnn_convolve_adjoint_input(W, y, mode='full'):
        """Apply the adjoint (w.r.t. the input) of a convolution using cuDNN.

        Parameters
        ----------
        W : array
            Filter of shape ``(output_channel, input_channel) + filter_shape``.
        y : array
            Array of shape ``(batch_size, output_channel) + output_shape``.
        mode : {'full', 'valid'}
            Mode of the forward convolution whose adjoint is applied.

        Returns
        -------
        x : array
            Array of shape ``(batch_size, input_channel) + input_shape`` with
            the dtype of ``y``.

        Raises
        ------
        ValueError
            If ``mode`` is not 'full' or 'valid'.
        """
        if mode not in ('full', 'valid'):
            raise ValueError(f"mode must be 'full' or 'valid', got {mode!r}.")

        dtype = y.dtype
        device = backend.get_device(y)
        xp = device.xp
        if np.issubdtype(dtype, np.complexfloating):
            with device:
                Wr = xp.real(W)
                Wi = xp.imag(W)
                yr = xp.real(y)
                yi = xp.imag(y)

                # Rewrite the complex problem as a real one: real/imaginary
                # parts become separate channels, and W becomes the real
                # block filter [[Wr, -Wi], [Wi, Wr]] (output x input).
                y = xp.concatenate([yr, yi], axis=1)
                W = xp.concatenate([
                    xp.concatenate([Wr, -Wi], axis=1),
                    xp.concatenate([Wi, Wr], axis=1)
                ],
                                   axis=0)

                x = _cudnn_convolve_adjoint_input(W, y, mode=mode)

                # First half of the channels is the real part, second half
                # the imaginary part.
                half = x.shape[1] // 2
                x = x[:, :half] + 1j * x[:, half:]
                return x.astype(dtype)

        ndim = y.ndim - 2
        batch_size = len(y)
        input_channel = W.shape[1]
        output_shape = y.shape[-ndim:]
        filter_shape = W.shape[-ndim:]
        strides = (1, ) * ndim
        dilations = (1, ) * ndim
        groups = 1
        auto_tune = True
        tensor_core = 'auto'
        deterministic = False
        if mode == 'full':
            input_shape = tuple(p - n + 1
                                for p, n in zip(output_shape, filter_shape))
            pads = tuple(n - 1 for n in W.shape[2:])
        else:  # mode == 'valid'
            input_shape = tuple(p + n - 1
                                for p, n in zip(output_shape, filter_shape))
            pads = (0, ) * ndim

        with device:
            x = xp.empty((batch_size, input_channel) + input_shape,
                         dtype=dtype)
            # Flip the filter spatially before handing it to cuDNN.
            W = util.flip(W, axes=range(-ndim, 0))
            cudnn.convolution_backward_data(W,
                                            y,
                                            None,
                                            x,
                                            pads,
                                            strides,
                                            dilations,
                                            groups,
                                            deterministic=deterministic,
                                            auto_tune=auto_tune,
                                            tensor_core=tensor_core)

        return x
コード例 #8
0
ファイル: conv.py プロジェクト: jychengmri/sigpy
    def _cudnn_convolve(x, W, mode='full'):
        """Multi-channel convolution of ``x`` with ``W`` using cuDNN.

        Parameters
        ----------
        x : array
            Input of shape ``(batch_size, input_channel) + input_shape``.
        W : array
            Filter of shape ``(output_channel, input_channel) + filter_shape``.
        mode : {'full', 'valid'}
            'full' yields ``m + n - 1`` samples per spatial axis, 'valid'
            yields ``m - n + 1``.

        Returns
        -------
        y : array
            Output of shape ``(batch_size, output_channel) + output_shape``
            with the dtype of ``x``.

        Raises
        ------
        ValueError
            If ``mode`` is not 'full' or 'valid'.
        """
        if mode not in ('full', 'valid'):
            raise ValueError(f"mode must be 'full' or 'valid', got {mode!r}.")

        dtype = x.dtype
        device = backend.get_device(x)
        xp = device.xp
        if np.issubdtype(dtype, np.complexfloating):
            with device:
                xr = xp.real(x)
                xi = xp.imag(x)
                Wr = xp.real(W)
                Wi = xp.imag(W)

                # Rewrite the complex problem as a real one: real/imaginary
                # parts become separate channels, and W becomes the real
                # block filter [[Wr, -Wi], [Wi, Wr]] so that
                # yr = Wr*xr - Wi*xi and yi = Wi*xr + Wr*xi.
                x = xp.concatenate([xr, xi], axis=1)
                W = xp.concatenate([
                    xp.concatenate([Wr, -Wi], axis=1),
                    xp.concatenate([Wi, Wr], axis=1)
                ],
                                   axis=0)

                y = _cudnn_convolve(x, W, mode=mode)

                # First half of the channels is the real part, second half
                # the imaginary part.
                half = y.shape[1] // 2
                y = y[:, :half] + 1j * y[:, half:]
                return y.astype(dtype)

        ndim = x.ndim - 2
        batch_size = len(x)
        output_channel = W.shape[0]
        input_shape = x.shape[-ndim:]
        filter_shape = W.shape[-ndim:]
        strides = (1, ) * ndim
        dilations = (1, ) * ndim
        groups = 1
        auto_tune = True
        tensor_core = 'auto'
        if mode == 'full':
            output_shape = tuple(m + n - 1
                                 for m, n in zip(input_shape, filter_shape))
            pads = tuple(n - 1 for n in W.shape[2:])
        else:  # mode == 'valid'
            output_shape = tuple(m - n + 1
                                 for m, n in zip(input_shape, filter_shape))
            pads = (0, ) * ndim

        with device:
            y = xp.empty((batch_size, output_channel) + output_shape,
                         dtype=dtype)
            # Flip the filter spatially before handing it to cuDNN.
            W = util.flip(W, axes=range(-ndim, 0))
            cudnn.convolution_forward(x,
                                      W,
                                      None,
                                      y,
                                      pads,
                                      strides,
                                      dilations,
                                      groups,
                                      auto_tune=auto_tune,
                                      tensor_core=tensor_core)

        return y
コード例 #9
0
 def _apply(self, input):
     """Return ``input`` flipped along ``self.axes``.

     The flip is performed inside the device context of ``input``.
     """
     with backend.get_device(input):
         flipped = util.flip(input, self.axes)
     return flipped
コード例 #10
0
ファイル: conv.py プロジェクト: jtamir/sigpy
def _cudnn_convolve_adjoint_filter(x, y, mode='full'):
    """Apply the adjoint (w.r.t. the filter) of a convolution using cuDNN.

    Parameters
    ----------
    x : array
        Convolution input, ``(batch_size, input_channel) + input_shape``.
    y : array
        Convolution output, ``(batch_size, output_channel) + output_shape``.
    mode : {'full', 'valid'}
        Mode of the forward convolution whose adjoint is applied.

    Returns
    -------
    W : array
        Array of shape ``(output_channel, input_channel) + filter_shape``
        with the dtype of ``y``.

    Raises
    ------
    ValueError
        If ``mode`` is not 'full' or 'valid'.
    """
    if mode not in ('full', 'valid'):
        raise ValueError(f"mode must be 'full' or 'valid', got {mode!r}.")

    dtype = y.dtype
    device = backend.get_device(y)
    xp = device.xp
    if np.issubdtype(dtype, np.complexfloating):
        with device:
            xr = xp.real(x)
            xi = xp.imag(x)
            yr = xp.real(y)
            yi = xp.imag(y)

            # Concatenate real and imaginary parts along the channel axis and
            # solve the equivalent real problem with doubled channels.
            x = xp.concatenate([xr, xi], axis=1)
            y = xp.concatenate([yr, yi], axis=1)

            W = _cudnn_convolve_adjoint_filter(x, y, mode=mode)

            # Fold the real 2x2 channel-block result back into a complex
            # filter. With blocks [[rr, ri], [ir, ii]] (output part x input
            # part): real = rr + ii, imag = ir - ri.
            c_o = W.shape[0] // 2
            c_i = W.shape[1] // 2
            Wr = W[:c_o, :c_i] + W[c_o:, c_i:]
            Wi = W[c_o:, :c_i] - W[:c_o, c_i:]
            return (Wr + 1j * Wi).astype(dtype)

    ndim = y.ndim - 2
    input_channel = x.shape[1]
    output_channel = y.shape[1]
    input_shape = x.shape[-ndim:]
    output_shape = y.shape[-ndim:]
    strides = (1, ) * ndim
    dilations = (1, ) * ndim
    groups = 1
    auto_tune = True
    tensor_core = 'auto'
    deterministic = False
    if mode == 'full':
        # Full mode: p = m + n - 1, so n = p - m + 1.
        filter_shape = tuple(p - m + 1
                             for m, p in zip(input_shape, output_shape))
        pads = tuple(n - 1 for n in filter_shape)
    else:  # mode == 'valid'
        # Valid mode: p = m - n + 1, so n = m - p + 1.
        filter_shape = tuple(m - p + 1
                             for m, p in zip(input_shape, output_shape))
        pads = (0, ) * ndim

    with device:
        W = xp.empty((output_channel, input_channel) + filter_shape,
                     dtype=dtype)
        cudnn.convolution_backward_filter(x,
                                          y,
                                          W,
                                          pads,
                                          strides,
                                          dilations,
                                          groups,
                                          deterministic=deterministic,
                                          auto_tune=auto_tune,
                                          tensor_core=tensor_core)
        # Spatially flip the computed filter before returning it.
        W = util.flip(W, axes=range(-ndim, 0))

    return W