def _fft_convolve_adjoint_input(W, y, mode='full'):
    """Adjoint of convolution with respect to the input, computed via FFT.

    Correlates ``y`` with ``conj(W)`` (summing over output channels), which
    is the adjoint of the map ``x -> conv(x, W)`` for the given ``mode``.

    Args:
        W (array): filters of shape
            ``(output_channel, input_channel) + filter_shape``.
        y (array): convolution output of shape
            ``(batch_size, output_channel) + output_shape``.
        mode (str): mode of the forward convolution, 'full' or 'valid'.

    Returns:
        array: of shape ``(batch_size, input_channel) + input_shape``.

    Raises:
        ValueError: if ``mode`` is neither 'full' nor 'valid'.
    """
    ndim = y.ndim - 2
    batch_size = len(y)
    output_channel, input_channel = W.shape[:2]
    output_shape = y.shape[-ndim:]
    filter_shape = W.shape[-ndim:]
    if mode == 'full':
        # Forward 'full' output is input + filter - 1 long; invert that.
        input_shape = tuple(p - n + 1
                            for p, n in zip(output_shape, filter_shape))
        pad_shape = output_shape
    elif mode == 'valid':
        # Forward 'valid' output is input - filter + 1 long; invert that.
        input_shape = tuple(p + n - 1
                            for p, n in zip(output_shape, filter_shape))
        pad_shape = input_shape
    else:
        # Previously this fell through to an UnboundLocalError.
        raise ValueError(f'Invalid mode, got {mode}.')

    dtype = y.dtype
    device = backend.get_device(y)
    xp = device.xp
    axes = range(-ndim, 0)
    with device:
        # Singleton input-channel axis so y broadcasts against W's
        # (output_channel, input_channel) leading axes.
        y = y.reshape((batch_size, output_channel, 1) + output_shape)
        # Correlation with W == convolution with the flipped conjugate of W.
        W = xp.conj(util.flip(W, axes=axes))
        # Zero-pad both operands to pad_shape; the samples extracted below
        # are not corrupted by the circular wrap-around of the FFT product.
        y_pad = util.resize(y, (batch_size, output_channel, 1) + pad_shape,
                            oshift=[0] * y.ndim)
        W_pad = util.resize(W, (output_channel, input_channel) + pad_shape,
                            oshift=[0] * W.ndim)
        if np.issubdtype(dtype, np.floating):
            y_fft = xp.fft.rfftn(y_pad, axes=axes, norm='ortho')
            W_fft = xp.fft.rfftn(W_pad, axes=axes, norm='ortho')
            # Sum over the output_channel axis.
            x_fft = xp.sum(y_fft * W_fft, axis=-ndim - 2)
            x = xp.fft.irfftn(x_fft, pad_shape, axes=axes,
                              norm='ortho').astype(dtype)
        else:
            y_fft = fourier.fft(y_pad, axes=axes, center=False)
            W_fft = fourier.fft(W_pad, axes=axes, center=False)
            # Sum over the output_channel axis.
            x_fft = xp.sum(y_fft * W_fft, axis=-ndim - 2)
            x = fourier.ifft(x_fft, axes=axes, center=False)

        # In 'full' mode the adjoint samples start at offset (filter - 1)
        # of the circular result; in 'valid' mode they start at 0.
        if mode == 'full':
            shift = [0, 0] + [n - 1 for n in filter_shape]
        else:
            shift = [0] * x.ndim

        x = util.resize(x, (batch_size, input_channel) + input_shape,
                        ishift=shift)
        # Orthonormal transforms leave the circular convolution scaled by
        # 1 / sqrt(prod(pad_shape)); undo it. NOTE(review): this assumes
        # fourier.fft/ifft are orthonormal too, since the same factor is
        # applied to both branches.
        x *= util.prod(pad_shape)**0.5

    return x
