def test_defvjp_all_higher_order_revmode(self):
    """Second-order reverse mode differentiates through a custom VJP rule.

    The cotangent rule ``ct * x**2`` is itself differentiated by the outer
    ``grad``. The inner gradient of ``2 * bind(x)`` is ``2 * x**2``, whose
    derivative ``4 * x`` is 12 at x = 3.
    """
    prim = Primitive('foo')
    ad.defvjp_all(prim, lambda x: (x ** 2, lambda ct: (ct * x ** 2,)))

    def scaled(x):
        return 2. * prim.bind(x)

    second_derivative = api.grad(api.grad(scaled))(3.)
    self.assertAllClose(second_derivative, 2 * 2 * 3., check_dtypes=False)
def test_defvjp_all_const(self):
    """A VJP rule that ignores its cotangent produces a constant gradient."""
    prim = Primitive('foo')
    ad.defvjp_all(prim, lambda x: (x ** 2, lambda _ct: (12.,)))

    value, gradient = api.value_and_grad(lambda x: prim.bind(x))(3.)
    self.assertAllClose(value, 9., check_dtypes=False)
    self.assertAllClose(gradient, 12., check_dtypes=True)
def test_defvjp_all(self):
    """value_and_grad uses the registered VJP rule, not the true derivative.

    The primal output is ``x**2``, but the rule reports ``4 * ct * sin(x)``.
    The expected gradient follows that rule, which confirms it was used.
    """
    prim = Primitive('foo')

    def fwd_and_bwd(x):
        primal_out = x ** 2
        return primal_out, lambda ct: (4 * ct * np.sin(x),)

    ad.defvjp_all(prim, fwd_and_bwd)
    value, gradient = api.value_and_grad(lambda x: 2. * prim.bind(x))(3.)
    self.assertAllClose(value, 2 * 3.**2, check_dtypes=False)
    self.assertAllClose(gradient, 4 * 2 * onp.sin(3.), check_dtypes=False)
def test_defvjp_all_multiple_arguments(self):
    """Custom VJP rule for a two-argument primitive.

    Differentiating only with respect to the first argument also exercises
    the path where the second argument's tangent is a symbolic zero.
    """
    prim = Primitive('foo')

    def fwd_and_bwd(x, y):
        def bwd(ct):
            return ct + x + y, ct * x * 9.
        return x ** 2 + y ** 3, bwd

    ad.defvjp_all(prim, fwd_and_bwd)
    two_arg = lambda x, y: prim.bind(x, y)

    # Gradient w.r.t. the first argument only.
    value, gradient = api.value_and_grad(two_arg)(3., 4.)
    self.assertAllClose(value, 3.**2 + 4.**3, check_dtypes=False)
    self.assertAllClose(gradient, 1. + 3. + 4., check_dtypes=False)

    # Gradients w.r.t. both arguments.
    both_grads = api.grad(two_arg, (0, 1))(3., 4.)
    self.assertAllClose(
        both_grads, (1. + 3. + 4., 1. * 3. * 9.), check_dtypes=False)
def decorator(fenics_function: Callable) -> Callable:
    """Wrap ``fenics_function`` as a differentiable, batchable JAX primitive.

    The returned callable binds the ``jax_solve_eval`` primitive. The
    primitive's concrete implementation is the first output of
    ``solve_eval(fenics_function, fenics_templates, *args)``. Reverse-mode
    derivatives are registered with ``defvjp_all`` through a helper primitive
    implemented by ``vjp_solve_eval``. A batching rule evaluates the solve once
    per batch element.

    Args:
        fenics_function: the FEniCS-side function to wrap.

    Returns:
        A callable carrying ``fenics_function``'s metadata that forwards its
        positional arguments to the primitive.
    """
    # ``jax.batching`` is not a documented public attribute, so import the
    # batching module explicitly.
    from jax.interpreters import batching

    @functools.wraps(fenics_function)
    def jax_solve_eval(*args):
        return jax_solve_eval_p.bind(*args)

    jax_solve_eval_p = Primitive("jax_solve_eval")
    jax_solve_eval_p.def_impl(
        lambda *args: solve_eval(fenics_function, fenics_templates, *args)[0]
    )
    # The abstract value is derived from the first output of ``solve_eval``.
    # NOTE(review): this calls solve_eval on whatever values abstract
    # evaluation receives; confirm it works under jit.
    jax_solve_eval_p.def_abstract_eval(
        lambda *args: jax.abstract_arrays.make_shaped_array(
            solve_eval(fenics_function, fenics_templates, *args)[0]
        )
    )

    def jax_solve_eval_batch(vector_arg_values, batch_axes):
        """Batching rule: solve once per batch element and stack the results.

        An argument whose batch axis is ``None`` is not batched and is passed
        unchanged to every evaluation. Batched arguments may carry their
        batch dimension on any axis. The stacked result has its batch
        dimension at axis 0.
        """
        if all(axis is None for axis in batch_axes):
            # Nothing is batched: evaluate once; the output stays unbatched.
            return jax_solve_eval(*vector_arg_values), None
        batch_size = next(
            arg.shape[axis]
            for arg, axis in zip(vector_arg_values, batch_axes)
            if axis is not None
        )
        # Only move non-leading batch axes, so the common all-zero case
        # indexes the original arrays directly.
        leading = [
            arg if axis in (None, 0) else np.moveaxis(arg, axis, 0)
            for arg, axis in zip(vector_arg_values, batch_axes)
        ]
        # Compute the function row by row.
        res = np.asarray([
            jax_solve_eval(*(
                arg[i] if axis is not None else arg
                for arg, axis in zip(leading, batch_axes)
            ))
            for i in range(batch_size)
        ])
        return res, 0

    batching.primitive_batchers[jax_solve_eval_p] = jax_solve_eval_batch

    def djax_solve_eval(*args):
        # defvjp_all expects this to return ``(primal_out, vjp_fn)``;
        # presumably vjp_solve_eval returns exactly that (confirm).
        return djax_solve_eval_p.bind(*args)

    djax_solve_eval_p = Primitive("djax_solve_eval")
    djax_solve_eval_p.def_impl(
        lambda *args: vjp_solve_eval(fenics_function, fenics_templates, *args)
    )
    defvjp_all(jax_solve_eval_p, djax_solve_eval)
    return jax_solve_eval
