def idct_N(x, expk=None):
    """ IDCT-style transform along the last dimension, evaluated with one length-N complex IFFT.

    The input is packed into an (..., N, 2) buffer whose trailing axis is used as
    (real, imag), rotated by expk, sent through torch_fft_api.ifft, and the real
    slot of the result is de-interleaved into the output.

    @param x tensor; the last dimension (length N) is the one transformed
    @param expk tensor indexed as expk[..., 0] and expk[..., 1] for the two rotation
           components; obtained from get_expk(N, ...) when None
    @return new tensor with the same shape as x
    """
    N = x.size(-1)
    last_dim = x.ndimension() - 1

    if expk is None:
        expk = get_expk(N, dtype=x.dtype, device=x.device)
    rot_a = expk[..., 0]
    rot_b = expk[..., 1]

    # zero-initialised (..., N, 2) work buffer
    packed = torch.zeros(list(x.size()) + [2], dtype=x.dtype, device=x.device)
    # seed slot 0 with x (it is overwritten by the rotation below)
    packed[..., 0] = x
    # slot 1: entry 0 stays zero, entry k (k >= 1) receives -x[N - k];
    # flip() returns a copy, so the in-place negation never touches x
    mirrored = x.flip([last_dim])[..., :N - 1]
    packed[..., 1:, 1] = mirrored.mul_(-1)

    # rotate (x, slot 1) by (rot_a, rot_b); slot 0 has to be computed first
    # because it reads slot 1 before slot 1 is updated in place
    slot1 = packed[..., 1]
    packed[..., 0] = x.mul(rot_a).sub_(slot1.mul(rot_b))
    slot1.mul_(rot_a)
    slot1.add_(x.mul(rot_b))

    # factor 0.5 is used to match idct_2N; the normal way would multiply by 0.25
    packed.mul_(0.5)

    transformed = torch_fft_api.ifft(packed, signal_ndim=1, normalized=False)
    transformed.mul_(N)

    # slot 0 of the IFFT output: first half goes to the even positions,
    # the reversed second half goes to the odd positions
    slot0_out = transformed[..., 0]
    half = (N + 1) // 2
    z = torch.empty_like(x)
    z[..., 0:N:2] = slot0_out[..., :half]
    z[..., 1:N:2] = slot0_out[..., half:].flip([last_dim])

    return z
def test_fft_3D(N, dtype):
    """ Print a 3-D FFT / IFFT round trip on random data for visual inspection.

    A (2, N, N, N, 2) tensor is filled uniformly from [0, 10), transformed with
    torch_fft_api.fft and inverted with torch_fft_api.ifft, once with the third
    argument False and once with True (presumably the `normalized` flag -- confirm
    against torch_fft_api). Nothing is asserted; the tensors are only printed.

    @param N size of each of the three transformed dimensions
    @param dtype dtype of the random input tensor
    """
    signal = torch.empty(2, N, N, N, 2, dtype=dtype).uniform_(0, 10.0)

    spectrum_plain = torch_fft_api.fft(signal, 3, False)
    spectrum_flagged = torch_fft_api.fft(signal, 3, True)
    recovered_plain = torch_fft_api.ifft(spectrum_plain, 3, False)
    recovered_flagged = torch_fft_api.ifft(spectrum_flagged, 3, True)

    report = (
        ("x", signal),
        ("y1", spectrum_plain),
        ("y2", spectrum_flagged),
        ("x1_hat", recovered_plain),
        ("x2_hat", recovered_flagged),
    )
    for label, tensor in report:
        print(label)
        print(tensor)
def idxt(x, cos_or_sin_flag, expk=None):
    r""" Batch Inverse Discrete Cosine/Sine Transformation without normalization to coefficients.
    For the cosine case, compute y_u = \sum_i x_i cos(pi*(2u+1)*i/(2N)).
    Implements the 2N padding trick to solve IDCT with IFFT in the following link,
    https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/spectral_ops.py

    1. Multiply by 2*exp(1j*pi*u/(2N))
    2. Pad x by zeros
    3. Perform IFFT
    4. Extract the real part (cosine) or the imaginary part (sine)

    @param x batch 1D tensor for conversion; the last dimension (length N) is transformed
    @param cos_or_sin_flag 0 for cosine transformation and 1 for sine transformation;
           used as the index into the trailing (real, imag) dimension of the IFFT output
    @param expk 2*exp(j*pi*k/(2N)); computed with get_expk when None
    @return tensor with the same shape as x
    """
    # length of the last (transformed) dimension
    N = x.size(-1)
    if expk is None:
        expk = get_expk(N, dtype=x.dtype, device=x.device)
    # a 1-D input gets a temporary leading batch dimension around the IFFT
    is_1d = x.dim() == 1

    # multiply by 2*exp(1j*pi*u/(2N)); the new trailing dimension holds the complex pair
    x_pad = x.unsqueeze(-1).mul(expk)
    # pad second last dimension from N to 2N, excluding the complex number dimension
    x_pad = F.pad(x_pad, (0, 0, 0, N), 'constant', 0)
    if is_1d:
        x_pad.unsqueeze_(0)

    # the signal dimension here is -2 because complex numbers introduce a new dimension
    # Must use IFFT here
    y = torch_fft_api.ifft(x_pad, signal_ndim=1, normalized=False)[..., 0:N, cos_or_sin_flag]
    y.mul_(N)

    if is_1d:
        y.squeeze_(0)

    return y