def _convolve_filter_adjoint_cuda(output, data, filt_shape,
                                  mode='full', strides=None,
                                  multi_channel=False):
    """Adjoint of cuDNN convolution with respect to the filter.

    Args:
        output (array): convolution output.
        data (array): convolution input.
        filt_shape (tuple or list of ints): shape of the returned filter.
        mode (str): 'full' or 'valid'.
        strides (None or sequence of ints): convolution strides.
        multi_channel (bool): forwarded to ``_get_convolve_params``;
            presumably selects whether channel axes are present.

    Returns:
        array: filter of shape ``filt_shape``.

    Raises:
        ValueError: if there are more than 3 convolution dimensions, or
            ``mode`` is neither 'full' nor 'valid'.
    """
    # Use the data's device (not the current device) for every allocation
    # and cuDNN call below, matching the other cuDNN helpers.
    device = backend.get_device(data)
    xp = device.xp
    D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
        data.shape, filt_shape, mode, strides, multi_channel)

    if D == 1:
        # Append a trailing singleton axis to run a 2D problem, then drop it.
        return _convolve_filter_adjoint_cuda(
            xp.expand_dims(output, -1),
            xp.expand_dims(data, -1),
            list(filt_shape) + [1],
            mode=mode,
            strides=list(strides) + [1] if strides is not None else None,
            multi_channel=multi_channel).squeeze(-1)
    elif D > 3:
        raise ValueError(
            f'cuDNN convolution only supports 1, 2 or 3D, got {D}.')

    dilations = (1, ) * D
    groups = 1
    auto_tune = True
    tensor_core = 'auto'
    deterministic = False
    if mode == 'full':
        pads = tuple(n_d - 1 for n_d in n)
    elif mode == 'valid':
        pads = (0, ) * D
    else:
        raise ValueError(f'Invalid mode, got {mode}.')

    with device:
        data = data.reshape((B, c_i) + m)
        output = output.reshape((B, c_o) + p)
        filt = xp.empty((c_o, c_i) + n, dtype=output.dtype)
        cudnn.convolution_backward_filter(
            data, output, filt, pads, s, dilations, groups,
            deterministic=deterministic, auto_tune=auto_tune,
            tensor_core=tensor_core)
        # The forward pass flips the filter before calling cuDNN, so flip
        # the gradient back.
        filt = util.flip(filt, axes=range(-D, 0))
        filt = filt.reshape(filt_shape)

    return filt
def _convolve_data_adjoint_cuda(output, filt, data_shape,
                                mode='full', strides=None,
                                multi_channel=False):
    """Adjoint of cuDNN convolution with respect to the data.

    Args:
        output (array): convolution output.
        filt (array): convolution filter.
        data_shape (tuple or list of ints): shape of the returned data.
        mode (str): 'full' or 'valid'.
        strides (None or sequence of ints): convolution strides.
        multi_channel (bool): forwarded to ``_get_convolve_params``;
            presumably selects whether channel axes are present.

    Returns:
        array: data of shape ``data_shape``.

    Raises:
        ValueError: if there are more than 3 convolution dimensions, or
            ``mode`` is neither 'full' nor 'valid'.
    """
    device = backend.get_device(output)
    xp = device.xp
    D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
        data_shape, filt.shape, mode, strides, multi_channel)

    if D == 1:
        # Append a trailing singleton axis to run a 2D problem, then drop it
        # (same approach as _convolve_filter_adjoint_cuda).
        return _convolve_data_adjoint_cuda(
            xp.expand_dims(output, -1),
            xp.expand_dims(filt, -1),
            tuple(data_shape) + (1, ),
            mode=mode,
            strides=list(strides) + [1] if strides is not None else None,
            multi_channel=multi_channel).squeeze(-1)
    elif D > 3:
        raise ValueError(
            f'cuDNN convolution only supports 1, 2 or 3D, got {D}.')

    dilations = (1, ) * D
    groups = 1
    auto_tune = True
    tensor_core = 'auto'
    deterministic = False
    if mode == 'full':
        pads = tuple(n_d - 1 for n_d in n)
    elif mode == 'valid':
        pads = (0, ) * D
    else:
        raise ValueError(f'Invalid mode, got {mode}.')

    with device:
        output = output.reshape((B, c_o) + p)
        filt = filt.reshape((c_o, c_i) + n)
        data = xp.empty((B, c_i) + m, dtype=output.dtype)
        # Flip to match the flip applied in the forward cuDNN convolution.
        filt = util.flip(filt, axes=range(-D, 0))
        cudnn.convolution_backward_data(
            filt, output, None, data, pads, s, dilations, groups,
            deterministic=deterministic, auto_tune=auto_tune,
            tensor_core=tensor_core)
        data = data.reshape(data_shape)

    return data
