コード例 #1
0
ファイル: FFT.py プロジェクト: gentaiscool/scatternet
 def c2r(self, input):
     """Run a cuFFT complex-to-real transform on ``input`` and return the result.

     The output is a new tensor of the same type as ``input`` with the last
     dimension dropped (presumably the size-2 (real, imag) axis — confirm
     against callers). The cuFFT plan is looked up in ``self.fft_cache`` under
     ``(size, CUFFT_C2R, device)`` and built with ``self.buildCache`` on a miss.

     Raises:
         RuntimeError: if ``input`` is not contiguous. cuFFT reads the raw
             buffer behind ``data_ptr()``, so a strided view would be misread.
     """
     if not input.is_contiguous():
         raise RuntimeError('Tensors must be contiguous!')
     key = (input.size(), cufft.CUFFT_C2R, input.get_device())
     output = input.new(input.size()[:-1])
     # NOTE(review): assumes self.fft_cache yields None for unseen keys
     # (e.g. a defaultdict) — confirm against the constructor.
     if self.fft_cache[key] is None:
         self.buildCache(input, cufft.CUFFT_C2R)
     cufft.cufftExecC2R(self.fft_cache[key], input.data_ptr(), output.data_ptr())
     return output
コード例 #2
0
ファイル: utils.py プロジェクト: eugenium/pyscatlight
    def __call__(self, input, direction='C2C', inplace=False, inverse=False):
        """Apply a 2-D FFT to a tensor holding complex values as (real, imag) pairs.

        ``torch.cuda.FloatTensor`` inputs are transformed with cuFFT plans
        cached in ``self.fft_cache``. ``torch.FloatTensor`` and
        ``torch.DoubleTensor`` inputs fall back to ``numpy.fft``. The last
        dimension holds the real and imaginary parts, and the transform runs
        over the two dimensions before it.

        Args:
            input: tensor whose last dimension holds (real, imag).
            direction: ``'C2C'`` (complex to complex) or ``'C2R'`` (complex
                to real). ``'C2R'`` always computes an inverse transform.
            inplace: for ``'C2C'``, write the result into ``input``.
            inverse: for ``'C2C'``, compute the inverse transform.

        Returns:
            The transformed tensor; for ``'C2R'`` the last dimension is
            dropped. With ``inplace=True`` (``'C2C'``) the returned tensor is
            ``input`` itself, on both the CPU and the CUDA path.

        Raises:
            ValueError: if ``direction`` is neither ``'C2C'`` nor ``'C2R'``.
            TypeError: if ``input`` is not a supported tensor type, or a CUDA
                input is not complex.
            RuntimeError: if a CUDA input is not contiguous.
        """
        # Reject unknown directions instead of silently falling through.
        if direction not in ('C2C', 'C2R'):
            raise ValueError(
                "direction must be 'C2C' or 'C2R', got {!r}".format(direction))

        if direction == 'C2R':
            inverse = True

        if not isinstance(input, torch.cuda.FloatTensor):
            if not isinstance(input, (torch.FloatTensor, torch.DoubleTensor)):
                raise TypeError('The input should be a torch.cuda.FloatTensor, '
                                'torch.FloatTensor or a torch.DoubleTensor')

            # CPU fallback: rebuild a complex ndarray from the (real, imag) pair.
            input_np = input[..., 0].numpy() + 1.0j * input[..., 1].numpy()
            out_type = input.numpy().dtype
            # numpy's ifft2 divides by the number of transformed elements.
            # Multiply that back so inverse results are unnormalized
            # (presumably to match cuFFT, whose transforms are unnormalized).
            scale = input.size(-2) * input.size(-3)

            def to_pairs(z):
                # Split a complex ndarray into a trailing (real, imag) axis.
                return np.stack((np.real(z), np.imag(z)), axis=z.ndim)

            if direction == 'C2R':
                out = (np.real(np.fft.ifft2(input_np)) * scale).astype(out_type)
                return torch.from_numpy(out)

            if inverse:
                out = (to_pairs(np.fft.ifft2(input_np)) * scale).astype(out_type)
            else:
                out = to_pairs(np.fft.fft2(input_np)).astype(out_type)

            if inplace:
                input.copy_(torch.from_numpy(out))
                # Return the updated tensor, matching the CUDA path.
                return input
            return torch.from_numpy(out)

        if not iscomplex(input):
            raise TypeError('The input should be complex (e.g. last dimension is 2)')

        # cuFFT reads the raw buffer behind data_ptr(), so strides must be dense.
        if not input.is_contiguous():
            raise RuntimeError('Tensors must be contiguous!')

        # NOTE(review): assumes self.fft_cache yields None for unseen keys
        # (e.g. a defaultdict) — confirm against the constructor.
        if direction == 'C2R':
            key = (input.size(), cufft.CUFFT_C2R, input.get_device())
            output = input.new(input.size()[:-1])
            if self.fft_cache[key] is None:
                self.buildCache(input, cufft.CUFFT_C2R)
            cufft.cufftExecC2R(self.fft_cache[key], input.data_ptr(),
                               output.data_ptr())
            return output

        key = (input.size(), cufft.CUFFT_C2C, input.get_device())
        output = input if inplace else input.new(input.size())
        flag = cufft.CUFFT_INVERSE if inverse else cufft.CUFFT_FORWARD
        if self.fft_cache[key] is None:
            self.buildCache(input, cufft.CUFFT_C2C)
        cufft.cufftExecC2C(self.fft_cache[key], input.data_ptr(),
                           output.data_ptr(), flag)
        return output
