def test_Helmholtz2(SD):
    """Check the compiled Helmholtz matvec kernels against ADDmat/BDDmat.

    A random signal is projected onto the ``SD`` space of size ``M = 2 * N``.
    The product ``(A + kx**2 B) u_hat`` is computed with the ``ADDmat`` /
    ``BDDmat`` matrix classes. It is compared with:

    * ``SFTc.Mult_Helmholtz_1D`` on the 1D coefficient array;
    * ``SFTc.Mult_Helmholtz_3D_complex`` on the same coefficients copied
      into every (j, k) point of an (M, 4, 4) complex array, with
      identical real and imaginary parts.
    """
    M = 2 * N
    kx = 11
    is_gl = SD.quad == "GL"

    # Forward + backward transform of random data: uj is replaced by its
    # projection onto the discrete space, so the round trip below is exact.
    uj = np.random.randn(M)
    u_hat = SD.fst(uj, np.zeros(M))
    uj = SD.ifst(u_hat, uj)

    # np.float was removed in NumPy 1.24; use the builtin float dtype.
    A = ADDmat(np.arange(M, dtype=float))
    B = BDDmat(np.arange(M, dtype=float), SD.quad)

    u1 = SD.fst(uj, np.zeros(M))
    c = A.matvec(u1) + kx ** 2 * B.matvec(u1)

    # 1D kernel writes its result into b; the third argument is passed as 1
    # (presumably a scaling factor - confirm against the SFTc signature).
    b = np.zeros(M)
    SFTc.Mult_Helmholtz_1D(M, is_gl, 1, kx ** 2, u1, b)
    assert np.allclose(c, b)

    # 3D complex kernel: every (j, k) column holds u1 + 1j*u1, and the
    # wavenumber field is constant, so each column must reproduce c in both
    # the real and the imaginary part.
    u3 = np.zeros((M, 4, 4), dtype=complex)
    u3[:] = (u1 + 1j * u1)[:, None, None]
    kx3 = np.full((4, 4), kx, dtype=float)
    b3 = np.zeros((M, 4, 4), dtype=complex)
    SFTc.Mult_Helmholtz_3D_complex(M, is_gl, 1.0, kx3 ** 2, u3, b3)
    assert np.linalg.norm(b3[:, 2, 2].real - c) / (M * 16) < 1e-12
    assert np.linalg.norm(b3[:, 2, 2].imag - c) / (M * 16) < 1e-12


def test_Helmholtz(ST2):
    """Solve a 1D Helmholtz problem directly and cross-check the results.

    With ``M = 4 * N``, a random right-hand side is projected onto ``ST2``.
    Its scalar product ``f_hat`` is computed, and
    ``(A + kx**2 B) u_hat = f_hat`` is solved on the index slice ``s``
    with a sparse direct solver. The test then checks that:

    * ``A.matvec`` / ``B.matvec`` reproduce ``f_hat``;
    * the dense matrices give the same product on ``s``;
    * the ``Helmholtz`` solver, applied to an (M, 4, 4) complex copy of
      ``f_hat``, returns the same physical-space solution in both the real
      and the imaginary part.

    Raises:
        ValueError: if ``ST2`` is neither a ShenDirichletBasis nor a
            ShenNeumannBasis (previously this surfaced as a NameError).
    """
    M = 4 * N
    kx = 12
    basis = ST2.__class__.__name__

    # Project random data onto the space so the transforms are consistent.
    fj = np.random.randn(M)
    f_hat = ST2.fst(fj, np.zeros(M))
    fj = ST2.ifst(f_hat, fj)

    # np.float was removed in NumPy 1.24; use the builtin float dtype.
    if basis == "ShenDirichletBasis":
        A = ADDmat(np.arange(M, dtype=float))
        B = BDDmat(np.arange(M, dtype=float), ST2.quad)
        s = slice(0, M - 2)
    elif basis == "ShenNeumannBasis":
        A = ANNmat(np.arange(M, dtype=float))
        B = BNNmat(np.arange(M, dtype=float), ST2.quad)
        # Index 0 is excluded from the solve for the Neumann basis
        # (presumably its null/constant mode - confirm).
        s = slice(1, M - 2)
    else:
        raise ValueError("Unsupported basis for Helmholtz test: %s" % basis)

    f_hat = ST2.fastShenScalar(fj, np.zeros(M))
    u_hat = np.zeros(M)
    u_hat[s] = la.spsolve(A.diags() + kx ** 2 * B.diags(), f_hat[s])

    u1 = ST2.ifst(u_hat, np.zeros(M))

    c = A.matvec(u_hat) + kx ** 2 * B.matvec(u_hat)
    c2 = (np.dot(A.diags().toarray(), u_hat[s])
          + kx ** 2 * np.dot(B.diags().toarray(), u_hat[s]))

    assert np.allclose(c, f_hat)
    assert np.allclose(c[s], c2)

    # Multidimensional: every (j, k) column holds f_hat + 1j*f_hat and the
    # wavenumber field is the constant kx, so each column must solve to u1.
    f_hat3 = np.zeros((M, 4, 4), dtype=complex)
    f_hat3[:] = (f_hat + 1j * f_hat)[:, None, None]
    kx3 = np.full((4, 4), kx, dtype=float)
    H = Helmholtz(M, kx3, ST2.quad, basis == "ShenNeumannBasis")
    u0_hat = H(np.zeros((M, 4, 4), dtype=complex), f_hat3)
    u0 = ST2.ifst(u0_hat, np.zeros((M, 4, 4), dtype=complex))

    assert np.linalg.norm(u0[:, 2, 2].real - u1) / (M * 16) < 1e-12
    assert np.linalg.norm(u0[:, 2, 2].imag - u1) / (M * 16) < 1e-12