def _convolve_cuda(data, filt, mode='full', strides=None,
                   multi_channel=False):
    """Convolve ``data`` with ``filt`` using cuDNN.

    The filter is flipped before the cuDNN call, presumably because cuDNN
    computes cross-correlation rather than convolution.

    Args:
        data (array): convolution input.
        filt (array): convolution filter.
        mode (str): 'full' or 'valid'.
        strides (None or sequence of ints): convolution strides.
        multi_channel (bool): if True the result keeps an output-channel
            axis (see the final reshape).

    Returns:
        array: of shape ``b + (c_o,) + p`` if ``multi_channel`` else
        ``b + p``, with ``b``, ``c_o``, ``p`` from ``_get_convolve_params``.

    Raises:
        ValueError: if there are more than 3 convolution dimensions, or
            ``mode`` is neither 'full' nor 'valid'.
    """
    device = backend.get_device(data)
    xp = device.xp
    D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
        data.shape, filt.shape, mode, strides, multi_channel)

    if D == 1:
        # Append a trailing singleton axis to run a 2D problem, then drop it
        # (same approach as _convolve_filter_adjoint_cuda).
        return _convolve_cuda(
            xp.expand_dims(data, -1),
            xp.expand_dims(filt, -1),
            mode=mode,
            strides=list(strides) + [1] if strides is not None else None,
            multi_channel=multi_channel).squeeze(-1)
    elif D > 3:
        raise ValueError(
            f'cuDNN convolution only supports 1, 2 or 3D, got {D}.')

    dilations = (1, ) * D
    groups = 1
    auto_tune = True
    tensor_core = 'auto'
    if mode == 'full':
        pads = tuple(n_d - 1 for n_d in n)
    elif mode == 'valid':
        pads = (0, ) * D
    else:
        raise ValueError(f'Invalid mode, got {mode}.')

    with device:
        data = data.reshape((B, c_i) + m)
        filt = filt.reshape((c_o, c_i) + n)
        output = xp.empty((B, c_o) + p, dtype=data.dtype)
        filt = util.flip(filt, axes=range(-D, 0))
        cudnn.convolution_forward(
            data, filt, None, output, pads, s, dilations, groups,
            auto_tune=auto_tune, tensor_core=tensor_core)

        if multi_channel:
            output = output.reshape(b + (c_o, ) + p)
        else:
            output = output.reshape(b + p)

    return output
def _convolve_filter_adjoint_cuda(output, data, filt_shape,
                                  mode='full', strides=None,
                                  multi_channel=False):
    """Adjoint of cuDNN convolution with respect to the filter.

    Args:
        output (array): convolution output.
        data (array): convolution input.
        filt_shape (tuple or list of ints): shape of the returned filter.
        mode (str): 'full' or 'valid'.
        strides (None or sequence of ints): convolution strides.
        multi_channel (bool): forwarded to ``_get_convolve_params``;
            presumably selects whether channel axes are present.

    Returns:
        array: filter of shape ``filt_shape``.

    Raises:
        ValueError: if there are more than 3 convolution dimensions, or
            ``mode`` is neither 'full' nor 'valid'.
    """
    # Use the data's device (not the current device) for every allocation
    # and cuDNN call below, matching the other cuDNN helpers.
    device = backend.get_device(data)
    xp = device.xp
    D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
        data.shape, filt_shape, mode, strides, multi_channel)

    if D == 1:
        # Append a trailing singleton axis to run a 2D problem, then drop it.
        return _convolve_filter_adjoint_cuda(
            xp.expand_dims(output, -1),
            xp.expand_dims(data, -1),
            list(filt_shape) + [1],
            mode=mode,
            strides=list(strides) + [1] if strides is not None else None,
            multi_channel=multi_channel).squeeze(-1)
    elif D > 3:
        raise ValueError(
            f'cuDNN convolution only supports 1, 2 or 3D, got {D}.')

    dilations = (1, ) * D
    groups = 1
    auto_tune = True
    tensor_core = 'auto'
    deterministic = False
    if mode == 'full':
        pads = tuple(n_d - 1 for n_d in n)
    elif mode == 'valid':
        pads = (0, ) * D
    else:
        raise ValueError(f'Invalid mode, got {mode}.')

    with device:
        data = data.reshape((B, c_i) + m)
        output = output.reshape((B, c_o) + p)
        filt = xp.empty((c_o, c_i) + n, dtype=output.dtype)
        cudnn.convolution_backward_filter(
            data, output, filt, pads, s, dilations, groups,
            deterministic=deterministic, auto_tune=auto_tune,
            tensor_core=tensor_core)
        # The forward pass flips the filter before calling cuDNN, so flip
        # the gradient back.
        filt = util.flip(filt, axes=range(-D, 0))
        filt = filt.reshape(filt_shape)

    return filt
def _apply(self, input):
    """Flip ``input`` along ``self.axes``.

    The flip runs with ``input``'s device active, matching the other
    device-aware helpers in this file.
    """
    device = backend.get_device(input)
    with device:
        return util.flip(input, self.axes)
