f = fa.Expression( "10*exp(-(pow(x[0] - 0.5, 2) + pow(x[1] - 0.5, 2)) / 0.02)", degree=2) u = fa.Function(V) bcs = [fa.DirichletBC(V, fa.Constant(0.0), "on_boundary")] inner, grad, dx = ufl.inner, ufl.grad, ufl.dx JJ = 0.5 * inner(kappa0 * grad(u), grad(u)) * dx - kappa1 * f * u * dx v = fenics.TestFunction(V) F = fenics.derivative(JJ, u, v) fa.solve(F == 0, u, bcs=bcs) return u templates = (fa.Constant(0.0), fa.Constant(0.0)) jax_solve_eval = build_jax_fem_eval(templates)(solve_fenics) # multivariate output function ff = lambda x, y: np.sqrt(np.square(jax_solve_eval(np.sqrt(x**3), y)) ) # noqa: E731 x_input = np.ones(1) y_input = 1.2 * np.ones(1) # multivariate output function of the first argument hh = lambda x: ff(x, y_input) # noqa: E731 # multivariate output function of the second argument gg = lambda y: ff(x_input, y) # noqa: E731 def test_jacobian(): fdm_jac0 = fdm.jacobian(hh)(x_input)
def assemble_firedrake(u, kappa0, kappa1):
    """Assemble the scalar energy functional of the Firedrake test problem.

    J(u) = 0.5 * <kappa0 * grad(u), grad(u)> - <kappa1 * f, u>, integrated over
    ``mesh`` with the source term f = x[0] (first spatial coordinate).
    Returns the result of ``firedrake.assemble`` on that 0-form.
    """
    x = firedrake.SpatialCoordinate(mesh)
    f = x[0]
    inner, grad, dx = ufl.inner, ufl.grad, ufl.dx
    J_form = 0.5 * inner(kappa0 * grad(u), grad(u)) * dx - kappa1 * f * u * dx
    J = firedrake.assemble(J_form)
    return J


# One template per argument of assemble_firedrake: (u, kappa0, kappa1).
templates = (
    firedrake.Function(V),
    firedrake.Constant(0.0),
    firedrake.Constant(0.0),
)
inputs = (np.ones(V.dim()), np.ones(1) * 0.5, np.ones(1) * 0.6)
hh = build_jax_fem_eval(templates)(assemble_firedrake)
# Restrictions of hh that vary one argument at a time, the others fixed at `inputs`.
hh0 = lambda x: hh(x, inputs[1], inputs[2])  # noqa: E731
hh1 = lambda y: hh(inputs[0], y, inputs[2])  # noqa: E731
hh2 = lambda z: hh(inputs[0], inputs[1], z)  # noqa: E731


def _check_jacobian_and_vjp(fun, x, seed=0):
    """Check JAX's Jacobian and VJP of ``fun`` at ``x`` against finite differences."""
    fdm_jac = fdm.jacobian(fun)(x)
    jax_jac = jax.jacrev(fun)(x)
    with check:
        assert np.allclose(fdm_jac, jax_jac)

    out, vjp_fun = jax.vjp(fun, x)
    # Seeded cotangent shaped like the primal output, so failures are reproducible.
    v = np.asarray(onp.random.default_rng(seed).normal(size=np.shape(out)))
    fdm_vjp = np.ravel(v) @ fdm_jac
    # vjp_fun returns one cotangent per primal input; fun has exactly one.
    (jax_vjp,) = vjp_fun(v)
    with check:
        assert np.allclose(fdm_vjp, jax_vjp)


def test_jacobian_and_vjp():
    # Exercise the derivative with respect to every argument, not just the first.
    for fun, x in zip((hh0, hh1, hh2), inputs):
        _check_jacobian_and_vjp(fun, x)
def assemble_fenics(u, kappa0, kappa1):
    """Assemble the scalar energy functional of the FEniCS test problem.

    J(u) = 0.5 * <kappa0 * grad(u), grad(u)> - <kappa1 * f, u>, where f is a
    Gaussian bump of height 10 centred at (0.5, 0.5) (see the expression string).
    Returns the result of ``fa.assemble`` on that 0-form.
    """
    f = fa.Expression(
        "10*exp(-(pow(x[0] - 0.5, 2) + pow(x[1] - 0.5, 2)) / 0.02)", degree=2
    )
    inner, grad, dx = ufl.inner, ufl.grad, ufl.dx
    J_form = 0.5 * inner(kappa0 * grad(u), grad(u)) * dx - kappa1 * f * u * dx
    J = fa.assemble(J_form)
    return J


# One template per argument of assemble_fenics: (u, kappa0, kappa1).
templates = (fa.Function(V), fa.Constant(0.0), fa.Constant(0.0))
inputs = (np.ones(V.dim()), np.ones(1) * 0.5, np.ones(1) * 0.6)
hh = build_jax_fem_eval(templates)(assemble_fenics)
# Restrictions of hh that vary one argument at a time, the others fixed at `inputs`.
hh0 = lambda x: hh(x, inputs[1], inputs[2])  # noqa: E731
hh1 = lambda y: hh(inputs[0], y, inputs[2])  # noqa: E731
hh2 = lambda z: hh(inputs[0], inputs[1], z)  # noqa: E731


def _check_jacobian_and_vjp(fun, x, seed=0):
    """Check JAX's Jacobian and VJP of ``fun`` at ``x`` against finite differences."""
    fdm_jac = fdm.jacobian(fun)(x)
    jax_jac = jax.jacrev(fun)(x)
    with check:
        assert np.allclose(fdm_jac, jax_jac)

    out, vjp_fun = jax.vjp(fun, x)
    # Seeded cotangent shaped like the primal output, so failures are reproducible.
    v = np.asarray(onp.random.default_rng(seed).normal(size=np.shape(out)))
    fdm_vjp = np.ravel(v) @ fdm_jac
    # vjp_fun returns one cotangent per primal input; fun has exactly one.
    (jax_vjp,) = vjp_fun(v)
    with check:
        assert np.allclose(fdm_vjp, jax_vjp)


def test_jacobian_and_vjp():
    # Exercise the derivative with respect to every argument, not just the first.
    for fun, x in zip((hh0, hh1, hh2), inputs):
        _check_jacobian_and_vjp(fun, x)
f = x[0] u = firedrake.Function(V) bcs = [firedrake.DirichletBC(V, firedrake.Constant(0.0), "on_boundary")] inner, grad, dx = ufl.inner, ufl.grad, ufl.dx JJ = 0.5 * inner(kappa0 * grad(u), grad(u)) * dx - q * kappa1 * f * u * dx v = firedrake.TestFunction(V) F = firedrake.derivative(JJ, u, v) firedrake.solve(F == 0, u, bcs=bcs) return u templates = (firedrake.Function(V), firedrake.Constant(0.0), firedrake.Constant(0.0)) inputs = (np.ones(V.dim()), np.ones(1), np.ones(1) * 1.2) jax_solve_eval = build_jax_fem_eval(templates)(solve_firedrake) # multivariate output function ff = lambda x, y, z: np.sqrt( # noqa: E731 np.square(jax_solve_eval(x, np.sqrt(y ** 3), z)) ) ff0 = lambda x: ff(x, inputs[1], inputs[2]) # noqa: E731 ff1 = lambda y: ff(inputs[0], y, inputs[2]) # noqa: E731 ff2 = lambda z: ff(inputs[0], inputs[1], z) # noqa: E731 def test_vmap(): bdim = 2 vinputs = ( np.ones((bdim, V.dim())), np.ones((bdim, 1)) * 0.5,