示例#1
0
文件: opencl.py 项目: vincefn/pyvkfft
 def ifft(self, src: cla.Array, dest: cla.Array = None):
     """
     Compute the backward FFT

     :param src: the source pyopencl Array
     :param dest: the destination pyopencl Array. Should be None for an inplace
         transform; an array sharing the same buffer as src is also accepted.
     :return: the transformed array. For a C2R inplace transform, the float view of the
         array is returned.
     :raises RuntimeError: if this is an inplace transform and dest uses a different
         buffer than src, or if this is an out-of-place transform and dest is None
         or shares its buffer with src.
     """
     # self.precision 2 / 4 / 8 selects the float16 / float32 / float64 scalar
     # used for the "ortho" scaling; any other value leaves the result unscaled.
     scale_types = {2: np.float16, 4: np.float32, 8: np.float64}

     if self.inplace:
         if dest is not None and src.data.int_ptr != dest.data.int_ptr:
             raise RuntimeError(
                 "VkFFTApp.ifft: dest!=src but this is an inplace transform"
             )
         if self.batch_shape is None:
             s = src
         elif self.r2c:
             # src is reshaped with a last axis half as long as batch_shape's
             # (presumably batch_shape describes the padded real layout while
             # src holds complex values here - TODO confirm against __init__).
             src_shape = tuple(self.batch_shape[:-1]) + (
                 self.batch_shape[-1] // 2,)
             s = src.reshape(src_shape)
         else:
             s = src.reshape(self.batch_shape)
         # Same buffer is passed as both source and destination
         _vkfft_opencl.ifft(self.app, int(s.data.int_ptr),
                            int(s.data.int_ptr), int(self.queue.int_ptr))
         if self.norm == "ortho":
             scale_type = scale_types.get(self.precision)
             if scale_type is not None:
                 src *= scale_type(self._get_ifft_scale(norm=0))
         if self.r2c:
             # Return the real view matching the complex dtype of src
             if src.dtype == np.complex64:
                 return src.view(dtype=np.float32)
             elif src.dtype == np.complex128:
                 return src.view(dtype=np.float64)
         return src

     # Out-of-place transform (the inplace branch above always returns)
     if dest is None:
         raise RuntimeError(
             "VkFFTApp.ifft: dest is None but this is an out-of-place transform"
         )
     if src.data.int_ptr == dest.data.int_ptr:
         raise RuntimeError(
             "VkFFTApp.ifft: dest and src are identical but this is an out-of-place transform"
         )
     if self.r2c:
         # For every n complex values along the last axis of src, dest must
         # hold 2 * (n - 1) values.
         assert (dest.size == src.size // src.shape[-1] * 2 *
                 (src.shape[-1] - 1))
         # Special case, src and dest buffer sizes are different,
         # VkFFT is configured to go back to the source buffer
         if self.batch_shape is not None:
             src_shape = tuple(self.batch_shape[:-1]) + (
                 self.batch_shape[-1] // 2 + 1,)
             s = src.reshape(src_shape)
             d = dest.reshape(self.batch_shape)
         else:
             s, d = src, dest
         # Hence the buffers are passed dest first, then src (the reverse
         # of the C2C call below).
         _vkfft_opencl.ifft(self.app, int(d.data.int_ptr),
                            int(s.data.int_ptr),
                            int(self.queue.int_ptr))
     else:
         if self.batch_shape is not None:
             s = src.reshape(self.batch_shape)
             d = dest.reshape(self.batch_shape)
         else:
             s, d = src, dest
         _vkfft_opencl.ifft(self.app, int(s.data.int_ptr),
                            int(d.data.int_ptr),
                            int(self.queue.int_ptr))
     if self.norm == "ortho":
         scale_type = scale_types.get(self.precision)
         if scale_type is not None:
             dest *= scale_type(self._get_ifft_scale(norm=0))
     return dest
示例#2
0
文件: opencl.py 项目: vincefn/pyvkfft
 def fft(self, src: cla.Array, dest: cla.Array = None):
     """
     Compute the forward FFT

     :param src: the source pyopencl Array
     :param dest: the destination pyopencl Array. Should be None for an inplace
         transform; an array sharing the same buffer as src is also accepted.
     :return: the transformed array. For a R2C inplace transform, the complex view of the
         array is returned.
     :raises RuntimeError: if this is an inplace transform and dest uses a different
         buffer than src, or if this is an out-of-place transform and dest is None
         or shares its buffer with src.
     """
     # self.precision 2 / 4 / 8 selects the float16 / float32 / float64 scalar
     # used for the "ortho" scaling; any other value leaves the result unscaled.
     scalar_for_precision = {2: np.float16, 4: np.float32, 8: np.float64}

     if self.inplace:
         if dest is not None and src.data.int_ptr != dest.data.int_ptr:
             raise RuntimeError(
                 "VkFFTApp.fft: dest is not None but this is an inplace transform"
             )
         work = src if self.batch_shape is None else src.reshape(self.batch_shape)
         # The same buffer acts as both input and output
         _vkfft_opencl.fft(self.app, int(work.data.int_ptr),
                           int(work.data.int_ptr), int(self.queue.int_ptr))
         if self.norm == "ortho":
             to_scalar = scalar_for_precision.get(self.precision)
             if to_scalar is not None:
                 src *= to_scalar(self._get_fft_scale(norm=0))
         if self.r2c:
             # Hand back the complex view matching the real dtype of src
             for real_type, complex_type in ((np.float32, np.complex64),
                                             (np.float64, np.complex128)):
                 if src.dtype == real_type:
                     return src.view(dtype=complex_type)
         return src

     # Out-of-place transform from here on
     if dest is None:
         raise RuntimeError(
             "VkFFTApp.fft: dest is None but this is an out-of-place transform"
         )
     if src.data.int_ptr == dest.data.int_ptr:
         raise RuntimeError(
             "VkFFTApp.fft: dest and src are identical but this is an out-of-place transform"
         )
     if self.r2c:
         # For every n values along the last axis of dest, src must hold
         # 2 * (n - 1) values.
         assert (
             src.size
             == dest.size // dest.shape[-1] * 2 * (dest.shape[-1] - 1)
         )
     if self.batch_shape is None:
         in_arr, out_arr = src, dest
     else:
         in_arr = src.reshape(self.batch_shape)
         if self.r2c:
             # Destination last axis is batch_shape[-1] // 2 + 1 long
             half_shape = tuple(self.batch_shape[:-1]) + (
                 self.batch_shape[-1] // 2 + 1,)
             out_arr = dest.reshape(half_shape)
         else:
             out_arr = dest.reshape(self.batch_shape)
     _vkfft_opencl.fft(self.app, int(in_arr.data.int_ptr),
                       int(out_arr.data.int_ptr), int(self.queue.int_ptr))
     if self.norm == "ortho":
         to_scalar = scalar_for_precision.get(self.precision)
         if to_scalar is not None:
             dest *= to_scalar(self._get_fft_scale(norm=0))
     return dest