def test_jit(self, f, args):
    """Check that the jit-compiled ``f`` agrees with plain ``f`` on ``args``.

    Both versions are called with ``*args`` and the two results are
    compared with ``jtu.check_close``.
    """
    # Evaluate the compiled call first, then the plain call, as before.
    compiled_result = jit(f)(*args)
    plain_result = f(*args)
    jtu.check_close(compiled_result, plain_result)
def test_custom_root_with_aux(self):
    """Exercise ``lax.custom_root`` with ``has_aux=True`` on a Cholesky solve.

    The root of ``x -> A @ x - b`` is compared against
    ``jnp.linalg.solve`` both eagerly and under ``jax.jit``; the aux value
    returned under ``jit`` is compared against the original aux dict.
    Gradients are then checked up to second order, and ``jax.jacfwd`` of
    the solver is compared against ``jacfwd`` of the reference solve, with
    the Jacobians of the aux outputs expected to be all zeros.
    """
    # Auxiliary payload handed back by the solver. It is only read when the
    # solver is actually invoked, so defining it up front is equivalent.
    aux_payload = {
        "converged": np.array(1.),
        "nfev": np.array(12345.),
        "grad": np.array([1.0, 2.0, 3.0]),
    }

    def gram(m):
        # Build the matrix m @ m.T that is passed to the solver.
        return high_precision_dot(m, m.T)

    def solve_with_aux(matrix, rhs):
        def residual(x):
            return high_precision_dot(matrix, x) - rhs

        chol = jsp.linalg.cho_factor(matrix)

        def factor_solve(unused_fn, target):
            return jsp.linalg.cho_solve(chol, target), aux_payload

        def tangent_solve(linear_fn, target):
            # Drop the aux output so the solver can serve as tangent_solve.
            def factor_solve_noaux(fn, y):
                return factor_solve(fn, y)[0]

            return lax.custom_linear_solve(
                linear_fn, target, factor_solve_noaux, symmetric=True)

        return lax.custom_root(
            residual, rhs, factor_solve, tangent_solve, has_aux=True)

    def solve_from_factor(x, y):
        return solve_with_aux(gram(x), y)

    def reference_from_factor(x, y):
        return jnp.linalg.solve(gram(x), y)

    rng = self.rng()
    a = rng.randn(2, 2)
    b = rng.randn(2)

    # NOTE(review): the eager aux is unpacked but never compared against
    # aux_payload; only the jit aux is checked below.
    eager_root, _eager_aux = solve_with_aux(gram(a), b)
    jit_root, jit_aux = jax.jit(solve_with_aux)(gram(a), b)
    reference_root = reference_from_factor(a, b)

    self.assertAllClose(reference_root, eager_root)
    self.assertAllClose(reference_root, jit_root)
    jtu.check_eq(jit_aux, aux_payload)

    # Gradient check with aux outputs present.
    jtu.check_grads(
        solve_from_factor, (a, b), order=2, rtol={jnp.float32: 1e-2})

    # jacfwd combines vmap and jvp; compare against the reference solve.
    jac_solver = jax.jacfwd(solve_from_factor, argnums=(0, 1))
    jac_reference = jax.jacfwd(reference_from_factor, argnums=(0, 1))

    jac_root, jac_aux = jac_solver(a, b)
    jac_expected = jac_reference(a, b)
    tolerances = {np.float32: 5E-6, np.float64: 5E-12}
    self.assertAllClose(jac_root, jac_expected, rtol=tolerances)

    # The Jacobians of the aux outputs are expected to be all zeros.
    zero_aux = tree_util.tree_map(jnp.zeros_like, jac_aux)
    jtu.check_close(jac_aux, zero_aux)