コード例 #1
0
    f = fa.Expression(
        "10*exp(-(pow(x[0] - 0.5, 2) + pow(x[1] - 0.5, 2)) / 0.02)", degree=2)

    u = fa.Function(V)
    bcs = [fa.DirichletBC(V, fa.Constant(0.0), "on_boundary")]

    inner, grad, dx = ufl.inner, ufl.grad, ufl.dx
    JJ = 0.5 * inner(kappa0 * grad(u), grad(u)) * dx - kappa1 * f * u * dx
    v = fenics.TestFunction(V)
    F = fenics.derivative(JJ, u, v)
    fa.solve(F == 0, u, bcs=bcs)
    return u


templates = (fa.Constant(0.0), fa.Constant(0.0))
jax_solve_eval = build_jax_fem_eval(templates)(solve_fenics)


def ff(x, y):
    """Multivariate output function of both inputs.

    Passes ``sqrt(x**3)`` and ``y`` to the JAX-wrapped FEniCS solve and
    returns ``sqrt(square(.))`` of the result, i.e. its elementwise magnitude.
    """
    solution = jax_solve_eval(np.sqrt(x**3), y)
    return np.sqrt(np.square(solution))


# Fixed evaluation points used by the single-argument wrappers below.
x_input = np.ones(1)
y_input = 1.2 * np.ones(1)


def hh(x):
    """Multivariate output function of the first argument (``y`` fixed)."""
    return ff(x, y_input)


def gg(y):
    """Multivariate output function of the second argument (``x`` fixed)."""
    return ff(x_input, y)


def test_jacobian():
    """Estimate the Jacobian of ``hh`` at ``x_input`` by finite differences.

    NOTE(review): in this excerpt ``fdm_jac0`` is computed but never compared
    against a JAX-derived Jacobian; the body appears truncated. Confirm the
    full test performs the comparison.
    """
    fdm_jac0 = fdm.jacobian(hh)(x_input)
コード例 #2
0
def assemble_firedrake(u, kappa0, kappa1):
    """Assemble a scalar energy functional with Firedrake.

    The assembled form is

        0.5 * inner(kappa0 * grad(u), grad(u)) * dx - kappa1 * f * u * dx

    where ``f`` is the first component of the spatial coordinate of the
    module-level ``mesh``.

    Args:
        u: field the functional is evaluated at (presumably a
            ``firedrake.Function``; TODO confirm against callers).
        kappa0: coefficient scaling the gradient term.
        kappa1: coefficient scaling the source term.

    Returns:
        Whatever ``firedrake.assemble`` returns for this form.
    """
    coords = firedrake.SpatialCoordinate(mesh)
    source = coords[0]
    quadratic_term = 0.5 * ufl.inner(kappa0 * ufl.grad(u), ufl.grad(u)) * ufl.dx
    linear_term = kappa1 * source * u * ufl.dx
    return firedrake.assemble(quadratic_term - linear_term)


templates = (
    firedrake.Function(V),
    firedrake.Constant(0.0),
    firedrake.Constant(0.0),
)
# Evaluation point: one array per template (presumably matched by position).
inputs = (np.ones(V.dim()), np.ones(1) * 0.5, np.ones(1) * 0.6)

hh = build_jax_fem_eval(templates)(assemble_firedrake)


def hh0(x):
    """Evaluate ``hh`` varying only the first input."""
    return hh(x, inputs[1], inputs[2])


def hh1(y):
    """Evaluate ``hh`` varying only the second input."""
    return hh(inputs[0], y, inputs[2])


def hh2(z):
    """Evaluate ``hh`` varying only the third input."""
    return hh(inputs[0], inputs[1], z)


def test_jacobian_and_vjp():
    """Compare JAX derivatives of ``hh0`` with finite-difference estimates.

    Checks that ``jax.jacrev`` agrees with the Jacobian estimated by ``fdm``,
    then that the vector-Jacobian product from ``jax.vjp`` agrees with the
    same product formed from the finite-difference Jacobian.
    """
    fdm_jac0 = fdm.jacobian(hh0)(inputs[0])
    jax_jac0 = jax.jacrev(hh0)(inputs[0])
    with check:
        assert np.allclose(fdm_jac0, jax_jac0)

    # Seeded generator so that a failing comparison can be reproduced exactly.
    rng = onp.random.default_rng(0)
    v0 = np.asarray(rng.normal(size=(1,)))
    fdm_vjp0 = v0 @ fdm_jac0
    # The pullback returned by jax.vjp yields a tuple with one cotangent per
    # primal argument. Unpack the single entry so both sides are arrays;
    # jax.numpy functions reject tuple arguments.
    _, vjp_fun = jax.vjp(hh0, inputs[0])
    (jax_vjp0,) = vjp_fun(v0)
    with check:
        assert np.allclose(fdm_vjp0, jax_vjp0)
