Exemplo n.º 1
0
def test_jacobian_and_vjp():
    """Cross-check finite-difference and JAX derivatives of hh0, hh1 and hh2.

    For the i-th function, evaluated at ``inputs[i]``:
      * the ``fdm`` Jacobian must match ``jax.jacrev``;
      * a vector-Jacobian product built from the fdm Jacobian must match
        the pullback returned by ``jax.vjp``.
    Every comparison runs under ``check`` so one failure does not stop the
    remaining comparisons.
    """
    for idx, func in enumerate((hh0, hh1, hh2)):
        point = inputs[idx]

        numeric_jac = fdm.jacobian(func)(point)
        autodiff_jac = jax.jacrev(func)(point)
        with check:
            assert np.allclose(numeric_jac, autodiff_jac)

        # The cotangent has a single entry, so `cotangent @ numeric_jac` only
        # works when the function yields one output element (assumed here).
        cotangent = np.asarray(onp.random.normal(size=(1,)))
        numeric_vjp = cotangent @ numeric_jac
        # jax.vjp returns (primal_out, pullback); call the pullback.
        autodiff_vjp = jax.vjp(func, point)[1](cotangent)
        with check:
            assert np.allclose(numeric_vjp, autodiff_vjp)
Exemplo n.º 2
0
def test_jacobian():
    """Cross-check fdm and JAX derivatives of hh (at x_input) and gg (at y_input).

    For each function, the fdm Jacobian is compared with ``jax.jacrev``.
    A vector-Jacobian product with a seeded random cotangent is then compared
    with ``jax.vjp``. Seed 0 is used for hh and seed 1 for gg.
    """
    cases = ((hh, x_input), (gg, y_input))
    for seed, (func, point) in enumerate(cases):
        numeric_jac = fdm.jacobian(func)(point)
        autodiff_jac = jax.jacrev(func)(point)

        with check:
            assert np.allclose(numeric_jac, autodiff_jac)

        # Fixed per-function seeds keep the cotangent reproducible.
        key = jax.random.PRNGKey(seed)
        cotangent = jax.random.normal(key, shape=(V.dim(), ), dtype="float64")

        numeric_vjp = cotangent @ numeric_jac
        _, pullback = jax.vjp(func, point)
        autodiff_vjp = pullback(cotangent)

        with check:
            assert np.allclose(numeric_vjp, autodiff_vjp)
Exemplo n.º 3
0
def test_vjp_assemble_eval():
    """Check the pullback from ``vjp_fem_eval`` against fdm Jacobians.

    ``vjp_fem_eval(assemble_fenics, templates, *inputs)`` returns the primal
    output and a pullback. The pullback is applied to an all-ones cotangent.
    The i-th resulting cotangent is compared with the finite-difference
    Jacobian of ``ff<i>`` at ``inputs[i]``. Presumably ``ff<i>`` is
    ``assemble_fenics`` with all inputs but the i-th held fixed
    (TODO confirm).

    All three comparisons are made before asserting. The failure message
    lists every input whose gradient disagreed, instead of a bare
    ``assert a and b and c``.
    """
    numpy_output, vjp_fun = vjp_fem_eval(assemble_fenics, templates, *inputs)
    g = np.ones_like(numpy_output)
    vjp_out = vjp_fun(g)

    fdm_jacs = (
        fdm.jacobian(ff0)(inputs[0]),
        fdm.jacobian(ff1)(inputs[1]),
        fdm.jacobian(ff2)(inputs[2]),
    )

    mismatched = [
        i for i, fdm_jac in enumerate(fdm_jacs)
        if not np.allclose(vjp_out[i], fdm_jac)
    ]
    assert not mismatched, (
        f"VJP disagrees with finite differences for inputs {mismatched}"
    )
