def jvp_jax_fem_eval_impl(primals, tangents):
    """Evaluate the primal output together with its pushforward (JVP).

    Each primal and each tangent is paired with the matching entry of
    ``fenics_templates`` and passed through ``jax_to_fenics_numpy`` before
    being handed to ``evaluate_primal`` / ``evaluate_pushforward``.

    NOTE(review): ``fenics_function`` and ``fenics_templates`` are not
    defined in this block; presumably they come from an enclosing scope.

    Args:
        primals: Iterable of primal inputs, one per template.
        tangents: Iterable of tangent inputs, one per template.

    Returns:
        A pair ``(numpy_output, dnumpy_output)``: the primal result and
        the result of ``evaluate_pushforward``.
    """
    # Conversions are lazy; the primal iterator is consumed when it is
    # unpacked into the ``evaluate_primal`` call below.
    converted_primals = map(jax_to_fenics_numpy, primals, fenics_templates)
    primal_result = evaluate_primal(
        fenics_function, fenics_templates, *converted_primals
    )
    numpy_output, fenics_output, fenics_inputs, tape = primal_result

    # The tangent iterator is handed over unconsumed, so the conversions
    # run only when ``evaluate_pushforward`` iterates over it.
    converted_tangents = map(jax_to_fenics_numpy, tangents, fenics_templates)
    dnumpy_output = evaluate_pushforward(
        fenics_output, fenics_inputs, tape, converted_tangents
    )
    return numpy_output, dnumpy_output
def test_jvp_assemble_eval():
    """Compare the pushforward of ``assemble_fenics`` with ``fdm.jvp``.

    One random direction is drawn per input. The per-input ``fdm.jvp``
    estimates (presumably finite-difference approximations -- confirm
    against the ``fdm`` package) are summed and compared with the single
    result of ``evaluate_pushforward`` over all directions at once.
    """
    primals = inputs

    # Shapes are listed in input order; the draws happen in that same order.
    direction_shapes = ((V.dim(),), (1,), (1,))
    tangents = tuple(
        np.random.normal(size=shape) for shape in direction_shapes
    )

    # Reference values: one JVP per input, each using its own function.
    jvp_first, jvp_second, jvp_third = (
        fdm.jvp(func, direction)(point)
        for func, direction, point in zip((ff0, ff1, ff2), tangents, primals)
    )
    expected = jvp_first + jvp_second + jvp_third

    # Only the taped objects are needed here; the numpy output is unused.
    _unused_output, fenics_output, fenics_inputs, tape = evaluate_primal(
        assemble_fenics, templates, *inputs
    )
    out_tangent = evaluate_pushforward(
        fenics_output, fenics_inputs, tape, tangents
    )

    assert np.allclose(expected, out_tangent)
def test_fenics_jvp():
    """Compare the pushforward of ``solve_fenics`` with ``fdm.jvp``.

    Two random directions of shape ``(1,)`` are drawn, one per input.
    For each input, ``fdm.jvp`` is applied to a function that varies only
    that input while the other stays at its primal value (presumably a
    finite-difference estimate -- confirm against the ``fdm`` package).
    The sum of both estimates is compared with ``evaluate_pushforward``.
    """
    primals = inputs
    tangents = tuple(np.random.normal(size=(1,)) for _ in range(2))

    def vary_first(x):
        # First element of the evaluate_primal result, second input fixed.
        return evaluate_primal(solve_fenics, templates, x, primals[1])[0]

    def vary_second(y):
        # First element of the evaluate_primal result, first input fixed.
        return evaluate_primal(solve_fenics, templates, primals[0], y)[0]

    jvp_first = fdm.jvp(vary_first, tangents[0])(primals[0])
    jvp_second = fdm.jvp(vary_second, tangents[1])(primals[1])
    expected = jvp_first + jvp_second

    # Only the taped objects are needed here; the numpy output is unused.
    _unused_output, fenics_output, fenics_inputs, tape = evaluate_primal(
        solve_fenics, templates, *inputs
    )
    out_tangent = evaluate_pushforward(
        fenics_output, fenics_inputs, tape, tangents
    )

    assert np.allclose(expected, out_tangent)