def test_firedrake_forward():
    """evaluate_primal(assemble_firedrake) agrees with a direct assemble.

    The direct call uses u interpolated from the constant 1.0 and
    kappa0=0.5, kappa1=0.6 (presumably the values held in ``inputs``).
    """
    numpy_output, _fd_output, _fd_inputs, _tape = evaluate_primal(
        assemble_firedrake, templates, *inputs
    )
    u_ones = firedrake.interpolate(firedrake.Constant(1.0), V)
    kappa0 = firedrake.Constant(0.5)
    kappa1 = firedrake.Constant(0.6)
    direct_value = assemble_firedrake(u_ones, kappa0, kappa1)
    assert np.isclose(numpy_output, direct_value)
def perform(self, node, inputs, outputs):
    """Run the wrapped function on numeric ``inputs`` (Theano ``Op.perform`` hook).

    Side effect: stores a ``FenicsVJPOp`` built from this evaluation's
    output, inputs and tape on ``self.vjp_op``. That op is presumably used
    later for the gradient; confirm against ``grad``.

    The primal numpy result is written into ``outputs[0][0]``. ``node`` is
    unused.
    """
    primal_result = evaluate_primal(self.ofunc, self.templates, *inputs)
    numpy_output, fenics_output, fenics_inputs, tape = primal_result
    self.vjp_op = FenicsVJPOp(
        self.ofunc,
        self.templates,
        fenics_output,
        tuple(fenics_inputs),
        tape,
    )
    outputs[0][0] = numpy_output
def primal(*args):
    """Forward rule (e.g. for ``custom_vjp.defvjp``).

    Evaluates ``fenics_function`` on ``args`` and returns
    ``(numpy_output, residuals)``. ``residuals`` is
    ``(PyadjointMetadata(fenics_output, fenics_inputs, tape), args)``.
    """
    numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
        fenics_function, fenics_templates, *args
    )
    metadata = PyadjointMetadata(fenics_output, fenics_inputs, tape)
    residuals = (metadata, args)
    return numpy_output, residuals
def decorator(fenics_function: Callable) -> Callable:
    """Wrap ``fenics_function`` so it can be called, differentiated and vmapped in JAX.

    ``fenics_templates`` comes from the enclosing scope. It is forwarded
    unchanged to ``evaluate_primal`` and ``get_pullback_function``.

    Args:
        fenics_function: the FEniCS-based function to wrap.

    Returns:
        ``jax_fem_eval``: a ``custom_vjp`` function. Its forward pass binds
        the ``jax_fem_eval`` primitive. Its backward pass uses the pullback
        returned by ``get_pullback_function``.
    """

    @functools.wraps(fenics_function)
    @custom_vjp
    def jax_fem_eval(*args):
        return jax_fem_eval_p.bind(*args)

    jax_fem_eval_p = Primitive("jax_fem_eval")
    # Concrete evaluation: run the FEniCS computation and keep only the
    # numpy output.
    jax_fem_eval_p.def_impl(
        lambda *args: evaluate_primal(fenics_function, fenics_templates, *args)[0]
    )
    # Output shape/dtype is derived by running the full primal computation
    # and converting its numpy output into a shaped array.
    # NOTE(review): this presumably requires the arguments to carry concrete
    # values, and it costs one full FEniCS evaluation per trace. Confirm.
    jax_fem_eval_p.def_abstract_eval(
        lambda *args: jax.abstract_arrays.make_shaped_array(
            evaluate_primal(fenics_function, fenics_templates, *args)[0]
        )
    )

    def jax_fem_eval_batch(vector_arg_values, batch_axes):
        """Batching rule: evaluate once per batch element and stack the results.

        Only the case where every argument is batched along axis 0 is
        supported. Any other layout raises NotImplementedError, for example
        a mix of batch axes or an unbatched argument.

        This is an explicit raise rather than ``assert``, so the check
        survives ``python -O`` and reports what was received.
        """
        if len(set(batch_axes)) != 1 or batch_axes[0] != 0:
            raise NotImplementedError(
                "jax_fem_eval batching only supports all arguments batched "
                "along axis 0; got batch_axes={}".format(batch_axes)
            )
        # map() walks every batched argument along its leading axis in lockstep.
        res = np.asarray(list(map(jax_fem_eval, *vector_arg_values)))
        return res, batch_axes[0]

    jax.interpreters.batching.primitive_batchers[
        jax_fem_eval_p
    ] = jax_fem_eval_batch

    def primal(*args):
        """Forward rule: evaluate, and save pyadjoint data as VJP residuals."""
        numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
            fenics_function, fenics_templates, *args
        )
        return (
            numpy_output,
            (PyadjointMetadata(fenics_output, fenics_inputs, tape), args),
        )

    def pullback(aux_args, g):
        """Backward rule: map the output cotangent ``g`` to input cotangents."""
        pb_fn = get_pullback_function(fenics_function, fenics_templates)
        # custom_vjp requires the backward rule to return a tuple with one
        # entry per primal input, but pb_fn returns a list.
        return tuple(pb_fn(aux_args, g))

    jax_fem_eval.defvjp(primal, pullback)
    return jax_fem_eval
