def solve_heat_equation(k, time_stepping_method):
    """
    Solve the heat equation on a hard-coded mesh with a hard-coded initial and boundary conditions

    The exact solution (a Gaussian centred at the origin) supplies the initial
    condition, the Dirichlet boundary values and the reference for the relative
    L2 error, which is printed every 100 time steps.

    :param k: Thermal conductivity
    :param time_stepping_method: Time stepping method. Can be one of ["forward_euler", "backward_euler", "trapezoidal"]
    :raises ValueError: If ``time_stepping_method`` is not one of the supported methods
    """
    # Weight of the new time level in the theta-scheme
    # u_m = (1 - theta) * u_old + theta * u_new
    thetas = {"forward_euler": 0.0, "backward_euler": 1.0, "trapezoidal": 0.5}
    if time_stepping_method not in thetas:
        # Validate before the (expensive) mesh/FEM setup; an unknown name used
        # to leave ``theta`` unbound and fail later with UnboundLocalError.
        raise ValueError("Unknown time_stepping_method {!r}, expected one of {}"
                         .format(time_stepping_method, sorted(thetas)))
    theta = thetas[time_stepping_method]

    mesh, boundary = setup_geometry()

    # Exact solution (Gauss curve). The formula divides by t, so the initial
    # time is a tiny positive value instead of 0.
    gauss = "exp(-(x[0]*x[0]+x[1]*x[1])/(4*a*t))/(4*pi*a*t)"
    ue = Expression(gauss, a=k, t=1e-7, domain=mesh, degree=1)

    # Polynomial degree
    r = 1

    # Setup FEM function space
    V = FunctionSpace(mesh, "CG", r)

    # Create boundary condition; it references ``ue``, so updating ``ue.t``
    # also updates the boundary values.
    bc = DirichletBC(V, ue, boundary)

    # Setup FEM functions
    v = TestFunction(V)
    u = Function(V)

    # Time parameters
    time_step = 0.001
    t_start, t_end = 0.0, 20.0
    # Integer step count: accumulating ``t += time_step`` in floating point can
    # leave t slightly below t_end and run one spurious extra step.
    num_steps = int(round((t_end - t_start) / time_step))

    # Initial condition: a separate Expression frozen at the initial time, so
    # advancing ``ue`` below does not change the initial data.
    u0 = Expression(gauss, a=k, t=1e-7, domain=mesh, degree=1)
    for step in range(num_steps):
        # Advance time in the exact solution *before* solving, so the Dirichlet
        # data belong to the new time level being computed.
        t = t_start + (step + 1) * time_step
        ue.t = t

        # Intermediate value for u (depending on the chosen time stepping method)
        um = (1.0 - theta) * u0 + theta * u

        # Weak form of the heat equation (residual form, solved for u)
        a = (u - u0) / time_step * v * dx + k * inner(grad(um), grad(v)) * dx

        # Solve the heat equation (one time step)
        solve(a == 0, u, bc)

        if step % 100 == 0:
            # Compute error in L2 norm
            error_L2 = errornorm(ue, u, 'L2')
            # or equivalently
            # sqrt(assemble((ue - u) * (ue - u) * dx))
            # Compute norm of exact solution
            nue = norm(ue)
            # Print relative error
            print("Relative error = {}".format(error_L2 / nue))

        # Shift to next time step
        u0 = project(u, V)
def solve_wave_equation(a, symmetric=True):
    """
    Solve the wave equation on a hard-coded mesh with a hard-coded initial and boundary conditions

    The equation is written as the first-order system du/dt = v,
    dv/dt = a**2 * laplace(u) and advanced with the trapezoidal
    (Crank-Nicolson) rule. The exact solution supplies the initial values of
    u and v, their Dirichlet boundary values and the reference for the maximum
    vertex error; every 10 steps the solution is plotted and that error printed.

    :param a: Wave propagation factor (the wave speed in the equation above)
    :param symmetric: If True, the exact solution is the sum of a pulse moving
        towards +x[0] and one moving towards -x[0]; otherwise only the pulse
        moving towards -x[0] is used
    """
    # Exact displacement u and velocity v = du/dt as C++ expressions in x[0], t
    if symmetric:
        u_code = "(1-pow(a*t-x[0],2))*exp(-pow(a*t-x[0],2)) + (1-pow(a*t+x[0],2))*exp(-pow(a*t+x[0],2))"
        v_code = ("2*a*(a*t-x[0])*(pow(a*t-x[0],2)-2)*exp(-pow(a*t-x[0],2))"
                  "+ 2*a*(a*t+x[0])*(pow(a*t+x[0],2)-2)*exp(-pow(a*t+x[0],2))")
    else:
        u_code = "(1-pow(a*t+x[0],2))*exp(-pow(a*t+x[0],2))"
        v_code = "2*a*(a*t+x[0])*(pow(a*t+x[0],2)-2)*exp(-pow(a*t+x[0],2))"

    mesh, boundary = setup_geometry()

    def exact(cpp_code, t):
        # Exact-solution Expression for ``cpp_code`` at time t
        return Expression(cpp_code, a=a, t=t, domain=mesh, degree=2)

    # Polynomial degree
    r = 1

    # Setup FEM function spaces: Q for one scalar field, W for the pair (u, v)
    Q = FunctionSpace(mesh, "CG", r)
    W = VectorFunctionSpace(mesh, "CG", r, dim=2)

    # Time parameters
    time_step = 0.05
    t_start, t_end = 0.0, 5.0
    # Integer step count: 0.05 is not exact in binary, so accumulating
    # ``t += time_step`` could run one spurious extra step.
    num_steps = int(round((t_end - t_start) / time_step))

    # Exact solution at the current time level (also the Dirichlet data)
    ue = exact(u_code, t_start)
    ve = exact(v_code, t_start)

    # Create boundary conditions; they reference ue/ve, so advancing
    # ue.t/ve.t also updates the boundary values.
    bcu = DirichletBC(W.sub(0), ue, boundary)
    bcv = DirichletBC(W.sub(1), ve, boundary)
    bcs = [bcu, bcv]

    # Setup FEM functions
    p, q = TestFunctions(W)
    w = Function(W)
    u, v = w[0], w[1]

    # Initial conditions: separate Expressions frozen at t_start, so advancing
    # ue/ve below does not change the initial data.
    u0 = exact(u_code, t_start)
    v0 = exact(v_code, t_start)
    for step in range(num_steps):
        # Advance time in the exact solution *before* solving, so the Dirichlet
        # data belong to the new time level being computed.
        t = t_start + (step + 1) * time_step
        ue.t = t
        ve.t = t

        # Weak form of the wave equation (trapezoidal rule)
        um = 0.5 * (u + u0)
        vm = 0.5 * (v + v0)
        a1 = (u - u0) / time_step * p * dx - vm * p * dx
        a2 = (v - v0) / time_step * q * dx + a**2 * inner(grad(um),
                                                          grad(q)) * dx

        # Solve the wave equation (one time step)
        solve(a1 + a2 == 0, w, bcs)

        if step % 10 == 0:
            # Plot solution at current time step
            fig = plt.figure()
            plot(u, fig=fig)
            plt.show()

            # Compute max error at vertices; the vertex values of w hold all
            # u values followed by all v values, so the first half is u.
            vertex_values_ue = ue.compute_vertex_values(mesh)
            vertex_values_w = w.compute_vertex_values(mesh)
            vertex_values_u = np.split(vertex_values_w, 2)[0]
            error_max = np.max(np.abs(vertex_values_ue - vertex_values_u))
            # Print error
            print(error_max)

        # Shift to next time step
        u0 = project(u, Q)
        v0 = project(v, Q)
