Пример #1
0
def test_firedrake_forward():
    """Check the numpy-wrapped Firedrake primal against a direct assembly."""
    (wrapped_value, _firedrake_out, _firedrake_in,
     _tape) = evaluate_primal(assemble_firedrake, templates, *inputs)

    # Reference value: assemble directly with u == 1 and constants 0.5 / 0.6
    # (presumably the same values `inputs` encodes — TODO confirm).
    ones = firedrake.interpolate(firedrake.Constant(1.0), V)
    kappa0 = firedrake.Constant(0.5)
    kappa1 = firedrake.Constant(0.6)
    reference = assemble_firedrake(ones, kappa0, kappa1)

    assert np.isclose(wrapped_value, reference)
Пример #2
0
    def perform(self, node, inputs, outputs):
        """Evaluate the wrapped FEniCS function on the numeric ``inputs``.

        Side effects: stores a ``FenicsVJPOp`` built from the recorded primal
        state on ``self.vjp_op`` and writes the numpy result into
        ``outputs[0][0]``.
        """
        primal_result = evaluate_primal(self.ofunc, self.templates, *inputs)
        numpy_value, fenics_value, fenics_args, recorded_tape = primal_result

        # Keep the forward-pass state so the gradient op can reuse the tape.
        self.vjp_op = FenicsVJPOp(
            self.ofunc,
            self.templates,
            fenics_value,
            tuple(fenics_args),
            recorded_tape,
        )
        outputs[0][0] = numpy_value
Пример #3
0
 def primal(*args):
     """Run the forward pass and package residuals for the VJP.

     Returns ``(numpy_output, residuals)``, where ``residuals`` pairs a
     ``PyadjointMetadata`` (output, inputs, tape) with the original ``args``.
     """
     result = evaluate_primal(fenics_function, fenics_templates, *args)
     numpy_output, fenics_output, fenics_inputs, tape = result
     metadata = PyadjointMetadata(fenics_output, fenics_inputs, tape)
     residuals = (metadata, args)
     return numpy_output, residuals
Пример #4
0
    def decorator(fenics_function: Callable) -> Callable:
        """Wrap ``fenics_function`` as a JAX-callable with a custom VJP.

        Registers a ``jax_fem_eval`` primitive (impl, abstract eval and a
        batching rule) and attaches ``primal``/``pullback`` through
        ``defvjp``.

        Args:
            fenics_function: the FEniCS-side function to wrap; evaluated via
                ``evaluate_primal`` with ``fenics_templates``.

        Returns:
            The wrapped ``jax_fem_eval`` callable.
        """

        @functools.wraps(fenics_function)
        @custom_vjp
        def jax_fem_eval(*args):
            return jax_fem_eval_p.bind(*args)

        jax_fem_eval_p = Primitive("jax_fem_eval")
        jax_fem_eval_p.def_impl(
            lambda *args: evaluate_primal(fenics_function, fenics_templates, *args)[0]
        )

        # NOTE(review): abstract evaluation runs the full forward evaluation
        # just to obtain the output shape/dtype, which may be expensive.
        jax_fem_eval_p.def_abstract_eval(
            lambda *args: jax.abstract_arrays.make_shaped_array(
                evaluate_primal(fenics_function, fenics_templates, *args)[0]
            )
        )

        def jax_fem_eval_batch(vector_arg_values, batch_axes):
            """Batching rule: evaluate each batch element separately.

            Only the case where every argument is batched along axis 0 is
            supported.

            Raises:
                NotImplementedError: if the arguments do not share a single
                    batch axis, or if that axis is not 0.
            """
            # Explicit checks rather than ``assert``: asserts are stripped
            # under ``python -O``, which would silently map over the wrong axis.
            if len(set(batch_axes)) != 1:
                raise NotImplementedError(
                    "jax_fem_eval batching requires all arguments to share the "
                    "same batch axis, got {!r}".format(batch_axes)
                )
            if batch_axes[0] != 0:
                raise NotImplementedError(
                    "jax_fem_eval batching only supports batch axis 0, "
                    "got {!r}".format(batch_axes[0])
                )
            # Map over the leading axis of every argument in lockstep.
            res = list(map(jax_fem_eval, *vector_arg_values))
            res = np.asarray(res)
            return res, batch_axes[0]

        jax.interpreters.batching.primitive_batchers[
            jax_fem_eval_p
        ] = jax_fem_eval_batch

        def primal(*args):
            """Forward pass for ``custom_vjp``: return output and residuals.

            Residuals are ``(PyadjointMetadata(...), args)`` and are consumed
            by ``pullback``.
            """
            numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
                fenics_function, fenics_templates, *args
            )
            return (
                numpy_output,
                (PyadjointMetadata(fenics_output, fenics_inputs, tape), args),
            )

        def pullback(aux_args, g):
            """Backward pass for ``custom_vjp``: map cotangent ``g`` to inputs."""
            pb_fn = get_pullback_function(fenics_function, fenics_templates)
            # ``custom_vjp`` expects the backward rule to return a tuple of
            # cotangents, but ``pb_fn`` returns a list.
            return tuple(pb_fn(aux_args, g))

        jax_fem_eval.defvjp(primal, pullback)
        return jax_fem_eval