def _cudnn_convolve_adjoint_input(W, y, mode='full'):
    """Adjoint of cuDNN convolution with respect to the input.

    Complex arrays are handled by stacking real and imaginary parts along
    the channel axis and applying the real adjoint of the equivalent real
    block filter ``[[Wr, -Wi], [Wi, Wr]]``.

    Args:
        W (array): filter of shape
            ``(output_channel, input_channel) + filter_shape``.
        y (array): output of shape
            ``(batch_size, output_channel) + output_shape``.
        mode (str): 'full' or 'valid'.

    Returns:
        array: of shape ``(batch_size, input_channel) + input_shape`` with
        ``y``'s dtype.

    Raises:
        ValueError: if ``mode`` is neither 'full' nor 'valid'.
    """
    # Validate up front; previously a bad mode surfaced as UnboundLocalError.
    if mode not in ('full', 'valid'):
        raise ValueError(f'Invalid mode, got {mode}.')

    dtype = y.dtype
    device = backend.get_device(y)
    xp = device.xp
    if np.issubdtype(dtype, np.complexfloating):
        with device:
            Wr = xp.real(W)
            Wi = xp.imag(W)
            yr = xp.real(y)
            yi = xp.imag(y)
            # Concatenate real and imaginary parts along the channel axes.
            y = xp.concatenate([yr, yi], axis=1)
            W = xp.concatenate([
                xp.concatenate([Wr, -Wi], axis=1),
                xp.concatenate([Wi, Wr], axis=1)
            ], axis=0)
            x = _cudnn_convolve_adjoint_input(W, y, mode=mode)
            # First half of the channels is the real part, second the imag.
            half = x.shape[1] // 2
            x = x[:, :half] + 1j * x[:, half:]
            return x.astype(dtype)

    ndim = y.ndim - 2
    batch_size = len(y)
    input_channel = W.shape[1]
    output_shape = y.shape[-ndim:]
    filter_shape = W.shape[-ndim:]
    strides = (1, ) * ndim
    dilations = (1, ) * ndim
    groups = 1
    auto_tune = True
    tensor_core = 'auto'
    deterministic = False
    if mode == 'full':
        input_shape = tuple(p - n + 1
                            for p, n in zip(output_shape, filter_shape))
        pads = tuple(n - 1 for n in W.shape[2:])
    else:  # 'valid'
        input_shape = tuple(p + n - 1
                            for p, n in zip(output_shape, filter_shape))
        pads = (0, ) * ndim

    with device:
        x = xp.empty((batch_size, input_channel) + input_shape, dtype=dtype)
        # Flip to match the flip applied in the forward cuDNN convolution.
        W = util.flip(W, axes=range(-ndim, 0))
        cudnn.convolution_backward_data(
            W, y, None, x, pads, strides, dilations, groups,
            deterministic=deterministic, auto_tune=auto_tune,
            tensor_core=tensor_core)

    return x
def _cudnn_convolve(x, W, mode='full'):
    """Convolve ``x`` with ``W`` using cuDNN.

    Complex arrays are handled by stacking real and imaginary parts along
    the channel axis and convolving with the real block filter
    ``[[Wr, -Wi], [Wi, Wr]]``, which yields ``[yr, yi]``.

    Args:
        x (array): input of shape
            ``(batch_size, input_channel) + input_shape``.
        W (array): filter of shape
            ``(output_channel, input_channel) + filter_shape``.
        mode (str): 'full' or 'valid'.

    Returns:
        array: of shape ``(batch_size, output_channel) + output_shape`` with
        ``x``'s dtype.

    Raises:
        ValueError: if ``mode`` is neither 'full' nor 'valid'.
    """
    # Validate up front; previously a bad mode surfaced as UnboundLocalError.
    if mode not in ('full', 'valid'):
        raise ValueError(f'Invalid mode, got {mode}.')

    dtype = x.dtype
    device = backend.get_device(x)
    xp = device.xp
    if np.issubdtype(dtype, np.complexfloating):
        with device:
            xr = xp.real(x)
            xi = xp.imag(x)
            Wr = xp.real(W)
            Wi = xp.imag(W)
            # Concatenate real and imaginary parts along the channel axes.
            x = xp.concatenate([xr, xi], axis=1)
            W = xp.concatenate([
                xp.concatenate([Wr, -Wi], axis=1),
                xp.concatenate([Wi, Wr], axis=1)
            ], axis=0)
            y = _cudnn_convolve(x, W, mode=mode)
            # First half of the channels is the real part, second the imag.
            half = y.shape[1] // 2
            y = y[:, :half] + 1j * y[:, half:]
            return y.astype(dtype)

    ndim = x.ndim - 2
    batch_size = len(x)
    output_channel = W.shape[0]
    input_shape = x.shape[-ndim:]
    filter_shape = W.shape[-ndim:]
    strides = (1, ) * ndim
    dilations = (1, ) * ndim
    groups = 1
    auto_tune = True
    tensor_core = 'auto'
    if mode == 'full':
        output_shape = tuple(m + n - 1
                             for m, n in zip(input_shape, filter_shape))
        pads = tuple(n - 1 for n in W.shape[2:])
    else:  # 'valid'
        output_shape = tuple(m - n + 1
                             for m, n in zip(input_shape, filter_shape))
        pads = (0, ) * ndim

    with device:
        y = xp.empty((batch_size, output_channel) + output_shape, dtype=dtype)
        # Presumably flipped because cuDNN computes cross-correlation.
        W = util.flip(W, axes=range(-ndim, 0))
        cudnn.convolution_forward(
            x, W, None, y, pads, strides, dilations, groups,
            auto_tune=auto_tune, tensor_core=tensor_core)

    return y