def solve_heat_equation(k, time_stepping_method="backward_euler"):
    """
    Solve the heat equation on a hard-coded mesh with a hard-coded initial and boundary conditions

    Every 5 time steps the current solution is plotted and the absolute L2
    error against the exact solution (a Gaussian centred at the origin) is
    printed.

    NOTE(review): this definition shadows the earlier ``solve_heat_equation``
    above (in Python the last definition wins), so it also accepts that
    version's ``time_stepping_method`` argument. The default keeps the
    backward Euler scheme this version has always used.

    :param k: Thermal conductivity
    :param time_stepping_method: Time stepping method. Can be one of ["forward_euler", "backward_euler", "trapezoidal"].
        Forward Euler is only conditionally stable and may blow up with the comparatively large time step used here.
    :raises ValueError: If ``time_stepping_method`` is not one of the supported methods
    """
    # Weight of the new time level in the theta-scheme
    # u_m = (1 - theta) * u_old + theta * u_new
    thetas = {"forward_euler": 0.0, "backward_euler": 1.0, "trapezoidal": 0.5}
    if time_stepping_method not in thetas:
        # Validate before the (expensive) mesh/FEM setup
        raise ValueError("Unknown time_stepping_method {!r}, expected one of {}"
                         .format(time_stepping_method, sorted(thetas)))
    theta = thetas[time_stepping_method]

    mesh, boundary = setup_geometry()

    # Exact solution (Gauss curve). The formula divides by t, so the initial
    # time is a tiny positive value instead of 0.
    gauss = "exp(-(x[0]*x[0]+x[1]*x[1])/(4*a*t))/(4*pi*a*t)"
    ue = Expression(gauss, a=k, t=1e-7, domain=mesh, degree=2)

    # Polynomial degree
    r = 1

    # Setup FEM function space
    V = FunctionSpace(mesh, "CG", r)

    # Create boundary condition; it references ``ue``, so updating ``ue.t``
    # also updates the boundary values.
    bc = DirichletBC(V, ue, boundary)

    # Setup FEM functions
    v = TestFunction(V)
    u = Function(V)

    # Time parameters
    time_step = 0.5
    t_start, t_end = 0.0, 20.0
    # Integer step count instead of accumulating ``t += time_step``
    num_steps = int(round((t_end - t_start) / time_step))

    # Initial condition: a separate Expression frozen at the initial time, so
    # advancing ``ue`` below does not change the initial data.
    u0 = Expression(gauss, a=k, t=1e-7, domain=mesh, degree=2)
    for step in range(num_steps):
        # Advance time in the exact solution *before* solving, so the Dirichlet
        # data belong to the new time level being computed.
        t = t_start + (step + 1) * time_step
        ue.t = t

        # Intermediate value for u (theta = 1 gives backward Euler)
        um = (1.0 - theta) * u0 + theta * u

        # Weak form of the heat equation (residual form, solved for u)
        a = (u - u0) / time_step * v * dx + k * inner(grad(um), grad(v)) * dx

        # Solve the heat equation (one time step)
        solve(a == 0, u, bc)

        if step % 5 == 0:
            # Plot solution at current time step
            fig = plt.figure()
            plot(u, fig=fig)
            plt.show()

            # Compute error in L2 norm
            error_L2 = errornorm(ue, u, 'L2')
            # or equivalently
            # sqrt(assemble((ue - u) * (ue - u) * dx))
            # Print error
            print(error_L2)

        # Shift to next time step
        u0 = project(u, V)