# 예제 #1 (Example #1)
# 0
        def jvp_jax_fem_eval_impl(primals, tangents):
            """Evaluate the wrapped FEniCS function and its JVP for JAX.

            Each JAX primal and tangent is converted with
            ``jax_to_fenics_numpy`` against the matching entry of
            ``fenics_templates``. The function is run forward with
            ``evaluate_primal``, and the tangents are then pushed through the
            recorded ``tape`` with ``evaluate_pushforward``.

            Args:
                primals: Iterable of JAX input values, one per template.
                tangents: Iterable of JAX tangent values, one per template.

            Returns:
                Tuple ``(numpy_output, dnumpy_output)``. The first item is the
                primal output from ``evaluate_primal``. The second is the
                output tangent from ``evaluate_pushforward``.
            """
            fenics_primals = tuple(
                jax_to_fenics_numpy(p, ft) for p, ft in zip(primals, fenics_templates)
            )
            numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
                fenics_function, fenics_templates, *fenics_primals
            )

            # Materialize the tangents as a tuple rather than a one-shot
            # generator. evaluate_pushforward then receives the same kind of
            # sequence the direct pushforward tests pass it, and it can index
            # or re-iterate the sequence safely.
            fenics_tangents = tuple(
                jax_to_fenics_numpy(t, ft) for t, ft in zip(tangents, fenics_templates)
            )
            dnumpy_output = evaluate_pushforward(
                fenics_output, fenics_inputs, tape, fenics_tangents
            )
            return numpy_output, dnumpy_output
# 예제 #2 (Example #2)
# 0
def test_jvp_assemble_eval():
    """Check the JVP of ``assemble_fenics`` against finite differences.

    A random tangent is drawn for each of the three inputs. The
    finite-difference JVPs of the single-argument slices ``ff0``, ``ff1`` and
    ``ff2`` are summed and compared with ``evaluate_pushforward`` applied to
    the tape recorded by ``evaluate_primal``.
    """
    primals = inputs
    # One random tangent per input, drawn in input order.
    tangents = (
        np.random.normal(size=(V.dim(), )),
        np.random.normal(size=(1, )),
        np.random.normal(size=(1, )),
    )

    # Finite-difference JVP of each slice, taken at the matching primal.
    slices = (ff0, ff1, ff2)
    fdm_jvps = [
        fdm.jvp(fn, tangents[idx])(primals[idx])
        for idx, fn in enumerate(slices)
    ]

    _, fenics_output, fenics_inputs, tape = evaluate_primal(
        assemble_fenics, templates, *primals)
    out_tangent = evaluate_pushforward(
        fenics_output, fenics_inputs, tape, tangents)

    assert np.allclose(sum(fdm_jvps), out_tangent)
# 예제 #3 (Example #3)
# 0
def test_fenics_jvp():
    """Check the JVP of ``solve_fenics`` against finite differences.

    A random tangent is drawn for each of the two inputs. The test computes
    the finite-difference JVPs of the two single-argument slices of
    ``solve_fenics``; each slice varies one input while holding the other at
    its primal value. Their sum is checked against ``evaluate_pushforward`` on
    the tape recorded by ``evaluate_primal``.
    """
    primals = inputs
    tangent0 = np.random.normal(size=(1, ))
    tangent1 = np.random.normal(size=(1, ))
    tangents = (tangent0, tangent1)

    # Named inner functions rather than assigned lambdas, so no E731
    # suppression is needed.
    def ff0(x):
        return evaluate_primal(solve_fenics, templates, x, primals[1])[0]

    def ff1(y):
        return evaluate_primal(solve_fenics, templates, primals[0], y)[0]

    fdm_jvp0 = fdm.jvp(ff0, tangents[0])(primals[0])
    fdm_jvp1 = fdm.jvp(ff1, tangents[1])(primals[1])

    _, fenics_output, fenics_inputs, tape = evaluate_primal(
        solve_fenics, templates, *primals)
    out_tangent = evaluate_pushforward(fenics_output, fenics_inputs, tape,
                                       tangents)

    assert np.allclose(fdm_jvp0 + fdm_jvp1, out_tangent)