Ejemplo n.º 1
0
  def test_jvp_zeros(self):
    """Check that a jvp over a closure gives the same result with and without jit.

    ``foo`` differentiates ``y -> sin(x * y)`` (closing over ``x``) at the
    primal point ``3 * x`` along the tangent ``2 * x``; the jitted and eager
    evaluations at ``x = 0.5`` must be equal.
    """
    def foo(x):
      primals = (3 * x,)
      tangents = (2 * x,)
      # The differentiated function captures `x` from the enclosing call.
      return jvp(lambda y: jnp.sin(x * y), primals, tangents)

    jitted_result = jit(foo)(0.5)
    eager_result = foo(0.5)
    jtu.check_eq(jitted_result, eager_result)
Ejemplo n.º 2
0
    def test_custom_linear_solve_aux(self):
        """Exercise ``lax.custom_linear_solve`` with ``has_aux=True``.

        The solve is compared against ``jnp.linalg.solve`` both eagerly and
        under ``jax.jit``, the auxiliary dict is checked to come back
        unchanged, gradients are checked up to second order, and finally the
        solve is vmapped over the columns of a matrix right-hand side.
        """
        # Array-valued aux entries, so that jtu.check_grads can be used on the
        # (solution, aux) output.
        array_aux = {"converged": np.array(1.), "nfev": np.array(12345.)}

        def explicit_jacobian_solve_aux(matvec, b):
            # Materialise the operator as a dense Jacobian and solve directly.
            dense_operator = jax.jacobian(matvec)(b)
            solution = lax.stop_gradient(jnp.linalg.solve(dense_operator, b))
            return solution, array_aux

        def linear_solve_aux(a, b):
            # The same solver is supplied as both solve and transpose solve.
            return lax.custom_linear_solve(
                partial(high_precision_dot, a),
                b,
                explicit_jacobian_solve_aux,
                explicit_jacobian_solve_aux,
                symmetric=True,
                has_aux=True,
            )

        rng = self.rng()
        # Symmetrise the random matrix; the solve is declared symmetric=True.
        unsymmetric = rng.randn(3, 3)
        a = unsymmetric + unsymmetric.T
        b = rng.randn(3)

        reference = jnp.linalg.solve(a, b)
        eager_solution, eager_aux = linear_solve_aux(a, b)
        jitted_solution, jitted_aux = jax.jit(linear_solve_aux)(a, b)

        self.assertAllClose(reference, eager_solution)
        self.assertAllClose(reference, jitted_solution)
        # The aux dict must be passed through untouched (scalar entries).
        self.assertDictEqual(eager_aux, array_aux)
        self.assertDictEqual(jitted_aux, array_aux)

        # jvp / vjp check.
        jtu.check_grads(linear_solve_aux, (a, b), order=2, rtol=4e-3)

        # vmap check: batch over the two columns of `c`.
        c = rng.randn(3, 2)
        batched_reference = jnp.linalg.solve(a, c)
        # Each aux leaf is expected once per batch element.
        batched_aux_reference = tree_util.tree_map(
            lambda leaf: np.repeat(leaf, repeats=2), array_aux)
        batched_solve = jax.vmap(linear_solve_aux, in_axes=(None, 1),
                                 out_axes=-1)
        batched_solution, batched_aux = batched_solve(a, c)

        self.assertAllClose(batched_reference, batched_solution)
        jtu.check_eq(batched_aux_reference, batched_aux)
Ejemplo n.º 3
0
    def test_custom_root_with_aux(self):
        """Exercise ``lax.custom_root`` with ``has_aux=True``.

        The root of ``a @ x - b`` is found with a Cholesky-based solver that
        also returns an auxiliary dict.  The test checks the solution against
        ``jnp.linalg.solve`` (eagerly and under ``jax.jit``), checks that the
        aux dict comes back unchanged on both paths, checks gradients up to
        second order, and checks via ``jax.jacfwd`` that the aux outputs have
        zero derivative with respect to the inputs.
        """
        def root_aux(a, b):
            f = lambda x: high_precision_dot(a, x) - b
            # Every call site below passes `a` as x @ x.T.
            factors = jsp.linalg.cho_factor(a)
            # Solver used by custom_root: returns (solution, aux).
            cho_solve = lambda f, b: (jsp.linalg.cho_solve(factors, b),
                                      orig_aux)

            def pos_def_solve(g, b):
                # prune aux to allow use as tangent_solve
                cho_solve_noaux = lambda f, b: cho_solve(f, b)[0]
                return lax.custom_linear_solve(g,
                                               b,
                                               cho_solve_noaux,
                                               symmetric=True)

            return lax.custom_root(f,
                                   b,
                                   cho_solve,
                                   pos_def_solve,
                                   has_aux=True)

        # Constant aux payload mixing scalar and vector leaves.
        orig_aux = {
            "converged": np.array(1.),
            "nfev": np.array(12345.),
            "grad": np.array([1.0, 2.0, 3.0])
        }

        rng = self.rng()
        a = rng.randn(2, 2)
        b = rng.randn(2)

        actual, actual_aux = root_aux(high_precision_dot(a, a.T), b)
        actual_jit, actual_jit_aux = jax.jit(root_aux)(high_precision_dot(
            a, a.T), b)
        expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)

        self.assertAllClose(expected, actual)
        self.assertAllClose(expected, actual_jit)
        # Verify the aux output on both the eager and the jitted path; the
        # eager aux was previously computed but never asserted on.
        jtu.check_eq(actual_aux, orig_aux)
        jtu.check_eq(actual_jit_aux, orig_aux)

        # grad check with aux
        jtu.check_grads(lambda x, y: root_aux(high_precision_dot(x, x.T), y),
                        (a, b),
                        order=2,
                        rtol={jnp.float32: 1e-2})

        # test vmap and jvp combined by jacfwd
        fwd = jax.jacfwd(lambda x, y: root_aux(high_precision_dot(x, x.T), y),
                         argnums=(0, 1))
        expected_fwd = jax.jacfwd(
            lambda x, y: jnp.linalg.solve(high_precision_dot(x, x.T), y),
            argnums=(0, 1))

        # `fwd` differentiates the whole (solution, aux) output, so it yields
        # one Jacobian pytree for the solution and one for the aux dict.
        fwd_val, fwd_aux = fwd(a, b)
        expected_fwd_val = expected_fwd(a, b)
        self.assertAllClose(fwd_val,
                            expected_fwd_val,
                            rtol={
                                np.float32: 5E-6,
                                np.float64: 5E-12
                            })

        # The aux values are constants, so their Jacobians must be all zeros.
        jtu.check_close(fwd_aux, tree_util.tree_map(jnp.zeros_like, fwd_aux))