def _fft_convolve_adjoint_filter(x, y, mode='full'):
    """Compute the adjoint of FFT-based convolution with respect to the filter.

    Correlates ``y`` with the input ``x`` in the Fourier domain. This is done by
    convolving with the conjugated, spatially flipped ``x``. The result is summed
    over the batch dimension, which yields a filter-shaped array.

    Args:
        x (array): input of shape ``(batch, input_channel, *input_shape)``.
        y (array): output-domain array of shape
            ``(batch, output_channel, *output_shape)``.
        mode (str): ``'full'`` or ``'valid'``; selects how the filter shape
            is inferred from the input and output shapes.

    Returns:
        array: of shape ``(output_channel, input_channel, *filter_shape)``.

    Raises:
        ValueError: if ``mode`` is not ``'full'`` or ``'valid'``.
    """
    if mode not in ('full', 'valid'):
        # Without this check an unknown mode skips both branches below and
        # later fails with an unhelpful UnboundLocalError.
        raise ValueError(
            "mode must be 'full' or 'valid', got {!r}".format(mode))

    ndim = x.ndim - 2
    batch_size = len(x)
    output_channel = y.shape[1]
    input_channel = x.shape[1]
    output_shape = y.shape[-ndim:]
    input_shape = x.shape[-ndim:]
    if mode == 'full':
        filter_shape = tuple(p - m + 1 for m, p in zip(input_shape, output_shape))
        pad_shape = output_shape
    else:  # mode == 'valid'
        filter_shape = tuple(m - p + 1 for m, p in zip(input_shape, output_shape))
        pad_shape = input_shape

    dtype = x.dtype
    device = backend.get_device(x)
    xp = device.xp
    with device:
        # Convolving with the conjugated, flipped input is a correlation,
        # i.e. the adjoint of convolution with respect to the filter.
        x = xp.conj(util.flip(x, axes=range(-ndim, 0)))
        # Insert singleton axes so the product broadcasts to
        # (batch, output_channel, input_channel, *pad_shape).
        x = x.reshape((batch_size, 1, input_channel) + input_shape)
        y = y.reshape((batch_size, output_channel, 1) + output_shape)

        # Zero-pad both operands to pad_shape, anchored at index 0. The
        # circular wraparound this leaves only lands in entries that the
        # final crop (ishift below) discards.
        x_pad = util.resize(x, (batch_size, 1, input_channel) + pad_shape,
                            oshift=[0] * x.ndim)
        y_pad = util.resize(y, (batch_size, output_channel, 1) + pad_shape,
                            oshift=[0] * y.ndim)

        if np.issubdtype(dtype, np.floating):
            x_fft = xp.fft.rfftn(x_pad, axes=range(-ndim, 0), norm='ortho')
            y_fft = xp.fft.rfftn(y_pad, axes=range(-ndim, 0), norm='ortho')
            W_fft = xp.sum(x_fft * y_fft, axis=0)  # sum over batch
            W = xp.fft.irfftn(W_fft, pad_shape, axes=range(-ndim, 0),
                              norm='ortho').astype(dtype)
        else:
            x_fft = fourier.fft(x_pad, axes=range(-ndim, 0), center=False)
            y_fft = fourier.fft(y_pad, axes=range(-ndim, 0), center=False)
            W_fft = xp.sum(x_fft * y_fft, axis=0)  # sum over batch
            W = fourier.ifft(W_fft, axes=range(-ndim, 0), center=False)

        # Crop offset: flipping x moves its last sample to index 0, so the
        # filter taps start m - 1 ('full') or p - 1 ('valid') samples in.
        if mode == 'full':
            shift = [0, 0] + [m - 1 for m in input_shape]
        else:  # mode == 'valid'
            shift = [0, 0] + [p - 1 for p in output_shape]

        W = util.resize(W, (output_channel, input_channel) + filter_shape,
                        ishift=shift)
        # Orthonormal forward/forward/inverse transforms leave a net
        # 1/sqrt(N) factor on the circular convolution; undo it. (Assumes
        # fourier.fft/ifft are orthonormal like the real branch — TODO confirm.)
        W *= util.prod(pad_shape)**0.5

        return W
