def test_jvp_zeros(self):
    """A function whose body calls ``jvp`` must give the same result jitted and eager.

    The inner function closes over ``x``, and both the primal (3x) and the
    tangent (2x) passed to ``jvp`` are derived from ``x``.
    """
    def outer(x):
        return jvp(lambda y: jnp.sin(x * y), (3 * x,), (2 * x,))

    # Evaluate the jitted version first, then eager, matching the original order.
    jitted_out = jit(outer)(0.5)
    eager_out = outer(0.5)
    jtu.check_eq(jitted_out, eager_out)
def test_custom_linear_solve_aux(self):
    """Check ``lax.custom_linear_solve`` with ``has_aux=True``.

    The aux dict produced by the solver must come back unchanged in eager and
    jit mode. Under vmap it must be broadcast along the batch axis. It must not
    break first/second-order gradient checks of the solve.
    """
    # array aux values, to be able to use jtu.check_grads
    array_aux = {"converged": np.array(1.), "nfev": np.array(12345.)}

    def explicit_jacobian_solve_aux(matvec, b):
        # Materialize the operator as a dense Jacobian and solve it directly.
        # stop_gradient: derivatives are expected to come from
        # custom_linear_solve's own rules, not from differentiating this solve.
        jac = jax.jacobian(matvec)(b)
        solution = lax.stop_gradient(jnp.linalg.solve(jac, b))
        return solution, array_aux

    def matrix_free_solve_aux(matvec, b):
        # The same explicit solver serves as both solve and transpose_solve.
        return lax.custom_linear_solve(
            matvec,
            b,
            explicit_jacobian_solve_aux,
            explicit_jacobian_solve_aux,
            symmetric=True,
            has_aux=True,
        )

    def linear_solve_aux(a, b):
        return matrix_free_solve_aux(partial(high_precision_dot, a), b)

    rng = self.rng()
    raw = rng.randn(3, 3)
    mat = raw + raw.T  # symmetrize, since the solve is declared symmetric=True
    rhs = rng.randn(3)

    expected = jnp.linalg.solve(mat, rhs)
    actual_nojit, nojit_aux = linear_solve_aux(mat, rhs)
    actual_jit, jit_aux = jax.jit(linear_solve_aux)(mat, rhs)

    for solved in (actual_nojit, actual_jit):
        self.assertAllClose(expected, solved)
    # scalar dict equality check (every aux entry is 0-d, so dict == is safe)
    for aux in (nojit_aux, jit_aux):
        self.assertDictEqual(aux, array_aux)

    # jvp / vjp test
    jtu.check_grads(linear_solve_aux, (mat, rhs), order=2, rtol=4e-3)

    # vmap test: batch over the columns of rhs_batch, stack results on the last axis.
    rhs_batch = rng.randn(3, 2)
    expected = jnp.linalg.solve(mat, rhs_batch)
    # The unbatched aux is expected to be broadcast to the batch size (2).
    expected_aux = tree_util.tree_map(partial(np.repeat, repeats=2), array_aux)
    batched_solve = jax.vmap(linear_solve_aux, in_axes=(None, 1), out_axes=-1)
    actual_vmap, vmap_aux = batched_solve(mat, rhs_batch)
    self.assertAllClose(expected, actual_vmap)
    jtu.check_eq(expected_aux, vmap_aux)
def test_custom_root_with_aux(self):
    """Check ``lax.custom_root`` with ``has_aux=True``.

    The solver returns ``(solution, orig_aux)``. The test verifies that:
      * the root matches ``jnp.linalg.solve`` in eager and jit mode;
      * the aux dict is returned unchanged on BOTH the eager and jit paths;
      * second-order gradients of the root pass ``jtu.check_grads``;
      * ``jacfwd`` (vmap + jvp) of the root matches the reference, and the
        Jacobian of the constant aux is identically zero.
    """

    def root_aux(a, b):
        # Residual f(x) = a @ x - b, whose root solves a x = b.
        f = lambda x: high_precision_dot(a, x) - b
        factors = jsp.linalg.cho_factor(a)
        # custom_root's solver: ignores the residual function it is handed and
        # solves a x = rhs directly with the Cholesky factors of `a`,
        # returning the constant aux alongside the solution.
        cho_solve = lambda _fun, rhs: (jsp.linalg.cho_solve(factors, rhs), orig_aux)

        def pos_def_solve(g, rhs):
            # prune aux to allow use as tangent_solve
            cho_solve_noaux = lambda matvec, y: cho_solve(matvec, y)[0]
            return lax.custom_linear_solve(g, rhs, cho_solve_noaux, symmetric=True)

        return lax.custom_root(f, b, cho_solve, pos_def_solve, has_aux=True)

    orig_aux = {
        "converged": np.array(1.),
        "nfev": np.array(12345.),
        "grad": np.array([1.0, 2.0, 3.0]),
    }

    def sym_root_aux(x, y):
        # Root of (x @ x.T) z = y; used for the gradient and jacfwd checks.
        return root_aux(high_precision_dot(x, x.T), y)

    rng = self.rng()
    a = rng.randn(2, 2)
    b = rng.randn(2)
    # Build the symmetric system matrix once and reuse it for every path.
    aat = high_precision_dot(a, a.T)

    actual, actual_aux = root_aux(aat, b)
    actual_jit, actual_jit_aux = jax.jit(root_aux)(aat, b)
    expected = jnp.linalg.solve(aat, b)

    self.assertAllClose(expected, actual)
    self.assertAllClose(expected, actual_jit)
    # Aux must round-trip unchanged on both the eager and the jit path.
    # check_eq (not assertDictEqual) because "grad" is a non-scalar array,
    # whose truth value under dict == would be ambiguous.
    jtu.check_eq(actual_aux, orig_aux)
    jtu.check_eq(actual_jit_aux, orig_aux)

    # grad check with aux
    jtu.check_grads(sym_root_aux, (a, b), order=2, rtol={jnp.float32: 1e-2})

    # test vmap and jvp combined by jacfwd
    fwd = jax.jacfwd(sym_root_aux, argnums=(0, 1))
    expected_fwd = jax.jacfwd(
        lambda x, y: jnp.linalg.solve(high_precision_dot(x, x.T), y),
        argnums=(0, 1),
    )

    fwd_val, fwd_aux = fwd(a, b)
    expected_fwd_val = expected_fwd(a, b)
    self.assertAllClose(
        fwd_val, expected_fwd_val, rtol={np.float32: 5E-6, np.float64: 5E-12})

    # orig_aux does not depend on the inputs, so its Jacobian must be all zeros.
    jtu.check_close(fwd_aux, tree_util.tree_map(jnp.zeros_like, fwd_aux))