def test_firedrake_vjp():
    """Pullback of solve_firedrake with an all-ones cotangent matches reference values."""
    primal_out, fd_output, fd_inputs, tape = evaluate_primal(
        solve_firedrake, templates, *inputs
    )
    cotangent = np.ones_like(primal_out)
    grads = evaluate_pullback(fd_output, fd_inputs, tape, cotangent)
    first_ok = np.isclose(grads[0], np.asarray(-1.13533304))
    second_ok = np.isclose(grads[1], np.asarray(0.94611087))
    assert first_ok and second_ok
def test_fenics_vjp():
    """Pullback of solve_fenics with an all-ones cotangent matches reference values."""
    primal_out, fe_output, fe_inputs, tape = evaluate_primal(
        solve_fenics, templates, *inputs
    )
    cotangent = np.ones_like(primal_out)
    grads = evaluate_pullback(fe_output, fe_inputs, tape, cotangent)
    first_ok = np.isclose(grads[0], np.asarray(-2.91792642))
    second_ok = np.isclose(grads[1], np.asarray(2.43160535))
    assert first_ok and second_ok
def test_theano_primal():
    """The Theano op from create_fenics_theano_op agrees with evaluate_primal."""
    # Symbolic vectors below carry no test values; "ignore" stops Theano
    # from complaining. NOTE: this sets a process-wide Theano flag and does
    # not restore it.
    theano.config.compute_test_value = "ignore"
    hh = create_fenics_theano_op(templates)(assemble_firedrake)
    x = theano.tensor.vector()
    y = theano.tensor.vector()
    z = theano.tensor.vector()
    f = theano.function([x, y, z], hh(x, y, z))
    theano_output = f(*inputs)
    numpy_output = evaluate_primal(assemble_firedrake, templates, *inputs)[0]
    assert np.isclose(theano_output, numpy_output)
def test_vjp_assemble_eval():
    """evaluate_pullback of assemble_fenics agrees with finite-difference Jacobians.

    The cotangent is all ones. The pullback for input i is compared with
    fdm's Jacobian of the matching one-input function (ff0/ff1/ff2) at
    ``inputs[i]``.
    """
    numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
        assemble_fenics, templates, *inputs
    )
    vjp_out = evaluate_pullback(
        fenics_output, fenics_inputs, tape, np.ones_like(numpy_output)
    )
    one_input_fns = (ff0, ff1, ff2)
    fdm_jacs = [fdm.jacobian(fn)(inputs[i]) for i, fn in enumerate(one_input_fns)]
    # Build the full list first so every comparison runs before the assert.
    checks = [np.allclose(vjp_out[i], jac) for i, jac in enumerate(fdm_jacs)]
    assert all(checks)
