def ifft(self, src: cua.GPUArray, dest: cua.GPUArray = None):
    """
    Compute the backward FFT

    :param src: the source GPUarray
    :param dest: the destination GPUarray. Should be None for an inplace
        transform (a destination sharing the source buffer is also
        accepted), and must be a distinct buffer for an out-of-place
        transform.
    :return: the transformed array. For a C2R inplace transform, the float
        view of the array is returned.
    :raises RuntimeError: if ``dest`` does not match the inplace /
        out-of-place configuration of this app.
    """
    if self.inplace:
        # Compare device addresses rather than the ``gpudata`` objects: two
        # arrays over the same buffer may hold different wrapper objects,
        # while the backend only ever sees the integer address.
        if dest is not None and int(src.gpudata) != int(dest.gpudata):
            raise RuntimeError(
                "VkFFTApp.ifft: dest!=src but this is an inplace transform"
            )
        if self.batch_shape is None:
            s = src
        elif self.r2c:
            # Complex view of the buffer: last axis is batch_shape[-1] // 2.
            # NOTE(review): presumably batch_shape is the padded real shape
            # for inplace R2C apps - confirm against the constructor.
            src_shape = tuple(
                list(self.batch_shape[:-1]) + [self.batch_shape[-1] // 2])
            s = src.reshape(src_shape)
        else:
            s = src.reshape(self.batch_shape)
        _vkfft_cuda.ifft(self.app, int(s.gpudata), int(s.gpudata))
        result = src
    else:
        if dest is None:
            raise RuntimeError(
                "VkFFTApp.ifft: dest is None but this is an out-of-place transform"
            )
        if int(src.gpudata) == int(dest.gpudata):
            raise RuntimeError(
                "VkFFTApp.ifft: dest and src are identical but this is an out-of-place transform"
            )
        if self.r2c:
            # The real output must hold 2 * (nc - 1) values for each complex
            # row of nc = src.shape[-1] values.
            assert dest.size == src.size // src.shape[-1] * 2 * (src.shape[-1] - 1), (
                "VkFFTApp.ifft: dest size does not match the C2R transform of src")
            if self.batch_shape is None:
                s, d = src, dest
            else:
                # Here batch_shape is the real (dest) shape; the complex
                # source has batch_shape[-1] // 2 + 1 values on its last axis.
                src_shape = tuple(
                    list(self.batch_shape[:-1]) + [self.batch_shape[-1] // 2 + 1])
                s = src.reshape(src_shape)
                d = dest.reshape(self.batch_shape)
            # Special case: src and dest buffer sizes differ, and VkFFT is
            # configured to go back to the source buffer, so the buffers are
            # handed to the backend in swapped order (dest first).
            _vkfft_cuda.ifft(self.app, int(d.gpudata), int(s.gpudata))
        else:
            if self.batch_shape is None:
                s, d = src, dest
            else:
                s = src.reshape(self.batch_shape)
                d = dest.reshape(self.batch_shape)
            _vkfft_cuda.ifft(self.app, int(s.gpudata), int(d.gpudata))
        result = dest

    if self.norm == "ortho":
        # Multiply by the ortho scale, cast to a scalar matching the app's
        # precision (2/4/8 -> float16/float32/float64). Other precision values
        # leave the result unscaled, as before.
        scalar_type = {2: np.float16, 4: np.float32, 8: np.float64}.get(self.precision)
        if scalar_type is not None:
            result *= scalar_type(self._get_ifft_scale(norm=0))

    if self.inplace and self.r2c:
        # Inplace C2R: expose the (complex-typed) buffer as real values.
        if result.dtype == np.complex64:
            return result.view(dtype=np.float32)
        if result.dtype == np.complex128:
            return result.view(dtype=np.float64)
    return result
def fft(self, src: cua.GPUArray, dest: cua.GPUArray = None):
    """
    Compute the forward FFT

    :param src: the source GPUarray
    :param dest: the destination GPUarray. Should be None for an inplace
        transform (a destination sharing the source buffer is also
        accepted), and must be a distinct buffer for an out-of-place
        transform.
    :return: the transformed array. For a R2C inplace transform, the complex
        view of the array is returned.
    :raises RuntimeError: if ``dest`` does not match the inplace /
        out-of-place configuration of this app.
    """
    if self.inplace:
        # Compare device addresses rather than the ``gpudata`` objects: two
        # arrays over the same buffer may hold different wrapper objects,
        # while the backend only ever sees the integer address.
        if dest is not None and int(src.gpudata) != int(dest.gpudata):
            raise RuntimeError(
                "VkFFTApp.fft: dest!=src but this is an inplace transform"
            )
        s = src if self.batch_shape is None else src.reshape(self.batch_shape)
        _vkfft_cuda.fft(self.app, int(s.gpudata), int(s.gpudata))
        result = src
    else:
        if dest is None:
            raise RuntimeError(
                "VkFFTApp.fft: dest is None but this is an out-of-place transform"
            )
        if int(src.gpudata) == int(dest.gpudata):
            raise RuntimeError(
                "VkFFTApp.fft: dest and src are identical but this is an out-of-place transform"
            )
        if self.r2c:
            # The real input must hold 2 * (nc - 1) values for each complex
            # output row of nc = dest.shape[-1] values.
            assert src.size == dest.size // dest.shape[-1] * 2 * (dest.shape[-1] - 1), (
                "VkFFTApp.fft: src size does not match the R2C transform into dest")
        if self.batch_shape is None:
            s, d = src, dest
        else:
            s = src.reshape(self.batch_shape)
            if self.r2c:
                # Complex output has batch_shape[-1] // 2 + 1 values on its
                # last axis.
                c_shape = tuple(
                    list(self.batch_shape[:-1]) + [self.batch_shape[-1] // 2 + 1])
                d = dest.reshape(c_shape)
            else:
                d = dest.reshape(self.batch_shape)
        _vkfft_cuda.fft(self.app, int(s.gpudata), int(d.gpudata))
        result = dest

    if self.norm == "ortho":
        # Multiply by the ortho scale, cast to a scalar matching the app's
        # precision (2/4/8 -> float16/float32/float64). Other precision values
        # leave the result unscaled, as before.
        scalar_type = {2: np.float16, 4: np.float32, 8: np.float64}.get(self.precision)
        if scalar_type is not None:
            result *= scalar_type(self._get_fft_scale(norm=0))

    if self.inplace and self.r2c:
        # Inplace R2C: expose the (real-typed) buffer as complex values.
        if result.dtype == np.float32:
            return result.view(dtype=np.complex64)
        if result.dtype == np.float64:
            return result.view(dtype=np.complex128)
    return result