Example #1
0
 def test_jit(self, f, args):
   """Compare ``jit(f)(*args)`` against eager ``f(*args)`` via ``jtu.check_close``."""
   # The compiled call is evaluated before the eager one, matching the
   # order in which the arguments to check_close were originally evaluated.
   compiled = jit(f)
   compiled_out = compiled(*args)
   eager_out = f(*args)
   jtu.check_close(compiled_out, eager_out)
Example #2
0
    def test_custom_root_with_aux(self):
        """Test ``lax.custom_root(..., has_aux=True)`` using a Cholesky solver.

        The root of ``x -> a @ x - b`` is found by a Cholesky-based ``solve``
        callback that also returns the fixed auxiliary pytree ``orig_aux``.
        The test checks that:

        * the root matches ``jnp.linalg.solve`` both eagerly and under ``jit``;
        * the auxiliary output equals ``orig_aux`` both eagerly and under
          ``jit``;
        * second-order gradients pass ``jtu.check_grads``;
        * ``jax.jacfwd`` (exercising vmap and jvp together) matches the
          reference Jacobian, and the Jacobian of the aux output is all zeros.
        """
        # Auxiliary data returned unchanged by the solver. Defined before
        # ``root_aux`` (which closes over it) so the closure reads top-down.
        orig_aux = {
            "converged": np.array(1.),
            "nfev": np.array(12345.),
            "grad": np.array([1.0, 2.0, 3.0])
        }

        def root_aux(a, b):
            # Residual whose root solves a @ x = b.
            residual = lambda x: high_precision_dot(a, x) - b
            factors = jsp.linalg.cho_factor(a)

            # Primal solver: ignores the residual function and Cholesky-solves
            # against its second argument (the initial guess, which is ``b``
            # in the custom_root call below), returning ``orig_aux`` as aux.
            solve_with_aux = lambda _func, rhs: (
                jsp.linalg.cho_solve(factors, rhs), orig_aux)

            def pos_def_solve(g, rhs):
                # Prune aux so the same solver can be used as tangent_solve.
                solve_noaux = lambda func, y: solve_with_aux(func, y)[0]
                return lax.custom_linear_solve(g,
                                               rhs,
                                               solve_noaux,
                                               symmetric=True)

            return lax.custom_root(residual,
                                   b,
                                   solve_with_aux,
                                   pos_def_solve,
                                   has_aux=True)

        def root_of_gram(x, y):
            # Root (and aux) of (x @ x.T) z = y, as a function of both x and y.
            return root_aux(high_precision_dot(x, x.T), y)

        rng = self.rng()
        a = rng.randn(2, 2)
        b = rng.randn(2)
        # a @ a.T is symmetric and, for a generic random ``a``,
        # positive-definite, which the Cholesky factorization requires.
        gram = high_precision_dot(a, a.T)

        actual, actual_aux = root_aux(gram, b)
        actual_jit, actual_jit_aux = jax.jit(root_aux)(gram, b)
        expected = jnp.linalg.solve(gram, b)

        self.assertAllClose(expected, actual)
        self.assertAllClose(expected, actual_jit)
        # The aux output must come back unchanged on both the eager and the
        # jitted path.
        jtu.check_eq(actual_aux, orig_aux)
        jtu.check_eq(actual_jit_aux, orig_aux)

        # grad check with aux
        jtu.check_grads(root_of_gram,
                        (a, b),
                        order=2,
                        rtol={jnp.float32: 1e-2})

        # test vmap and jvp combined by jacfwd
        fwd = jax.jacfwd(root_of_gram, argnums=(0, 1))
        expected_fwd = jax.jacfwd(
            lambda x, y: jnp.linalg.solve(high_precision_dot(x, x.T), y),
            argnums=(0, 1))

        fwd_val, fwd_aux = fwd(a, b)
        expected_fwd_val = expected_fwd(a, b)
        self.assertAllClose(fwd_val,
                            expected_fwd_val,
                            rtol={
                                np.float32: 5E-6,
                                np.float64: 5E-12
                            })

        # The aux values are constants independent of the inputs, so their
        # forward-mode Jacobian must be identically zero.
        jtu.check_close(fwd_aux, tree_util.tree_map(jnp.zeros_like, fwd_aux))