def test_FST_padded(ST, quad):
    """Check the distributed padded ('3/2-rule') Shen slab transform against a serial one.

    Rank 0 builds a reference on ``MPI.COMM_SELF``:
    1. Round-trip random data once.
    2. Zero the ``-M[1] // 2`` plane along axis 1.
    3. Produce a padded physical-space field ``A_pad`` and its forward transform ``A_hat``.

    Both arrays are broadcast to every rank. Each rank then checks that its
    local slab of the distributed transform matches the reference, in both
    directions, to a relative tolerance.

    Parameters
    ----------
    ST : callable
        Shen basis class, called as ``ST(N, quad=quad)``.
    quad : str
        Quadrature identifier forwarded to ``ST``.
    """
    ST1 = ST(N, quad=quad)
    M = np.array([N, 2 * N, 4 * N])
    FST = SlabShen_R2C(M, np.array([2 * pi, 2 * pi, 2 * pi]), comm,
                       communication='Alltoall')
    FST_SELF = SlabShen_R2C(M, np.array([2 * pi, 2 * pi, 2 * pi]), MPI.COMM_SELF)
    # ``np.complex`` was merely an alias of the builtin ``complex`` and was
    # removed in NumPy 1.24; the builtin is the same object.
    ST1.plan(FST.complex_shape(), 0, complex, {})

    if FST.rank == 0:
        # Serial reference computed on a single process.
        ST0 = ST(N, quad=quad)
        ST0.plan(FST_SELF.complex_shape(), 0, complex, {})
        A = np.random.random(M).astype(FST.float)
        A_hat = np.zeros(FST_SELF.complex_shape(), dtype=FST.complex)
        # Forward/backward/forward so A lies exactly in the transform's range.
        A_hat = FST_SELF.forward(A, A_hat, ST0)
        A = FST_SELF.backward(A_hat, A, ST0)
        A_hat = FST_SELF.forward(A, A_hat, ST0)
        # Zero this axis-1 plane (presumably the Nyquist mode) before padding.
        A_hat[:, -M[1] // 2] = 0
        A_pad = np.zeros(FST_SELF.real_shape_padded(), dtype=FST.float)
        A_pad = FST_SELF.backward(A_hat, A_pad, ST0, dealias='3/2-rule')
        A_hat = FST_SELF.forward(A_pad, A_hat, ST0, dealias='3/2-rule')
    else:
        # Receive buffers for the broadcast reference data.
        A_pad = np.zeros(FST_SELF.real_shape_padded(), dtype=FST.float)
        A_hat = np.zeros(FST_SELF.complex_shape(), dtype=FST.complex)

    # Looser tolerance for single precision.
    rtol = 1e-8 if FST.float is np.float64 else 1e-4

    FST.comm.Bcast(A_pad, root=0)
    FST.comm.Bcast(A_hat, root=0)

    a = np.zeros(FST.real_shape_padded(), dtype=FST.float)
    c = np.zeros(FST.complex_shape(), dtype=FST.complex)
    a[:] = A_pad[FST.real_local_slice(padsize=1.5)]

    c = FST.forward(a, c, ST1, dealias='3/2-rule')
    assert np.all(abs((c - A_hat[FST.complex_local_slice()]) / c.max()) < rtol)

    a = FST.backward(c, a, ST1, dealias='3/2-rule')
    assert np.all(
        abs((a - A_pad[FST.real_local_slice(padsize=1.5)]) / a.max()) < rtol)
def test_FST(ST, quad):
    """Check the distributed Shen slab transform against a serial one.

    Rank 0 computes a reference forward transform ``B2`` of random data ``A``
    on ``MPI.COMM_SELF``. Both arrays are broadcast to every rank. Each rank
    then checks that its local slab of the distributed forward and backward
    transforms matches the reference to a relative tolerance.

    Parameters
    ----------
    ST : callable
        Shen basis class, called as ``ST(N, quad=quad)``.
    quad : str
        Quadrature identifier forwarded to ``ST``.
    """
    ST1 = ST(N, quad=quad)
    FST = SlabShen_R2C(np.array([N, N, N]), np.array([2 * pi, 2 * pi, 2 * pi]), comm)
    # ``np.complex`` was an alias of builtin ``complex`` (removed in NumPy 1.24).
    ST1.plan(FST.complex_shape(), 0, complex, {})

    if FST.rank == 0:
        # Serial reference computed on a single process.
        FST_SELF = SlabShen_R2C(np.array([N, N, N]),
                                np.array([2 * pi, 2 * pi, 2 * pi]), MPI.COMM_SELF)
        ST0 = ST(N, quad=quad)
        A = np.random.random((N, N, N)).astype(FST.float)
        B2 = np.zeros(FST_SELF.complex_shape(), dtype=FST.complex)
        ST0.plan(FST_SELF.complex_shape(), 0, complex, {})
        # Forward/backward/forward so A lies exactly in the transform's range.
        B2 = FST_SELF.forward(A, B2, ST0)
        A = FST_SELF.backward(B2, A, ST0)
        B2 = FST_SELF.forward(A, B2, ST0)
    else:
        # Receive buffers; the shape assumes FST_SELF.complex_shape() is
        # (N, N, N // 2 + 1) — TODO confirm if the transform layout changes.
        A = np.zeros((N, N, N), dtype=FST.float)
        B2 = np.zeros((N, N, N // 2 + 1), dtype=FST.complex)

    # Looser tolerance for single precision.
    rtol = 1e-8 if FST.float is np.float64 else 1e-4

    FST.comm.Bcast(A, root=0)
    FST.comm.Bcast(B2, root=0)

    a = np.zeros(FST.real_shape(), dtype=FST.float)
    c = np.zeros(FST.complex_shape(), dtype=FST.complex)
    a[:] = A[FST.real_local_slice()]

    c = FST.forward(a, c, ST1)
    assert np.all(abs((c - B2[FST.complex_local_slice()]) / c.max()) < rtol)

    a = FST.backward(c, a, ST1)
    assert np.all(abs((a - A[FST.real_local_slice()]) / a.max()) < rtol)
def refine(infile, mesh):
    """Refine a 3D checkpointed velocity field onto a finer mesh.

    Reads ``3D/checkpoint/U/0`` and ``3D/checkpoint/U/1`` from ``infile``.
    Each velocity component is transformed to spectral space and transformed
    back with ``dealias="3/2-rule"`` on a transform built with ``padsize=2``,
    which yields values on a padded (refined) mesh. The results are written to
    ``<infile without extension>_refined.h5``, together with the ``dt``, ``N``
    and ``L`` attributes. ``N`` is doubled along every axis for
    ``"triplyperiodic"`` and along axes 1 and 2 only for ``"channel"``.

    Parameters
    ----------
    infile : str
        Path to the HDF5 checkpoint file. It is opened collectively with the
        ``mpio`` driver on ``MPI.COMM_WORLD``.
    mesh : str
        Either ``"channel"`` (Shen biharmonic basis for component 0, Shen
        Dirichlet basis for components 1 and 2) or ``"triplyperiodic"``
        (Fourier in all directions).

    Raises
    ------
    AssertionError
        If ``mesh`` is not recognised or ``infile`` has no ``3D/checkpoint``
        group.
    """
    import os

    assert mesh in ("channel", "triplyperiodic")
    comm = MPI.COMM_WORLD
    # splitext strips only the final extension, so paths containing other
    # dots (directories, version tags) no longer break the unpacking.
    filename, _ = os.path.splitext(infile)

    with h5py.File(infile, driver="mpio", comm=comm) as fin:
        # Validate before creating the output file so failures leave no
        # partially written *_refined.h5 behind.
        assert "checkpoint" in fin["3D"]

        N = fin.attrs["N"]
        L = fin.attrs["L"]
        N1 = N.copy()
        bases = None
        if mesh == "channel":
            N1[1:] *= 2
            FFT0 = FST(N, L, MPI, padsize=2)
            SB = ShenBiharmonicBasis("GL")
            ST = ShenDirichletBasis("GL")
            # Basis used for each velocity component, in component order.
            bases = (SB, ST, ST)
        else:  # "triplyperiodic"
            N1 *= 2
            FFT0 = FFT(N, L, MPI, "double", padsize=2)

        with h5py.File(filename + "_refined.h5", "w",
                       driver="mpio", comm=comm) as fout:
            shape = (3, N1[0], N1[1], N1[2])
            fout.create_group("3D")
            fout["3D"].create_group("checkpoint")
            fout["3D/checkpoint"].create_group("U")
            fout.attrs.create("dt", fin.attrs["dt"])
            fout.attrs.create("N", N1)
            fout.attrs.create("L", L)
            fout["3D/checkpoint/U"].create_dataset("0", shape=shape, dtype=FFT0.float)
            fout["3D/checkpoint/U"].create_dataset("1", shape=shape, dtype=FFT0.float)

            # Local slabs of the original and padded meshes owned by this rank.
            s = FFT0.real_local_slice()
            s1 = FFT0.real_local_slice(padsize=2)
            U0 = np.empty((3,) + FFT0.real_shape(), dtype=FFT0.float)
            U0_hat = np.empty((3,) + FFT0.complex_shape(), FFT0.complex)
            U0_pad = np.empty((3,) + FFT0.real_shape_padded(), dtype=FFT0.float)

            def _refine_field(path):
                """Read dataset ``path``, interpolate it spectrally and write it to fout."""
                U0[:] = fin[path][:, s[0], s[1], s[2]]
                if mesh == "triplyperiodic":
                    for i in range(3):
                        U0_hat[i] = FFT0.fftn(U0[i], U0_hat[i])
                    for i in range(3):
                        # Name is 3/2-rule, but padsize is 2
                        U0_pad[i] = FFT0.ifftn(U0_hat[i], U0_pad[i],
                                               dealias="3/2-rule")
                else:
                    for i, basis in enumerate(bases):
                        U0_hat[i] = FFT0.fst(U0[i], U0_hat[i], basis)
                    for i, basis in enumerate(bases):
                        # Name is 3/2-rule, but padsize is 2
                        U0_pad[i] = FFT0.ifst(U0_hat[i], U0_pad[i], basis,
                                              dealias="3/2-rule")
                fout[path][:, s1[0], s1[1], s1[2]] = U0_pad[:]

            for key in ("0", "1"):
                _refine_field("3D/checkpoint/U/" + key)