def _fft_convolve_adjoint_input(W, y, mode='full'):
    """Compute the adjoint of FFT-based convolution with respect to the input.

    Convolves ``y`` with the conjugated, spatially flipped filter ``W`` in the
    Fourier domain and sums over output channels, which yields an input-shaped
    array.

    Args:
        W (array): filter of shape
            ``(output_channel, input_channel, *filter_shape)``.
        y (array): output-domain array of shape
            ``(batch, output_channel, *output_shape)``.
        mode (str): ``'full'`` or ``'valid'``; selects how the input shape
            is inferred from the output and filter shapes.

    Returns:
        array: of shape ``(batch, input_channel, *input_shape)``.

    Raises:
        ValueError: if ``mode`` is not ``'full'`` or ``'valid'``.
    """
    if mode not in ('full', 'valid'):
        # Without this check an unknown mode skips both branches below and
        # later fails with an unhelpful UnboundLocalError.
        raise ValueError(
            "mode must be 'full' or 'valid', got {!r}".format(mode))

    ndim = y.ndim - 2
    batch_size = len(y)
    output_channel, input_channel = W.shape[:2]
    output_shape = y.shape[-ndim:]
    filter_shape = W.shape[-ndim:]
    if mode == 'full':
        input_shape = tuple(p - n + 1 for p, n in zip(output_shape, filter_shape))
        pad_shape = output_shape
    else:  # mode == 'valid'
        input_shape = tuple(p + n - 1 for p, n in zip(output_shape, filter_shape))
        pad_shape = input_shape

    dtype = y.dtype
    device = backend.get_device(y)
    xp = device.xp
    with device:
        # Singleton input-channel axis so y broadcasts against W to
        # (batch, output_channel, input_channel, *pad_shape).
        y = y.reshape((batch_size, output_channel, 1) + output_shape)
        # Convolving with the conjugated, flipped filter is a correlation,
        # i.e. the adjoint of convolution with respect to the input.
        W = xp.conj(util.flip(W, axes=range(-ndim, 0)))

        # Zero-pad both operands to pad_shape, anchored at index 0. Any
        # circular wraparound lands only in entries the final crop discards.
        y_pad = util.resize(y, (batch_size, output_channel, 1) + pad_shape,
                            oshift=[0] * y.ndim)
        W_pad = util.resize(W, (output_channel, input_channel) + pad_shape,
                            oshift=[0] * W.ndim)

        if np.issubdtype(dtype, np.floating):
            y_fft = xp.fft.rfftn(y_pad, axes=range(-ndim, 0), norm='ortho')
            W_fft = xp.fft.rfftn(W_pad, axes=range(-ndim, 0), norm='ortho')
            # axis -ndim - 2 is the output_channel axis.
            x_fft = xp.sum(y_fft * W_fft, axis=-ndim - 2)
            x = xp.fft.irfftn(x_fft, pad_shape, axes=range(-ndim, 0),
                              norm='ortho').astype(dtype)
        else:
            y_fft = fourier.fft(y_pad, axes=range(-ndim, 0), center=False)
            W_fft = fourier.fft(W_pad, axes=range(-ndim, 0), center=False)
            x_fft = xp.sum(y_fft * W_fft, axis=-ndim - 2)
            x = fourier.ifft(x_fft, axes=range(-ndim, 0), center=False)

        # Crop offset: the flipped filter shifts 'full' results by n - 1;
        # in 'valid' mode pad_shape already equals the input shape.
        if mode == 'full':
            shift = [0, 0] + [n - 1 for n in filter_shape]
        else:  # mode == 'valid'
            shift = [0] * x.ndim

        x = util.resize(x, (batch_size, input_channel) + input_shape,
                        ishift=shift)
        # Undo the net 1/sqrt(N) left by orthonormal transforms. (Assumes
        # fourier.fft/ifft are orthonormal like the real branch — TODO confirm.)
        x *= util.prod(pad_shape)**0.5

        return x