def jvp_jax_fem_eval_impl(primals, tangents):
    """Compute the primal output and its Jacobian-vector product.

    Args:
        primals: inputs, one per entry of ``fenics_templates``.
        tangents: tangent vectors, one per primal input.

    Returns:
        ``(numpy_output, dnumpy_output)``: the numpy output of
        ``fenics_function`` and its pushforward along ``tangents``.
    """
    # Materialize conversions as tuples rather than one-shot generators.
    # evaluate_pushforward receives ``tangents`` directly and may iterate,
    # index or take len() of it; other callers pass tuples as well.
    primals = tuple(
        jax_to_fenics_numpy(p, ft) for p, ft in zip(primals, fenics_templates)
    )
    numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
        fenics_function, fenics_templates, *primals
    )
    tangents = tuple(
        jax_to_fenics_numpy(t, ft) for t, ft in zip(tangents, fenics_templates)
    )
    dnumpy_output = evaluate_pushforward(
        fenics_output, fenics_inputs, tape, tangents
    )
    return numpy_output, dnumpy_output
def test_theano_vjp():
    """FenicsVJPOp applied symbolically agrees with evaluate_pullback."""
    theano.config.compute_test_value = "ignore"
    primal_result = evaluate_primal(assemble_firedrake, templates, *inputs)
    _, fenics_output, fenics_inputs, tape = primal_result
    vjp_op = FenicsVJPOp(
        assemble_firedrake, templates, fenics_output, tuple(fenics_inputs), tape
    )
    cotangent = theano.tensor.vector()
    vjp_fn = theano.function([cotangent], vjp_op(cotangent))
    theano_output = vjp_fn(np.ones(1))
    reference_output = evaluate_pullback(
        fenics_output, tuple(fenics_inputs), tape, np.ones(1)
    )
    for theano_val, reference_val in zip(theano_output, reference_output):
        # ``check`` (presumably pytest-check) records a failure and continues.
        with check:
            assert np.allclose(theano_val, reference_val)
def test_jvp_assemble_eval():
    """evaluate_pushforward of assemble_fenics agrees with finite differences.

    The JVP with respect to all inputs is the sum of the per-input JVPs. Each
    per-input JVP is computed by fdm on ff0/ff1/ff2, with the other inputs
    held at ``inputs``.
    """
    primals = inputs
    # Seeded generator so a tolerance failure is reproducible rather than flaky.
    rng = np.random.RandomState(0)
    tangent0 = rng.normal(size=(V.dim(),))
    tangent1 = rng.normal(size=(1,))
    tangent2 = rng.normal(size=(1,))
    tangents = (tangent0, tangent1, tangent2)

    fdm_jvp0 = fdm.jvp(ff0, tangents[0])(primals[0])
    fdm_jvp1 = fdm.jvp(ff1, tangents[1])(primals[1])
    fdm_jvp2 = fdm.jvp(ff2, tangents[2])(primals[2])

    _, fenics_output, fenics_inputs, tape = evaluate_primal(
        assemble_fenics, templates, *primals
    )
    out_tangent = evaluate_pushforward(fenics_output, fenics_inputs, tape, tangents)
    assert np.allclose(fdm_jvp0 + fdm_jvp1 + fdm_jvp2, out_tangent)
