def test_fenics_vjp():
    """Check the VJP of ``solve_fenics`` against hard-coded reference values.

    The primal problem is evaluated once with ``evaluate_primal``. The
    returned tape is then reused by ``evaluate_vjp`` with an all-ones
    cotangent ``g``.
    """
    numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
        solve_fenics, templates, *inputs)
    g = np.ones_like(numpy_output)
    vjp_out = evaluate_vjp(g, fenics_output, fenics_inputs, tape)

    # One assert per component, so a failure points at the offending input.
    # np.allclose reduces to a single bool. Chaining element-wise
    # np.isclose arrays with ``and`` raises ValueError for multi-element
    # outputs.
    assert np.allclose(vjp_out[0], np.asarray(-2.91792642))
    assert np.allclose(vjp_out[1], np.asarray(2.43160535))
def decorator(fenics_function: Callable) -> Callable:
    """Wrap ``fenics_function`` as a JAX primitive with batching and VJP rules.

    Uses ``fenics_templates`` from the enclosing scope, which is not visible
    here. It is passed through to ``evaluate_primal`` and ``vjp_fem_eval``.

    Args:
        fenics_function: FEniCS function executed in the forward pass.

    Returns:
        ``jax_fem_eval``, a function that binds the ``jax_fem_eval``
        primitive. It carries ``fenics_function``'s metadata via
        ``functools.wraps``.
    """

    @functools.wraps(fenics_function)
    def jax_fem_eval(*args):
        return jax_fem_eval_p.bind(*args)

    jax_fem_eval_p = Primitive("jax_fem_eval")
    jax_fem_eval_p.def_impl(
        lambda *args: evaluate_primal(
            fenics_function, fenics_templates, *args)[0])
    # The output shape/dtype is obtained by running the full primal
    # evaluation and wrapping its numpy output.
    jax_fem_eval_p.def_abstract_eval(
        lambda *args: jax.abstract_arrays.make_shaped_array(
            evaluate_primal(fenics_function, fenics_templates, *args)[0]))

    def jax_fem_eval_batch(vector_arg_values, batch_axes):
        """Batching rule: evaluate the primitive once per row along axis 0."""
        # Explicit checks rather than ``assert``, so they survive
        # ``python -O``. Without them, a mixed or non-zero batch axis would
        # silently be mapped over axis 0 and produce wrong results.
        if len(set(batch_axes)) != 1:
            raise NotImplementedError(
                "jax_fem_eval batching requires all arguments to share the "
                f"same batch axis, got {batch_axes}")
        if batch_axes[0] != 0:
            raise NotImplementedError(
                "jax_fem_eval batching only supports batch axis 0, "
                f"got {batch_axes[0]}")
        res = np.asarray(list(map(jax_fem_eval, *vector_arg_values)))
        return res, batch_axes[0]

    jax.interpreters.batching.primitive_batchers[
        jax_fem_eval_p] = jax_fem_eval_batch

    def djax_fem_eval(*args):
        return djax_fem_eval_p.bind(*args)

    # The impl returns vjp_fem_eval's (primal output, VJP callable) pair.
    # This is presumably the custom-VJP form that ``defvjp_all`` expects
    # (confirm against the JAX version in use).
    djax_fem_eval_p = Primitive("djax_fem_eval")
    djax_fem_eval_p.def_impl(
        lambda *args: vjp_fem_eval(fenics_function, fenics_templates, *args))
    defvjp_all(jax_fem_eval_p, djax_fem_eval)

    return jax_fem_eval
def test_vjp_assemble_eval():
    """Compare VJPs of ``assemble_fenics`` with finite-difference Jacobians.

    Each VJP component, computed with an all-ones cotangent, is checked
    against ``fdm.jacobian`` of the matching one-argument wrapper
    (``ff0``/``ff1``/``ff2``), evaluated at the matching entry of
    ``inputs``.
    """
    primal = evaluate_primal(assemble_fenics, templates, *inputs)
    numpy_output, fenics_output, fenics_inputs, tape = primal
    cotangent = np.ones_like(numpy_output)
    vjp_out = evaluate_vjp(cotangent, fenics_output, fenics_inputs, tape)

    # Compute every finite-difference Jacobian first, then compare.
    partial_funcs = (ff0, ff1, ff2)
    fd_jacobians = [
        fdm.jacobian(func)(point) for func, point in zip(partial_funcs, inputs)
    ]
    close0, close1, close2 = (
        np.allclose(vjp, jac) for vjp, jac in zip(vjp_out, fd_jacobians)
    )
    assert close0 and close1 and close2
