Example #1
0
def _fft_convolve_adjoint_filter(x, y, mode='full'):
    ndim = x.ndim - 2
    batch_size = len(x)
    output_channel = y.shape[1]
    input_channel = x.shape[1]
    output_shape = y.shape[-ndim:]
    input_shape = x.shape[-ndim:]
    if mode == 'full':
        filter_shape = tuple(p - m + 1
                             for m, p in zip(input_shape, output_shape))
        pad_shape = output_shape
    elif mode == 'valid':
        filter_shape = tuple(m - p + 1
                             for m, p in zip(input_shape, output_shape))
        pad_shape = input_shape

    dtype = x.dtype
    device = backend.get_device(x)
    xp = device.xp
    with device:
        x = xp.conj(util.flip(x, axes=range(-ndim, 0)))
        x = x.reshape((batch_size, 1, input_channel) + input_shape)
        y = y.reshape((batch_size, output_channel, 1) + output_shape)

        x_pad = util.resize(x, (batch_size, 1, input_channel) + pad_shape,
                            oshift=[0] * x.ndim)
        y_pad = util.resize(y, (batch_size, output_channel, 1) + pad_shape,
                            oshift=[0] * y.ndim)

        if np.issubdtype(dtype, np.floating):
            x_fft = xp.fft.rfftn(x_pad, axes=range(-ndim, 0), norm='ortho')
            y_fft = xp.fft.rfftn(y_pad, axes=range(-ndim, 0), norm='ortho')
            W_fft = xp.sum(x_fft * y_fft, axis=0)
            W = xp.fft.irfftn(W_fft,
                              pad_shape,
                              axes=range(-ndim, 0),
                              norm='ortho').astype(dtype)
        else:
            x_fft = fourier.fft(x_pad, axes=range(-ndim, 0), center=False)
            y_fft = fourier.fft(y_pad, axes=range(-ndim, 0), center=False)
            W_fft = xp.sum(x_fft * y_fft, axis=0)
            W = fourier.ifft(W_fft, axes=range(-ndim, 0), center=False)

        if mode == 'full':
            shift = [0, 0] + [m - 1 for m in input_shape]
        elif mode == 'valid':
            shift = [0, 0] + [p - 1 for p in output_shape]

        W = util.resize(W, (output_channel, input_channel) + filter_shape,
                        ishift=shift)
        W *= util.prod(pad_shape)**0.5
        return W
Example #2
0
def _fft_convolve_adjoint_input(W, y, mode='full'):
    ndim = y.ndim - 2
    batch_size = len(y)
    output_channel, input_channel = W.shape[:2]
    output_shape = y.shape[-ndim:]
    filter_shape = W.shape[-ndim:]
    if mode == 'full':
        input_shape = tuple(p - n + 1
                            for p, n in zip(output_shape, filter_shape))
        pad_shape = output_shape
    elif mode == 'valid':
        input_shape = tuple(p + n - 1
                            for p, n in zip(output_shape, filter_shape))
        pad_shape = input_shape

    dtype = y.dtype
    device = backend.get_device(y)
    xp = device.xp
    with device:
        y = y.reshape((batch_size, output_channel, 1) + output_shape)
        W = xp.conj(util.flip(W, axes=range(-ndim, 0)))

        y_pad = util.resize(y, (batch_size, output_channel, 1) + pad_shape,
                            oshift=[0] * y.ndim)
        W_pad = util.resize(W, (output_channel, input_channel) + pad_shape,
                            oshift=[0] * W.ndim)

        if np.issubdtype(dtype, np.floating):
            y_fft = xp.fft.rfftn(y_pad, axes=range(-ndim, 0), norm='ortho')
            W_fft = xp.fft.rfftn(W_pad, axes=range(-ndim, 0), norm='ortho')
            x_fft = xp.sum(y_fft * W_fft, axis=-ndim - 2)
            x = xp.fft.irfftn(x_fft,
                              pad_shape,
                              axes=range(-ndim, 0),
                              norm='ortho').astype(dtype)
        else:
            y_fft = fourier.fft(y_pad, axes=range(-ndim, 0), center=False)
            W_fft = fourier.fft(W_pad, axes=range(-ndim, 0), center=False)
            x_fft = xp.sum(y_fft * W_fft, axis=-ndim - 2)
            x = fourier.ifft(x_fft, axes=range(-ndim, 0), center=False)

        if mode == 'full':
            shift = [0, 0] + [n - 1 for n in filter_shape]
        elif mode == 'valid':
            shift = [0] * x.ndim

        x = util.resize(x, (batch_size, input_channel) + input_shape,
                        ishift=shift)
        x *= util.prod(pad_shape)**0.5
        return x
