示例#1
0
 def test_rfftn_numpy(self):
     """Check that rfftn_numpy commutes with axis permutations.

     For each fixture array (``self.ad`` and ``self.af``) and every
     permutation ``a`` of the three axes, transforming ``np.transpose(x, a)``
     must match transforming ``x`` with ``axes=a`` and then applying the same
     transpose.
     """
     axes = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]
     for x in [self.ad, self.af]:
         # _get_rtol_atol is called with x alone, so look tolerances up once
         # per array instead of once per permutation.
         r_tol, a_tol = _get_rtol_atol(x)
         dtype = np.asarray(x).dtype
         for a in axes:
             rfft_tr = mkl_fft.rfftn_numpy(np.transpose(x, a))
             tr_rfft = np.transpose(mkl_fft.rfftn_numpy(x, axes=a), a)
             # Name the failing permutation/dtype so a failure is actionable.
             assert_allclose(rfft_tr, tr_rfft, rtol=r_tol, atol=a_tol,
                             err_msg="axes={}, dtype={}".format(a, dtype))
示例#2
0
文件: fft.py 项目: inwwin/cddm
def _rfft2(a, overwrite_x=False, extra=None):
    """Compute the 2D real FFT of ``a`` over its last two axes.

    The backend is chosen by ``CDDMConfig["rfft2lib"]`` and must be one of
    ``"mkl_fft"``, ``"scipy"``, ``"numpy"`` or ``"pyfftw"``. Every backend
    transforms ``a.real``, so the imaginary part of a complex input is
    discarded and all backends agree on the result.

    Parameters
    ----------
    a : ndarray
        Input array; the transform is taken over axes (-2, -1).
    overwrite_x : bool, optional
        Forwarded to the scipy backend only; the other backends ignore it.
    extra : dict, optional
        Extra keyword arguments forwarded to the backend function.

    Returns
    -------
    ndarray
        Complex spectrum. The scipy branch keeps the first
        ``a.shape[-1] // 2 + 1`` entries of the last axis, which matches the
        half-spectrum layout of ``numpy.fft.rfft2``.

    Raises
    ------
    ValueError
        If ``CDDMConfig["rfft2lib"]`` names an unsupported backend.
    """
    # Avoid a shared mutable default; callers may still pass their own dict.
    extra = {} if extra is None else extra
    libname = CDDMConfig["rfft2lib"]
    # Force real in case the input is complex, for every backend.
    x = a.real
    if libname == "mkl_fft":
        return mkl_fft.rfftn_numpy(x, axes=(-2, -1), **extra)
    elif libname == "scipy":
        # scipy is used via the full complex fft2; keep only the
        # non-negative frequencies of the last axis, as rfft2 would.
        cutoff = a.shape[-1] // 2 + 1
        return spfft.fft2(x, overwrite_x=overwrite_x, **extra)[..., 0:cutoff]
    elif libname == "numpy":
        return np.fft.rfft2(x, **extra)
    elif libname == "pyfftw":
        return fftw.numpy_fft.rfft2(x, **extra)
    # Previously an unknown backend silently returned None.
    raise ValueError("Unsupported rfft2lib: {!r}".format(libname))
示例#3
0
def _rfft2(a, overwrite_x=False, extra=None):
    """Compute the 2D real FFT of ``a`` over its last two axes as ``CDTYPE``.

    The backend is chosen by ``CDDMConfig["rfft2lib"]`` and must be one of
    ``"mkl_fft"``, ``"scipy"``, ``"numpy"`` or ``"pyfftw"``. Every backend
    transforms ``a.real``, so the imaginary part of a complex input is
    discarded and all backends agree on the result.

    Parameters
    ----------
    a : ndarray
        Input array; the transform is taken over axes (-2, -1).
    overwrite_x : bool, optional
        Forwarded to the scipy backend only; the other backends ignore it.
    extra : dict, optional
        Extra keyword arguments forwarded to the backend function.

    Returns
    -------
    ndarray
        Complex spectrum converted to dtype ``CDTYPE``. The scipy branch keeps
        the first ``a.shape[-1] // 2 + 1`` entries of the last axis, which
        matches the half-spectrum layout of ``numpy.fft.rfft2``.

    Raises
    ------
    ValueError
        If ``CDDMConfig["rfft2lib"]`` names an unsupported backend.
    """
    # Avoid a shared mutable default; callers may still pass their own dict.
    extra = {} if extra is None else extra
    libname = CDDMConfig["rfft2lib"]
    # Force real in case the input is complex, for every backend.
    x = a.real
    if libname == "mkl_fft":
        out = mkl_fft.rfftn_numpy(x, axes=(-2, -1), **extra)
    elif libname == "scipy":
        # scipy is used via the full complex fft2; keep only the
        # non-negative frequencies of the last axis, as rfft2 would.
        cutoff = a.shape[-1] // 2 + 1
        out = spfft.fft2(x, overwrite_x=overwrite_x, **extra)[..., 0:cutoff]
    elif libname == "numpy":
        out = np.fft.rfft2(x, **extra)
    elif libname == "pyfftw":
        out = fftw.numpy_fft.rfft2(x, **extra)
    else:
        # Previously this fell through to an UnboundLocalError on `out`.
        raise ValueError("Unsupported rfft2lib: {!r}".format(libname))

    # Depending on how the libraries are compiled, the output may not have the
    # requested dtype (e.g. float32 input may come back as complex128), so
    # convert explicitly to the configured complex type.
    return np.asarray(out, CDTYPE)