def test_Mult_Div():
    """Check the ``SFTc.Mult_Div_1D`` / ``SFTc.Mult_Div_3D`` kernels.

    Random complex coefficient vectors are pushed through the Shen Dirichlet
    ("GC") transforms, then the kernels are called with the scalars
    ``m = n = 7``, writing into the slice ``b[1:N-2]``.  The result must match
    the explicitly assembled quantity

        Cm.matvec(u) + 1j*7*Bm.matvec(v) + 1j*7*Bm.matvec(w)

    with ``Cm = CNDmat(k)`` and ``Bm = BNDmat(k, "GC")``, ``k = 0..N-1``.
    The 3D variant repeats the check on inputs broadcast to shape (N, 4, 4).
    """
    SD = ShenDirichletBasis("GC")
    # ``np.float`` / ``np.complex`` were only aliases of the builtins and were
    # removed in NumPy 1.24; the builtins give the same dtypes.
    Cm = CNDmat(np.arange(N).astype(float))
    Bm = BNDmat(np.arange(N).astype(float), "GC")

    def _projected_random():
        """Return SD coefficients of a random complex length-N vector."""
        data = np.random.randn(N) + np.random.randn(N) * 1j
        coef = np.zeros(N, dtype=complex)
        coef = SD.fst(data, coef)
        # Round trip fst -> ifst -> fst; presumably this projects the random
        # data onto the Dirichlet space so the coefficients are consistent
        # with the basis (NOTE(review): confirm against SD.fst/ifst).
        data = SD.ifst(coef, data)
        return SD.fst(data, coef)

    def _expected(u, v, w):
        """Assemble the reference result with explicit matrix-vector products."""
        uu = Cm.matvec(u)
        uu += 1j * 7 * Bm.matvec(v) + 1j * 7 * Bm.matvec(w)
        return uu

    def _to_3d(a):
        """Repeat each entry of ``a`` into a (4, 4) block: result is a3 + 1j*a3."""
        a3 = a.repeat(4 * 4).reshape((N, 4, 4))
        return a3 + 1j * a3

    uk0 = _projected_random()
    vk0 = _projected_random()
    wk0 = _projected_random()

    # 1D kernel: output is written into the slice b[1:N-2].
    b = np.zeros(N, dtype=complex)
    SFTc.Mult_Div_1D(N, 7, 7, uk0[:N - 2], vk0[:N - 2], wk0[:N - 2], b[1:N - 2])
    assert np.allclose(_expected(uk0, vk0, wk0), b)

    # 3D kernel: same check with (N, 4, 4) inputs and per-mode m, n arrays.
    uk0, vk0, wk0 = _to_3d(uk0), _to_3d(vk0), _to_3d(wk0)
    b = np.zeros((N, 4, 4), dtype=complex)
    m = np.zeros((4, 4)) + 7
    n = np.zeros((4, 4)) + 7
    SFTc.Mult_Div_3D(N, m, n, uk0[:N - 2], vk0[:N - 2], wk0[:N - 2], b[1:N - 2])
    assert np.allclose(_expected(uk0, vk0, wk0), b)
assert np.allclose(uk[:, 1, 1].imag, uk_hat) b = np.zeros((N-2, 10, 10), dtype=np.complex) SFTc.Mult_Helmholtz_3D_complex(N, ST.quad=="GL", 1, alfa**2, uk, b) assert np.allclose(b[:, 1, 1].real, fk[:-2]) assert np.allclose(b[:, 1, 1].imag, fk[:-2]) return uk_hat f_hat = fj.copy() f_hat = ST.fastShenScalar(fj, f_hat) uk_hat = fj.copy() uk_hat[:-2] = solve(f_hat) uq = uk_hat.copy() uq = ST.ifst(uk_hat, uq) uqf = uq.copy() uqf = ST.fst(uq, uqf) uq0 = uq.copy() assert np.allclose(ST.ifst(uqf, uq0), uq) u_exact = np.array([u.subs(x, h) for h in points], dtype=np.float) plt.figure(); plt.plot(points, [u.subs(x, i) for i in points]); plt.title("U") plt.figure(); plt.plot(points, uq - u_exact); plt.title("Error") print "Error = ", np.linalg.norm(uq - u_exact) #plt.show()