def test_fenics_jvp():
    """evaluate_pushforward of solve_fenics agrees with finite differences.

    The JVP with respect to both inputs is the sum of the per-input JVPs.
    Each is computed by fdm with the other input held at ``inputs``.
    """
    primals = inputs
    # Seeded generator so a tolerance failure is reproducible rather than flaky.
    rng = np.random.RandomState(0)
    tangent0 = rng.normal(size=(1,))
    tangent1 = rng.normal(size=(1,))
    tangents = (tangent0, tangent1)

    def ff0(x):
        # solve_fenics output as a function of the first input only.
        return evaluate_primal(solve_fenics, templates, x, primals[1])[0]

    def ff1(y):
        # solve_fenics output as a function of the second input only.
        return evaluate_primal(solve_fenics, templates, primals[0], y)[0]

    fdm_jvp0 = fdm.jvp(ff0, tangents[0])(primals[0])
    fdm_jvp1 = fdm.jvp(ff1, tangents[1])(primals[1])

    _, fenics_output, fenics_inputs, tape = evaluate_primal(
        solve_fenics, templates, *primals
    )
    out_tangent = evaluate_pushforward(fenics_output, fenics_inputs, tape, tangents)
    assert np.allclose(fdm_jvp0 + fdm_jvp1, out_tangent)
def test_fenics_forward():
    """evaluate_primal(solve_fenics) agrees with a direct solve.

    The direct solve uses kappa0=0.5 and kappa1=0.6 (presumably the values
    held in ``inputs``).
    """
    numpy_output, _fenics_output, _fenics_inputs, _tape = evaluate_primal(
        solve_fenics, templates, *inputs
    )
    kappa0 = fa.Constant(0.5)
    kappa1 = fa.Constant(0.6)
    direct_solution = solve_fenics(kappa0, kappa1)
    assert np.allclose(numpy_output, to_numpy(direct_solution))
def test_fenics_forward():
    """evaluate_primal(assemble_fenics) agrees with a direct assemble.

    The direct call uses u interpolated from the constant 1.0 and
    kappa0=0.5, kappa1=0.6 (presumably the values held in ``inputs``).
    """
    numpy_output, _fenics_output, _fenics_inputs, _tape = evaluate_primal(
        assemble_fenics, templates, *inputs
    )
    u_ones = fa.interpolate(fa.Constant(1.0), V)
    kappa0 = fa.Constant(0.5)
    kappa1 = fa.Constant(0.6)
    direct_value = assemble_fenics(u_ones, kappa0, kappa1)
    assert np.isclose(numpy_output, direct_value)
def assemble_fenics(u, kappa0, kappa1):
    """Assemble the scalar functional J(u; kappa0, kappa1).

    J = 0.5 * inner(kappa0 * grad(u), grad(u)) * dx - kappa1 * f * u * dx,
    where f = 10 * exp(-((x0 - 0.5)^2 + (x1 - 0.5)^2) / 0.02), a bump of
    height 10 centred at (0.5, 0.5).

    Returns the value produced by ``fa.assemble`` on that form.
    """
    f = fa.Expression(
        "10*exp(-(pow(x[0] - 0.5, 2) + pow(x[1] - 0.5, 2)) / 0.02)", degree=2)
    inner, grad, dx = ufl.inner, ufl.grad, ufl.dx
    J_form = 0.5 * inner(kappa0 * grad(u), grad(u)) * dx - kappa1 * f * u * dx
    J = fa.assemble(J_form)
    return J


# One template per positional argument of assemble_fenics (u, kappa0, kappa1).
# ``inputs`` holds the matching numpy values: u = 1 everywhere on V,
# kappa0 = 0.5, kappa1 = 0.6.
templates = (fa.Function(V), fa.Constant(0.0), fa.Constant(0.0))
inputs = (np.ones(V.dim()), np.ones(1) * 0.5, np.ones(1) * 0.6)
# ff: numpy inputs -> numpy output of assemble_fenics via evaluate_primal.
# ff0/ff1/ff2 vary a single input (the others fixed at ``inputs``) for
# finite-difference checks.
ff = lambda *args: evaluate_primal(assemble_fenics, templates, *args)[
    0]  # noqa: E731
ff0 = lambda x: ff(x, inputs[1], inputs[2])  # noqa: E731
ff1 = lambda y: ff(inputs[0], y, inputs[2])  # noqa: E731
ff2 = lambda z: ff(inputs[0], inputs[1], z)  # noqa: E731


