def _cudnn_convolve(x, W, mode='full'):
    """Convolve ``x`` with ``W`` using cuDNN.

    Complex inputs are handled by splitting into real/imaginary parts,
    stacking them along the channel axes, and running one real-valued
    convolution that implements complex multiplication.

    Args:
        x (array): input, shape ``(batch, input_channel, *spatial)``.
            -- assumed layout; TODO confirm against callers.
        W (array): filters, shape ``(output_channel, input_channel, *spatial)``.
        mode (str): ``'full'`` or ``'valid'``.

    Returns:
        array: output of shape ``(batch, output_channel, *output_spatial)``.

    Raises:
        ValueError: if ``mode`` is not ``'full'`` or ``'valid'``.
    """
    # Validate early: otherwise an invalid mode leaves output_shape/pads
    # unbound and fails later with a confusing UnboundLocalError.
    if mode not in ('full', 'valid'):
        raise ValueError(f"mode must be 'full' or 'valid', got {mode}.")

    dtype = x.dtype
    device = backend.get_device(x)
    xp = device.xp
    if np.issubdtype(dtype, np.complexfloating):
        with device:
            xr = xp.real(x)
            xi = xp.imag(x)
            Wr = xp.real(W)
            Wi = xp.imag(W)

            # Concatenate real and imaginary parts into the channel axes.
            # The 2x2 block filter [[Wr, -Wi], [Wi, Wr]] realizes complex
            # multiplication with a single real convolution.
            x = xp.concatenate([xr, xi], axis=1)
            W = xp.concatenate(
                [xp.concatenate([Wr, -Wi], axis=1),
                 xp.concatenate([Wi, Wr], axis=1)], axis=0)

            y = _cudnn_convolve(x, W, mode=mode)

            # First half of the output channels is the real part, second
            # half the imaginary part; recombine into a complex array.
            return (y[:, :y.shape[1] // 2]
                    + 1j * y[:, y.shape[1] // 2:]).astype(dtype)

    ndim = x.ndim - 2
    batch_size = len(x)
    output_channel = W.shape[0]
    input_shape = x.shape[-ndim:]
    filter_shape = W.shape[-ndim:]
    strides = (1, ) * ndim
    dilations = (1, ) * ndim
    groups = 1
    auto_tune = True
    tensor_core = 'auto'
    if mode == 'full':
        output_shape = tuple(
            m + n - 1 for m, n in zip(input_shape, filter_shape))
        # Pad by filter_size - 1 on each side to get the 'full' output.
        pads = tuple(n - 1 for n in W.shape[2:])
    elif mode == 'valid':
        output_shape = tuple(
            m - n + 1 for m, n in zip(input_shape, filter_shape))
        pads = (0, ) * ndim

    with device:
        y = xp.empty((batch_size, output_channel) + output_shape,
                     dtype=dtype)
        # cuDNN computes cross-correlation; flip the filter spatially to
        # obtain a true convolution.
        W = util.flip(W, axes=range(-ndim, 0))
        cudnn.convolution_forward(x, W, None, y, pads, strides, dilations,
                                  groups, auto_tune=auto_tune,
                                  tensor_core=tensor_core)

    return y
def call(self):
    """Run the cuDNN forward convolution on the pre-staged operands.

    All arrays (``x``, ``W``, ``b``, ``y``) and convolution options are
    assumed to have been prepared on ``self`` beforehand; the result is
    written into ``self.y`` in place.
    """
    operands = (self.x, self.W, self.b, self.y,
                self.pads, self.strides, self.dilations, self.groups)
    cudnn.convolution_forward(*operands,
                              auto_tune=self.auto_tune,
                              tensor_core=self.tensor_core)
def _convolve_cuda(data, filt, mode='full', strides=None,
                   multi_channel=False):
    """Convolve ``data`` with ``filt`` on GPU via cuDNN.

    Args:
        data (array): input array.
        filt (array): filter array.
        mode (str): ``'full'`` or ``'valid'``.
        strides (list or None): convolution strides per spatial dimension.
        multi_channel (bool): whether input/output have channel dimensions.

    Returns:
        array: convolved output, reshaped back to the caller's layout.

    Raises:
        ValueError: if ``mode`` is invalid, or the number of spatial
            dimensions exceeds 3.
    """
    xp = backend.get_array_module(data)
    # Validate early: an invalid mode would otherwise leave `pads` unbound
    # and fail later with a confusing UnboundLocalError. Consistent with
    # the explicit ValueError raised below for unsupported dimensions.
    if mode not in ('full', 'valid'):
        raise ValueError(f"mode must be 'full' or 'valid', got {mode}.")

    D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
        data.shape, filt.shape, mode, strides, multi_channel)

    if D == 1:
        # cuDNN does not support 1D convolution: append a singleton
        # spatial axis, run as 2D, and squeeze the axis back off.
        return _convolve_cuda(
            xp.expand_dims(data, -1),
            xp.expand_dims(filt, -1),
            mode=mode,
            strides=list(strides) + [1] if strides is not None else None,
            multi_channel=multi_channel).squeeze(-1)
    elif D > 3:
        raise ValueError(
            f'cuDNN convolution only supports 1, 2, or 3D, got {D}.')

    dilations = (1, ) * D
    groups = 1
    auto_tune = True
    tensor_core = 'auto'
    if mode == 'full':
        # Pad by filter_size - 1 on each side to get the 'full' output.
        pads = tuple(n_d - 1 for n_d in n)
    elif mode == 'valid':
        pads = (0, ) * D

    data = data.reshape((B, c_i) + m)
    filt = filt.reshape((c_o, c_i) + n)
    output = xp.empty((B, c_o) + p, dtype=data.dtype)
    # cuDNN computes cross-correlation; flip the filter spatially to
    # obtain a true convolution.
    filt = util.flip(filt, axes=range(-D, 0))
    cudnn.convolution_forward(data, filt, None, output,
                              pads, s, dilations, groups,
                              auto_tune=auto_tune,
                              tensor_core=tensor_core)

    # Restore the caller's batch (and optional channel) layout.
    if multi_channel:
        output = output.reshape(b + (c_o, ) + p)
    else:
        output = output.reshape(b + p)

    return output
def _convolve_cuda(data, filt, mode='full', strides=None,
                   multi_channel=False):
    """Convolve ``data`` with ``filt`` on GPU via cuDNN.

    NOTE(review): unlike the sibling version of this function, this one
    performs no 1D-to-2D lift and no D > 3 check -- presumably callers
    guarantee 2D or 3D spatial dimensions; verify against call sites.

    Args:
        data (array): input array.
        filt (array): filter array.
        mode (str): ``'full'`` or ``'valid'``.
        strides (list or None): convolution strides per spatial dimension.
        multi_channel (bool): whether input/output have channel dimensions.

    Returns:
        array: convolved output, reshaped back to the caller's layout.

    Raises:
        ValueError: if ``mode`` is not ``'full'`` or ``'valid'``.
    """
    device = backend.get_device(data)
    xp = device.xp
    # Validate early: an invalid mode would otherwise leave `pads` unbound
    # and fail later with a confusing UnboundLocalError.
    if mode not in ('full', 'valid'):
        raise ValueError(f"mode must be 'full' or 'valid', got {mode}.")

    D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
        data.shape, filt.shape, mode, strides, multi_channel)

    dilations = (1, ) * D
    groups = 1
    auto_tune = True
    tensor_core = 'auto'
    if mode == 'full':
        # Pad by filter_size - 1 on each side to get the 'full' output.
        pads = tuple(n_d - 1 for n_d in n)
    elif mode == 'valid':
        pads = (0, ) * D

    with device:
        data = data.reshape((B, c_i) + m)
        filt = filt.reshape((c_o, c_i) + n)
        output = xp.empty((B, c_o) + p, dtype=data.dtype)
        # cuDNN computes cross-correlation; flip the filter spatially to
        # obtain a true convolution.
        filt = util.flip(filt, axes=range(-D, 0))
        cudnn.convolution_forward(data, filt, None, output,
                                  pads, s, dilations, groups,
                                  auto_tune=auto_tune,
                                  tensor_core=tensor_core)

        # Restore the caller's batch (and optional channel) layout.
        if multi_channel:
            output = output.reshape(b + (c_o, ) + p)
        else:
            output = output.reshape(b + p)

        return output