Пример #5
0
def test_firedrake_vjp():
    """Compare the Firedrake VJP of ``solve_firedrake`` with reference values."""
    primal = evaluate_primal(solve_firedrake, templates, *inputs)
    numpy_output, firedrake_output, firedrake_inputs, tape = primal

    cotangent = np.ones_like(numpy_output)
    grads = evaluate_pullback(firedrake_output, firedrake_inputs, tape,
                              cotangent)

    expected = (np.asarray(-1.13533304), np.asarray(0.94611087))
    check1 = np.isclose(grads[0], expected[0])
    check2 = np.isclose(grads[1], expected[1])
    assert check1
    assert check2
Пример #6
0
def test_fenics_vjp():
    """Compare the FEniCS VJP of ``solve_fenics`` with reference values."""
    (numpy_output, fenics_output, fenics_inputs,
     tape) = evaluate_primal(solve_fenics, templates, *inputs)
    unit_cotangent = np.ones_like(numpy_output)
    vjp_out = evaluate_pullback(fenics_output, fenics_inputs, tape,
                                unit_cotangent)

    reference = (-2.91792642, 2.43160535)
    checks = [
        np.isclose(vjp_out[idx], np.asarray(ref))
        for idx, ref in enumerate(reference)
    ]
    assert checks[0] and checks[1]
Пример #7
0
def test_theano_primal():
    """The Theano op output must match the direct numpy primal evaluation."""
    theano.config.compute_test_value = "ignore"
    fenics_op = create_fenics_theano_op(templates)(assemble_firedrake)

    # One symbolic vector per input of assemble_firedrake.
    x = theano.tensor.vector()
    y = theano.tensor.vector()
    z = theano.tensor.vector()
    compiled = theano.function([x, y, z], fenics_op(x, y, z))

    theano_output = compiled(*inputs)
    numpy_output = evaluate_primal(assemble_firedrake, templates, *inputs)[0]
    assert np.isclose(theano_output, numpy_output)
Пример #8
0
def test_vjp_assemble_eval():
    """The VJP of ``assemble_fenics`` must match finite-difference Jacobians."""
    numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
        assemble_fenics, templates, *inputs)
    cotangent = np.ones_like(numpy_output)
    vjp_out = evaluate_pullback(fenics_output, fenics_inputs, tape, cotangent)

    # One finite-difference Jacobian per input; ff0/ff1/ff2 presumably vary a
    # single input while holding the others fixed.
    partials = (ff0, ff1, ff2)
    fdm_jacs = [fdm.jacobian(fn)(arg) for fn, arg in zip(partials, inputs)]

    checks = [np.allclose(got, jac) for got, jac in zip(vjp_out, fdm_jacs)]
    assert checks[0] and checks[1] and checks[2]
Пример #9
0
        def jvp_jax_fem_eval_impl(primals, tangents):
            """Forward-mode rule: return ``(primal_output, tangent_output)``.

            Each primal and tangent is converted with ``jax_to_fenics_numpy``
            against the matching entry of ``fenics_templates``. The primal is
            evaluated with ``evaluate_primal``, and the tangents are then
            pushed forward through the recorded tape.
            """
            # Materialize as tuples rather than one-shot generators:
            # ``evaluate_pushforward`` is called with tuples elsewhere, and a
            # generator breaks if the callee indexes it, takes ``len`` or
            # iterates it twice.
            primals = tuple(
                jax_to_fenics_numpy(p, ft) for p, ft in zip(primals, fenics_templates)
            )
            numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
                fenics_function, fenics_templates, *primals
            )

            tangents = tuple(
                jax_to_fenics_numpy(t, ft) for t, ft in zip(tangents, fenics_templates)
            )
            dnumpy_output = evaluate_pushforward(
                fenics_output, fenics_inputs, tape, tangents
            )
            return numpy_output, dnumpy_output
Пример #10
0
def test_theano_vjp():
    """The Theano VJP op must agree with ``evaluate_pullback`` for ones(1)."""
    theano.config.compute_test_value = "ignore"
    _primal_out, fenics_output, fenics_inputs, tape = evaluate_primal(
        assemble_firedrake, templates, *inputs)

    vjp_op = FenicsVJPOp(assemble_firedrake, templates, fenics_output,
                         tuple(fenics_inputs), tape)
    cotangent = theano.tensor.vector()
    vjp_fn = theano.function([cotangent], vjp_op(cotangent))
    theano_grads = vjp_fn(np.ones(1))

    expected_grads = evaluate_pullback(fenics_output, tuple(fenics_inputs),
                                       tape, np.ones(1))
    # `check` soft-asserts so every component is reported, not just the first.
    for got, want in zip(theano_grads, expected_grads):
        with check:
            assert np.allclose(got, want)
