def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared implementation for the complex-to-real FFTs (irfft and hfft).

    Promotes ``input`` (which must be complex), resizes the transformed
    dimension when ``n`` is given, runs the c2r primitive, and applies the
    requested normalization.
    """
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    transform_dims = (utils.canonicalize_dim(input.ndim, dim),)
    # Output length of the real signal: explicit n, or inferred from the
    # onesided input length via the Hermitian-symmetry relation.
    if n is None:
        out_length = 2 * (input.shape[dim] - 1)
    else:
        out_length = n
    check(
        out_length >= 1,
        lambda: f"Invalid number of data points ({n}) specified",
    )
    if n is not None:
        input = _resize_fft_input(
            input, dims=transform_dims, sizes=(out_length // 2 + 1,)
        )
    # hfft (forward=True) is implemented as the c2r transform of the conjugate.
    if forward:
        input = torch.conj(input)
    result = prims.fft_c2r(input, dim=transform_dims, last_dim_size=out_length)
    return _apply_norm(result, norm=norm, signal_numel=out_length, forward=forward)
def glu_backward(grad_output: Tensor, self: Tensor, dim: int) -> Tensor:
    """Gradient of glu with respect to its input.

    glu splits ``self`` in two along ``dim`` as (a, b) and returns
    ``a * sigmoid(b)``; the gradients are ``sigmoid(b) * grad`` for the first
    half and ``a * sigmoid(b) * (1 - sigmoid(b)) * grad`` for the second,
    concatenated back along the same dimension.
    """
    assert self.dim() > 0, "glu does not support 0-dimensional tensors"
    d = utils.canonicalize_dim(self.dim(), dim)
    full_size = self.size(d)
    assert (
        full_size % 2 == 0
    ), f"Halving dimension must be even, but dimension {d} is size {full_size}"
    half_size = full_size // 2
    a = self.narrow(d, 0, half_size)
    b = self.narrow(d, half_size, half_size)
    gate = torch.sigmoid(b)
    # d/db sigmoid(b) = sigmoid(b) * (1 - sigmoid(b))
    grad_b = (1.0 - gate) * gate * a * grad_output
    grad_a = gate * grad_output
    return torch.cat([grad_a, grad_b], dim=d)
def _fft_c2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared implementation for the complex-to-complex FFTs (fft and ifft).

    Validates that ``input`` is complex, resizes the transformed dimension
    when ``n`` is given, runs the c2c primitive in the requested direction,
    and applies the requested normalization.
    """
    check(
        input.dtype.is_complex,
        lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
    )
    transform_dims = (utils.canonicalize_dim(input.ndim, dim),)
    if n is not None:
        input = _resize_fft_input(input, transform_dims, (n,))
    result = prims.fft_c2c(input, dim=transform_dims, forward=forward)
    # input.shape[dim] is the (possibly resized) signal length.
    return _apply_norm(result, norm, input.shape[dim], forward)
def _fft_r2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
    onesided: bool,
) -> TensorLikeType:
    """Shared implementation for the real-to-complex FFTs (rfft and ihfft).

    Validates that ``input`` is real, promotes it to a floating dtype,
    resizes the transformed dimension when ``n`` is given, runs the r2c
    primitive, and applies the requested normalization.
    """
    check(
        not input.dtype.is_complex,
        lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
    )
    input = _maybe_promote_tensor_fft(input)
    transform_dims = (utils.canonicalize_dim(input.ndim, dim),)
    if n is not None:
        input = _resize_fft_input(input, transform_dims, (n,))
    result = prims.fft_r2c(input, dim=transform_dims, onesided=onesided)
    result = _apply_norm(result, norm, input.shape[dim], forward)
    # ihfft (forward=False) is the conjugate of the forward r2c transform.
    return result if forward else torch.conj(result)