Example #3
0
def _fft_convolve(x, W, mode='full'):
    ndim = x.ndim - 2
    batch_size = len(x)
    output_channel, input_channel = W.shape[:2]
    input_shape = x.shape[-ndim:]
    filter_shape = W.shape[-ndim:]
    if mode == 'full':
        output_shape = tuple(m + n - 1
                             for m, n in zip(input_shape, filter_shape))
        pad_shape = output_shape
    elif mode == 'valid':
        output_shape = tuple(m - n + 1
                             for m, n in zip(input_shape, filter_shape))
        pad_shape = input_shape

    dtype = x.dtype
    device = backend.get_device(x)
    xp = device.xp
    with device:
        x = x.reshape((batch_size, 1, input_channel) + input_shape)
        x_pad = util.resize(x, (batch_size, 1, input_channel) + pad_shape,
                            oshift=[0] * x.ndim)
        W_pad = util.resize(W, (output_channel, input_channel) + pad_shape,
                            oshift=[0] * W.ndim)

        if np.issubdtype(dtype, np.floating):
            x_fft = xp.fft.rfftn(x_pad, axes=range(-ndim, 0), norm='ortho')
            W_fft = xp.fft.rfftn(W_pad, axes=range(-ndim, 0), norm='ortho')
            y_fft = xp.sum(x_fft * W_fft, axis=-ndim - 1)
            y = xp.fft.irfftn(y_fft,
                              pad_shape,
                              axes=range(-ndim, 0),
                              norm='ortho').astype(dtype)
        else:
            x_fft = fourier.fft(x_pad, axes=range(-ndim, 0), center=False)
            W_fft = fourier.fft(W_pad, axes=range(-ndim, 0), center=False)
            y_fft = xp.sum(x_fft * W_fft, axis=-ndim - 1)
            y = fourier.ifft(y_fft, axes=range(-ndim, 0), center=False)

        if mode == 'full':
            shift = [0] * y.ndim
        elif mode == 'valid':
            shift = [0, 0] + [n - 1 for n in filter_shape]

        y = util.resize(y, (batch_size, output_channel) + output_shape,
                        ishift=shift)
        y *= util.prod(pad_shape)**0.5
        return y
Example #4
0
    def test_fft_dtype(self):
        """``fourier.fft`` must return the same complex dtype it was given."""
        complex_dtypes = (np.complex64, np.complex128)
        for complex_dtype in complex_dtypes:
            impulse = np.array([0, 1, 0], dtype=complex_dtype)
            spectrum = fourier.fft(impulse)

            assert spectrum.dtype == complex_dtype
Example #5
0
    def test_fft(self):
        """Check ``fourier.fft`` against closed-form values and centered NumPy FFT.

        ``np.complex`` (a deprecated alias of the builtin ``complex``) was
        removed in NumPy 1.24, so the explicit ``np.complex128`` is used.
        """
        # A centered impulse maps to a flat spectrum scaled by 1/sqrt(N).
        impulse = np.array([0, 1, 0], dtype=np.complex128)
        npt.assert_allclose(fourier.fft(impulse), np.ones(3) / 3**0.5,
                            atol=1e-5)

        # A constant signal concentrates all energy in the center bin.
        constant = np.array([1, 1, 1], dtype=np.complex128)
        npt.assert_allclose(fourier.fft(constant), [0, 3**0.5, 0], atol=1e-5)

        # Random input: must equal an orthonormal FFT wrapped in
        # ifftshift/fftshift, i.e. a centered transform.
        random_input = util.randn([4, 5, 6])
        npt.assert_allclose(fourier.fft(random_input),
                            np.fft.fftshift(
                                np.fft.fftn(np.fft.ifftshift(random_input),
                                            norm='ortho')),
                            atol=1e-5)

        # With oshape=[5] the centered impulse is expected to yield a flat
        # length-5 spectrum.
        impulse = np.array([0, 1, 0], dtype=np.complex128)
        npt.assert_allclose(fourier.fft(impulse, oshape=[5]),
                            np.ones(5) / 5**0.5,
                            atol=1e-5)
Example #6
0
 def _apply(self, input):
     """Apply ``fourier.fft`` along this operator's axes and centering."""
     axes, center = self.axes, self.center
     return fourier.fft(input, axes=axes, center=center)
Example #7
0
 def _apply(self, input):
     """Apply ``fourier.fft`` on the device that holds ``input``."""
     input_device = backend.get_device(input)
     with input_device:
         return fourier.fft(input, axes=self.axes, center=self.center)