Пример #11
0
def test_jvp_assemble_eval():
    """The JVP of ``assemble_fenics`` must match finite-difference JVPs.

    The pushed-forward tangent with all three input tangents set at once
    must equal the sum of the per-input finite-difference JVPs.
    """
    # Seeded generator so any failure is reproducible across runs.
    rng = np.random.default_rng(0)
    primals = inputs
    tangent0 = rng.normal(size=(V.dim(), ))
    tangent1 = rng.normal(size=(1, ))
    tangent2 = rng.normal(size=(1, ))
    tangents = (tangent0, tangent1, tangent2)

    fdm_jvp0 = fdm.jvp(ff0, tangents[0])(primals[0])
    fdm_jvp1 = fdm.jvp(ff1, tangents[1])(primals[1])
    fdm_jvp2 = fdm.jvp(ff2, tangents[2])(primals[2])

    _, fenics_output, fenics_inputs, tape = evaluate_primal(
        assemble_fenics, templates, *inputs)
    out_tangent = evaluate_pushforward(fenics_output, fenics_inputs, tape,
                                       tangents)

    assert np.allclose(fdm_jvp0 + fdm_jvp1 + fdm_jvp2, out_tangent)
Пример #12
0
def test_fenics_jvp():
    """The JVP of ``solve_fenics`` must match finite-difference JVPs.

    The pushed-forward tangent with both input tangents set at once must
    equal the sum of the per-input finite-difference JVPs.
    """
    # Seeded generator so any failure is reproducible across runs.
    rng = np.random.default_rng(0)
    primals = inputs
    tangent0 = rng.normal(size=(1, ))
    tangent1 = rng.normal(size=(1, ))
    tangents = (tangent0, tangent1)

    def ff0(x):
        """Numpy output of ``solve_fenics`` as a function of input 0 only."""
        return evaluate_primal(solve_fenics, templates, x, primals[1])[0]

    def ff1(y):
        """Numpy output of ``solve_fenics`` as a function of input 1 only."""
        return evaluate_primal(solve_fenics, templates, primals[0], y)[0]

    fdm_jvp0 = fdm.jvp(ff0, tangents[0])(primals[0])
    fdm_jvp1 = fdm.jvp(ff1, tangents[1])(primals[1])

    _, fenics_output, fenics_inputs, tape = evaluate_primal(
        solve_fenics, templates, *inputs)
    out_tangent = evaluate_pushforward(fenics_output, fenics_inputs, tape,
                                       tangents)

    assert np.allclose(fdm_jvp0 + fdm_jvp1, out_tangent)
Пример #13
0
def test_fenics_forward():
    """``evaluate_primal`` must reproduce a direct ``solve_fenics`` call."""
    (numpy_output, _fenics_out, _fenics_in,
     _tape) = evaluate_primal(solve_fenics, templates, *inputs)
    kappa0 = fa.Constant(0.5)
    kappa1 = fa.Constant(0.6)
    solution = solve_fenics(kappa0, kappa1)
    assert np.allclose(numpy_output, to_numpy(solution))
Пример #14
0
def test_fenics_forward():
    """The wrapped ``assemble_fenics`` must equal a direct assembly at u == 1."""
    result = evaluate_primal(assemble_fenics, templates, *inputs)
    numpy_output, _, _, _ = result
    u_const = fa.interpolate(fa.Constant(1.0), V)
    direct = assemble_fenics(u_const, fa.Constant(0.5), fa.Constant(0.6))
    assert np.isclose(numpy_output, direct)
Пример #15
0

def assemble_fenics(u, kappa0, kappa1):
    """Assemble a scalar functional of ``u`` with FEniCS.

    Returns the assembled value of
    ``0.5 * inner(kappa0 * grad(u), grad(u)) * dx - kappa1 * f * u * dx``,
    where ``f`` is a Gaussian bump of height 10 centred at (0.5, 0.5).
    """
    source = fa.Expression(
        "10*exp(-(pow(x[0] - 0.5, 2) + pow(x[1] - 0.5, 2)) / 0.02)", degree=2)

    diffusion = 0.5 * ufl.inner(kappa0 * ufl.grad(u), ufl.grad(u)) * ufl.dx
    forcing = kappa1 * source * u * ufl.dx
    return fa.assemble(diffusion - forcing)


