def test_jacobian_and_vjp():
    """Compare finite-difference and JAX reverse-mode derivatives of hh0/hh1/hh2.

    For each ``hhK`` evaluated at ``inputs[K]`` the test checks that:

    * ``fdm.jacobian`` agrees with ``jax.jacrev``;
    * the VJP built from the finite-difference Jacobian (``v @ J``) agrees
      with ``jax.vjp``'s pullback applied to the same cotangent ``v``.

    The cotangent is a one-element array drawn from ``onp.random.normal``;
    no seed is set inside this function. Each assertion runs under
    ``with check:`` so a failure in one case does not stop the others.
    """
    for idx, func in enumerate((hh0, hh1, hh2)):
        point = inputs[idx]

        fd_jac = fdm.jacobian(func)(point)
        ad_jac = jax.jacrev(func)(point)
        with check:
            assert np.allclose(fd_jac, ad_jac)

        cotangent = np.asarray(onp.random.normal(size=(1,)))
        fd_vjp = cotangent @ fd_jac
        pullback = jax.vjp(func, point)[1]
        ad_vjp = pullback(cotangent)
        with check:
            assert np.allclose(fd_vjp, ad_vjp)
def test_jacobian():
    """Check Jacobians and VJPs of ``hh`` and ``gg`` against finite differences.

    ``hh`` is differentiated at ``x_input`` and ``gg`` at ``y_input``. Each
    case uses its own fixed PRNG seed (0 and 1 respectively) to draw a
    float64 cotangent of length ``V.dim()``. That cotangent is pulled back
    both through the finite-difference Jacobian and through ``jax.vjp``.
    """
    cases = (
        (hh, x_input, 0),
        (gg, y_input, 1),
    )
    for func, point, seed in cases:
        fd_jac = fdm.jacobian(func)(point)
        ad_jac = jax.jacrev(func)(point)
        with check:
            assert np.allclose(fd_jac, ad_jac)

        key = jax.random.PRNGKey(seed)
        cotangent = jax.random.normal(key, shape=(V.dim(), ), dtype="float64")
        fd_vjp = cotangent @ fd_jac
        _, pullback = jax.vjp(func, point)
        ad_vjp = pullback(cotangent)
        with check:
            assert np.allclose(fd_vjp, ad_vjp)
def test_vjp_assemble_eval():
    """Check the pullback from ``vjp_fem_eval`` against finite differences.

    ``vjp_fem_eval(assemble_fenics, templates, *inputs)`` returns the primal
    output and a pullback function. Pulling back a cotangent of ones must
    reproduce, for each ``K`` in 0..2, ``fdm.jacobian(ffK)`` evaluated at
    ``inputs[K]``.

    NOTE(review): the direct comparison of each pulled-back gradient with a
    Jacobian relies on the assembled output being a single value, so that a
    ones cotangent turns the VJP into the Jacobian itself. Confirm this holds
    for ``assemble_fenics``.
    """
    numpy_output, vjp_fun = vjp_fem_eval(assemble_fenics, templates, *inputs)
    g = np.ones_like(numpy_output)
    vjp_out = vjp_fun(g)

    # Compute every finite-difference Jacobian first (as before), then report
    # *which* inputs disagree instead of a single anonymous boolean.
    fdm_jacs = [
        fdm.jacobian(ff)(inputs[idx])
        for idx, ff in enumerate((ff0, ff1, ff2))
    ]
    mismatched = [
        idx
        for idx, fdm_jac in enumerate(fdm_jacs)
        if not np.allclose(vjp_out[idx], fdm_jac)
    ]
    assert not mismatched, (
        f"VJP disagrees with finite-difference Jacobian for inputs {mismatched}"
    )
def test_vjp_assemble_eval():
    """Check ``evaluate_pullback`` for ``assemble_fenics`` against finite differences.

    The primal is evaluated with ``evaluate_primal``. The recorded tape is
    then pulled back with a cotangent of ones via ``evaluate_pullback``. Each
    resulting gradient ``K`` (0..2) must match ``fdm.jacobian(ffK)`` at
    ``inputs[K]``.

    NOTE(review): the direct comparison of each pulled-back gradient with a
    Jacobian relies on the assembled output being a single value, so that a
    ones cotangent turns the VJP into the Jacobian itself. Confirm this holds
    for ``assemble_fenics``.
    """
    numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
        assemble_fenics, templates, *inputs)
    g = np.ones_like(numpy_output)
    vjp_out = evaluate_pullback(fenics_output, fenics_inputs, tape, g)

    # Compute every finite-difference Jacobian first (as before), then report
    # *which* inputs disagree instead of a single anonymous boolean.
    fdm_jacs = [
        fdm.jacobian(ff)(inputs[idx])
        for idx, ff in enumerate((ff0, ff1, ff2))
    ]
    mismatched = [
        idx
        for idx, fdm_jac in enumerate(fdm_jacs)
        if not np.allclose(vjp_out[idx], fdm_jac)
    ]
    assert not mismatched, (
        f"VJP disagrees with finite-difference Jacobian for inputs {mismatched}"
    )
