def c2r(self, input):
    """Run the cached cuFFT complex-to-real plan on ``input``.

    A plan is looked up in ``self.fft_cache`` under the key
    ``(input.size(), cufft.CUFFT_C2R, input.get_device())``; when the
    cache holds ``None`` for that key, ``self.buildCache`` is called first
    to create it.

    Args:
        input: tensor handed to cuFFT through ``data_ptr()``. No type,
            layout or contiguity validation is done here.
            NOTE(review): presumably a contiguous CUDA float tensor whose
            last dimension holds (real, imag) -- confirm against callers.

    Returns:
        A new tensor allocated with ``input.new`` whose shape is
        ``input.size()`` without its last dimension.
    """
    plan_key = (input.size(), cufft.CUFFT_C2R, input.get_device())

    # The result drops the trailing dimension of the input.
    real_out = input.new(input.size()[:-1])

    # Build the plan lazily the first time this size/device pair is seen.
    if self.fft_cache[plan_key] is None:
        self.buildCache(input, cufft.CUFFT_C2R)

    cufft.cufftExecC2R(
        self.fft_cache[plan_key],
        input.data_ptr(),
        real_out.data_ptr(),
    )
    return real_out
def __call__(self, input, direction='C2C', inplace=False, inverse=False):
    """Apply a 2D FFT to a complex tensor stored as (..., 2) real/imag pairs.

    ``torch.cuda.FloatTensor`` inputs are executed by cuFFT using plans
    cached in ``self.fft_cache`` (created on demand by ``self.buildCache``).
    CPU ``torch.FloatTensor`` / ``torch.DoubleTensor`` inputs fall back to
    ``numpy.fft``: index 0 of the last dimension is read as the real part,
    index 1 as the imaginary part, and the transform runs over the two
    dimensions preceding it.

    Args:
        input: complex tensor whose last dimension holds (real, imag).
        direction: ``'C2C'`` (complex to complex) or ``'C2R'`` (complex to
            real). ``'C2R'`` always forces ``inverse=True`` and returns a
            tensor without the trailing dimension. On the CPU path any
            other value is treated like ``'C2C'``.
        inplace: for ``'C2C'``, write the result into ``input`` and return
            ``input`` itself. Ignored for ``'C2R'``.
        inverse: compute the inverse transform. On the CPU path the result
            is multiplied by the sizes of the two transformed dimensions,
            i.e. the inverse is left unnormalized.

    Returns:
        The transformed tensor: a new tensor, or ``input`` itself when
        ``inplace`` is true and ``direction`` is ``'C2C'``.

    Raises:
        TypeError: if ``input`` is not one of the supported tensor types,
            or (GPU path) ``iscomplex(input)`` is false.
        RuntimeError: if a GPU ``input`` is not contiguous.
        ValueError: if a GPU ``input`` is given an unsupported
            ``direction`` (previously this silently returned ``None``).
    """
    if direction == 'C2R':
        # A complex-to-real transform is only offered as an inverse.
        inverse = True

    if not isinstance(input, torch.cuda.FloatTensor):
        if not isinstance(input, (torch.FloatTensor, torch.DoubleTensor)):
            raise TypeError('The input should be a torch.cuda.FloatTensor, '
                            'torch.FloatTensor or a torch.DoubleTensor')

        # ---- CPU fallback through NumPy ---------------------------------
        out_type = input.numpy().dtype
        input_np = input[..., 0].numpy() + 1.0j * input[..., 1].numpy()

        if inverse:
            transformed = np.fft.ifft2(input_np)
        else:
            transformed = np.fft.fft2(input_np)

        if direction == 'C2R':
            out = np.real(transformed).astype(out_type)
        else:
            # Re-pack as a trailing (real, imag) dimension.
            out = np.stack((np.real(transformed), np.imag(transformed)),
                           axis=-1).astype(out_type)

        if inverse:
            # np.fft.ifft2 divides by the number of transformed points;
            # multiply back so the inverse stays unnormalized
            # (presumably to agree with the cuFFT path -- cuFFT does not
            # normalize its transforms).
            out = out * input.size(-2) * input.size(-3)

        result = torch.from_numpy(out)
        if direction != 'C2R' and inplace:
            input.copy_(result)
            # Return the input, exactly like the GPU in-place branch does.
            return input
        return result

    # ---- GPU path through cuFFT -----------------------------------------
    if not iscomplex(input):
        raise TypeError('The input should be complex (e.g. last dimension is 2)')

    if not input.is_contiguous():
        raise RuntimeError('Tensors must be contiguous!')

    if direction == 'C2R':
        fft_type = cufft.CUFFT_C2R
    elif direction == 'C2C':
        fft_type = cufft.CUFFT_C2C
    else:
        raise ValueError("Unsupported direction %r; expected 'C2C' or 'C2R'"
                         % (direction,))

    # Plans are cached per (shape, transform type, device).
    plan_key = (input.size(), fft_type, input.get_device())
    if self.fft_cache[plan_key] is None:
        self.buildCache(input, fft_type)
    plan = self.fft_cache[plan_key]

    if direction == 'C2R':
        output = input.new(input.size()[:-1])
        cufft.cufftExecC2R(plan, input.data_ptr(), output.data_ptr())
        return output

    output = input if inplace else input.new(input.size())
    flag = cufft.CUFFT_INVERSE if inverse else cufft.CUFFT_FORWARD
    cufft.cufftExecC2C(plan, input.data_ptr(), output.data_ptr(), flag)
    return output