Example 1: custom_linear_solve with a Cholesky factorization
  def test_custom_linear_solve_cholesky(self):

    def positive_definite_solve(a, b):
      factors = jsp.linalg.cho_factor(a)
      def solve(matvec, x):
        return jsp.linalg.cho_solve(factors, x)
      matvec = partial(high_precision_dot, a)
      return lax.custom_linear_solve(matvec, b, solve, symmetric=True)

    rng = self.rng()
    a = rng.randn(2, 2)
    b = rng.randn(2)

    tol = {np.float32: 1E-3 if jtu.device_under_test() == "tpu" else 1E-5,
           np.float64: 1E-12}
    expected = jnp.linalg.solve(np.asarray(posify(a)), b)
    actual = positive_definite_solve(posify(a), b)
    self.assertAllClose(expected, actual, rtol=tol, atol=tol)

    actual = jax.jit(positive_definite_solve)(posify(a), b)
    self.assertAllClose(expected, actual, rtol=tol, atol=tol)

    # numerical gradients are only well defined if ``a`` is guaranteed to be
    # positive definite.
    jtu.check_grads(
        lambda x, y: positive_definite_solve(posify(x), y),
        (a, b), order=2, rtol=0.3)
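
The snippets in this collection are excerpts from JAX test modules and rely on shared setup that is not shown. A minimal sketch of the assumed imports and of the `posify` / `high_precision_dot` helpers, consistent with how they are used here (the exact definitions in the original test files may differ), is:

# Assumed shared setup (sketch only, not the verbatim originals):
from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import lax
from jax._src import test_util as jtu  # older releases: from jax import test_util as jtu

# Dot product at the highest available precision, so that low-precision TPU
# matmuls do not dominate the error budget checked by jtu.check_grads.
high_precision_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)

def posify(matrix):
  # Build a symmetric positive definite matrix from an arbitrary square matrix,
  # so Cholesky-based solves and their numerical gradients are well defined.
  return high_precision_dot(matrix, matrix.T.conj())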
Example 2: gradients of image.resize
  def testResizeGradients(self, dtype, image_shape, target_shape, method,
                          antialias):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(image_shape, dtype),)
    jax_fn = partial(image.resize, shape=target_shape, method=method,
                     antialias=antialias)
    jtu.check_grads(jax_fn, args_maker(), order=2, rtol=1e-2, eps=1.)
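
This snippet additionally assumes that `image` is JAX's image module:

from jax import image  # provides jax.image.resize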
Example 3: matrix-free custom_linear_solve (optionally symmetric)
  def test_custom_linear_solve(self, symmetric):

    def explicit_jacobian_solve(matvec, b):
      return lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))

    def matrix_free_solve(matvec, b):
      return lax.custom_linear_solve(
          matvec, b, explicit_jacobian_solve, explicit_jacobian_solve,
          symmetric=symmetric)

    def linear_solve(a, b):
      return matrix_free_solve(partial(high_precision_dot, a), b)

    rng = self.rng()
    a = rng.randn(3, 3)
    if symmetric:
      a = a + a.T
    b = rng.randn(3)
    jtu.check_grads(linear_solve, (a, b), order=2, rtol=3e-3)

    expected = jnp.linalg.solve(a, b)
    actual = jax.jit(linear_solve)(a, b)
    self.assertAllClose(expected, actual)

    c = rng.randn(3, 2)
    expected = jnp.linalg.solve(a, c)
    actual = jax.vmap(linear_solve, (None, 1), 1)(a, c)
    self.assertAllClose(expected, actual)
Example 4: custom_linear_solve without a transpose_solve
    def test_custom_linear_solve_without_transpose_solve(self):
        def explicit_jacobian_solve(matvec, b):
            return lax.stop_gradient(
                jnp.linalg.solve(jax.jacobian(matvec)(b), b))

        def loss(a, b):
            matvec = partial(high_precision_dot, a)
            x = lax.custom_linear_solve(matvec, b, explicit_jacobian_solve)
            return jnp.sum(x)

        rng = self.rng()
        a = rng.randn(2, 2)
        b = rng.randn(2)

        jtu.check_grads(loss, (a, b),
                        order=2,
                        modes=['fwd'],
                        atol={
                            np.float32: 2e-3,
                            np.float64: 1e-11
                        })
        jtu.check_grads(jax.vmap(loss), (a[None, :, :], b[None, :]),
                        order=2,
                        modes=['fwd'],
                        atol={
                            np.float32: 2e-3,
                            np.float64: 1e-11
                        })

        with self.assertRaisesRegex(TypeError, "transpose_solve required"):
            jax.grad(loss)(a, b)
Example 5: jax.scipy.special functions checked against SciPy
    def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes,
                            dtypes, test_autodiff, nondiff_argnums):
        if (jtu.device_under_test() == "cpu"
                and (lax_op is lsp_special.gammainc
                     or lax_op is lsp_special.gammaincc)):
            # TODO(b/173608403): re-enable test when LLVM bug is fixed.
            raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
        rng = rng_factory(self.rng())
        args_maker = self._GetArgsMaker(rng, shapes, dtypes)
        args = args_maker()
        self.assertAllClose(scipy_op(*args),
                            lax_op(*args),
                            atol=1e-3,
                            rtol=1e-3,
                            check_dtypes=False)
        self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

        if test_autodiff:

            def partial_lax_op(*vals):
                list_args = list(vals)
                for i in nondiff_argnums:
                    list_args.insert(i, args[i])
                return lax_op(*list_args)

            assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
            diff_args = [
                x for i, x in enumerate(args) if i not in nondiff_argnums
            ]
            jtu.check_grads(partial_lax_op,
                            diff_args,
                            order=1,
                            atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                            rtol=.1,
                            eps=1e-3)
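
Here `scipy_op` and `lax_op` are SciPy functions and their JAX counterparts, supplied by the test's parametrization. The assumed imports (sketch) are:

# Assumed imports for the scipy.special comparison (sketch):
import unittest

import scipy.special  # source of scipy_op in the parametrization
from jax.scipy import special as lsp_special  # source of lax_op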
Example 6: custom_root combined with custom_linear_solve
    def test_custom_root_with_custom_linear_solve(self):
        def linear_solve(a, b):
            f = lambda x: high_precision_dot(a, x) - b
            factors = jsp.linalg.cho_factor(a)
            cho_solve = lambda f, b: jsp.linalg.cho_solve(factors, b)

            def pos_def_solve(g, b):
                return lax.custom_linear_solve(g, b, cho_solve, symmetric=True)

            return lax.custom_root(f, b, cho_solve, pos_def_solve)

        rng = self.rng()
        a = rng.randn(2, 2)
        b = rng.randn(2)

        actual = linear_solve(high_precision_dot(a, a.T), b)
        expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)
        self.assertAllClose(expected, actual)

        actual = jax.jit(linear_solve)(high_precision_dot(a, a.T), b)
        expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)
        self.assertAllClose(expected, actual)

        jtu.check_grads(
            lambda x, y: linear_solve(high_precision_dot(x, x.T), y), (a, b),
            order=2,
            rtol={jnp.float32: 1e-2})
Example 7: custom_root with a vector-valued nonlinear function
    def test_custom_root_vector_nonlinear(self):
        def nonlinear_func(x, y):
            # func(x, y) == 0 if and only if x == y.
            return (x - y) * (x**2 + y**2 + 1)

        def tangent_solve(g, y):
            return jnp.linalg.solve(
                jax.jacobian(g)(y).reshape(-1, y.size),
                y.ravel()).reshape(y.shape)

        def nonlinear_solve(y):
            f = lambda x: nonlinear_func(x, y)
            x0 = -jnp.ones_like(y)
            return lax.custom_root(f, x0, newton_raphson, tangent_solve)

        y = self.rng().randn(3, 1)
        jtu.check_grads(nonlinear_solve, (y, ),
                        order=2,
                        rtol={
                            jnp.float32: 1e-2,
                            jnp.float64: 1e-3
                        })

        actual = jax.jit(nonlinear_solve)(y)
        self.assertAllClose(y, actual, rtol=1e-5, atol=1e-5)
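
`newton_raphson` is a forward solver defined elsewhere in the original test module. A hedged sketch with the `(f, initial_guess)` signature that `lax.custom_root` expects here (the original presumably iterates to a tolerance rather than for a fixed number of steps) is:

def newton_raphson(f, x):
  # Sketch: a fixed number of undamped Newton steps on a vector-valued f.
  for _ in range(10):
    fx = f(x)
    jac = jax.jacobian(f)(x).reshape(fx.size, x.size)
    x = x - jnp.linalg.solve(jac, fx.ravel()).reshape(x.shape)
  return x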
Example 8: custom_linear_solve with an LU factorization
  def test_custom_linear_solve_lu(self):

    def linear_solve(a, b):
      a_factors = jsp.linalg.lu_factor(a)
      at_factors = jsp.linalg.lu_factor(a.T)
      def solve(matvec, x):
        return jsp.linalg.lu_solve(a_factors, x)
      def transpose_solve(vecmat, x):
        return jsp.linalg.lu_solve(at_factors, x)
      return lax.custom_linear_solve(
          partial(high_precision_dot, a), b, solve, transpose_solve)

    rng = self.rng()
    a = rng.randn(3, 3)
    b = rng.randn(3)

    expected = jnp.linalg.solve(a, b)
    actual = linear_solve(a, b)
    self.assertAllClose(expected, actual)

    jtu.check_grads(linear_solve, (a, b), order=2, rtol=2e-3)

    # regression test for https://github.com/google/jax/issues/1536
    jtu.check_grads(jax.jit(linear_solve), (a, b), order=2,
                    rtol={np.float32: 2e-3})
Example 9: n-dimensional FFTs checked against NumPy
    def testFftn(self, inverse, real, shape, dtype, axes, s, norm):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng(shape, dtype), )
        jnp_op = _get_fftn_func(jnp.fft, inverse, real)
        np_op = _get_fftn_func(np.fft, inverse, real)
        jnp_fn = lambda a: jnp_op(a, axes=axes, norm=norm)
        np_fn = lambda a: (np_op(a, axes=axes, norm=norm)
                           if axes is None or axes else a)
        # Numpy promotes to complex128 aggressively.
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(jnp_fn, args_maker)
        # Test gradient for differentiable types.
        if (config.x64_enabled and dtype
                in (float_dtypes if real and not inverse else inexact_dtypes)):
            # TODO(skye): can we be more precise?
            tol = 0.15
            jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

        # check dtypes
        dtype = jnp_fn(rng(shape, dtype)).dtype
        expected_dtype = jnp.promote_types(
            float if inverse and real else complex, dtype)
        self.assertEqual(dtype, expected_dtype)
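
`_get_fftn_func` is a small dispatcher defined in the original FFT test module, and `config`, `float_dtypes`, and `inexact_dtypes` are module-level fixtures there (the JAX config object and the dtype lists used for parametrization). A hedged sketch of the dispatcher, consistent with its usage above (the original may additionally zero out imaginary parts before `irfftn`), is:

def _get_fftn_func(module, inverse, real):
  # Sketch: pick the n-dimensional transform for the (inverse, real) combination.
  # `module` is either np.fft or jnp.fft, which expose the same function names.
  if real:
    return module.irfftn if inverse else module.rfftn
  return module.ifftn if inverse else module.fftn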
Example 10: autodiff of approx_max_k / approx_min_k
  def test_autodiff(self, shape, dtype, k, is_max_k):
    vals = np.arange(prod(shape), dtype=dtype)
    vals = self.rng().permutation(vals).reshape(shape)
    if is_max_k:
      fn = lambda vs: ann.approx_max_k(vs, k=k)[0]
    else:
      fn = lambda vs: ann.approx_min_k(vs, k=k)[0]
    jtu.check_grads(fn, (vals,), 2, ["fwd", "rev"], eps=1e-2)
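
This snippet assumes a `prod` helper and an `ann` module exposing `approx_max_k` / `approx_min_k`. In current JAX these functions live in `jax.lax`, so one way to make the snippet self-contained (an assumption, not the original test's imports) is:

import math

from jax import lax

prod = math.prod
ann = lax  # approx_max_k / approx_min_k are exposed via jax.lax in current releases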
Example 11: odeint gradient with a closed-over parameter
  def test_grad_closure(self):
    # simplification of https://github.com/google/jax/issues/2718
    def experiment(x):
      def model(y, t):
        return -x * y
      history = odeint(model, 1., np.arange(0, 10, 0.1))
      return history[-1]
    jtu.check_grads(experiment, (0.01,), modes=["rev"], order=1)
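
The odeint examples in this collection (this one and the pendulum, decay, swoop, pytree, and complex examples below) assume JAX's experimental ODE integrator and the pytree utilities; several of them also call a `check_against_scipy` helper defined on the original test class, which presumably compares against scipy.integrate.odeint. The assumed imports (sketch) are:

# Assumed imports for the odeint and pytree examples (sketch):
from jax.experimental.ode import odeint
from jax import tree_util
from jax.tree_util import tree_map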
Example 12: autodiff through nested pjit
  def testAutodiff(self, mesh, resources):
    if len(mesh) != 2: return
    assert resources == ('x', 'y')
    # Add a constant captured by the nested pjit to make things more complicated
    h = jnp.arange(4)
    f = pjit(lambda x: x.sum(1) * h.sum(),
             in_axis_resources=P('x', 'y'), out_axis_resources=P(('x', 'y')))
    g = pjit(lambda x: f(jnp.sin(x * 4 + 2)),
             in_axis_resources=P('x', None), out_axis_resources=P(('x', 'y')))
    jtu.check_grads(g, (jnp.arange(16, dtype=jnp.float32).reshape((4, 4)) / 100,),
                    order=2)
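
The pjit example assumes a parametrized device `mesh` fixture plus the following imports; `P` is JAX's `PartitionSpec` (the snippet predates the newer `in_shardings` / `out_shardings` parameter names, hence `in_axis_resources` / `out_axis_resources`). Module paths vary across JAX versions, so this is a sketch:

from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P  # older releases: from jax.experimental import PartitionSpec as P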
Example 13: odeint with pytree state
  def test_pytree_state(self):
    """Test calling odeint with y(t) values that are pytrees."""
    def dynamics(y, _t):
      return tree_map(jnp.negative, y)

    y0 = (np.array(-0.1), np.array([[[0.1]]]))
    ts = np.linspace(0., 1., 11)
    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3

    integrate = partial(odeint, dynamics)
    jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
                    atol=tol, rtol=tol)
Example 14: pendulum with time-dependent dynamics
  def test_weird_time_pendulum_grads(self):
    """Test that gradients are correct when the dynamics depend on t."""
    def dynamics(_np, y, t):
      return _np.array([y[1] * -t, -1 * y[1] - 9.8 * _np.sin(y[0])])

    y0 = [np.pi - 0.1, 0.0]
    ts = np.linspace(0., 1., 11)
    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3

    self.check_against_scipy(dynamics, y0, ts, tol=tol)

    integrate = partial(odeint, partial(dynamics, jnp))
    jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
                    rtol=tol, atol=tol)
Example 15: odeint with a larger state vector
  def test_swoop_bigger(self):
    def swoop(_np, y, t, arg1, arg2):
      return _np.array(y - _np.sin(t) - _np.cos(t) * arg1 + arg2)

    ts = np.array([0.1, 0.2])
    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3
    big_y0 = np.linspace(1.1, 10.9, 10)
    args = (0.1, 0.3)

    self.check_against_scipy(swoop, big_y0, ts, *args, tol=tol)

    integrate = partial(odeint, partial(swoop, jnp))
    jtu.check_grads(integrate, (big_y0, ts, *args), modes=["rev"], order=2,
                    rtol=tol, atol=tol)
Example 16: custom_linear_solve with complex inputs
  def test_custom_linear_solve_complex(self):

    def solve(a, b):
      def solve(matvec, x):
        return jsp.linalg.solve(a, x)
      def tr_solve(matvec, x):
        return jsp.linalg.solve(a.T, x)
      matvec = partial(high_precision_dot, a)
      return lax.custom_linear_solve(matvec, b, solve, tr_solve)

    rng = self.rng()
    a = 0.5 * rng.randn(2, 2) + 0.5j * rng.randn(2, 2)
    b = 0.5 * rng.randn(2) + 0.5j * rng.randn(2)
    jtu.check_grads(solve, (a, b), order=2, rtol=1e-2)
Example 17: pendulum gradients
  def test_pend_grads(self):
    def pend(_np, y, _, m, g):
      theta, omega = y
      return [omega, -m * omega - g * _np.sin(theta)]

    y0 = [np.pi - 0.1, 0.0]
    ts = np.linspace(0., 1., 11)
    args = (0.25, 9.8)
    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3

    self.check_against_scipy(pend, y0, ts, *args, tol=tol)

    integrate = partial(odeint, partial(pend, jnp))
    jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
                    atol=tol, rtol=tol)
Example 18: complex-valued odeint
  def test_complex_odeint(self):
    # https://github.com/google/jax/issues/3986

    def dy_dt(y, t, alpha):
      return alpha * y

    def f(y0, ts, alpha):
      return odeint(dy_dt, y0, ts, alpha).real

    alpha = 3 + 4j
    y0 = 1 + 2j
    ts = jnp.linspace(0., 1., 11)
    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3

    jtu.check_grads(f, (y0, ts, alpha), modes=["rev"], order=2, atol=tol, rtol=tol)
Example 19: jnp.fft.rfftfreq checked against NumPy
  def testRfftfreq(self, size, d, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng([size], dtype),)
    jnp_op = jnp.fft.rfftfreq
    np_op = np.fft.rfftfreq
    jnp_fn = lambda a: jnp_op(size, d=d)
    np_fn = lambda a: np_op(size, d=d)
    # Numpy promotes to complex128 aggressively.
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(jnp_fn, args_maker)
    # Test gradient for differentiable types.
    if dtype in inexact_dtypes:
      tol = 0.15  # TODO(skye): can we be more precise?
      jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
Example 20: odeint with decay dynamics
  def test_decay(self):
    def decay(_np, y, t, arg1, arg2):
      return -_np.sqrt(t) - y + arg1 - _np.mean((y + arg2)**2)

    rng = self.rng()
    args = (rng.randn(3), rng.randn(3))
    y0 = rng.randn(3)
    ts = np.linspace(0.1, 0.2, 4)
    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3

    self.check_against_scipy(decay, y0, ts, *args, tol=tol)

    integrate = partial(odeint, partial(decay, jnp))
    jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
                    rtol=tol, atol=tol)
Example 21: custom_linear_solve with auxiliary output
    def test_custom_linear_solve_aux(self):
        def explicit_jacobian_solve_aux(matvec, b):
            x = lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))
            return x, array_aux

        def matrix_free_solve_aux(matvec, b):
            return lax.custom_linear_solve(matvec,
                                           b,
                                           explicit_jacobian_solve_aux,
                                           explicit_jacobian_solve_aux,
                                           symmetric=True,
                                           has_aux=True)

        def linear_solve_aux(a, b):
            return matrix_free_solve_aux(partial(high_precision_dot, a), b)

        # array aux values, to be able to use jtu.check_grads
        array_aux = {"converged": np.array(1.), "nfev": np.array(12345.)}
        rng = self.rng()
        a = rng.randn(3, 3)
        a = a + a.T
        b = rng.randn(3)

        expected = jnp.linalg.solve(a, b)
        actual_nojit, nojit_aux = linear_solve_aux(a, b)
        actual_jit, jit_aux = jax.jit(linear_solve_aux)(a, b)

        self.assertAllClose(expected, actual_nojit)
        self.assertAllClose(expected, actual_jit)
        # scalar dict equality check
        self.assertDictEqual(nojit_aux, array_aux)
        self.assertDictEqual(jit_aux, array_aux)

        # jvp / vjp test
        jtu.check_grads(linear_solve_aux, (a, b), order=2, rtol=4e-3)

        # vmap test
        c = rng.randn(3, 2)
        expected = jnp.linalg.solve(a, c)
        expected_aux = tree_util.tree_map(partial(np.repeat, repeats=2),
                                          array_aux)
        actual_vmap, vmap_aux = jax.vmap(linear_solve_aux, (None, 1), -1)(a, c)

        self.assertAllClose(expected, actual_vmap)
        jtu.check_eq(expected_aux, vmap_aux)
Example 22: custom_linear_solve with an iterative (Richardson) solver
    def test_custom_linear_solve_iterative(self):
        def richardson_iteration(matvec, b, omega=0.1, tolerance=1e-6):
            # Equivalent to vanilla gradient descent:
            # https://en.wikipedia.org/wiki/Modified_Richardson_iteration
            def cond(x):
                return jnp.linalg.norm(matvec(x) - b) > tolerance

            def body(x):
                return x + omega * (b - matvec(x))

            return lax.while_loop(cond, body, b)

        def matrix_free_solve(matvec, b):
            return lax.custom_linear_solve(matvec, b, richardson_iteration,
                                           richardson_iteration)

        def build_and_solve(a, b):
            # intentionally non-linear in a and b
            matvec = partial(high_precision_dot, jnp.exp(a))
            return matrix_free_solve(matvec, jnp.cos(b))

        # rng = self.rng()
        # This test is very sensitive to the inputs, so we use a known working seed.
        rng = np.random.RandomState(0)
        a = rng.randn(2, 2)
        b = rng.randn(2)
        expected = jnp.linalg.solve(jnp.exp(a), jnp.cos(b))
        actual = build_and_solve(a, b)
        self.assertAllClose(expected, actual, atol=1e-5)
        jtu.check_grads(build_and_solve, (a, b),
                        atol=1e-5,
                        order=2,
                        rtol={
                            jnp.float32: 6e-2,
                            jnp.float64: 2e-3
                        })

        # vmap across an empty dimension
        jtu.check_grads(jax.vmap(build_and_solve), (a[None, :, :], b[None, :]),
                        atol=1e-5,
                        order=2,
                        rtol={
                            jnp.float32: 6e-2,
                            jnp.float64: 2e-3
                        })
Example 23: conjugate gradient as the linear solver
    def test_cg_as_solve(self, shape, dtype):

        rng = jtu.rand_default(self.rng())
        a = rng(shape, dtype)
        b = rng(shape[:1], dtype)

        expected = np.linalg.solve(posify(a), b)
        actual = lax_cg(posify(a), b)
        self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)

        actual = jit(lax_cg)(posify(a), b)
        self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)

        # numerical gradients are only well defined if ``a`` is guaranteed to be
        # positive definite.
        jtu.check_grads(lambda x, y: lax_cg(posify(x), y), (a, b),
                        order=2,
                        rtol=2e-1)
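
`lax_cg` and `jit` come from the original sparse-linear-algebra test module; `jit` is `jax.jit`, and a hedged sketch of `lax_cg` as a thin wrapper around JAX's conjugate-gradient solver, consistent with its usage above, is:

from jax import jit
from jax.scipy.sparse.linalg import cg

def lax_cg(A, b, **kwargs):
  # Sketch: solve A x = b by conjugate gradients with a high-precision matvec;
  # cg returns (solution, info) and only the solution is compared above.
  matvec = partial(high_precision_dot, A)
  x, _ = cg(matvec, b, **kwargs)
  return x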
Example 24: custom_linear_solve gradients with one argument held fixed
  def test_custom_linear_solve_zeros(self):
    def explicit_jacobian_solve(matvec, b):
      return lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))

    def matrix_free_solve(matvec, b):
      return lax.custom_linear_solve(matvec, b, explicit_jacobian_solve,
                                     explicit_jacobian_solve)

    def linear_solve(a, b):
      return matrix_free_solve(partial(high_precision_dot, a), b)

    rng = self.rng()
    a = rng.randn(3, 3)
    b = rng.randn(3)
    jtu.check_grads(lambda x: linear_solve(x, b), (a,), order=2,
                    rtol={np.float32: 5e-3})
    jtu.check_grads(lambda x: linear_solve(a, x), (b,), order=2,
                    rtol={np.float32: 5e-3})
Example 25: custom_root with a scalar solve
    def test_custom_root_scalar(self, solve_method):
        def scalar_solve(f, y):
            return y / f(1.0)

        def sqrt_cubed(x, tangent_solve=scalar_solve):
            f = lambda y: y**2 - x**3
            # Note: Nonzero derivative at x0 required for newton_raphson
            return lax.custom_root(f, 1.0, solve_method, tangent_solve)

        value, grad = jax.value_and_grad(sqrt_cubed)(5.0)
        self.assertAllClose(value, 5**1.5, check_dtypes=False, rtol=1e-6)
        self.assertAllClose(grad,
                            jax.grad(pow)(5.0, 1.5),
                            check_dtypes=False,
                            rtol=1e-7)
        jtu.check_grads(sqrt_cubed, (5.0, ),
                        order=2,
                        rtol={
                            jnp.float32: 1e-2,
                            jnp.float64: 1e-3
                        })

        inputs = jnp.array([4.0, 5.0])
        results = jax.vmap(sqrt_cubed)(inputs)
        self.assertAllClose(
            results,
            inputs**1.5,
            check_dtypes=False,
            atol={
                jnp.float32: 1e-3,
                jnp.float64: 1e-6
            },
            rtol={
                jnp.float32: 1e-3,
                jnp.float64: 1e-6
            },
        )

        results = jax.jit(sqrt_cubed)(5.0)
        self.assertAllClose(results,
                            5.0**1.5,
                            check_dtypes=False,
                            rtol={np.float64: 1e-7})
Example 26: complex-valued odeint under standard dtype promotion
    def test_complex_odeint(self):
        # https://github.com/google/jax/issues/3986
        # https://github.com/google/jax/issues/8757

        def dy_dt(y, t, alpha):
            return alpha * y * jnp.exp(-t).astype(y.dtype)

        def f(y0, ts, alpha):
            return odeint(dy_dt, y0, ts, alpha).real

        alpha = 3 + 4j
        y0 = 1 + 2j
        ts = jnp.linspace(0., 1., 11)
        tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3

        # During the backward pass, this ravels all parameters into a single array
        # such that dtype promotion is unavoidable.
        with jax.numpy_dtype_promotion('standard'):
            jtu.check_grads(f, (y0, ts, alpha),
                            modes=["rev"],
                            order=2,
                            atol=tol,
                            rtol=tol)
Example 27: custom_root with a closure-based solve
    def test_custom_root_vector_with_solve_closure(self):
        def vector_solve(f, y):
            return jnp.linalg.solve(jax.jacobian(f)(y), y)

        def linear_solve(a, b):
            f = lambda y: high_precision_dot(a, y) - b
            x0 = jnp.zeros_like(b)
            solution = jnp.linalg.solve(a, b)
            oracle = lambda func, x0: solution
            return lax.custom_root(f, x0, oracle, vector_solve)

        rng = self.rng()
        a = rng.randn(2, 2)
        b = rng.randn(2)
        jtu.check_grads(linear_solve, (a, b),
                        order=2,
                        atol={
                            np.float32: 1e-2,
                            np.float64: 1e-11
                        })

        actual = jax.jit(linear_solve)(a, b)
        expected = jnp.linalg.solve(a, b)
        self.assertAllClose(expected, actual)
Example 28: custom_linear_solve with pytree inputs and outputs
  def test_custom_linear_solve_pytree(self):
    """Test custom linear solve with inputs and outputs that are pytrees."""

    def unrolled_matvec(mat, x):
      """Apply a Python list of lists of scalars to a list of scalars."""
      result = []
      for i in range(len(mat)):
        v = 0
        for j in range(len(x)):
          if mat[i][j] is not None:
            v += mat[i][j] * x[j]
        result.append(v)
      return result

    def unrolled_substitution_solve(matvec, b, lower_tri):
      """Solve a triangular unrolled system with fwd/back substitution."""
      zero = jnp.zeros(())
      one = jnp.ones(())
      x = [zero for _ in b]
      ordering = range(len(b)) if lower_tri else range(len(b) - 1, -1, -1)
      for i in ordering:
        residual = b[i] - matvec(x)[i]
        diagonal = matvec([one if i == j else zero for j in range(len(b))])[i]
        x[i] = residual / diagonal
      return x

    def custom_unrolled_lower_tri_solve(mat, b):
      return lax.custom_linear_solve(
          partial(unrolled_matvec, mat), b,
          partial(unrolled_substitution_solve, lower_tri=True),
          partial(unrolled_substitution_solve, lower_tri=False))

    mat = [[1.0, None, None, None, None, None, None],
           [1.0, 1.0, None, None, None, None, None],
           [None, 1.0, 1.0, None, None, None, None],
           [None, None, 1.0, 1.0, None, None, None],
           [None, None, None, 1.0, 1.0, None, None],
           [None, None, None, None, None, 2.0, None],
           [None, None, None, None, None, 4.0, 3.0]]

    rng = self.rng()
    b = list(rng.randn(7))

    # Non-batched
    jtu.check_grads(custom_unrolled_lower_tri_solve, (mat, b), order=2,
                    rtol={jnp.float32: 2e-2})

    # Batch one element of b (which, because of unrolling, should only affect
    # the first block of outputs)
    b_bat = list(b)
    b_bat[3] = rng.randn(3)
    jtu.check_grads(
        jax.vmap(
            custom_unrolled_lower_tri_solve,
            in_axes=(None, [None, None, None, 0, None, None, None]),
            out_axes=[0, 0, 0, 0, 0, None, None]), (mat, b_bat),
        order=2,
        rtol={jnp.float32: 1e-2})

    # Batch one element of mat (again only affecting first block)
    mat[2][1] = rng.randn(3)
    mat_axis_tree = [
        [0 if i == 2 and j == 1 else None for j in range(7)] for i in range(7)
    ]
    jtu.check_grads(
        jax.vmap(
            custom_unrolled_lower_tri_solve,
            in_axes=(mat_axis_tree, None),
            out_axes=[0, 0, 0, 0, 0, None, None]), (mat, b),
        order=2)
Example 29: custom_root with auxiliary output
    def test_custom_root_with_aux(self):
        def root_aux(a, b):
            f = lambda x: high_precision_dot(a, x) - b
            factors = jsp.linalg.cho_factor(a)
            cho_solve = lambda f, b: (jsp.linalg.cho_solve(factors, b),
                                      orig_aux)

            def pos_def_solve(g, b):
                # prune aux to allow use as tangent_solve
                cho_solve_noaux = lambda f, b: cho_solve(f, b)[0]
                return lax.custom_linear_solve(g,
                                               b,
                                               cho_solve_noaux,
                                               symmetric=True)

            return lax.custom_root(f,
                                   b,
                                   cho_solve,
                                   pos_def_solve,
                                   has_aux=True)

        orig_aux = {
            "converged": np.array(1.),
            "nfev": np.array(12345.),
            "grad": np.array([1.0, 2.0, 3.0])
        }

        rng = self.rng()
        a = rng.randn(2, 2)
        b = rng.randn(2)

        actual, actual_aux = root_aux(high_precision_dot(a, a.T), b)
        actual_jit, actual_jit_aux = jax.jit(root_aux)(high_precision_dot(
            a, a.T), b)
        expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)

        self.assertAllClose(expected, actual)
        self.assertAllClose(expected, actual_jit)
        jtu.check_eq(actual_jit_aux, orig_aux)

        # grad check with aux
        jtu.check_grads(lambda x, y: root_aux(high_precision_dot(x, x.T), y),
                        (a, b),
                        order=2,
                        rtol={jnp.float32: 1e-2})

        # test vmap and jvp combined by jacfwd
        fwd = jax.jacfwd(lambda x, y: root_aux(high_precision_dot(x, x.T), y),
                         argnums=(0, 1))
        expected_fwd = jax.jacfwd(
            lambda x, y: jnp.linalg.solve(high_precision_dot(x, x.T), y),
            argnums=(0, 1))

        fwd_val, fwd_aux = fwd(a, b)
        expected_fwd_val = expected_fwd(a, b)
        self.assertAllClose(fwd_val,
                            expected_fwd_val,
                            rtol={
                                np.float32: 5E-6,
                                np.float64: 5E-12
                            })

        jtu.check_close(fwd_aux, tree_util.tree_map(jnp.zeros_like, fwd_aux))