def test_jacobian():
    """Compare forward-mode JAX Jacobians with finite-difference Jacobians.

    Only ``ff1`` (at ``inputs[1]``) and ``ff2`` (at ``inputs[2]``) are
    checked; ``ff0`` is left out because finite-differencing it is expensive.
    """
    cases = (
        (ff1, inputs[1]),
        (ff2, inputs[2]),
    )
    for fun, point in cases:
        forward_jac = jax.jacfwd(fun)(point)
        numeric_jac = fdm.jacobian(fun)(point)
        with check:
            assert np.allclose(forward_jac, numeric_jac)
def test_jacobian():
    """Check that ``jacobian`` recovers the matrix of the linear map ``x -> a @ x``.

    A random 3x3 matrix defines the map. Its Jacobian, estimated at a random
    3-vector with the ``central_fdm(10, 1)`` method, is compared with the
    matrix itself.

    NOTE(review): the value returned by ``allclose`` is not asserted. This
    test can only fail if ``allclose`` raises on mismatch (e.g. a
    ``numpy.testing``-style helper). If it is a predicate like
    ``numpy.allclose``, it can never fail — confirm.
    """
    method = central_fdm(10, 1)
    matrix = np.random.randn(3, 3)

    def linear_map(vec):
        return np.matmul(matrix, vec)

    point = np.random.randn(3)
    estimate = jacobian(linear_map, method)(point)
    allclose(estimate, matrix)
def test_jacobian_and_vjp():
    """Compare finite-difference and JAX reverse-mode derivatives of ff1 and ff2.

    A single float64 cotangent of length ``V.dim()`` is drawn once with
    PRNG seed 0 and reused for every case. For ``ff1`` at ``inputs[1]`` and
    ``ff2`` at ``inputs[2]`` the test checks ``fdm.jacobian`` against
    ``jax.jacrev``. It also checks ``v @ J_fdm`` against ``jax.vjp``'s
    pullback of ``v``. ``ff0`` is skipped because finite-differencing it is
    expensive.
    """
    key = jax.random.PRNGKey(0)
    cotangent = jax.random.normal(key, shape=(V.dim(), ), dtype="float64")
    cases = (
        (ff1, inputs[1]),
        (ff2, inputs[2]),
    )
    for fun, point in cases:
        numeric_jac = fdm.jacobian(fun)(point)
        reverse_jac = jax.jacrev(fun)(point)
        with check:
            assert np.allclose(numeric_jac, reverse_jac)

        numeric_vjp = cotangent @ numeric_jac
        _, pullback = jax.vjp(fun, point)
        reverse_vjp = pullback(cotangent)
        with check:
            assert np.allclose(numeric_vjp, reverse_vjp)
def test_jacobian():
    """Check a two-argument forward-mode JVP of ``ff`` against finite differences.

    ``jax.jvp`` is evaluated at ``(x_input, y_input)`` with all-ones tangents
    for both arguments. Its tangent output is therefore ``J_x @ 1 + J_y @ 1``.
    ``J_x`` and ``J_y`` are the finite-difference Jacobians of ``hh`` at
    ``x_input`` and ``gg`` at ``y_input``.

    NOTE(review): presumably ``hh``/``gg`` are ``ff`` with the other argument
    held fixed. This also assumes ``fdm.jacobian`` returns a 2-D
    ``(outputs, inputs)`` array. Confirm both.
    """
    ones_x = np.ones_like(x_input)
    ones_y = np.ones_like(y_input)
    jax_jvp = jax.jvp(ff, (x_input, y_input), (ones_x, ones_y))[1]

    fdm_jac0 = fdm.jacobian(hh)(x_input)
    fdm_jac1 = fdm.jacobian(gg)(y_input)
    # J @ ones is the row sum of J. Comparing the raw Jacobians element-wise
    # against the JVP would only agree when every input has one component.
    fdm_jvp = np.sum(fdm_jac0, axis=1) + np.sum(fdm_jac1, axis=1)

    assert np.allclose(fdm_jvp, np.ravel(jax_jvp))