def decorator(fenics_function: Callable) -> Callable:
    """Wrap ``fenics_function`` as a differentiable, batchable JAX primitive.

    The returned callable binds the ``jax_assemble_eval`` primitive. The
    primitive's concrete implementation is the first output of
    ``assemble_eval(fenics_function, fenics_templates, *args)``. Reverse-mode
    derivatives are registered with ``defvjp_all`` through a helper primitive
    implemented by ``vjp_assemble_eval``. A batching rule assembles once per
    batch element.

    Args:
        fenics_function: the FEniCS-side function to wrap.

    Returns:
        A callable carrying ``fenics_function``'s metadata that forwards its
        positional arguments to the primitive.
    """
    import functools

    @functools.wraps(fenics_function)
    def jax_assemble_eval(*args):
        return jax_assemble_eval_p.bind(*args)

    jax_assemble_eval_p = Primitive("jax_assemble_eval")
    jax_assemble_eval_p.def_impl(
        lambda *args: assemble_eval(fenics_function, fenics_templates, *args)[0]
    )
    # The abstract value is derived from the first output of ``assemble_eval``.
    # NOTE(review): this calls assemble_eval on whatever values abstract
    # evaluation receives; confirm it works under jit.
    jax_assemble_eval_p.def_abstract_eval(
        lambda *args: jax.abstract_arrays.make_shaped_array(
            assemble_eval(fenics_function, fenics_templates, *args)[0]
        )
    )

    def jax_assemble_eval_batch(vector_arg_values, batch_axes):
        """Batching rule: assemble once per batch element and stack results.

        An argument whose batch axis is ``None`` is not batched and is passed
        unchanged to every evaluation. Batched arguments may carry their
        batch dimension on any axis. The stacked result has its batch
        dimension at axis 0.
        """
        if all(axis is None for axis in batch_axes):
            # Nothing is batched: evaluate once; the output stays unbatched.
            return jax_assemble_eval(*vector_arg_values), None
        batch_size = next(
            arg.shape[axis]
            for arg, axis in zip(vector_arg_values, batch_axes)
            if axis is not None
        )
        # Only move non-leading batch axes, so the common all-zero case
        # indexes the original arrays directly.
        leading = [
            arg if axis in (None, 0) else np.moveaxis(arg, axis, 0)
            for arg, axis in zip(vector_arg_values, batch_axes)
        ]
        res = np.asarray([
            jax_assemble_eval(*(
                arg[i] if axis is not None else arg
                for arg, axis in zip(leading, batch_axes)
            ))
            for i in range(batch_size)
        ])
        return res, 0

    batching.primitive_batchers[jax_assemble_eval_p] = jax_assemble_eval_batch

    def djax_assemble_eval(*args):
        # defvjp_all expects this to return ``(primal_out, vjp_fn)``;
        # presumably vjp_assemble_eval returns exactly that (confirm).
        return djax_assemble_eval_p.bind(*args)

    djax_assemble_eval_p = Primitive("djax_assemble_eval")
    djax_assemble_eval_p.def_impl(
        lambda *args: vjp_assemble_eval(fenics_function, fenics_templates, *args)
    )
    defvjp_all(jax_assemble_eval_p, djax_assemble_eval)
    return jax_assemble_eval