def test_FST_padded(ST, quad):
    """Compare the distributed padded slab transform with a serial reference.

    Rank 0 builds a reference on the full grid ``M = [N, 2N, 4N]`` with an
    ``MPI.COMM_SELF`` transform. It broadcasts both the padded physical array
    (``dealias='3/2-rule'``) and its spectral coefficients. Every rank then
    runs the distributed transform on its local slices and checks that the
    forward and backward results match the reference to a relative tolerance.

    Parameters
    ----------
    ST : callable
        Basis constructor, called as ``ST(N, quad=quad)`` and then planned.
    quad :
        Quadrature specifier forwarded to ``ST``.

    Notes
    -----
    Relies on module-level ``N``, ``comm``, ``pi``, ``MPI`` and
    ``SlabShen_R2C`` that are not visible in this block.
    """
    ST1 = ST(N, quad=quad)
    M = np.array([N, 2 * N, 4 * N])
    FST = SlabShen_R2C(M,
                       np.array([2 * pi, 2 * pi, 2 * pi]),
                       comm,
                       communication='Alltoall')
    # Serial transform, constructed on every rank so the non-root receive
    # buffers below can be sized from the same shapes as the root's data.
    FST_SELF = SlabShen_R2C(M, np.array([2 * pi, 2 * pi, 2 * pi]),
                            MPI.COMM_SELF)
    # ``complex`` replaces ``np.complex``: that name was only a deprecated
    # alias of the builtin and was removed in NumPy 1.24.
    ST1.plan(FST.complex_shape(), 0, complex, {})

    if FST.rank == 0:
        ST0 = ST(N, quad=quad)
        ST0.plan(FST_SELF.complex_shape(), 0, complex, {})
        A = np.random.random(M).astype(FST.float)
        A_hat = np.zeros(FST_SELF.complex_shape(), dtype=FST.complex)

        # forward/backward/forward round trip so the reference data is
        # whatever the transform pair reproduces (presumably a projection
        # onto the basis space).
        A_hat = FST_SELF.forward(A, A_hat, ST0)
        A = FST_SELF.backward(A_hat, A, ST0)
        A_hat = FST_SELF.forward(A, A_hat, ST0)

        # NOTE(review): presumably the Nyquist mode along axis 1, zeroed so
        # the padded round trip below is exact — confirm.
        A_hat[:, -M[1] // 2] = 0

        A_pad = np.zeros(FST_SELF.real_shape_padded(), dtype=FST.float)
        A_pad = FST_SELF.backward(A_hat, A_pad, ST0, dealias='3/2-rule')
        A_hat = FST_SELF.forward(A_pad, A_hat, ST0, dealias='3/2-rule')

    else:
        # Receive buffers for the broadcasts below.
        A_pad = np.zeros(FST_SELF.real_shape_padded(), dtype=FST.float)
        A_hat = np.zeros(FST_SELF.complex_shape(), dtype=FST.complex)

    rtol = 1e-8 if FST.float is np.float64 else 1e-4
    FST.comm.Bcast(A_pad, root=0)
    FST.comm.Bcast(A_hat, root=0)

    a = np.zeros(FST.real_shape_padded(), dtype=FST.float)
    c = np.zeros(FST.complex_shape(), dtype=FST.complex)
    a[:] = A_pad[FST.real_local_slice(padsize=1.5)]
    c = FST.forward(a, c, ST1, dealias='3/2-rule')

    assert np.all(abs((c - A_hat[FST.complex_local_slice()]) / c.max()) < rtol)

    a = FST.backward(c, a, ST1, dealias='3/2-rule')

    assert np.all(
        abs((a - A_pad[FST.real_local_slice(padsize=1.5)]) / a.max()) < rtol)
def test_FST(ST, quad):
    """Compare the distributed slab transform with a serial reference.

    Rank 0 computes a reference pair (physical array ``A``, spectral array
    ``B2``) on an ``N**3`` grid with an ``MPI.COMM_SELF`` transform and
    broadcasts it. Every rank then checks its local forward and backward
    results from the ``comm`` transform against the matching slices, to a
    relative tolerance.

    Parameters
    ----------
    ST : callable
        Basis constructor, called as ``ST(N, quad=quad)`` and then planned.
    quad :
        Quadrature specifier forwarded to ``ST``.

    Notes
    -----
    Relies on module-level ``N``, ``comm``, ``pi``, ``MPI`` and
    ``SlabShen_R2C`` that are not visible in this block.
    """
    ST1 = ST(N, quad=quad)
    FST = SlabShen_R2C(np.array([N, N, N]), np.array([2 * pi, 2 * pi, 2 * pi]),
                       comm)
    # Built on every rank (not only rank 0) so that the Bcast receive buffer
    # on non-root ranks has exactly the sender's complex_shape(), instead of
    # a hard-coded guess at it.
    FST_SELF = SlabShen_R2C(np.array([N, N, N]),
                            np.array([2 * pi, 2 * pi, 2 * pi]),
                            MPI.COMM_SELF)
    # ``complex`` replaces ``np.complex``: that name was only a deprecated
    # alias of the builtin and was removed in NumPy 1.24.
    ST1.plan(FST.complex_shape(), 0, complex, {})

    if FST.rank == 0:
        ST0 = ST(N, quad=quad)
        A = np.random.random((N, N, N)).astype(FST.float)
        B2 = np.zeros(FST_SELF.complex_shape(), dtype=FST.complex)
        ST0.plan(FST_SELF.complex_shape(), 0, complex, {})

        # forward/backward/forward round trip so the reference data is
        # whatever the transform pair reproduces.
        B2 = FST_SELF.forward(A, B2, ST0)
        A = FST_SELF.backward(B2, A, ST0)
        B2 = FST_SELF.forward(A, B2, ST0)

    else:
        # Receive buffers for the broadcasts below.
        A = np.zeros((N, N, N), dtype=FST.float)
        B2 = np.zeros(FST_SELF.complex_shape(), dtype=FST.complex)

    rtol = 1e-8 if FST.float is np.float64 else 1e-4
    FST.comm.Bcast(A, root=0)
    FST.comm.Bcast(B2, root=0)

    a = np.zeros(FST.real_shape(), dtype=FST.float)
    c = np.zeros(FST.complex_shape(), dtype=FST.complex)
    a[:] = A[FST.real_local_slice()]
    c = FST.forward(a, c, ST1)

    assert np.all(abs((c - B2[FST.complex_local_slice()]) / c.max()) < rtol)

    a = FST.backward(c, a, ST1)

    assert np.all(abs((a - A[FST.real_local_slice()]) / a.max()) < rtol)
Example #3
0
def _pad_components(FFT0, mesh, U0, U0_hat, U0_pad, SB=None, ST=None):
    """Transform the three components of ``U0`` and back onto the padded grid.

    Fills ``U0_hat`` with the forward transforms and ``U0_pad`` with the
    padded backward transforms. ``SB``/``ST`` are used only when ``mesh`` is
    ``"channel"``.
    """
    if mesh == "triplyperiodic":
        for i in range(3):
            U0_hat[i] = FFT0.fftn(U0[i], U0_hat[i])
        for i in range(3):
            # Name is 3/2-rule, but padsize is 2
            U0_pad[i] = FFT0.ifftn(U0_hat[i], U0_pad[i], dealias="3/2-rule")
    else:
        # Component 0 uses the biharmonic basis; components 1 and 2 use the
        # Dirichlet basis.
        bases = (SB, ST, ST)
        for i in range(3):
            U0_hat[i] = FFT0.fst(U0[i], U0_hat[i], bases[i])
        for i in range(3):
            # Name is 3/2-rule, but padsize is 2
            U0_pad[i] = FFT0.ifst(U0_hat[i], U0_pad[i], bases[i],
                                  dealias="3/2-rule")


def refine(infile, mesh):
    """Refine a 3D checkpoint solution onto a finer grid.

    Reads ``3D/checkpoint/U/0`` and ``3D/checkpoint/U/1`` from ``infile``,
    transforms each with a ``padsize=2`` transform, and writes the padded
    physical fields to ``<infile without extension>_refined.h5``. The output
    also gets copies of the ``dt`` and ``L`` attributes and the new ``N``.

    For ``"channel"`` only axes 1 and 2 are doubled. For ``"triplyperiodic"``
    all three axes are doubled. Both files are opened with the ``mpio``
    driver on ``MPI.COMM_WORLD``, so every rank must call this function.

    Parameters
    ----------
    infile : str
        Path to the input HDF5 checkpoint file.
    mesh : str
        ``"channel"`` or ``"triplyperiodic"``.
    """
    import os

    assert mesh in ("channel", "triplyperiodic")
    comm = MPI.COMM_WORLD
    # splitext copes with dots in directories or extra dots in the basename,
    # where unpacking ``infile.split(".")`` raised ValueError.
    filename = os.path.splitext(infile)[0]

    # The input is only read, so open it read-only rather than relying on the
    # h5py version's default mode.
    with h5py.File(infile, "r", driver="mpio", comm=comm) as fin:
        # Check before creating the output file so a bad input leaves no
        # partial output behind.
        assert "checkpoint" in fin["3D"]
        N = fin.attrs["N"]
        N1 = N.copy()
        SB = ST = None
        if mesh == "channel":
            N1[1:] *= 2
            FFT0 = FST(N, fin.attrs["L"], MPI, padsize=2)
            SB = ShenBiharmonicBasis("GL")
            ST = ShenDirichletBasis("GL")
        else:  # "triplyperiodic"
            N1 *= 2
            FFT0 = FFT(N, fin.attrs["L"], MPI, "double", padsize=2)

        with h5py.File(filename + "_refined.h5", "w", driver="mpio",
                       comm=comm) as fout:
            shape = (3, N1[0], N1[1], N1[2])
            fout.create_group("3D")
            fout["3D"].create_group("checkpoint")
            fout["3D/checkpoint"].create_group("U")
            fout.attrs.create("dt", fin.attrs["dt"])
            fout.attrs.create("N", N1)
            fout.attrs.create("L", fin.attrs["L"])
            fout["3D/checkpoint/U"].create_dataset("0", shape=shape,
                                                   dtype=FFT0.float)
            fout["3D/checkpoint/U"].create_dataset("1", shape=shape,
                                                   dtype=FFT0.float)

            U0 = np.empty((3,) + FFT0.real_shape(), dtype=FFT0.float)
            U0_hat = np.empty((3,) + FFT0.complex_shape(), FFT0.complex)
            U0_pad = np.empty((3,) + FFT0.real_shape_padded(),
                              dtype=FFT0.float)
            s = FFT0.real_local_slice()              # local slice, input grid
            s1 = FFT0.real_local_slice(padsize=2)    # local slice, padded grid

            # The two stored checkpoint fields receive identical treatment.
            for key in ("0", "1"):
                path = "3D/checkpoint/U/" + key
                U0[:] = fin[path][:, s[0], s[1], s[2]]
                _pad_components(FFT0, mesh, U0, U0_hat, U0_pad, SB, ST)
                fout[path][:, s1[0], s1[1], s1[2]] = U0_pad[:]