# Templates: a function on V plus two scalar constants, matching the
# (u, kappa0, kappa1) signature of assemble_fenics.
templates = (fa.Function(V), fa.Constant(0.0), fa.Constant(0.0))
# Numpy evaluation point: u == 1 everywhere, kappa0 == 0.5, kappa1 == 0.6.
inputs = (np.ones(V.dim()), np.ones(1) * 0.5, np.ones(1) * 0.6)


def ff(*args):
    """Return the numpy output of ``assemble_fenics`` at numpy ``args``."""
    return evaluate_primal(assemble_fenics, templates, *args)[0]


def ff0(x):
    """``ff`` as a function of input 0, others fixed at ``inputs``."""
    return ff(x, inputs[1], inputs[2])


def ff1(y):
    """``ff`` as a function of input 1, others fixed at ``inputs``."""
    return ff(inputs[0], y, inputs[2])


def ff2(z):
    """``ff`` as a function of input 2, others fixed at ``inputs``."""
    return ff(inputs[0], inputs[1], z)


def test_fenics_forward():
    """Check ``evaluate_primal`` against calling ``assemble_fenics`` directly."""
    (wrapped, _fenics_out, _fenics_in,
     _tape) = evaluate_primal(assemble_fenics, templates, *inputs)

    unit_field = fa.interpolate(fa.Constant(1.0), V)
    kappa0 = fa.Constant(0.5)
    kappa1 = fa.Constant(0.6)
    reference = assemble_fenics(unit_field, kappa0, kappa1)

    assert np.isclose(wrapped, reference)


def test_vjp_assemble_eval():
    numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
Пример #16
0
def test_firedrake_forward():
    """``evaluate_primal`` must reproduce a direct ``solve_firedrake`` call."""
    primal = evaluate_primal(solve_firedrake, templates, *inputs)
    numpy_output, _, _, _ = primal
    kappa0 = firedrake.Constant(0.5)
    kappa1 = firedrake.Constant(0.6)
    solution = solve_firedrake(kappa0, kappa1)
    assert np.allclose(numpy_output, to_numpy(solution))
Пример #17
0
 def jax_fem_eval_p_impl(*args):
     """Primitive impl: convert ``args`` and return the numpy primal output.

     Each argument is converted with ``jax_to_fenics_numpy`` against the
     matching entry of ``fenics_templates``. Only the first element of the
     ``evaluate_primal`` result is returned.
     """
     converted = [
         jax_to_fenics_numpy(value, template)
         for value, template in zip(args, fenics_templates)
     ]
     primal_result = evaluate_primal(fenics_function, fenics_templates, *converted)
     return primal_result[0]
Пример #18
0
def assemble_firedrake(u, kappa0, kappa1):
    """Assemble a scalar functional of ``u`` with Firedrake on ``mesh``.

    Returns the assembled value of
    ``0.5 * inner(kappa0 * grad(u), grad(u)) * dx - kappa1 * f * u * dx``,
    where the source ``f`` is the x-coordinate field.
    """
    source = firedrake.SpatialCoordinate(mesh)[0]
    measure = ufl.dx

    diffusion = 0.5 * ufl.inner(kappa0 * ufl.grad(u), ufl.grad(u)) * measure
    forcing = kappa1 * source * u * measure
    return firedrake.assemble(diffusion - forcing)


# Templates: a function on V plus two scalar constants, matching the
# (u, kappa0, kappa1) signature of assemble_firedrake.
templates = (firedrake.Function(V), firedrake.Constant(0.0),
             firedrake.Constant(0.0))
# Numpy evaluation point: u == 1 everywhere, kappa0 == 0.5, kappa1 == 0.6.
inputs = (np.ones(V.dim()), np.ones(1) * 0.5, np.ones(1) * 0.6)


def ff(*args):
    """Return the numpy output of ``assemble_firedrake`` at numpy ``args``."""
    return evaluate_primal(assemble_firedrake, templates, *args)[0]


def ff0(x):
    """``ff`` as a function of input 0, others fixed at ``inputs``."""
    return ff(x, inputs[1], inputs[2])


def ff1(y):
    """``ff`` as a function of input 1, others fixed at ``inputs``."""
    return ff(inputs[0], y, inputs[2])


def ff2(z):
    """``ff`` as a function of input 2, others fixed at ``inputs``."""
    return ff(inputs[0], inputs[1], z)


def test_firedrake_forward():
    """Check ``evaluate_primal`` against calling ``assemble_firedrake`` directly."""
    result = evaluate_primal(assemble_firedrake, templates, *inputs)
    wrapped, _firedrake_out, _firedrake_in, _tape = result

    unit_field = firedrake.interpolate(firedrake.Constant(1.0), V)
    reference = assemble_firedrake(unit_field,
                                   firedrake.Constant(0.5),
                                   firedrake.Constant(0.6))

    assert np.isclose(wrapped, reference)