コード例 #3
0
    def __call__(self, input, direction='C2C', inplace=False, inverse=False):
        """Apply a 2-D FFT to a tensor holding complex values as (real, imag) pairs.

        ``torch.cuda.FloatTensor`` inputs are transformed with cuFFT plans
        cached in ``self.fft_cache``. ``torch.FloatTensor`` and
        ``torch.DoubleTensor`` inputs fall back to ``numpy.fft``. The last
        dimension holds the real and imaginary parts, and the transform runs
        over the two dimensions before it.

        Args:
            input: tensor whose last dimension holds (real, imag).
            direction: ``'C2C'`` (complex to complex) or ``'C2R'`` (complex
                to real). ``'C2R'`` always computes an inverse transform.
            inplace: for ``'C2C'``, write the result into ``input``.
            inverse: for ``'C2C'``, compute the inverse transform.

        Returns:
            The transformed tensor; for ``'C2R'`` the last dimension is
            dropped. With ``inplace=True`` (``'C2C'``) the returned tensor is
            ``input`` itself, on both the CPU and the CUDA path.

        Raises:
            ValueError: if ``direction`` is neither ``'C2C'`` nor ``'C2R'``.
            TypeError: if ``input`` is not a supported tensor type, or a CUDA
                input is not complex.
            RuntimeError: if a CUDA input is not contiguous.
        """
        # Reject unknown directions instead of silently falling through.
        if direction not in ('C2C', 'C2R'):
            raise ValueError(
                "direction must be 'C2C' or 'C2R', got {!r}".format(direction))

        if direction == 'C2R':
            inverse = True

        if not isinstance(input, torch.cuda.FloatTensor):
            if not isinstance(input, (torch.FloatTensor, torch.DoubleTensor)):
                raise TypeError('The input should be a torch.cuda.FloatTensor, '
                                'torch.FloatTensor or a torch.DoubleTensor')

            # CPU fallback: rebuild a complex ndarray from the (real, imag) pair.
            input_np = input[..., 0].numpy() + 1.0j * input[..., 1].numpy()
            out_type = input.numpy().dtype
            # numpy's ifft2 divides by the number of transformed elements.
            # Multiply that back so inverse results are unnormalized
            # (presumably to match cuFFT, whose transforms are unnormalized).
            scale = input.size(-2) * input.size(-3)

            def to_pairs(z):
                # Split a complex ndarray into a trailing (real, imag) axis.
                return np.stack((np.real(z), np.imag(z)), axis=z.ndim)

            if direction == 'C2R':
                out = (np.real(np.fft.ifft2(input_np)) * scale).astype(out_type)
                return torch.from_numpy(out)

            if inverse:
                out = (to_pairs(np.fft.ifft2(input_np)) * scale).astype(out_type)
            else:
                out = to_pairs(np.fft.fft2(input_np)).astype(out_type)

            if inplace:
                input.copy_(torch.from_numpy(out))
                # Return the updated tensor, matching the CUDA path.
                return input
            return torch.from_numpy(out)

        if not iscomplex(input):
            raise TypeError('The input should be complex (e.g. last dimension is 2)')

        # cuFFT reads the raw buffer behind data_ptr(), so strides must be dense.
        if not input.is_contiguous():
            raise RuntimeError('Tensors must be contiguous!')

        # NOTE(review): assumes self.fft_cache yields None for unseen keys
        # (e.g. a defaultdict) — confirm against the constructor.
        if direction == 'C2R':
            key = (input.size(), cufft.CUFFT_C2R, input.get_device())
            output = input.new(input.size()[:-1])
            if self.fft_cache[key] is None:
                self.buildCache(input, cufft.CUFFT_C2R)
            cufft.cufftExecC2R(self.fft_cache[key], input.data_ptr(),
                               output.data_ptr())
            return output

        key = (input.size(), cufft.CUFFT_C2C, input.get_device())
        output = input if inplace else input.new(input.size())
        flag = cufft.CUFFT_INVERSE if inverse else cufft.CUFFT_FORWARD
        if self.fft_cache[key] is None:
            self.buildCache(input, cufft.CUFFT_C2C)
        cufft.cufftExecC2C(self.fft_cache[key], input.data_ptr(),
                           output.data_ptr(), flag)
        return output