Exemplo n.º 4
0
def test_vjp_assemble_eval():
    """Check ``evaluate_pullback`` against fdm Jacobians of assemble_fenics.

    ``evaluate_primal`` runs ``assemble_fenics`` and returns the numpy output,
    the FEniCS output and inputs, and a tape. ``evaluate_pullback`` then
    propagates an all-ones cotangent back through that tape. The i-th
    resulting cotangent is compared with the finite-difference Jacobian of
    ``ff<i>`` at ``inputs[i]``. Presumably ``ff<i>`` is ``assemble_fenics``
    restricted to its i-th input (TODO confirm).

    All three comparisons are made before asserting. The failure message
    lists every input whose gradient disagreed.
    """
    numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
        assemble_fenics, templates, *inputs)
    g = np.ones_like(numpy_output)
    vjp_out = evaluate_pullback(fenics_output, fenics_inputs, tape, g)

    fdm_jacs = (
        fdm.jacobian(ff0)(inputs[0]),
        fdm.jacobian(ff1)(inputs[1]),
        fdm.jacobian(ff2)(inputs[2]),
    )

    mismatched = [
        i for i, fdm_jac in enumerate(fdm_jacs)
        if not np.allclose(vjp_out[i], fdm_jac)
    ]
    assert not mismatched, (
        f"VJP disagrees with finite differences for inputs {mismatched}"
    )
def test_jacobian():
    """Compare forward-mode JAX Jacobians of ff1 and ff2 with fdm estimates.

    ff0 is deliberately left out because finite differencing it is expensive.
    """
    pairs = ((ff1, inputs[1]), (ff2, inputs[2]))
    for func, point in pairs:
        forward_jac = jax.jacfwd(func)(point)
        numeric_jac = fdm.jacobian(func)(point)
        with check:
            assert np.allclose(forward_jac, numeric_jac)
def test_jacobian():
    """The finite-difference Jacobian of the linear map x -> A @ x must equal A."""
    # Per fdm's central_fdm(order, deriv) signature: 10 points, 1st derivative.
    method = central_fdm(10, 1)
    matrix = np.random.randn(3, 3)

    def linear_map(vec):
        return np.matmul(matrix, vec)

    point = np.random.randn(3)
    # NOTE(review): the result of ``allclose`` is discarded. This only checks
    # anything if ``allclose`` is an asserting helper rather than
    # ``np.allclose``; confirm against the test utilities.
    allclose(jacobian(linear_map, method)(point), matrix)
Exemplo n.º 7
0
def test_jacobian_and_vjp():
    """Cross-check fdm and JAX reverse-mode derivatives of ff1 and ff2.

    One cotangent, drawn once with seed 0, is reused for both functions.
    ff0 is skipped because its finite-difference Jacobian is expensive.
    """
    key = jax.random.PRNGKey(0)
    cotangent = jax.random.normal(key, shape=(V.dim(), ), dtype="float64")

    for func, point in ((ff1, inputs[1]), (ff2, inputs[2])):
        numeric_jac = fdm.jacobian(func)(point)
        reverse_jac = jax.jacrev(func)(point)
        with check:
            assert np.allclose(numeric_jac, reverse_jac)

        numeric_vjp = cotangent @ numeric_jac
        _, pullback = jax.vjp(func, point)
        reverse_vjp = pullback(cotangent)
        with check:
            assert np.allclose(numeric_vjp, reverse_vjp)
def test_jacobian():
    """Check a forward-mode JVP of ff against fdm Jacobians of hh and gg.

    ``jax.jvp`` pushes an all-ones tangent through ``ff`` in both arguments
    at once. The finite-difference counterpart multiplies each partial
    Jacobian by its own all-ones tangent and sums the two products. This
    assumes hh and gg are ff with the other argument held fixed
    (TODO confirm).
    """
    tangent_x = np.ones_like(x_input)
    tangent_y = np.ones_like(y_input)
    jax_jvp = jax.jvp(ff, (x_input, y_input), (tangent_x, tangent_y))[1]

    # fdm Jacobians have shape (output size, input size).
    fdm_jac0 = fdm.jacobian(hh)(x_input)
    fdm_jac1 = fdm.jacobian(gg)(y_input)
    # Contract each Jacobian with its tangent explicitly. Summing the raw
    # Jacobians only matches the JVP when every input has exactly one element.
    fdm_jvp = fdm_jac0 @ np.ravel(tangent_x) + fdm_jac1 @ np.ravel(tangent_y)
    assert np.allclose(fdm_jvp, np.ravel(jax_jvp))