コード例 #3
0
def assemble_fenics(u, kappa0, kappa1):
    """Assemble a scalar energy functional with FEniCS.

    The assembled form is

        0.5 * inner(kappa0 * grad(u), grad(u)) * dx - kappa1 * f * u * dx

    where ``f`` is the FEniCS ``Expression``
    ``10*exp(-((x0 - 0.5)^2 + (x1 - 0.5)^2) / 0.02)`` built with ``degree=2``.

    Args:
        u: field the functional is evaluated at (presumably a
            ``fa.Function``; TODO confirm against callers).
        kappa0: coefficient scaling the gradient term.
        kappa1: coefficient scaling the source term.

    Returns:
        Whatever ``fa.assemble`` returns for this form.
    """
    source = fa.Expression(
        "10*exp(-(pow(x[0] - 0.5, 2) + pow(x[1] - 0.5, 2)) / 0.02)", degree=2
    )
    quadratic_term = 0.5 * ufl.inner(kappa0 * ufl.grad(u), ufl.grad(u)) * ufl.dx
    linear_term = kappa1 * source * u * ufl.dx
    return fa.assemble(quadratic_term - linear_term)


templates = (
    fa.Function(V),
    fa.Constant(0.0),
    fa.Constant(0.0),
)
# Evaluation point: one array per template (presumably matched by position).
inputs = (np.ones(V.dim()), np.ones(1) * 0.5, np.ones(1) * 0.6)

hh = build_jax_fem_eval(templates)(assemble_fenics)


def hh0(x):
    """Evaluate ``hh`` varying only the first input."""
    return hh(x, inputs[1], inputs[2])


def hh1(y):
    """Evaluate ``hh`` varying only the second input."""
    return hh(inputs[0], y, inputs[2])


def hh2(z):
    """Evaluate ``hh`` varying only the third input."""
    return hh(inputs[0], inputs[1], z)


def test_jacobian_and_vjp():
    """Compare JAX derivatives of ``hh0`` with finite-difference estimates.

    Checks that ``jax.jacrev`` agrees with the Jacobian estimated by ``fdm``,
    then that the vector-Jacobian product from ``jax.vjp`` agrees with the
    same product formed from the finite-difference Jacobian.
    """
    fdm_jac0 = fdm.jacobian(hh0)(inputs[0])
    jax_jac0 = jax.jacrev(hh0)(inputs[0])
    with check:
        assert np.allclose(fdm_jac0, jax_jac0)

    # Seeded generator so that a failing comparison can be reproduced exactly.
    rng = onp.random.default_rng(0)
    v0 = np.asarray(rng.normal(size=(1,)))
    fdm_vjp0 = v0 @ fdm_jac0
    # The pullback returned by jax.vjp yields a tuple with one cotangent per
    # primal argument. Unpack the single entry so both sides are arrays;
    # jax.numpy functions reject tuple arguments.
    _, vjp_fun = jax.vjp(hh0, inputs[0])
    (jax_vjp0,) = vjp_fun(v0)
    with check:
        assert np.allclose(fdm_vjp0, jax_vjp0)
コード例 #4
0
    f = x[0]

    u = firedrake.Function(V)
    bcs = [firedrake.DirichletBC(V, firedrake.Constant(0.0), "on_boundary")]

    inner, grad, dx = ufl.inner, ufl.grad, ufl.dx
    JJ = 0.5 * inner(kappa0 * grad(u), grad(u)) * dx - q * kappa1 * f * u * dx
    v = firedrake.TestFunction(V)
    F = firedrake.derivative(JJ, u, v)
    firedrake.solve(F == 0, u, bcs=bcs)
    return u


templates = (
    firedrake.Function(V),
    firedrake.Constant(0.0),
    firedrake.Constant(0.0),
)
# Evaluation point: one array per template (presumably matched by position).
inputs = (np.ones(V.dim()), np.ones(1), np.ones(1) * 1.2)
jax_solve_eval = build_jax_fem_eval(templates)(solve_firedrake)


def ff(x, y, z):
    """Multivariate output function of all three inputs.

    Passes ``x``, ``sqrt(y**3)`` and ``z`` to the JAX-wrapped Firedrake solve
    and returns ``sqrt(square(.))`` of the result, i.e. its elementwise
    magnitude.
    """
    solution = jax_solve_eval(x, np.sqrt(y ** 3), z)
    return np.sqrt(np.square(solution))


def ff0(x):
    """Evaluate ``ff`` varying only the first input."""
    return ff(x, inputs[1], inputs[2])


def ff1(y):
    """Evaluate ``ff`` varying only the second input."""
    return ff(inputs[0], y, inputs[2])


def ff2(z):
    """Evaluate ``ff`` varying only the third input."""
    return ff(inputs[0], inputs[1], z)


def test_vmap():
    bdim = 2
    vinputs = (
        np.ones((bdim, V.dim())),
        np.ones((bdim, 1)) * 0.5,