def test_fenics_forward():
    """evaluate_primal(assemble_fenics) equals a direct assemble at u = 1,
    kappa0 = 0.5, kappa1 = 0.6."""
    numpy_output, _, _, _, = evaluate_primal(assemble_fenics, templates,
                                             *inputs)
    u1 = fa.interpolate(fa.Constant(1.0), V)
    J = assemble_fenics(u1, fa.Constant(0.5), fa.Constant(0.6))
    assert np.isclose(numpy_output, J)


def test_vjp_assemble_eval():
    numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
def test_firedrake_forward():
    """evaluate_primal(solve_firedrake) agrees with a direct solve.

    The direct solve uses kappa0=0.5 and kappa1=0.6 (presumably the values
    held in ``inputs``).
    """
    numpy_output, _fd_output, _fd_inputs, _tape = evaluate_primal(
        solve_firedrake, templates, *inputs
    )
    kappa0 = firedrake.Constant(0.5)
    kappa1 = firedrake.Constant(0.6)
    direct_solution = solve_firedrake(kappa0, kappa1)
    assert np.allclose(numpy_output, to_numpy(direct_solution))
def jax_fem_eval_p_impl(*args):
    """Concrete implementation of the primitive.

    Converts each argument with its matching entry of ``fenics_templates``,
    evaluates ``fenics_function``, and returns only the numpy output.
    """
    converted = [
        jax_to_fenics_numpy(value, template)
        for value, template in zip(args, fenics_templates)
    ]
    primal_result = evaluate_primal(fenics_function, fenics_templates, *converted)
    return primal_result[0]
def assemble_firedrake(u, kappa0, kappa1):
    """Assemble the scalar functional J(u; kappa0, kappa1) on ``mesh``.

    J = 0.5 * inner(kappa0 * grad(u), grad(u)) * dx - kappa1 * x0 * u * dx,
    where the source term x0 is the first spatial coordinate of ``mesh``.
    """
    inner, grad, dx = ufl.inner, ufl.grad, ufl.dx
    source = firedrake.SpatialCoordinate(mesh)[0]
    diffusion_term = 0.5 * inner(kappa0 * grad(u), grad(u)) * dx
    forcing_term = kappa1 * source * u * dx
    return firedrake.assemble(diffusion_term - forcing_term)


# One template per positional argument of assemble_firedrake, with matching
# numpy values: u = 1 everywhere on V, kappa0 = 0.5, kappa1 = 0.6.
templates = (
    firedrake.Function(V),
    firedrake.Constant(0.0),
    firedrake.Constant(0.0),
)
inputs = (np.ones(V.dim()), np.ones(1) * 0.5, np.ones(1) * 0.6)


def ff(*args):
    """Numpy output of assemble_firedrake evaluated through evaluate_primal."""
    return evaluate_primal(assemble_firedrake, templates, *args)[0]


def ff0(x):
    """``ff`` as a function of the first input (others fixed at ``inputs``)."""
    return ff(x, inputs[1], inputs[2])


def ff1(y):
    """``ff`` as a function of the second input (others fixed at ``inputs``)."""
    return ff(inputs[0], y, inputs[2])


def ff2(z):
    """``ff`` as a function of the third input (others fixed at ``inputs``)."""
    return ff(inputs[0], inputs[1], z)


def test_firedrake_forward():
    """evaluate_primal(assemble_firedrake) agrees with a direct assemble at
    u = 1, kappa0 = 0.5, kappa1 = 0.6."""
    numpy_output, _fd_output, _fd_inputs, _tape = evaluate_primal(
        assemble_firedrake, templates, *inputs
    )
    u_ones = firedrake.interpolate(firedrake.Constant(1.0), V)
    kappa0 = firedrake.Constant(0.5)
    kappa1 = firedrake.Constant(0.6)
    direct_value = assemble_firedrake(u_ones, kappa0, kappa1)
    assert np.isclose(numpy_output, direct_value)