def test_jvp_assemble_eval():
    """Compare evaluate_jvp on assemble_fenics with summed fdm.jvp results.

    One random tangent is drawn per input. The reference value is the sum of
    the three per-argument JVPs obtained from fdm.jvp on ff0, ff1 and ff2.
    """
    primals = inputs

    # Shapes are listed in draw order: the V-sized tangent first, then the
    # two length-1 tangents.
    tangent_shapes = ((V.dim(),), (1,), (1,))
    tangents = tuple(np.random.normal(size=shape) for shape in tangent_shapes)

    # One reference JVP per argument, evaluated in argument order.
    reference_jvps = [
        fdm.jvp(func, tangents[idx])(primals[idx])
        for idx, func in enumerate((ff0, ff1, ff2))
    ]
    expected = reference_jvps[0] + reference_jvps[1] + reference_jvps[2]

    result = evaluate_jvp(assemble_fenics, templates, primals, tangents)
    _, out_tangent = result
    assert np.allclose(expected, out_tangent)
def test_fenics_jvp():
    """Compare jvp_fem_eval on solve_fenics with summed fdm.jvp results.

    Each reference JVP varies one argument of the solve while the other is
    held at its primal value.
    """
    primals = inputs

    # Two length-1 tangents, drawn with plain NumPy and wrapped by np.asarray.
    tangents = tuple(
        np.asarray(onp.random.normal(size=(1,))) for _ in range(2)
    )

    def solve_varying_first(x):
        # First output of the solve with the second argument fixed.
        return fem_eval(solve_fenics, templates, x, primals[1])[0]

    def solve_varying_second(y):
        # First output of the solve with the first argument fixed.
        return fem_eval(solve_fenics, templates, primals[0], y)[0]

    jvp_first = fdm.jvp(solve_varying_first, tangents[0])(primals[0])
    jvp_second = fdm.jvp(solve_varying_second, tangents[1])(primals[1])
    expected = jvp_first + jvp_second

    result = jvp_fem_eval(solve_fenics, templates, primals, tangents)
    _, out_tangent = result
    assert np.allclose(expected, out_tangent)
def test_jvp():
    """Check jax.jvp against fdm.jvp for ff0, ff1 and ff2.

    Each function is paired with the matching entry of ``inputs`` and pushed
    forward along a constant direction. Every comparison runs inside
    ``check``.
    """
    functions = (ff0, ff1, ff2)
    for func, primal in zip(functions, inputs):
        # Constant direction with the same shape as the primal input.
        direction = 0.432543 * np.ones_like(primal)

        expected = fdm.jvp(func, direction)(primal)
        # jax.jvp returns (primal_out, tangent_out); only the tangent is used.
        _, actual = jax.jvp(func, (primal,), (direction,))

        with check:
            assert np.allclose(expected, actual)
def test_jvp_assemble_eval():
    """Compare evaluate_pushforward on assemble_firedrake with fdm.jvp sums.

    The primal is evaluated once through evaluate_primal to obtain the output,
    the inputs and the tape, which are then handed to evaluate_pushforward
    together with the random tangents.
    """
    primals = inputs

    # Shapes are listed in draw order: the V-sized tangent first, then the
    # two length-1 tangents.
    tangent_shapes = ((V.dim(),), (1,), (1,))
    tangents = tuple(np.random.normal(size=shape) for shape in tangent_shapes)

    # Reference JVPs are computed before the taped primal evaluation, as in
    # the original statement order.
    reference_jvps = [
        fdm.jvp(func, tangents[idx])(primals[idx])
        for idx, func in enumerate((ff0, ff1, ff2))
    ]
    expected = reference_jvps[0] + reference_jvps[1] + reference_jvps[2]

    primal_result = evaluate_primal(assemble_firedrake, templates, *inputs)
    _, firedrake_output, firedrake_inputs, tape = primal_result
    out_tangent = evaluate_pushforward(
        firedrake_output, firedrake_inputs, tape, tangents
    )
    assert np.allclose(expected, out_tangent)
def test_jvp_directional():
    """Check that jvp agrees with the gradient contracted with the direction.

    Uses f(x) = sum(a * x) with random ``a``, a random evaluation point and a
    random direction, all of length 3.
    """
    method = central_fdm(10, 1)
    weights = np.random.randn(3)

    def weighted_sum(point):
        return np.sum(weights * point)

    # Draw order: weights, then the evaluation point, then the direction.
    point = np.random.randn(3)
    direction = np.random.randn(3)

    grad_at_point = gradient(weighted_sum, method)(point)
    directional_from_grad = np.sum(grad_at_point * direction)
    directional_from_jvp = jvp(weighted_sum, direction, method)(point)

    # NOTE(review): `allclose` is defined outside this block; presumably it
    # asserts internally, since its result is not checked here. Confirm.
    allclose(directional_from_grad, directional_from_jvp)
def test_jvp():
    """Check jvp of the linear map x -> a @ x against the product a @ v.

    ``a`` is a random 3x3 matrix; the evaluation point and the direction are
    random vectors of length 3.
    """
    method = central_fdm(10, 1)
    matrix = np.random.randn(3, 3)

    def apply_matrix(vector):
        return np.matmul(matrix, vector)

    # Draw order: matrix, then the evaluation point, then the direction.
    point = np.random.randn(3)
    direction = np.random.randn(3)

    estimated = jvp(apply_matrix, direction, method)(point)
    exact = np.matmul(matrix, direction)

    # NOTE(review): `allclose` is defined outside this block; presumably it
    # asserts internally, since its result is not checked here. Confirm.
    allclose(estimated, exact)
def test_fenics_jvp():
    """Compare evaluate_pushforward on solve_fenics with summed fdm.jvp results.

    Each reference JVP varies one argument of the solve while the other is
    held at its primal value. The pushforward uses the output, inputs and tape
    returned by a single evaluate_primal call.
    """
    primals = inputs
    tangents = tuple(np.random.normal(size=(1,)) for _ in range(2))

    def solve_varying_first(x):
        # First element of the evaluate_primal result, second argument fixed.
        return evaluate_primal(solve_fenics, templates, x, primals[1])[0]

    def solve_varying_second(y):
        # First element of the evaluate_primal result, first argument fixed.
        return evaluate_primal(solve_fenics, templates, primals[0], y)[0]

    # Reference JVPs are computed before the taped primal evaluation, as in
    # the original statement order.
    jvp_first = fdm.jvp(solve_varying_first, tangents[0])(primals[0])
    jvp_second = fdm.jvp(solve_varying_second, tangents[1])(primals[1])
    expected = jvp_first + jvp_second

    primal_result = evaluate_primal(solve_fenics, templates, *inputs)
    _, fenics_output, fenics_inputs, tape = primal_result
    out_tangent = evaluate_pushforward(
        fenics_output, fenics_inputs, tape, tangents
    )
    assert np.allclose(expected, out_tangent)
def test_jvp():
    """Check jax.jvp of ``hh`` at ``x_input`` against fdm.jvp.

    The push-forward direction is a constant array shaped like ``x_input``.
    """
    direction = 0.432543 * np.ones_like(x_input)

    expected = fdm.jvp(hh, direction)(x_input)
    # jax.jvp returns (primal_out, tangent_out); only the tangent is used.
    _, actual = jax.jvp(hh, (x_input,), (direction,))

    assert np.allclose(expected, actual)