def test_fenics_vjp():
    numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
        solve_fenics, templates, *inputs)
    g = np.ones_like(numpy_output)
    vjp_out = evaluate_vjp(g, fenics_output, fenics_inputs, tape)
    check1 = np.isclose(vjp_out[0], np.asarray(-2.91792642))
    check2 = np.isclose(vjp_out[1], np.asarray(2.43160535))
    assert check1 and check2
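The snippets on this page reference a shared test setup (fa, V, templates, inputs, solve_fenics) that is not shown here. A minimal sketch of the assumed imports and function space follows; the mesh size and the exact import paths are assumptions:

import functools
import itertools

import fenics
import fenics_adjoint as fa
import ufl
import fdm
import jax
import jax.numpy as np  # assumed: the batching rules below apply np ops to traced values
from jax.core import Primitive
from jax.interpreters.ad import defvjp_all  # older JAX API, as used in Example #2

mesh = fa.UnitSquareMesh(6, 5)  # arbitrary size for this sketch
V = fenics.FunctionSpace(mesh, "P", 1)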
Example #2
    def decorator(fenics_function: Callable) -> Callable:
        @functools.wraps(fenics_function)
        def jax_fem_eval(*args):
            return jax_fem_eval_p.bind(*args)

        jax_fem_eval_p = Primitive("jax_fem_eval")
        jax_fem_eval_p.def_impl(lambda *args: evaluate_primal(
            fenics_function, fenics_templates, *args)[0])

        jax_fem_eval_p.def_abstract_eval(
            lambda *args: jax.abstract_arrays.make_shaped_array(
                evaluate_primal(fenics_function, fenics_templates, *args)[0]))

        def jax_fem_eval_batch(vector_arg_values, batch_axes):
            # All inputs must be batched along the same axis, and only
            # axis 0 is supported for now; a general implementation would
            # first move the batch axis to the front.
            assert len(set(batch_axes)) == 1
            assert batch_axes[0] == 0
            # Apply the primitive row-by-row and restack along the batch axis.
            res = list(map(jax_fem_eval, *vector_arg_values))
            res = np.asarray(res)
            return res, batch_axes[0]

        jax.interpreters.batching.primitive_batchers[
            jax_fem_eval_p] = jax_fem_eval_batch

        # @trace("djax_fem_eval")
        def djax_fem_eval(*args):
            return djax_fem_eval_p.bind(*args)

        djax_fem_eval_p = Primitive("djax_fem_eval")
        # djax_fem_eval_p.multiple_results = True
        djax_fem_eval_p.def_impl(lambda *args: vjp_fem_eval(
            fenics_function, fenics_templates, *args))

        defvjp_all(jax_fem_eval_p, djax_fem_eval)
        return jax_fem_eval
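A usage sketch for the decorator above, assuming the enclosing factory that supplies fenics_templates is exposed under a name such as build_jax_fem_eval (the factory name and the templates for solve_fenics are assumptions):

@build_jax_fem_eval((fa.Constant(0.0), fa.Constant(0.0)))  # hypothetical factory name
def solve(kappa0, kappa1):
    return solve_fenics(kappa0, kappa1)

u = solve(np.ones(1) * 0.5, np.ones(1) * 0.6)  # forward pass via evaluate_primal
# Reverse mode dispatches to djax_fem_eval / vjp_fem_eval:
dJdk0 = jax.grad(lambda k0, k1: np.sum(solve(k0, k1)))(
    np.ones(1) * 0.5, np.ones(1) * 0.6)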
def test_vjp_assemble_eval():
    numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
        assemble_fenics, templates, *inputs
    )
    g = np.ones_like(numpy_output)
    vjp_out = evaluate_vjp(g, fenics_output, fenics_inputs, tape)

    fdm_jac0 = fdm.jacobian(ff0)(inputs[0])
    fdm_jac1 = fdm.jacobian(ff1)(inputs[1])
    fdm_jac2 = fdm.jacobian(ff2)(inputs[2])

    check1 = np.allclose(vjp_out[0], fdm_jac0)
    check2 = np.allclose(vjp_out[1], fdm_jac1)
    check3 = np.allclose(vjp_out[2], fdm_jac2)
    assert check1 and check2 and check3
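Because assemble_fenics returns a scalar, seeding the VJP with g = np.ones_like(numpy_output) (i.e. g = 1) yields the full gradient, so each component can be compared directly against a finite-difference Jacobian from the fdm package.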
def test_fenics_forward():
    numpy_output, _, _, _ = evaluate_primal(assemble_fenics, templates, *inputs)
    u1 = fa.interpolate(fa.Constant(1.0), V)
    J = assemble_fenics(u1, fa.Constant(0.5), fa.Constant(0.6))
    assert np.isclose(numpy_output, J)
def assemble_fenics(u, kappa0, kappa1):

    f = fa.Expression(
        "10*exp(-(pow(x[0] - 0.5, 2) + pow(x[1] - 0.5, 2)) / 0.02)", degree=2
    )

    inner, grad, dx = ufl.inner, ufl.grad, ufl.dx
    J_form = 0.5 * inner(kappa0 * grad(u), grad(u)) * dx - kappa1 * f * u * dx
    J = fa.assemble(J_form)
    return J


templates = (fa.Function(V), fa.Constant(0.0), fa.Constant(0.0))
inputs = (np.ones(V.dim()), np.ones(1) * 0.5, np.ones(1) * 0.6)
ff = lambda *args: evaluate_primal(assemble_fenics, templates, *args)[0]  # noqa: E731
ff0 = lambda x: ff(x, inputs[1], inputs[2])  # noqa: E731
ff1 = lambda y: ff(inputs[0], y, inputs[2])  # noqa: E731
ff2 = lambda z: ff(inputs[0], inputs[1], z)  # noqa: E731


def test_fenics_forward():
    numpy_output, _, _, _ = evaluate_primal(solve_fenics, templates, *inputs)
    u = solve_fenics(fa.Constant(0.5), fa.Constant(0.6))
    assert np.allclose(numpy_output, fenics_to_numpy(u))
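solve_fenics itself is not shown on this page. For orientation, a representative implementation (an assumption, modeled on the assemble_fenics form above) could look like:

def solve_fenics(kappa0, kappa1):
    # Same source term as assemble_fenics above.
    f = fa.Expression(
        "10*exp(-(pow(x[0] - 0.5, 2) + pow(x[1] - 0.5, 2)) / 0.02)", degree=2
    )
    u = fa.Function(V)
    bcs = [fa.DirichletBC(V, fa.Constant(0.0), "on_boundary")]
    inner, grad, dx = ufl.inner, ufl.grad, ufl.dx
    JJ = 0.5 * inner(kappa0 * grad(u), grad(u)) * dx - kappa1 * f * u * dx
    F = ufl.derivative(JJ, u)
    fa.solve(F == 0, u, bcs)
    return u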
Example #7
def vjp_fem_eval(
    fenics_function: Callable,
    fenics_templates: Iterable[FenicsVariable],
    *args: np.ndarray,
) -> Tuple[np.ndarray, Callable]:
    """Evaluates fenics_function and constructs the VJP map for its inputs.
    Input:
        fenics_function (callable): FEniCS function to be executed during the forward pass
        args (tuple): JAX array representation of the input to fenics_function
    Output:
        A pair where the first element is the value of fenics_function applied to the
        arguments and the second element is a Python callable representing the VJP map
        from output cotangents to input cotangents. The returned VJP function accepts
        a value with the same shape as the output of fenics_function and returns a
        tuple with length equal to the number of positional arguments.
    """

    numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
        fenics_function, fenics_templates, *args)

    # @trace("vjp_fun1")
    def vjp_fun1(g):
        return vjp_fun1_p.bind(g)

    vjp_fun1_p = Primitive("vjp_fun1")
    vjp_fun1_p.multiple_results = True
    vjp_fun1_p.def_impl(lambda g: tuple(
        vjp if vjp is not None else jax.ad_util.zeros_like_jaxval(args[i])
        for i, vjp in enumerate(
            evaluate_vjp(g, fenics_output, fenics_inputs, tape))))

    # @trace("vjp_fun1_abstract_eval")
    def vjp_fun1_abstract_eval(g):
        if len(args) > 1:
            return tuple(
                (jax.abstract_arrays.ShapedArray(arg.shape, arg.dtype)
                 for arg in args))
        else:
            return (jax.abstract_arrays.ShapedArray((1, *args[0].shape),
                                                    args[0].dtype), )

    vjp_fun1_p.def_abstract_eval(vjp_fun1_abstract_eval)

    # @trace("vjp_fun1_batch")
    def vjp_fun1_batch(vector_arg_values, batch_axes):
        """Computes the batched version of the primitive.

        This must be a JAX-traceable function.

        Args:
            vector_arg_values: a tuple of arguments, each being a tensor of matching
            shape.
            batch_axes: the axes that are being batched. See vmap documentation.
        Returns:
            a tuple of the result, and the result axis that was batched.
        """
        # _trace("Using vjp_fun1 to compute the batch:")
        assert (
            batch_axes[0] == 0
        )  # assert that batch axis is zero, need to rewrite for a general case?
        # apply function row-by-row
        res = list(map(vjp_fun1, *vector_arg_values))
        # transpose resulting list
        res_T = list(itertools.zip_longest(*res))
        return tuple(map(np.vstack, res_T)), (batch_axes[0], ) * len(args)

    jax.interpreters.batching.primitive_batchers[vjp_fun1_p] = vjp_fun1_batch

    return numpy_output, vjp_fun1
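A sketch of calling the returned pair directly; solve_templates and solve_inputs are placeholder names for templates and inputs matching solve_fenics's signature:

numpy_output, vjp_fun = vjp_fem_eval(solve_fenics, solve_templates, *solve_inputs)
g = np.ones_like(numpy_output)  # output cotangent
input_cotangents = vjp_fun(g)   # tuple with one entry per positional argument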
Example #8
def jax_fem_eval_p_abstract_eval(*args):
    args = (jax_to_fenics_numpy(arg, ft)
            for arg, ft in zip(args, fenics_templates))
    return jax.abstract_arrays.make_shaped_array(
        evaluate_primal(fenics_function, fenics_templates, *args)[0])
Example #9
def jax_fem_eval_p_impl(*args):
    args = (jax_to_fenics_numpy(arg, ft)
            for arg, ft in zip(args, fenics_templates))
    return evaluate_primal(fenics_function, fenics_templates, *args)[0]
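Examples #8 and #9 form the impl/abstract-eval pair for the forward primitive; wiring them up follows the same registration pattern as Example #2:

jax_fem_eval_p = Primitive("jax_fem_eval")
jax_fem_eval_p.def_impl(jax_fem_eval_p_impl)
jax_fem_eval_p.def_abstract_eval(jax_fem_eval_p_abstract_eval)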