def test_custom_linear_solve_cholesky(self):

  def positive_definite_solve(a, b):
    factors = jsp.linalg.cho_factor(a)
    def solve(matvec, x):
      return jsp.linalg.cho_solve(factors, x)
    matvec = partial(high_precision_dot, a)
    return lax.custom_linear_solve(matvec, b, solve, symmetric=True)

  rng = self.rng()
  a = rng.randn(2, 2)
  b = rng.randn(2)

  tol = {np.float32: 1E-3 if jtu.device_under_test() == "tpu" else 1E-5,
         np.float64: 1E-12}
  expected = jnp.linalg.solve(np.asarray(posify(a)), b)
  actual = positive_definite_solve(posify(a), b)
  self.assertAllClose(expected, actual, rtol=tol, atol=tol)

  actual = jax.jit(positive_definite_solve)(posify(a), b)
  self.assertAllClose(expected, actual, rtol=tol, atol=tol)

  # numerical gradients are only well defined if ``a`` is guaranteed to be
  # positive definite.
  jtu.check_grads(
      lambda x, y: positive_definite_solve(posify(x), y),
      (a, b), order=2, rtol=0.3)

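# The two helpers below are used throughout these tests but defined elsewhere
# in the test module. Minimal sketches of what they need to do (the exact
# definitions here are assumptions, not the canonical ones):

def high_precision_dot(a, b):
  # Matrix multiply at the highest available precision, so that low-precision
  # TPU/GPU matmuls do not dominate the error in numerical gradient checks.
  return jnp.dot(a, b, precision=lax.Precision.HIGHEST)


def posify(matrix):
  # Map an arbitrary square matrix to a symmetric (Hermitian) positive
  # definite one by forming A @ A^H.
  return high_precision_dot(matrix, matrix.T.conj())
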
def testResizeGradients(self, dtype, image_shape, target_shape, method,
                        antialias):
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: (rng(image_shape, dtype),)
  jax_fn = partial(image.resize, shape=target_shape, method=method,
                   antialias=antialias)
  jtu.check_grads(jax_fn, args_maker(), order=2, rtol=1e-2, eps=1.)

def test_custom_linear_solve(self, symmetric): def explicit_jacobian_solve(matvec, b): return lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b)) def matrix_free_solve(matvec, b): return lax.custom_linear_solve( matvec, b, explicit_jacobian_solve, explicit_jacobian_solve, symmetric=symmetric) def linear_solve(a, b): return matrix_free_solve(partial(high_precision_dot, a), b) rng = self.rng() a = rng.randn(3, 3) if symmetric: a = a + a.T b = rng.randn(3) jtu.check_grads(linear_solve, (a, b), order=2, rtol=3e-3) expected = jnp.linalg.solve(a, b) actual = jax.jit(linear_solve)(a, b) self.assertAllClose(expected, actual) c = rng.randn(3, 2) expected = jnp.linalg.solve(a, c) actual = jax.vmap(linear_solve, (None, 1), 1)(a, c) self.assertAllClose(expected, actual)
def test_custom_linear_solve_without_transpose_solve(self):

  def explicit_jacobian_solve(matvec, b):
    return lax.stop_gradient(
        jnp.linalg.solve(jax.jacobian(matvec)(b), b))

  def loss(a, b):
    matvec = partial(high_precision_dot, a)
    x = lax.custom_linear_solve(matvec, b, explicit_jacobian_solve)
    return jnp.sum(x)

  rng = self.rng()
  a = rng.randn(2, 2)
  b = rng.randn(2)

  jtu.check_grads(loss, (a, b), order=2, modes=['fwd'],
                  atol={np.float32: 2e-3, np.float64: 1e-11})
  jtu.check_grads(jax.vmap(loss), (a[None, :, :], b[None, :]),
                  order=2, modes=['fwd'],
                  atol={np.float32: 2e-3, np.float64: 1e-11})

  with self.assertRaisesRegex(TypeError, "transpose_solve required"):
    jax.grad(loss)(a, b)

def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                        test_autodiff, nondiff_argnums):
  if (jtu.device_under_test() == "cpu" and
      (lax_op is lsp_special.gammainc or lax_op is lsp_special.gammaincc)):
    # TODO(b/173608403): re-enable test when LLVM bug is fixed.
    raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
  rng = rng_factory(self.rng())
  args_maker = self._GetArgsMaker(rng, shapes, dtypes)
  args = args_maker()
  self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                      check_dtypes=False)
  self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

  if test_autodiff:
    def partial_lax_op(*vals):
      list_args = list(vals)
      for i in nondiff_argnums:
        list_args.insert(i, args[i])
      return lax_op(*list_args)

    assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
    diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums]
    jtu.check_grads(partial_lax_op, diff_args, order=1,
                    atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                    rtol=.1, eps=1e-3)

def test_custom_root_with_custom_linear_solve(self):

  def linear_solve(a, b):
    f = lambda x: high_precision_dot(a, x) - b
    factors = jsp.linalg.cho_factor(a)
    cho_solve = lambda f, b: jsp.linalg.cho_solve(factors, b)
    def pos_def_solve(g, b):
      return lax.custom_linear_solve(g, b, cho_solve, symmetric=True)
    return lax.custom_root(f, b, cho_solve, pos_def_solve)

  rng = self.rng()
  a = rng.randn(2, 2)
  b = rng.randn(2)

  actual = linear_solve(high_precision_dot(a, a.T), b)
  expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)
  self.assertAllClose(expected, actual)

  actual = jax.jit(linear_solve)(high_precision_dot(a, a.T), b)
  expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)
  self.assertAllClose(expected, actual)

  jtu.check_grads(lambda x, y: linear_solve(high_precision_dot(x, x.T), y),
                  (a, b), order=2, rtol={jnp.float32: 1e-2})

def test_custom_root_vector_nonlinear(self):

  def nonlinear_func(x, y):
    # func(x, y) == 0 if and only if x == y.
    return (x - y) * (x ** 2 + y ** 2 + 1)

  def tangent_solve(g, y):
    return jnp.linalg.solve(
        jax.jacobian(g)(y).reshape(-1, y.size), y.ravel()).reshape(y.shape)

  def nonlinear_solve(y):
    f = lambda x: nonlinear_func(x, y)
    x0 = -jnp.ones_like(y)
    return lax.custom_root(f, x0, newton_raphson, tangent_solve)

  y = self.rng().randn(3, 1)
  jtu.check_grads(nonlinear_solve, (y,), order=2,
                  rtol={jnp.float32: 1e-2, jnp.float64: 1e-3})

  actual = jax.jit(nonlinear_solve)(y)
  self.assertAllClose(y, actual, rtol=1e-5, atol=1e-5)

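# ``newton_raphson`` is the forward root solver assumed by the custom_root
# tests; it is defined elsewhere in the test module. A minimal sketch (the
# iteration count and shape handling here are assumptions): iterate
# x <- x - J(x)^{-1} f(x). Only the forward solve matters for correctness,
# since custom_root derives gradients from tangent_solve.

def newton_raphson(f, x):
  x = jnp.asarray(x)
  for _ in range(10):
    fx = f(x)
    jac = jax.jacobian(f)(x).reshape(fx.size, x.size)
    x = x - jnp.linalg.solve(jac, fx.ravel()).reshape(x.shape)
  return x
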
def test_custom_linear_solve_lu(self):

  def linear_solve(a, b):
    a_factors = jsp.linalg.lu_factor(a)
    at_factors = jsp.linalg.lu_factor(a.T)
    def solve(matvec, x):
      return jsp.linalg.lu_solve(a_factors, x)
    def transpose_solve(vecmat, x):
      return jsp.linalg.lu_solve(at_factors, x)
    return lax.custom_linear_solve(
        partial(high_precision_dot, a), b, solve, transpose_solve)

  rng = self.rng()
  a = rng.randn(3, 3)
  b = rng.randn(3)

  expected = jnp.linalg.solve(a, b)
  actual = linear_solve(a, b)
  self.assertAllClose(expected, actual)

  jtu.check_grads(linear_solve, (a, b), order=2, rtol=2e-3)

  # regression test for https://github.com/google/jax/issues/1536
  jtu.check_grads(jax.jit(linear_solve), (a, b), order=2,
                  rtol={np.float32: 2e-3})

def testFftn(self, inverse, real, shape, dtype, axes, s, norm):
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: (rng(shape, dtype),)
  jnp_op = _get_fftn_func(jnp.fft, inverse, real)
  np_op = _get_fftn_func(np.fft, inverse, real)
  jnp_fn = lambda a: jnp_op(a, axes=axes, norm=norm)
  # When ``axes`` is an empty list, the reference function is the identity.
  np_fn = lambda a: (np_op(a, axes=axes, norm=norm)
                     if axes is None or axes else a)
  # Numpy promotes to complex128 aggressively.
  self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                          tol=1e-4)
  self._CompileAndCheck(jnp_fn, args_maker)
  # Test gradient for differentiable types.
  if (config.x64_enabled and
      dtype in (float_dtypes if real and not inverse else inexact_dtypes)):
    # TODO(skye): can we be more precise?
    tol = 0.15
    jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

  # check dtypes
  dtype = jnp_fn(rng(shape, dtype)).dtype
  expected_dtype = jnp.promote_types(float if inverse and real else complex,
                                     dtype)
  self.assertEqual(dtype, expected_dtype)

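# ``_get_fftn_func`` picks the n-dimensional FFT variant under test from
# either the jnp.fft or np.fft namespace. A plausible sketch (the real helper
# may additionally canonicalize inputs for irfftn; treat the exact definition
# as an assumption):

def _get_fftn_func(module, inverse, real):
  if inverse:
    return module.irfftn if real else module.ifftn
  return module.rfftn if real else module.fftn
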
def test_autodiff(self, shape, dtype, k, is_max_k):
  vals = np.arange(prod(shape), dtype=dtype)
  vals = self.rng().permutation(vals).reshape(shape)
  if is_max_k:
    fn = lambda vs: ann.approx_max_k(vs, k=k)[0]
  else:
    fn = lambda vs: ann.approx_min_k(vs, k=k)[0]
  jtu.check_grads(fn, (vals,), 2, ["fwd", "rev"], eps=1e-2)

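# ``prod`` above is assumed to be math.prod (or an equivalent reduce with
# operator.mul), i.e. the total number of elements implied by ``shape``.
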
def test_grad_closure(self):
  # simplification of https://github.com/google/jax/issues/2718
  def experiment(x):
    def model(y, t):
      return -x * y
    history = odeint(model, 1., np.arange(0, 10, 0.1))
    return history[-1]

  jtu.check_grads(experiment, (0.01,), modes=["rev"], order=1)

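# ``odeint`` throughout these tests is jax.experimental.ode.odeint: a
# Dormand-Prince adaptive Runge-Kutta integrator whose reverse-mode gradients
# are computed by solving the adjoint ODE backwards in time.
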
def testAutodiff(self, mesh, resources):
  if len(mesh) != 2:
    return
  assert resources == ('x', 'y')
  # Add a constant captured by the nested pjit to make things more complicated
  h = jnp.arange(4)
  f = pjit(lambda x: x.sum(1) * h.sum(),
           in_axis_resources=P('x', 'y'),
           out_axis_resources=P(('x', 'y')))
  g = pjit(lambda x: f(jnp.sin(x * 4 + 2)),
           in_axis_resources=P('x', None),
           out_axis_resources=P(('x', 'y')))
  jtu.check_grads(g, (jnp.arange(16, dtype=jnp.float32).reshape((4, 4)) / 100,),
                  order=2)

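# ``P`` here is presumably PartitionSpec (jax.sharding.PartitionSpec in
# current JAX; it lived under jax.experimental.pjit in older versions), which
# maps named mesh axes like 'x' and 'y' onto array dimensions for pjit's
# in_axis_resources / out_axis_resources.
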
def test_pytree_state(self):
  """Test calling odeint with y(t) values that are pytrees."""
  def dynamics(y, _t):
    return tree_map(jnp.negative, y)

  y0 = (np.array(-0.1), np.array([[[0.1]]]))
  ts = np.linspace(0., 1., 11)
  tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3

  integrate = partial(odeint, dynamics)
  jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
                  atol=tol, rtol=tol)

def test_weird_time_pendulum_grads(self):
  """Test that gradients are correct when the dynamics depend on t."""
  def dynamics(_np, y, t):
    return _np.array([y[1] * -t, -1 * y[1] - 9.8 * _np.sin(y[0])])

  y0 = [np.pi - 0.1, 0.0]
  ts = np.linspace(0., 1., 11)
  tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3

  self.check_against_scipy(dynamics, y0, ts, tol=tol)

  integrate = partial(odeint, partial(dynamics, jnp))
  jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
                  rtol=tol, atol=tol)

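# ``check_against_scipy`` is a method on this odeint test class that compares
# jax.experimental.ode.odeint with scipy.integrate.odeint on the same
# problem. A minimal sketch of the method body (an assumption, not the exact
# helper; note the dynamics ``fun`` takes the numpy-like module as its first
# argument, matching the tests here):

def check_against_scipy(self, fun, y0, ts, *args, tol=1e-6):
  import scipy.integrate  # reference integrator
  y0, ts = np.asarray(y0), np.asarray(ts)
  # scipy's odeint passes extra dynamics arguments as a single tuple.
  expected = scipy.integrate.odeint(partial(fun, np), y0, ts, args)
  actual = odeint(partial(fun, jnp), jnp.asarray(y0), jnp.asarray(ts), *args)
  self.assertAllClose(actual, expected, check_dtypes=False,
                      atol=tol, rtol=tol)
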
def test_swoop_bigger(self):
  def swoop(_np, y, t, arg1, arg2):
    return _np.array(y - _np.sin(t) - _np.cos(t) * arg1 + arg2)

  ts = np.array([0.1, 0.2])
  tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3
  big_y0 = np.linspace(1.1, 10.9, 10)
  args = (0.1, 0.3)

  self.check_against_scipy(swoop, big_y0, ts, *args, tol=tol)

  integrate = partial(odeint, partial(swoop, jnp))
  jtu.check_grads(integrate, (big_y0, ts, *args), modes=["rev"], order=2,
                  rtol=tol, atol=tol)

def test_custom_linear_solve_complex(self):

  def solve(a, b):
    def solve(matvec, x):
      return jsp.linalg.solve(a, x)
    def tr_solve(matvec, x):
      # custom_linear_solve transposes (rather than conjugate-transposes) the
      # system when differentiating, so a.T is correct even for complex a.
      return jsp.linalg.solve(a.T, x)
    matvec = partial(high_precision_dot, a)
    return lax.custom_linear_solve(matvec, b, solve, tr_solve)

  rng = self.rng()
  a = 0.5 * rng.randn(2, 2) + 0.5j * rng.randn(2, 2)
  b = 0.5 * rng.randn(2) + 0.5j * rng.randn(2)

  jtu.check_grads(solve, (a, b), order=2, rtol=1e-2)

def test_pend_grads(self):
  def pend(_np, y, _, m, g):
    theta, omega = y
    return [omega, -m * omega - g * _np.sin(theta)]

  y0 = [np.pi - 0.1, 0.0]
  ts = np.linspace(0., 1., 11)
  args = (0.25, 9.8)
  tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3

  self.check_against_scipy(pend, y0, ts, *args, tol=tol)

  integrate = partial(odeint, partial(pend, jnp))
  jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
                  atol=tol, rtol=tol)

def test_complex_odeint(self):
  # https://github.com/google/jax/issues/3986

  def dy_dt(y, t, alpha):
    return alpha * y

  def f(y0, ts, alpha):
    return odeint(dy_dt, y0, ts, alpha).real

  alpha = 3 + 4j
  y0 = 1 + 2j
  ts = jnp.linspace(0., 1., 11)
  tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3

  jtu.check_grads(f, (y0, ts, alpha), modes=["rev"], order=2,
                  atol=tol, rtol=tol)

def testRfftfreq(self, size, d, dtype):
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: (rng([size], dtype),)
  jnp_op = jnp.fft.rfftfreq
  np_op = np.fft.rfftfreq
  jnp_fn = lambda a: jnp_op(size, d=d)
  np_fn = lambda a: np_op(size, d=d)
  # Numpy promotes to complex128 aggressively.
  self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                          tol=1e-4)
  self._CompileAndCheck(jnp_fn, args_maker)
  # Test gradient for differentiable types.
  if dtype in inexact_dtypes:
    tol = 0.15  # TODO(skye): can we be more precise?
    jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

def test_decay(self):
  def decay(_np, y, t, arg1, arg2):
    return -_np.sqrt(t) - y + arg1 - _np.mean((y + arg2) ** 2)

  rng = self.rng()
  args = (rng.randn(3), rng.randn(3))
  y0 = rng.randn(3)
  ts = np.linspace(0.1, 0.2, 4)
  tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3

  self.check_against_scipy(decay, y0, ts, *args, tol=tol)

  integrate = partial(odeint, partial(decay, jnp))
  jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
                  rtol=tol, atol=tol)

def test_custom_linear_solve_aux(self):
  def explicit_jacobian_solve_aux(matvec, b):
    x = lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))
    return x, array_aux

  def matrix_free_solve_aux(matvec, b):
    return lax.custom_linear_solve(
        matvec, b, explicit_jacobian_solve_aux, explicit_jacobian_solve_aux,
        symmetric=True, has_aux=True)

  def linear_solve_aux(a, b):
    return matrix_free_solve_aux(partial(high_precision_dot, a), b)

  # array aux values, to be able to use jtu.check_grads
  array_aux = {"converged": np.array(1.), "nfev": np.array(12345.)}
  rng = self.rng()
  a = rng.randn(3, 3)
  a = a + a.T
  b = rng.randn(3)

  expected = jnp.linalg.solve(a, b)
  actual_nojit, nojit_aux = linear_solve_aux(a, b)
  actual_jit, jit_aux = jax.jit(linear_solve_aux)(a, b)

  self.assertAllClose(expected, actual_nojit)
  self.assertAllClose(expected, actual_jit)

  # scalar dict equality check
  self.assertDictEqual(nojit_aux, array_aux)
  self.assertDictEqual(jit_aux, array_aux)

  # jvp / vjp test
  jtu.check_grads(linear_solve_aux, (a, b), order=2, rtol=4e-3)

  # vmap test
  c = rng.randn(3, 2)
  expected = jnp.linalg.solve(a, c)
  expected_aux = tree_util.tree_map(partial(np.repeat, repeats=2), array_aux)
  actual_vmap, vmap_aux = jax.vmap(linear_solve_aux, (None, 1), -1)(a, c)

  self.assertAllClose(expected, actual_vmap)
  jtu.check_eq(expected_aux, vmap_aux)

def test_custom_linear_solve_iterative(self):

  def richardson_iteration(matvec, b, omega=0.1, tolerance=1e-6):
    # Equivalent to vanilla gradient descent:
    # https://en.wikipedia.org/wiki/Modified_Richardson_iteration
    def cond(x):
      return jnp.linalg.norm(matvec(x) - b) > tolerance
    def body(x):
      return x + omega * (b - matvec(x))
    return lax.while_loop(cond, body, b)

  def matrix_free_solve(matvec, b):
    return lax.custom_linear_solve(matvec, b, richardson_iteration,
                                   richardson_iteration)

  def build_and_solve(a, b):
    # intentionally non-linear in a and b
    matvec = partial(high_precision_dot, jnp.exp(a))
    return matrix_free_solve(matvec, jnp.cos(b))

  # rng = self.rng()
  # This test is very sensitive to the inputs, so we use a known working seed.
  rng = np.random.RandomState(0)
  a = rng.randn(2, 2)
  b = rng.randn(2)

  expected = jnp.linalg.solve(jnp.exp(a), jnp.cos(b))
  actual = build_and_solve(a, b)
  self.assertAllClose(expected, actual, atol=1e-5)
  jtu.check_grads(build_and_solve, (a, b), atol=1e-5, order=2,
                  rtol={jnp.float32: 6e-2, jnp.float64: 2e-3})

  # vmap across an empty dimension
  jtu.check_grads(jax.vmap(build_and_solve), (a[None, :, :], b[None, :]),
                  atol=1e-5, order=2,
                  rtol={jnp.float32: 6e-2, jnp.float64: 2e-3})

def test_cg_as_solve(self, shape, dtype):
  rng = jtu.rand_default(self.rng())
  a = rng(shape, dtype)
  b = rng(shape[:1], dtype)

  expected = np.linalg.solve(posify(a), b)
  actual = lax_cg(posify(a), b)
  self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)

  actual = jit(lax_cg)(posify(a), b)
  self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)

  # numerical gradients are only well defined if ``a`` is guaranteed to be
  # positive definite.
  jtu.check_grads(
      lambda x, y: lax_cg(posify(x), y),
      (a, b), order=2, rtol=2e-1)

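# ``lax_cg`` wraps jax.scipy.sparse.linalg.cg so it can be called like a
# dense solve. A minimal sketch (the helper name and defaults are
# assumptions):

def lax_cg(A, b, **kwargs):
  # cg accepts either a matrix or a matvec callable; pass a high-precision
  # matvec and discard the returned convergence info.
  x, _ = jax.scipy.sparse.linalg.cg(partial(high_precision_dot, A), b,
                                    **kwargs)
  return x
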
def test_custom_linear_solve_zeros(self):

  def explicit_jacobian_solve(matvec, b):
    return lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))

  def matrix_free_solve(matvec, b):
    return lax.custom_linear_solve(matvec, b, explicit_jacobian_solve,
                                   explicit_jacobian_solve)

  def linear_solve(a, b):
    return matrix_free_solve(partial(high_precision_dot, a), b)

  rng = self.rng()
  a = rng.randn(3, 3)
  b = rng.randn(3)

  jtu.check_grads(lambda x: linear_solve(x, b), (a,), order=2,
                  rtol={np.float32: 5e-3})
  jtu.check_grads(lambda x: linear_solve(a, x), (b,), order=2,
                  rtol={np.float32: 5e-3})

def test_custom_root_scalar(self, solve_method):

  def scalar_solve(f, y):
    return y / f(1.0)

  def sqrt_cubed(x, tangent_solve=scalar_solve):
    f = lambda y: y ** 2 - x ** 3
    # Note: Nonzero derivative at x0 required for newton_raphson
    return lax.custom_root(f, 1.0, solve_method, tangent_solve)

  value, grad = jax.value_and_grad(sqrt_cubed)(5.0)
  self.assertAllClose(value, 5 ** 1.5, check_dtypes=False, rtol=1e-6)
  self.assertAllClose(grad, jax.grad(pow)(5.0, 1.5), check_dtypes=False,
                      rtol=1e-7)
  jtu.check_grads(sqrt_cubed, (5.0,), order=2,
                  rtol={jnp.float32: 1e-2, jnp.float64: 1e-3})

  inputs = jnp.array([4.0, 5.0])
  results = jax.vmap(sqrt_cubed)(inputs)
  self.assertAllClose(
      results, inputs ** 1.5, check_dtypes=False,
      atol={jnp.float32: 1e-3, jnp.float64: 1e-6},
      rtol={jnp.float32: 1e-3, jnp.float64: 1e-6})

  results = jax.jit(sqrt_cubed)(5.0)
  self.assertAllClose(results, 5.0 ** 1.5, check_dtypes=False,
                      rtol={np.float64: 1e-7})

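# ``solve_method`` above is a parameterized forward solver. Besides
# ``newton_raphson``, a derivative-free bisection is a plausible second
# option; this sketch (the name and bracket bounds are assumptions) finds a
# root of a scalar monotonically increasing function:

def binary_search(func, x0, low=0.0, high=100.0):
  del x0  # unused: the bracket [low, high] replaces the initial guess

  def cond(state):
    low, high = state
    midpoint = 0.5 * (low + high)
    # Stop once the midpoint is no longer strictly inside the bracket,
    # i.e. the bracket has collapsed to floating-point resolution.
    return (low < midpoint) & (midpoint < high)

  def body(state):
    low, high = state
    midpoint = 0.5 * (low + high)
    update_upper = func(midpoint) > 0  # root lies below the midpoint
    low = jnp.where(update_upper, low, midpoint)
    high = jnp.where(update_upper, midpoint, high)
    return (low, high)

  solution, _ = lax.while_loop(cond, body, (low, high))
  return solution
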
def test_complex_odeint(self):
  # https://github.com/google/jax/issues/3986
  # https://github.com/google/jax/issues/8757

  def dy_dt(y, t, alpha):
    return alpha * y * jnp.exp(-t).astype(y.dtype)

  def f(y0, ts, alpha):
    return odeint(dy_dt, y0, ts, alpha).real

  alpha = 3 + 4j
  y0 = 1 + 2j
  ts = jnp.linspace(0., 1., 11)
  tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3

  # During the backward pass, this ravels all parameters into a single array
  # such that dtype promotion is unavoidable.
  with jax.numpy_dtype_promotion('standard'):
    jtu.check_grads(f, (y0, ts, alpha), modes=["rev"], order=2,
                    atol=tol, rtol=tol)

def test_custom_root_vector_with_solve_closure(self):

  def vector_solve(f, y):
    return jnp.linalg.solve(jax.jacobian(f)(y), y)

  def linear_solve(a, b):
    f = lambda y: high_precision_dot(a, y) - b
    x0 = jnp.zeros_like(b)
    solution = jnp.linalg.solve(a, b)
    oracle = lambda func, x0: solution
    return lax.custom_root(f, x0, oracle, vector_solve)

  rng = self.rng()
  a = rng.randn(2, 2)
  b = rng.randn(2)

  jtu.check_grads(linear_solve, (a, b), order=2,
                  atol={np.float32: 1e-2, np.float64: 1e-11})

  actual = jax.jit(linear_solve)(a, b)
  expected = jnp.linalg.solve(a, b)
  self.assertAllClose(expected, actual)

def test_custom_linear_solve_pytree(self):
  """Test custom linear solve with inputs and outputs that are pytrees."""

  def unrolled_matvec(mat, x):
    """Apply a Python list of lists of scalars to a list of scalars."""
    result = []
    for i in range(len(mat)):
      v = 0
      for j in range(len(x)):
        if mat[i][j] is not None:
          v += mat[i][j] * x[j]
      result.append(v)
    return result

  def unrolled_substitution_solve(matvec, b, lower_tri):
    """Solve a triangular unrolled system with fwd/back substitution."""
    zero = jnp.zeros(())
    one = jnp.ones(())
    x = [zero for _ in b]
    ordering = range(len(b)) if lower_tri else range(len(b) - 1, -1, -1)
    for i in ordering:
      residual = b[i] - matvec(x)[i]
      diagonal = matvec([one if i == j else zero for j in range(len(b))])[i]
      x[i] = residual / diagonal
    return x

  def custom_unrolled_lower_tri_solve(mat, b):
    return lax.custom_linear_solve(
        partial(unrolled_matvec, mat), b,
        partial(unrolled_substitution_solve, lower_tri=True),
        partial(unrolled_substitution_solve, lower_tri=False))

  mat = [[1.0, None, None, None, None, None, None],
         [1.0, 1.0, None, None, None, None, None],
         [None, 1.0, 1.0, None, None, None, None],
         [None, None, 1.0, 1.0, None, None, None],
         [None, None, None, 1.0, 1.0, None, None],
         [None, None, None, None, None, 2.0, None],
         [None, None, None, None, None, 4.0, 3.0]]

  rng = self.rng()
  b = list(rng.randn(7))

  # Non-batched
  jtu.check_grads(custom_unrolled_lower_tri_solve, (mat, b), order=2,
                  rtol={jnp.float32: 2e-2})

  # Batch one element of b (which, because of unrolling, should only affect
  # the first block of outputs)
  b_bat = list(b)
  b_bat[3] = rng.randn(3)
  jtu.check_grads(
      jax.vmap(
          custom_unrolled_lower_tri_solve,
          in_axes=(None, [None, None, None, 0, None, None, None]),
          out_axes=[0, 0, 0, 0, 0, None, None]),
      (mat, b_bat), order=2, rtol={jnp.float32: 1e-2})

  # Batch one element of mat (again only affecting first block)
  mat[2][1] = rng.randn(3)
  mat_axis_tree = [
      [0 if i == 2 and j == 1 else None for j in range(7)]
      for i in range(7)
  ]
  jtu.check_grads(
      jax.vmap(
          custom_unrolled_lower_tri_solve,
          in_axes=(mat_axis_tree, None),
          out_axes=[0, 0, 0, 0, 0, None, None]),
      (mat, b), order=2)

def test_custom_root_with_aux(self):
  def root_aux(a, b):
    f = lambda x: high_precision_dot(a, x) - b
    factors = jsp.linalg.cho_factor(a)
    cho_solve = lambda f, b: (jsp.linalg.cho_solve(factors, b), orig_aux)

    def pos_def_solve(g, b):
      # prune aux to allow use as tangent_solve
      cho_solve_noaux = lambda f, b: cho_solve(f, b)[0]
      return lax.custom_linear_solve(g, b, cho_solve_noaux, symmetric=True)

    return lax.custom_root(f, b, cho_solve, pos_def_solve, has_aux=True)

  orig_aux = {"converged": np.array(1.), "nfev": np.array(12345.),
              "grad": np.array([1.0, 2.0, 3.0])}

  rng = self.rng()
  a = rng.randn(2, 2)
  b = rng.randn(2)

  actual, actual_aux = root_aux(high_precision_dot(a, a.T), b)
  actual_jit, actual_jit_aux = jax.jit(root_aux)(high_precision_dot(a, a.T), b)
  expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)

  self.assertAllClose(expected, actual)
  self.assertAllClose(expected, actual_jit)
  jtu.check_eq(actual_jit_aux, orig_aux)

  # grad check with aux
  jtu.check_grads(lambda x, y: root_aux(high_precision_dot(x, x.T), y),
                  (a, b), order=2, rtol={jnp.float32: 1e-2})

  # test vmap and jvp combined by jacfwd
  fwd = jax.jacfwd(lambda x, y: root_aux(high_precision_dot(x, x.T), y),
                   argnums=(0, 1))
  expected_fwd = jax.jacfwd(
      lambda x, y: jnp.linalg.solve(high_precision_dot(x, x.T), y),
      argnums=(0, 1))

  fwd_val, fwd_aux = fwd(a, b)
  expected_fwd_val = expected_fwd(a, b)
  self.assertAllClose(fwd_val, expected_fwd_val,
                      rtol={np.float32: 5E-6, np.float64: 5E-12})
  jtu.check_close(fwd_aux, tree_util.tree_map(jnp.zeros_like, fwd_aux))