def idct_N(x, expk=None):
    """Batched inverse-DCT-style transform over the last dimension using a length-N IFFT.

    The input is packed into a two-channel (real, imag) buffer, multiplied
    by ``expk`` as a complex number, passed through a 1-D IFFT, and the real
    channel of the result is interleaved back into the output order.

    @param x tensor whose last dimension (length N) is transformed
    @param expk tensor whose trailing axis holds two channels, indexed here as
                ``[..., 0]`` and ``[..., 1]``; built with ``get_expk`` when None
    @return tensor with the same shape, dtype and device as ``x``
    """
    N = x.size(-1)
    last_dim = x.ndimension() - 1

    if expk is None:
        expk = get_expk(N, dtype=x.dtype, device=x.device)
    # NOTE(review): channels 0/1 are used as real/imag parts below -- confirm
    # against get_expk.
    expk_re = expk[..., 0]
    expk_im = expk[..., 1]

    # Work buffer in the dtype of x; the trailing axis holds (real, imag).
    buf = torch.zeros(list(x.size()) + [2], dtype=x.dtype, device=x.device)
    # Seed the real channel with x; it is overwritten by the product below.
    buf[..., 0] = x
    # Imag channel: entry 0 stays zero, entry k (k >= 1) becomes -x[N - k].
    negated_tail = x.flip([last_dim])[..., :N - 1].mul_(-1)
    buf[..., 1:, 1] = negated_tail

    # Complex product (x + j*imag) * (expk_re + j*expk_im).  The real part is
    # written first but is computed from x and the still-unmodified imag
    # channel, so the statement order below must be kept.
    rotated_re = x.mul(expk_re).sub_(buf[..., 1].mul(expk_im))
    buf[..., 0] = rotated_re
    buf[..., 1].mul_(expk_re).add_(x.mul(expk_im))

    # this is to match idct_2N
    # normal way should multiply 0.25
    buf.mul_(0.5)

    ifft_out = torch_fft_api.ifft(buf, signal_ndim=1, normalized=False)
    ifft_out.mul_(N)

    # Interleave the real channel: the first ceil(N/2) entries fill the even
    # slots, the remaining entries (reversed) fill the odd slots.
    half = (N + 1) // 2
    out = torch.empty_like(x)
    out[..., 0:N:2] = ifft_out[..., :half, 0]
    out[..., 1:N:2] = ifft_out[..., half:, 0].flip([last_dim])

    return out
# Example #2
# 0
def test_fft_3D(N, dtype):
    """Print a 3-D FFT / IFFT round trip on a random tensor for visual inspection.

    Builds a ``(2, N, N, N, 2)`` tensor with entries drawn uniformly from
    [0, 10), runs ``torch_fft_api.fft`` and ``torch_fft_api.ifft`` once with
    the last argument False and once with True, and prints every tensor.
    Nothing is asserted or returned.

    @param N extent of each of the three middle dimensions
    @param dtype dtype of the random input tensor
    """
    # NOTE(review): the trailing axis of size 2 presumably carries real/imag
    # parts, and the positional arguments are presumably (signal_ndim,
    # normalized) -- confirm against torch_fft_api.
    flags = (False, True)
    signal = torch.empty(2, N, N, N, 2, dtype=dtype).uniform_(0, 10.0)

    # Forward transforms first, then the matching inverse transforms.
    spectra = [torch_fft_api.fft(signal, 3, flag) for flag in flags]
    recovered = [
        torch_fft_api.ifft(spectrum, 3, flag)
        for spectrum, flag in zip(spectra, flags)
    ]

    report = (
        ("x", signal),
        ("y1", spectra[0]),
        ("y2", spectra[1]),
        ("x1_hat", recovered[0]),
        ("x2_hat", recovered[1]),
    )
    for label, tensor in report:
        print(label)
        print(tensor)
def idxt(x, cos_or_sin_flag, expk=None):
    r""" Batch inverse discrete cosine/sine-style transformation without normalization to coefficients.

    With cos_or_sin_flag == 0 this computes
        y_u = \sum_i  x_i cos(pi*(2u+1)*i/(2N)).
    With cos_or_sin_flag == 1 the other channel (index 1) of the same IFFT
    result is returned instead.
    Implements the 2N padding trick to solve IDCT with IFFT in the following link,
    https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/spectral_ops.py

    1. Multiply by 2*exp(1j*pi*u/(2N))
    2. Pad x by zeros
    3. Perform IFFT
    4. Extract the real part (or the imaginary part when cos_or_sin_flag is 1)

    @param x batch 1D tensor for conversion; the last dimension (length N) is transformed
    @param cos_or_sin_flag 0 for cosine transformation and 1 for sine transformation;
                           used as the index into the trailing (real, imag) axis
    @param expk 2*exp(j*pi*k/(2N)); built with get_expk when None
    @return tensor indexed like x; a 1D input yields a 1D output of length N
    """
    # last dimension
    N = x.size(-1)
    # A 1D input gets a temporary leading batch dimension around the IFFT.
    is_1d = x.dim() == 1

    if expk is None:
        expk = get_expk(N, dtype=x.dtype, device=x.device)

    # multiply by 2*exp(1j*pi*u/(2N)); unsqueeze adds the complex-number axis
    x_pad = x.unsqueeze(-1).mul(expk)
    # pad second last dimension with N zeros, excluding the complex number dimension
    x_pad = F.pad(x_pad, (0, 0, 0, N), 'constant', 0)

    if is_1d:
        x_pad.unsqueeze_(0)

    # the last dimension here becomes -2 because complex numbers introduce a new dimension
    # Must use IFFT here
    y = torch_fft_api.ifft(x_pad, signal_ndim=1,
                           normalized=False)[..., 0:N, cos_or_sin_flag]
    # the IFFT call is made with normalized=False; rescale the kept slice by N
    y.mul_(N)

    if is_1d:
        y.squeeze_(0)

    return y