def test_fenics_forward():
    """Check that ``evaluate_primal`` matches a direct ``assemble_fenics`` call.

    The reference assembles with ``u`` interpolated to the constant 1.0 and
    coefficients 0.5 and 0.6.
    """
    numpy_output, _fenics_output, _fenics_inputs, _tape = evaluate_primal(
        assemble_fenics, templates, *inputs)
    ones_field = fa.interpolate(fa.Constant(1.0), V)
    expected = assemble_fenics(ones_field, fa.Constant(0.5), fa.Constant(0.6))
    assert np.isclose(numpy_output, expected)
def assemble_fenics(u, kappa0, kappa1):
    """Assemble the scalar functional built from ``u``, ``kappa0`` and ``kappa1``.

    J = 0.5 * inner(kappa0 * grad(u), grad(u)) * dx - kappa1 * f * u * dx,
    where f = 10*exp(-((x0 - 0.5)^2 + (x1 - 0.5)^2) / 0.02) is a Gaussian
    bump centred at (0.5, 0.5).

    Returns:
        The value produced by ``fa.assemble`` on the form.
    """
    # Source term, interpolated with polynomial degree 2.
    f = fa.Expression(
        "10*exp(-(pow(x[0] - 0.5, 2) + pow(x[1] - 0.5, 2)) / 0.02)", degree=2
    )
    inner, grad, dx = ufl.inner, ufl.grad, ufl.dx
    J_form = 0.5 * inner(kappa0 * grad(u), grad(u)) * dx - kappa1 * f * u * dx
    J = fa.assemble(J_form)
    return J


# One template per argument of assemble_fenics (u, kappa0, kappa1).
# Presumably evaluate_primal uses these to convert the numpy inputs below;
# confirm against evaluate_primal.
templates = (fa.Function(V), fa.Constant(0.0), fa.Constant(0.0))
# Numpy inputs: every DOF of u set to 1, kappa0 = 0.5, kappa1 = 0.6.
inputs = (np.ones(V.dim()), np.ones(1) * 0.5, np.ones(1) * 0.6)
# Numpy-level wrapper returning only the primal output, plus one-argument
# partial applications that vary a single input while fixing the others.
# These are the wrappers passed to fdm.jacobian in test_vjp_assemble_eval.
ff = lambda *args: evaluate_primal(assemble_fenics, templates, *args)[0]  # noqa: E731
ff0 = lambda x: ff(x, inputs[1], inputs[2])  # noqa: E731
ff1 = lambda y: ff(inputs[0], y, inputs[2])  # noqa: E731
ff2 = lambda z: ff(inputs[0], inputs[1], z)  # noqa: E731


# NOTE(review): an identical test_fenics_forward also appears elsewhere in
# this chunk. If both live in the same module, the later definition shadows
# the earlier one.
def test_fenics_forward():
    numpy_output, _, _, _, = evaluate_primal(assemble_fenics, templates, *inputs)
    u1 = fa.interpolate(fa.Constant(1.0), V)
    J = assemble_fenics(u1, fa.Constant(0.5), fa.Constant(0.6))
    assert np.isclose(numpy_output, J)


# NOTE(review): the definition below is truncated in this chunk (the
# evaluate_primal call is never closed). It looks like the opening of
# test_vjp_assemble_eval. It is kept unchanged; confirm against the full file.
def test_vjp_assemble_eval():
    numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
        assemble_fenics, templates, *inputs
def test_fenics_forward():
    """Check that ``evaluate_primal`` on ``solve_fenics`` matches a direct solve.

    The reference solves with coefficients 0.5 and 0.6 and converts the
    FEniCS result with ``fenics_to_numpy`` before comparing.
    """
    primal_result = evaluate_primal(solve_fenics, templates, *inputs)
    numpy_output, _, _, _ = primal_result
    reference = fenics_to_numpy(solve_fenics(fa.Constant(0.5), fa.Constant(0.6)))
    assert np.allclose(numpy_output, reference)