def _apply(self, input):
    """Return ``input`` reversed along ``self.axes``.

    The flip is performed with the device of ``input`` active.
    """
    with backend.get_device(input):
        return util.flip(input, self.axes)
def _cudnn_convolve_adjoint_filter(x, y, mode='full'):
    """Adjoint of cuDNN convolution with respect to the filter.

    Complex arrays are handled by stacking real and imaginary parts along
    the channel axes, computing the real filter adjoint, and recombining
    its four channel blocks into ``conj(x)``-correlated real/imag parts.

    Args:
        x (array): input of shape
            ``(batch_size, input_channel) + input_shape``.
        y (array): output of shape
            ``(batch_size, output_channel) + output_shape``.
        mode (str): 'full' or 'valid'.

    Returns:
        array: filter of shape
        ``(output_channel, input_channel) + filter_shape`` with ``y``'s dtype.

    Raises:
        ValueError: if ``mode`` is neither 'full' nor 'valid'.
    """
    # Validate up front; previously a bad mode surfaced as UnboundLocalError.
    if mode not in ('full', 'valid'):
        raise ValueError(f'Invalid mode, got {mode}.')

    dtype = y.dtype
    device = backend.get_device(y)
    xp = device.xp
    if np.issubdtype(dtype, np.complexfloating):
        with device:
            xr = xp.real(x)
            xi = xp.imag(x)
            yr = xp.real(y)
            yi = xp.imag(y)
            # Concatenate real and imaginary parts along the channel axes.
            x = xp.concatenate([xr, xi], axis=1)
            y = xp.concatenate([yr, yi], axis=1)
            W = _cudnn_convolve_adjoint_filter(x, y, mode=mode)
            # Row blocks follow [yr, yi], column blocks follow [xr, xi].
            # Real part: (xr, yr) + (xi, yi); imag part: (xr, yi) - (xi, yr).
            co = W.shape[0] // 2
            ci = W.shape[1] // 2
            Wr = W[:co, :ci] + W[co:, ci:]
            Wi = W[co:, :ci] - W[:co, ci:]
            return (Wr + 1j * Wi).astype(dtype)

    ndim = y.ndim - 2
    input_channel = x.shape[1]
    output_channel = y.shape[1]
    input_shape = x.shape[-ndim:]
    output_shape = y.shape[-ndim:]
    strides = (1, ) * ndim
    dilations = (1, ) * ndim
    groups = 1
    auto_tune = True
    tensor_core = 'auto'
    deterministic = False
    if mode == 'full':
        filter_shape = tuple(p - m + 1
                             for m, p in zip(input_shape, output_shape))
        pads = tuple(n - 1 for n in filter_shape)
    else:  # 'valid'
        filter_shape = tuple(m - p + 1
                             for m, p in zip(input_shape, output_shape))
        pads = (0, ) * ndim

    with device:
        W = xp.empty((output_channel, input_channel) + filter_shape,
                     dtype=dtype)
        cudnn.convolution_backward_filter(
            x, y, W, pads, strides, dilations, groups,
            deterministic=deterministic, auto_tune=auto_tune,
            tensor_core=tensor_core)
        # The forward pass flips the filter before cuDNN; flip back.
        W = util.flip(W, axes=range(-ndim, 0))

    return W