def _fft_convolve(x, W, mode='full'):
    """Convolve a batched multi-channel input with a filter bank via FFT.

    Args:
        x (array): input of shape ``(batch, input_channel, *input_shape)``.
        W (array): filter of shape
            ``(output_channel, input_channel, *filter_shape)``.
        mode (str): ``'full'`` (output size ``m + n - 1``) or ``'valid'``
            (output size ``m - n + 1``) per spatial axis.

    Returns:
        array: of shape ``(batch, output_channel, *output_shape)``.

    Raises:
        ValueError: if ``mode`` is not ``'full'`` or ``'valid'``.
    """
    if mode not in ('full', 'valid'):
        # Without this check an unknown mode skips both branches below and
        # later fails with an unhelpful UnboundLocalError.
        raise ValueError(
            "mode must be 'full' or 'valid', got {!r}".format(mode))

    ndim = x.ndim - 2
    batch_size = len(x)
    output_channel, input_channel = W.shape[:2]
    input_shape = x.shape[-ndim:]
    filter_shape = W.shape[-ndim:]
    if mode == 'full':
        output_shape = tuple(m + n - 1 for m, n in zip(input_shape, filter_shape))
        pad_shape = output_shape
    else:  # mode == 'valid'
        output_shape = tuple(m - n + 1 for m, n in zip(input_shape, filter_shape))
        pad_shape = input_shape

    dtype = x.dtype
    device = backend.get_device(x)
    xp = device.xp
    with device:
        # Singleton output-channel axis so x broadcasts against W to
        # (batch, output_channel, input_channel, *pad_shape).
        x = x.reshape((batch_size, 1, input_channel) + input_shape)

        # Zero-pad to pad_shape, anchored at index 0. In 'full' mode this is
        # the linear-convolution length, so no wraparound occurs; in 'valid'
        # mode wraparound only hits the first n - 1 samples, cropped below.
        x_pad = util.resize(x, (batch_size, 1, input_channel) + pad_shape,
                            oshift=[0] * x.ndim)
        W_pad = util.resize(W, (output_channel, input_channel) + pad_shape,
                            oshift=[0] * W.ndim)

        if np.issubdtype(dtype, np.floating):
            x_fft = xp.fft.rfftn(x_pad, axes=range(-ndim, 0), norm='ortho')
            W_fft = xp.fft.rfftn(W_pad, axes=range(-ndim, 0), norm='ortho')
            # axis -ndim - 1 is the input_channel axis.
            y_fft = xp.sum(x_fft * W_fft, axis=-ndim - 1)
            y = xp.fft.irfftn(y_fft, pad_shape, axes=range(-ndim, 0),
                              norm='ortho').astype(dtype)
        else:
            x_fft = fourier.fft(x_pad, axes=range(-ndim, 0), center=False)
            W_fft = fourier.fft(W_pad, axes=range(-ndim, 0), center=False)
            y_fft = xp.sum(x_fft * W_fft, axis=-ndim - 1)
            y = fourier.ifft(y_fft, axes=range(-ndim, 0), center=False)

        if mode == 'full':
            shift = [0] * y.ndim
        else:  # mode == 'valid': skip the n - 1 wrapped/partial samples
            shift = [0, 0] + [n - 1 for n in filter_shape]

        y = util.resize(y, (batch_size, output_channel) + output_shape,
                        ishift=shift)
        # Undo the net 1/sqrt(N) left by orthonormal transforms. (Assumes
        # fourier.fft/ifft are orthonormal like the real branch — TODO confirm.)
        y *= util.prod(pad_shape)**0.5

        return y
def test_fft_dtype(self):
    """fourier.fft must return the same complex dtype it was given."""
    for complex_dtype in (np.complex64, np.complex128):
        impulse = np.array([0, 1, 0], dtype=complex_dtype)
        assert fourier.fft(impulse).dtype == complex_dtype
def test_fft(self):
    """Check fourier.fft against known values and NumPy's shifted ortho FFT."""
    # np.complex was a deprecated alias of the builtin complex and was
    # removed in NumPy 1.24; np.complex128 is the equivalent dtype.
    x = np.array([0, 1, 0], dtype=np.complex128)
    npt.assert_allclose(fourier.fft(x), np.ones(3) / 3**0.5, atol=1e-5)

    x = np.array([1, 1, 1], dtype=np.complex128)
    npt.assert_allclose(fourier.fft(x), [0, 3**0.5, 0], atol=1e-5)

    x = util.randn([4, 5, 6])
    npt.assert_allclose(
        fourier.fft(x),
        np.fft.fftshift(np.fft.fftn(np.fft.ifftshift(x), norm='ortho')),
        atol=1e-5)

    # With oshape, the output length changes from 3 to 5.
    x = np.array([0, 1, 0], dtype=np.complex128)
    npt.assert_allclose(fourier.fft(x, oshape=[5]), np.ones(5) / 5**0.5,
                        atol=1e-5)
def _apply(self, input):
    """Run fourier.fft on ``input`` using this operator's axes and center flag."""
    axes, center = self.axes, self.center
    return fourier.fft(input, axes=axes, center=center)
def _apply(self, input):
    """Run fourier.fft on ``input`` inside the device context of ``input``."""
    device = backend.get_device(input)
    with device:
        output = fourier.fft(input, axes=self.axes, center=self.center)
    return output