# Module-level setup of the test grid and of the transform/basis objects.

# Resolution: the first direction gets 2**M points, the other two 2**(M-1).
M = 5
N = array([2**M, 2**(M-1), 2**(M-1)])
# Length per direction; L[1] and L[2] set the extent of the uniform grids
# x1 and x2 built below.
L = array([2, 2*pi, 2*pi])

dx = (L / N).astype(float)
comm = MPI.COMM_WORLD
num_processes = comm.Get_size()
rank = comm.Get_rank()
# Number of points per process in each direction. Floor division is spelled
# out explicitly because Np[0] is used as a slice bound below: with a plain
# "/" the result turns into a float as soon as true division is in effect
# (Python 3 or "from __future__ import division"), which breaks the slicing.
# NOTE(review): assumes N[0] is divisible by num_processes; otherwise the
# trailing points are not assigned to any rank.
Np = N // num_processes
# Get points and weights for Chebyshev weighted integrals
quad = "GC"
# Coefficient arrays handed to ShenBasis; their layout is defined by
# ShenBasis itself, which is not visible here.
BC1 = array([1, 0, 0, 1, 0, 0])
BC2 = array([0, 1, 0, 0, 1, 0])
BC3 = array([0, 1, 0, 1, 0, 0])
SC = ChebyshevTransform(quad)
ST = ShenBasis(BC1, quad)
SN = ShenBasis(BC2, quad, Neumann = True)
SR = ShenBasis(BC3, quad)
SB = ShenBiharmonicBasis(quad, fast_transform=False)

points, weights = ST.points_and_weights(N[0])
pointsN, weightsN = SN.points_and_weights(N[0])

# Uniform grids in the second and third directions (end point excluded).
x1 = arange(N[1], dtype=float)*L[1]/N[1]
x2 = arange(N[2], dtype=float)*L[2]/N[2]

# Get grid for velocity points
# Each rank keeps only its own contiguous slab of Np[0] points in the first
# direction; X uses the points from ST, Y the points from SN.
X = array(meshgrid(points[rank*Np[0]:(rank+1)*Np[0]], x1, x2, indexing='ij'), dtype=float)
Y = array(meshgrid(pointsN[rank*Np[0]:(rank+1)*Np[0]], x1, x2, indexing='ij'), dtype=float)

# Number of independent complex wavenumbers in z-direction.
# Floor division keeps Nf an integer under any division semantics.
Nf = N[2]//2+1
    # Interactive driver: print the menu of available 1D test problems and
    # read the user's selection. The choice is kept as a string and compared
    # against "0", "1", ... below.
    print "          Menu                 "
    print "==============================="
    print "0 - Poisson Neumann BCs"
    print "1 - Poisson Dirichlet BCs"
    print "2 - Helmholtz Dirichlet BCs"
    print "3 - Biharmonic Dirichlet BCs"
    print "==============================="
    test = raw_input("Your choice: ")

    # Problem size and basis/transform objects used by the 1D tests.
    # M is the number of quadrature points requested from the basis and the
    # length of the arrays f / v_exact.
    M = 2 ** 3
    # NOTE(review): "GL" presumably selects Gauss-Lobatto quadrature; confirm
    # against ChebyshevTransform / ShenBasis.
    quad = "GL"
    # Coefficient arrays passed to ShenBasis and to the solver; their layout
    # is defined by ShenBasis, which is not visible here.
    BC1 = array([1, 0, 0, 1, 0, 0])
    BC2 = array([0, 1, 0, 0, 1, 0])
    BC3 = array([0, 1, 0, 1, 0, 0])
    SC = ChebyshevTransform(quad)
    ST = ShenBasis(BC1, quad)
    SN = ShenBasis(BC2, quad, Neumann=True)
    SR = ShenBasis(BC3, quad)
    SB = ShenBiharmonicBasis(quad, fast_transform=False)

    if test == "0":
        # Poisson problem with Neumann boundary conditions (menu entry 0).
        # The zeros() value of v_exact is a placeholder only; it is rebound
        # to the analytic expression a few lines below.
        v_exact = zeros(M)
        f = zeros(M)
        points, weights = SN.points_and_weights(M)
        # Manufactured solution: v_exact'' == f, and v_exact' vanishes at
        # points == -1 and points == 1. The trailing comments keep an
        # alternative sine-based test case.
        f[:] = (-8.0 / 3.0) * points  # -((pi/2.)**2)*sin(pi*points/2.)
        v_exact = -(4.0 / 9.0) * (points ** 3 - 3.0 * points)  # sin(pi*points/2.)
        v = Poisson1DNeumann(M, f, quad, SC, SN, BC2)

        # Report the error between the computed and exact solutions
        # (NOTE(review): maximum norm if `inf` is numpy's inf; confirm),
        # fail if they do not agree within allclose's default tolerances,
        # and plot both curves.
        print "Error: ", linalg.norm(v - v_exact, inf)
        assert allclose(v, v_exact)
        pl.plot(points, v_exact, points, v)