def vjp_fem_eval(
    fenics_function: Callable,
    fenics_templates: Iterable[FenicsVariable],
    *args: np.array,
) -> Tuple[np.array, Callable]:
    """Evaluate ``fenics_function`` and return its output with a VJP function.

    Args:
        fenics_function: FEniCS function executed during the forward pass.
        fenics_templates: FEniCS variables passed to ``evaluate_primal``
            alongside ``args``. Presumably they describe how each input is
            converted; confirm against ``evaluate_primal``.
        *args: Array representation of the inputs to ``fenics_function``.

    Returns:
        A pair ``(numpy_output, vjp_fun)``.

        - ``numpy_output`` is the primal output from ``evaluate_primal``.
        - ``vjp_fun`` maps an output cotangent ``g`` (same shape as
          ``numpy_output``) to a tuple with one cotangent per element of
          ``args``. Inputs for which ``evaluate_vjp`` yields ``None`` receive
          zeros shaped like that input.

    When ``vjp_fun`` is batched (e.g. under ``vmap``), it raises
    ``NotImplementedError`` if the batch axis is not 0.
    """
    numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
        fenics_function, fenics_templates, *args)

    def vjp_fun1(g):
        return vjp_fun1_p.bind(g)

    # NOTE(review): a fresh primitive is created on every call, and it is
    # also registered in the global primitive_batchers dict below. Each call
    # therefore adds one registry entry.
    vjp_fun1_p = Primitive("vjp_fun1")
    vjp_fun1_p.multiple_results = True
    # Replace missing (None) cotangents with zeros shaped like the
    # corresponding input, so the result always has one entry per arg.
    vjp_fun1_p.def_impl(lambda g: tuple(
        vjp if vjp is not None else jax.ad_util.zeros_like_jaxval(args[i])
        for i, vjp in enumerate(
            evaluate_vjp(g, fenics_output, fenics_inputs, tape))))

    def vjp_fun1_abstract_eval(g):
        """Report one ShapedArray per input, matching each input's shape/dtype."""
        if len(args) > 1:
            return tuple(
                (jax.abstract_arrays.ShapedArray(arg.shape, arg.dtype)
                 for arg in args))
        else:
            # NOTE(review): the single-input case adds a leading dimension
            # of 1, unlike the multi-input case. Confirm this is intended.
            return (jax.abstract_arrays.ShapedArray((1, *args[0].shape),
                                                    args[0].dtype), )

    vjp_fun1_p.def_abstract_eval(vjp_fun1_abstract_eval)

    def vjp_fun1_batch(vector_arg_values, batch_axes):
        """Computes the batched version of the primitive.

        This must be a JAX-traceable function.

        Args:
            vector_arg_values: a tuple of arguments, each being a tensor of
                matching shape.
            batch_axes: the axes that are being batched. See vmap
                documentation.

        Returns:
            a tuple of the result, and the result axis that was batched.

        Raises:
            NotImplementedError: if the batch axis is not 0.
        """
        # Explicit check rather than ``assert``, so it survives
        # ``python -O``. Without it, a non-zero batch axis would be mapped
        # over axis 0 and give wrong results.
        if batch_axes[0] != 0:
            raise NotImplementedError(
                "vjp_fun1 batching only supports batch axis 0, "
                f"got {batch_axes[0]}")
        # Apply the VJP row by row along the batch axis.
        res = list(map(vjp_fun1, *vector_arg_values))
        # Transpose the per-row result tuples into per-input sequences,
        # then stack each sequence into one batched array.
        res_T = list(itertools.zip_longest(*res))
        return tuple(map(np.vstack, res_T)), (batch_axes[0], ) * len(args)

    jax.interpreters.batching.primitive_batchers[vjp_fun1_p] = vjp_fun1_batch

    return numpy_output, vjp_fun1
def jax_fem_eval_p_abstract_eval(*args):
    """Abstract evaluation rule: derive the output ShapedArray from a primal run.

    Each argument is first converted with ``jax_to_fenics_numpy`` against its
    template from ``fenics_templates`` (a name from the enclosing scope).
    ``evaluate_primal`` is then run on the converted values, and its numpy
    output is wrapped as a shaped array.
    """
    make_shaped_array = jax.abstract_arrays.make_shaped_array
    fenics_args = [
        jax_to_fenics_numpy(value, template)
        for value, template in zip(args, fenics_templates)
    ]
    primal_output = evaluate_primal(
        fenics_function, fenics_templates, *fenics_args)[0]
    return make_shaped_array(primal_output)
def jax_fem_eval_p_impl(*args):
    """Concrete implementation rule: run the FEniCS function on converted args.

    Each argument is paired with its template from ``fenics_templates`` (a
    name from the enclosing scope) and converted with ``jax_to_fenics_numpy``.
    Returns only the first element of ``evaluate_primal``'s result (the
    numpy output).
    """
    converted = tuple(map(jax_to_fenics_numpy, args, fenics_templates))
    return evaluate_primal(fenics_function, fenics_templates, *converted)[0]