Example #1
class TestPolynomial(jtu.JaxTestCase):

  def testNotImplemented(self):
    for name in jnp.polynomial._NOT_IMPLEMENTED:
      func = getattr(jnp.polynomial, name)
      with self.assertRaises(NotImplementedError):
        func()

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}_leading={}_trailing={}".format(
       jtu.format_shape_dtype_string((length+leading+trailing,), dtype),
       leading, trailing),
     "dtype": dtype, "rng_factory": rng_factory, "length": length,
     "leading": leading, "trailing": trailing}
    for dtype in all_dtypes
    for rng_factory in [jtu.rand_default]
    for length in [0, 3, 9, 10, 17]
    for leading in [0, 1, 2, 3, 5, 7, 10]
    for trailing in [0, 1, 2, 3, 5, 7, 10]))
  def testRoots(self, dtype, rng_factory, length, leading, trailing):
    rng = rng_factory(np.random.RandomState(0))

    def args_maker():
      p = rng((length,), dtype)
      return jnp.concatenate(
        [jnp.zeros(leading, p.dtype), p, jnp.zeros(trailing, p.dtype)]),

    jnp_fn = lambda arg: jnp.sort(jnp.roots(arg))
    np_fn = lambda arg: np.sort(np.roots(arg))
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=3e-6)

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}_trailing={}".format(
        jtu.format_shape_dtype_string((length+trailing,), dtype), trailing),
     "dtype": dtype, "rng_factory": rng_factory, "length": length,
     "trailing": trailing}
    for dtype in all_dtypes
    for rng_factory in [jtu.rand_default]
    for length in [0, 1, 3, 10]
    for trailing in [0, 1, 3, 7]))
  def testRootsNostrip(self, length, dtype, rng_factory, trailing):
    rng = rng_factory(np.random.RandomState(0))

    def args_maker():
      p = rng((length,), dtype)
      if length != 0:
        return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]),
      else:
        # adding trailing zeros would make the input invalid (it would start with zeros)
        return p,

    jnp_fn = lambda arg: jnp.sort(jnp.roots(arg, strip_zeros=False))
    np_fn = lambda arg: np.sort(np.roots(arg))
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker,
                            check_dtypes=False, tol=1e-6)

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}_trailing={}".format(
      jtu.format_shape_dtype_string((length + trailing,), dtype), trailing),
      "dtype": dtype, "rng_factory": rng_factory, "length": length,
      "trailing": trailing}
    for dtype in all_dtypes
    for rng_factory in [jtu.rand_default]
    for length in [0, 1, 3, 10]
    for trailing in [0, 1, 3, 7]))
  # TODO: enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testRootsJit(self, length, dtype, rng_factory, trailing):
    rng = rng_factory(np.random.RandomState(0))

    def args_maker():
      p = rng((length,), dtype)
      if length != 0:
        return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]),
      else:
        # adding trailing zeros would make the input invalid (it would start with zeros)
        return p,

    roots_compiled = jit(partial(jnp.roots, strip_zeros=False))
    jnp_fn = lambda arg: jnp.sort(roots_compiled(arg))
    np_fn = lambda arg: np.sort(np.roots(arg))
    # Using strip_zeros=False makes the algorithm less efficient
    # and leads to slightly different values compared to numpy.
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker,
                            check_dtypes=False, tol=1e-6)

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_dtype={}_zeros={}_nonzeros={}".format(
        jtu.format_shape_dtype_string((zeros+nonzeros,), dtype),
        zeros, nonzeros),
     "zeros": zeros, "nonzeros": nonzeros, "dtype": dtype,
     "rng_factory": rng_factory}
    for dtype in all_dtypes
    for rng_factory in [jtu.rand_default]
    for zeros in [1, 2, 5]
    for nonzeros in [0, 3]))
  @jtu.skip_on_devices("gpu")
  def testRootsInvalid(self, zeros, nonzeros, dtype, rng_factory):
    rng = rng_factory(np.random.RandomState(0))

    # The polynomial coefficients here start with zero and would have to
    # be stripped before computing eigenvalues of the companion matrix.
    # Setting strip_zeros=False skips this check,
    # allowing jit transformation but yielding nan's for these inputs.
    p = jnp.concatenate([jnp.zeros(zeros, dtype), rng((nonzeros,), dtype)])

    if p.size == 1:
      # polynomial = const has no roots
      self.assertTrue(jnp.roots(p, strip_zeros=False).size == 0)
    else:
      self.assertTrue(jnp.any(jnp.isnan(jnp.roots(p, strip_zeros=False))))
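
# Illustrative sketch (not part of the original test file): np.roots, which the
# tests above compare against, strips leading zero coefficients and then takes
# the roots as eigenvalues of the companion matrix. With strip_zeros=False that
# stripping is skipped, which is why leading zeros yield nan's in
# testRootsInvalid. The helper name below is hypothetical.
import numpy as np

def roots_via_companion(p):
  p = np.trim_zeros(np.asarray(p, dtype=float), trim='f')  # drop leading zeros
  if p.size <= 1:
    return np.array([])  # empty or constant polynomial: no roots
  companion = np.diag(np.ones(p.size - 2), -1)  # ones on the first subdiagonal
  companion[0, :] = -p[1:] / p[0]
  return np.linalg.eigvals(companion)

# roots_via_companion([1., -3., 2.]) -> approximately [2., 1.]  (x**2 - 3*x + 2)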
Example #2
class NumpyLinalgTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testCholesky(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)

        def args_maker():
            factor_shape = shape[:-1] + (2 * shape[-1], )
            a = rng(factor_shape, dtype)
            return [onp.matmul(a, np.conj(T(a)))]

        if np.issubdtype(dtype, np.complexfloating) and (
                len(shape) > 2 or jtu.device_under_test() != "cpu"):
            self.skipTest(
                "Unimplemented case for complex Cholesky decomposition.")

        self._CheckAgainstNumpy(onp.linalg.cholesky,
                                np.linalg.cholesky,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.cholesky,
                              args_maker,
                              check_dtypes=True)

        if onp.finfo(dtype).bits == 64:
            jtu.check_grads(np.linalg.cholesky, args_maker(), order=2)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
            "n":
            n,
            "dtype":
            dtype,
            "rng":
            rng
        } for n in [0, 4, 5, 25]  # TODO(mattjj): complex64 unstable on large sizes?
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testDet(self, n, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng((n, n), dtype)]

        self._CheckAgainstNumpy(onp.linalg.det,
                                np.linalg.det,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.det, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
            "n":
            n,
            "dtype":
            dtype,
            "rng":
            rng
        } for n in [0, 4, 10, 200] for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testSlogdet(self, n, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng((n, n), dtype)]

        self._CheckAgainstNumpy(onp.linalg.slogdet,
                                np.linalg.slogdet,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.slogdet, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    # TODO(phawkins): enable when there is an eigendecomposition implementation
    # for GPU/TPU.
    @jtu.skip_on_devices("gpu", "tpu")
    def testEig(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        n = shape[-1]
        args_maker = lambda: [rng(shape, dtype)]

        # Norm, adjusted for dimension and type.
        def norm(x):
            norm = onp.linalg.norm(x, axis=(-2, -1))
            return norm / ((n + 1) * onp.finfo(dtype).eps)

        a, = args_maker()
        w, v = np.linalg.eig(a)
        self.assertTrue(
            onp.all(norm(onp.matmul(a, v) - w[..., None, :] * v) < 100))

        self._CompileAndCheck(partial(np.linalg.eig),
                              args_maker,
                              check_dtypes=True,
                              rtol=1e-3)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 4), (5, 5)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("gpu", "tpu")
    def testEigBatching(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        shape = (10, ) + shape
        args = rng(shape, dtype)
        ws, vs = vmap(np.linalg.eig)(args)
        self.assertTrue(
            onp.all(
                onp.linalg.norm(onp.matmul(args, vs) -
                                ws[..., None, :] * vs) < 1e-3))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}_lower={}".format(
                jtu.format_shape_dtype_string((n, n), dtype), lower),
            "n":
            n,
            "dtype":
            dtype,
            "lower":
            lower,
            "rng":
            rng
        } for n in [0, 4, 5, 50] for dtype in float_types + complex_types
                            for lower in [False, True]
                            for rng in [jtu.rand_default()]))
    # TODO(phawkins): enable when there is an eigendecomposition implementation
    # for TPU.
    @jtu.skip_on_devices("tpu")
    def testEigh(self, n, dtype, lower, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng((n, n), dtype)]

        uplo = "L" if lower else "U"

        # Norm, adjusted for dimension and type.
        def norm(x):
            norm = onp.linalg.norm(x, axis=(-2, -1))
            return norm / ((n + 1) * onp.finfo(dtype).eps)

        a, = args_maker()
        a = (a + onp.conj(a.T)) / 2
        w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a),
                              UPLO=uplo,
                              symmetrize_input=False)
        self.assertTrue(norm(onp.eye(n) - onp.matmul(onp.conj(T(v)), v)) < 5)
        self.assertTrue(norm(onp.matmul(a, v) - w * v) < 30)

        self._CompileAndCheck(partial(np.linalg.eigh, UPLO=uplo),
                              args_maker,
                              check_dtypes=True,
                              rtol=1e-3)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_lower={}".format(
                jtu.format_shape_dtype_string(shape, dtype), lower),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng,
            "lower":
            lower
        } for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]
                            for lower in [True, False]))
    # TODO(phawkins): enable when there is an eigendecomposition implementation
    # for TPU.
    @jtu.skip_on_devices("tpu")
    def testEighGrad(self, shape, dtype, rng, lower):
        self.skipTest("Test fails with numeric errors.")
        uplo = "L" if lower else "U"
        a = rng(shape, dtype)
        a = (a + onp.conj(a.T)) / 2
        a = onp.tril(a) if lower else onp.triu(a)
        # Gradient checks will fail without symmetrization as the eigh jvp rule
        # is only correct for tangents in the symmetric subspace, whereas the
        # checker checks against unconstrained (co)tangents.
        if dtype not in complex_types:
            f = partial(np.linalg.eigh, UPLO=uplo, symmetrize_input=True)
        else:  # only check eigenvalue grads for complex matrices
            f = lambda a: partial(
                np.linalg.eigh, UPLO=uplo, symmetrize_input=True)(a)[0]
        jtu.check_grads(f, (a, ), 2, rtol=1e-1)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_lower={}".format(
                jtu.format_shape_dtype_string(shape, dtype), lower),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng,
            "lower":
            lower,
            "eps":
            eps
        } for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
                            for dtype in complex_types
                            for rng in [jtu.rand_default()]
                            for lower in [True, False] for eps in [1e-4]))
    # TODO(phawkins): enable when there is an eigendecomposition implementation
    # for TPU.
    @jtu.skip_on_devices("tpu")
    def testEighGradVectorComplex(self, shape, dtype, rng, lower, eps):
        _skip_if_unsupported_type(dtype)
        # Special case to test for complex eigenvector grad correctness.
        # Exact eigenvector coordinate gradients are hard to test numerically for complex
        # eigensystem solvers given the extra degrees of per-eigenvector phase freedom.
        # Instead, we numerically verify the eigensystem properties on the perturbed
        # eigenvectors.  You only ever want to optimize eigenvector directions, not coordinates!
        uplo = "L" if lower else "U"
        a = rng(shape, dtype)
        a = (a + onp.conj(a.T)) / 2
        a = onp.tril(a) if lower else onp.triu(a)
        a_dot = eps * rng(shape, dtype)
        a_dot = (a_dot + onp.conj(a_dot.T)) / 2
        a_dot = onp.tril(a_dot) if lower else onp.triu(a_dot)
        # evaluate eigenvector gradient and groundtruth eigensystem for perturbed input matrix
        f = partial(np.linalg.eigh, UPLO=uplo)
        (w, v), (dw, dv) = jvp(f, primals=(a, ), tangents=(a_dot, ))
        new_a = a + a_dot
        new_w, new_v = f(new_a)
        new_a = (new_a + onp.conj(new_a.T)) / 2
        # Assert rtol eigenvalue delta between perturbed eigenvectors vs new true eigenvalues.
        RTOL = 1e-2
        assert onp.max(
            onp.abs((onp.diag(
                onp.dot(onp.conj(
                    (v + dv).T), onp.dot(new_a,
                                         (v + dv)))) - new_w) / new_w)) < RTOL
        # Redundant to above, but also assert rtol for eigenvector property with new true eigenvalues.
        assert onp.max(
            onp.linalg.norm(
                onp.abs(new_w * (v + dv) - onp.dot(new_a, (v + dv))), axis=0) /
            onp.linalg.norm(onp.abs(new_w * (v + dv)), axis=0)) < RTOL

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 4), (5, 5)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("tpu")
    def testEighBatching(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        shape = (10, ) + shape
        args = rng(shape, dtype)
        args = (args + onp.conj(T(args))) / 2
        ws, vs = vmap(jsp.linalg.eigh)(args)
        self.assertTrue(
            onp.all(
                onp.linalg.norm(onp.matmul(args, vs) -
                                ws[..., None, :] * vs) < 1e-3))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_ord={}_axis={}_keepdims={}".format(
                jtu.format_shape_dtype_string(shape, dtype), ord, axis,
                keepdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "axis":
            axis,
            "keepdims":
            keepdims,
            "ord":
            ord,
            "rng":
            rng
        } for axis, shape in [
            (None, (1, )), (None, (7, )), (None, (5, 8)),
            (0, (9, )), (0, (4, 5)),
            ((1, ), (10, 7, 3)), ((-2, ), (4, 8)), (-1, (6, 3)),
            ((0, 2), (3, 4, 5)), ((2, 0), (7, 8, 9)),
            (None, (7, 8, 11))]
                            for keepdims in [False, True]
                            for ord in (
                                [None] if axis is None and len(shape) > 2 else
                                [None, 0, 1, 2, 3, -1, -2, -3, np.inf, -np.inf]
                                if (axis is None and len(shape) == 1)
                                or isinstance(axis, int)
                                or (isinstance(axis, tuple) and len(axis) == 1)
                                else [None, 'fro', 1, 2, -1, -2,
                                      np.inf, -np.inf, 'nuc'])
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testNorm(self, shape, dtype, ord, axis, keepdims, rng):
        _skip_if_unsupported_type(dtype)
        if (ord in ('nuc', 2, -2)
                and (jtu.device_under_test() != "cpu" or
                     (isinstance(axis, tuple) and len(axis) == 2))):
            raise unittest.SkipTest("No adequate SVD implementation available")

        args_maker = lambda: [rng(shape, dtype)]
        onp_fn = partial(onp.linalg.norm,
                         ord=ord,
                         axis=axis,
                         keepdims=keepdims)
        np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
        # Older numpy versions promote to float64 unnecessarily.
        check_dtypes = numpy_version >= (1, 15)
        self._CheckAgainstNumpy(onp_fn,
                                np_fn,
                                args_maker,
                                check_dtypes=check_dtypes,
                                tol=1e-3)
        self._CompileAndCheck(np_fn, args_maker, check_dtypes=check_dtypes)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}_full_matrices={}_compute_uv={}".format(
                jtu.format_shape_dtype_string((
                    m, n), dtype), full_matrices, compute_uv),
            "m":
            m,
            "n":
            n,
            "dtype":
            dtype,
            "full_matrices":
            full_matrices,
            "compute_uv":
            compute_uv,
            "rng":
            rng
        } for m in [2, 7, 29, 53] for n in [2, 7, 29, 53]
                            for dtype in float_types + complex_types
                            for full_matrices in [False, True]
                            for compute_uv in [False, True]
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("tpu")
    def testSVD(self, m, n, dtype, full_matrices, compute_uv, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng((m, n), dtype)]

        # Norm, adjusted for dimension and type.
        def norm(x):
            norm = onp.linalg.norm(x, axis=(-2, -1))
            return norm / (max(m, n) * onp.finfo(dtype).eps)

        a, = args_maker()
        out = np.linalg.svd(a,
                            full_matrices=full_matrices,
                            compute_uv=compute_uv)

        if compute_uv:
            # Check the reconstructed matrices
            if full_matrices:
                k = min(m, n)
                if m < n:
                    self.assertTrue(
                        onp.all(
                            norm(a - onp.matmul(out[1] *
                                                out[0], out[2][:k, :])) < 50))
                else:
                    self.assertTrue(
                        onp.all(
                            norm(a - onp.matmul(out[1] *
                                                out[0][:, :k], out[2])) < 50))
            else:
                self.assertTrue(
                    onp.all(
                        norm(a - onp.matmul(out[1] * out[0], out[2])) < 50))

            # Check the unitary properties of the singular vector matrices.
            self.assertTrue(
                onp.all(
                    norm(
                        onp.eye(out[0].shape[1]) -
                        onp.matmul(onp.conj(T(out[0])), out[0])) < 10))
            if m >= n:
                self.assertTrue(
                    onp.all(
                        norm(
                            onp.eye(out[2].shape[1]) -
                            onp.matmul(onp.conj(T(out[2])), out[2])) < 10))
            else:
                self.assertTrue(
                    onp.all(
                        norm(
                            onp.eye(out[2].shape[0]) -
                            onp.matmul(out[2], onp.conj(T(out[2])))) < 20))

        else:
            self.assertTrue(
                onp.allclose(onp.linalg.svd(a, compute_uv=False),
                             onp.asarray(out),
                             atol=1e-4,
                             rtol=1e-4))

        self._CompileAndCheck(partial(np.linalg.svd,
                                      full_matrices=full_matrices,
                                      compute_uv=compute_uv),
                              args_maker,
                              check_dtypes=True)
        if not full_matrices:
            svd = partial(np.linalg.svd, full_matrices=False)
            jtu.check_jvp(svd,
                          partial(jvp, svd), (a, ),
                          atol=1e-1 if FLAGS.jax_enable_x64 else jtu.ATOL)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_fullmatrices={}".format(
                jtu.format_shape_dtype_string(shape, dtype), full_matrices),
            "shape":
            shape,
            "dtype":
            dtype,
            "full_matrices":
            full_matrices,
            "rng":
            rng
        } for shape in [(1, 1), (3, 3), (3, 4), (2, 10, 5), (2, 200, 100)]
                            for dtype in float_types + complex_types
                            for full_matrices in [False, True]
                            for rng in [jtu.rand_default()]))
    def testQr(self, shape, dtype, full_matrices, rng):
        _skip_if_unsupported_type(dtype)
        if (onp.issubdtype(dtype, onp.complexfloating)
                and (jtu.device_under_test() == "tpu" or jax.lib.version <=
                     (0, 1, 27))):
            raise unittest.SkipTest("No complex QR implementation")
        m, n = shape[-2:]

        if full_matrices:
            mode, k = "complete", m
        else:
            mode, k = "reduced", min(m, n)

        a = rng(shape, dtype)
        lq, lr = np.linalg.qr(a, mode=mode)

        # onp.linalg.qr doesn't support batch dimensions. But it seems like an
        # inevitable extension so we support it in our version.
        nq = onp.zeros(shape[:-2] + (m, k), dtype)
        nr = onp.zeros(shape[:-2] + (k, n), dtype)
        for index in onp.ndindex(*shape[:-2]):
            nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode)

        max_rank = max(m, n)

        # Norm, adjusted for dimension and type.
        def norm(x):
            n = onp.linalg.norm(x, axis=(-2, -1))
            return n / (max_rank * onp.finfo(dtype).eps)

        def compare_orthogonal(q1, q2):
            # Q is unique up to sign, so normalize the sign first.
            sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
            phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios))
            q1 *= phases
            self.assertTrue(onp.all(norm(q1 - q2) < 30))

        # Check a ~= qr
        self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30))

        # Compare the first 'k' vectors of Q; the remainder form an arbitrary
        # orthonormal basis for the null space.
        compare_orthogonal(nq[..., :k], lq[..., :k])

        # Check that q is close to unitary.
        self.assertTrue(
            onp.all(norm(onp.eye(k) - onp.matmul(onp.conj(T(lq)), lq)) < 5))

        if not full_matrices and m >= n:
            jtu.check_jvp(np.linalg.qr,
                          partial(jvp, np.linalg.qr), (a, ),
                          atol=1e-3)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(10, 4, 5), (5, 3, 3), (7, 6, 4)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testQrBatching(self, shape, dtype, rng):
        args = rng(shape, np.float32)
        qs, rs = vmap(jsp.linalg.qr)(args)
        self.assertTrue(
            onp.all(onp.linalg.norm(args - onp.matmul(qs, rs)) < 1e-3))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs={}_rhs={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype)),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for lhs_shape, rhs_shape in [
            ((1, 1), (1, 1)),
            ((4, 4), (4, )),
            ((8, 8), (8, 4)),
            ((1, 2, 2), (3, 2)),
            ((2, 1, 3, 3), (2, 4, 3, 4)),
        ] for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testSolve(self, lhs_shape, rhs_shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

        self._CheckAgainstNumpy(onp.linalg.solve,
                                np.linalg.solve,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.solve, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5)]
                            for dtype in float_types
                            for rng in [jtu.rand_default()]))
    def testInv(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        if jtu.device_under_test() == "gpu" and shape == (200, 200):
            raise unittest.SkipTest("Test is flaky on GPU")

        def args_maker():
            invertible = False
            while not invertible:
                a = rng(shape, dtype)
                try:
                    onp.linalg.inv(a)
                    invertible = True
                except onp.linalg.LinAlgError:
                    pass
            return [a]

        self._CheckAgainstNumpy(onp.linalg.inv,
                                np.linalg.inv,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)

    # Regression test for incorrect type for eigenvalues of a complex matrix.
    @jtu.skip_on_devices("tpu"
                         )  # TODO(phawkins): No eigh implementation on TPU.
    def testIssue669(self):
        def test(x):
            val, vec = np.linalg.eigh(x)
            return np.real(np.sum(val))

        grad_test_jc = jit(grad(jit(test)))
        xc = onp.eye(3, dtype=onp.complex)
        self.assertAllClose(xc, grad_test_jc(xc), check_dtypes=True)

    def testIssue1151(self):
        A = np.array(onp.random.randn(100, 3, 3), dtype=np.float32)
        b = np.array(onp.random.randn(100, 3), dtype=np.float32)
        x = np.linalg.solve(A, b)
        self.assertAllClose(vmap(np.dot)(A, x),
                            b,
                            atol=1e-3,
                            rtol=1e-3,
                            check_dtypes=True)
        jac0 = jax.jacobian(np.linalg.solve, argnums=0)(A, b)
        jac1 = jax.jacobian(np.linalg.solve, argnums=1)(A, b)
        jac0 = jax.jacobian(np.linalg.solve, argnums=0)(A[0], b[0])
        jac1 = jax.jacobian(np.linalg.solve, argnums=1)(A[0], b[0])
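
# Illustrative sketch (not part of the original test file): the residual checks
# above scale a Frobenius norm by the matrix dimension and the dtype's machine
# epsilon, so a threshold like "< 30" roughly means "within a few tens of ulps
# per element". The helper name below is hypothetical.
import numpy as onp

def scaled_residual(x, n, dtype):
    return onp.linalg.norm(x, axis=(-2, -1)) / ((n + 1) * onp.finfo(dtype).eps)

a = onp.array([[2.0, 1.0], [1.0, 3.0]])
w, v = onp.linalg.eigh(a)
# Columns of v are eigenvectors, so a @ v equals v scaled column-wise by w.
assert scaled_residual(a @ v - w[None, :] * v, n=2, dtype=onp.float64) < 30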
Example #3
class LaxRandomTest(jtu.JaxTestCase):
    def _CheckCollisions(self, samples, nbits):
        fail_prob = 0.01  # conservative bound on statistical fail prob by Chebyshev
        nitems = len(samples)
        nbins = 2**nbits
        nexpected = nbins * (1 - ((nbins - 1) / nbins)**nitems)
        ncollisions = len(np.unique(samples))
        sq_percent_deviation = ((ncollisions - nexpected) / nexpected)**2
        self.assertLess(sq_percent_deviation,
                        1 / np.sqrt(nexpected * fail_prob))
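        # Worked example (illustrative): with nbits=2 there are 4 bins, so 8
        # uniform draws are expected to occupy roughly 4 * (1 - (3/4)**8) ≈ 3.6
        # distinct values; the squared relative deviation of the observed
        # count from that expectation must stay below
        # 1 / sqrt(nexpected * fail_prob), the Chebyshev-style bound above.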

    def _CheckKolmogorovSmirnovCDF(self, samples, cdf):
        fail_prob = 0.01  # conservative bound on statistical fail prob by Kolmo CDF
        self.assertGreater(scipy.stats.kstest(samples, cdf).pvalue, fail_prob)

    def _CheckChiSquared(self, samples, pmf):
        alpha = 0.01  # significance level, threshold for p-value
        values, actual_freq = np.unique(samples, return_counts=True)
        expected_freq = pmf(values) * samples.size
        # per scipy: "A typical rule is that all of the observed and expected
        # frequencies should be at least 5."
        valid = (actual_freq > 5) & (expected_freq > 5)
        self.assertGreater(
            valid.sum(),
            1,
            msg='not enough valid frequencies for chi-squared test')
        _, p_value = scipy.stats.chisquare(actual_freq[valid],
                                           expected_freq[valid])
        self.assertGreater(p_value,
                           alpha,
                           msg=f'Failed chi-squared test with p={p_value}.\n'
                           'Expected vs. actual frequencies:\n'
                           f'{expected_freq[valid]}\n{actual_freq[valid]}')

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_{}".format(dtype),
            "dtype": np.dtype(dtype).name
        } for dtype in [np.float32, np.float64]))
    def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
        if not FLAGS.jax_enable_x64 and jnp.issubdtype(dtype, np.float64):
            raise SkipTest("can't test float64 agreement")

        bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
        numpy_bits = np.array(1., dtype).view(bits_dtype)
        xla_bits = api.jit(lambda: lax.bitcast_convert_type(
            np.array(1., dtype), bits_dtype))()
        self.assertEqual(numpy_bits, xla_bits)

    def testThreefry2x32(self):
        # We test the hash by comparing to known values provided in the test code of
        # the original reference implementation of Threefry. For the values, see
        # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
        def result_to_hex(result):
            return tuple([hex(x.copy()).rstrip("L") for x in result])

        expected = ("0x6b200159", "0x99ba4efe")
        result = random.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0]))

        self.assertEqual(expected, result_to_hex(result))

        expected = ("0x1cb996fc", "0xbb002be7")
        result = random.threefry_2x32(np.uint32([-1, -1]), np.uint32([-1, -1]))
        self.assertEqual(expected, result_to_hex(result))

        expected = ("0xc4923a9c", "0x483df7a0")
        result = random.threefry_2x32(np.uint32([0x13198a2e, 0x03707344]),
                                      np.uint32([0x243f6a88, 0x85a308d3]))
        self.assertEqual(expected, result_to_hex(result))

    def testThreefry2x32Large(self):
        n = 10000000
        result = random.threefry_2x32(
            (np.uint32(0x13198a2e), np.uint32(0x03707344)),
            jnp.concatenate([
                jnp.full((n, ), 0x243f6a88, jnp.uint32),
                jnp.full((n, ), 0x85a308d3, jnp.uint32)
            ]))
        np.testing.assert_equal(result[:n],
                                np.full((n, ), 0xc4923a9c, dtype=np.uint32))
        np.testing.assert_equal(result[n:],
                                np.full((n, ), 0x483df7a0, dtype=np.uint32))

    def testRngRandomBitsViewProperty(self):
        # TODO: add 64-bit if it ever supports this property.
        # TODO: will this property hold across endian-ness?
        N = 10
        key = random.PRNGKey(1701)
        nbits = [8, 16, 32]
        if jtu.device_under_test() == "tpu":
            # U8 and U16 are not supported on TPU.
            nbits = [32]
        rand_bits = [
            random._random_bits(key, n, (N * 64 // n, )) for n in nbits
        ]
        rand_bits_32 = np.array(
            [np.array(r).view(np.uint32) for r in rand_bits])
        print(rand_bits_32)
        assert np.all(rand_bits_32 == rand_bits_32[0])

    def testRngRandomBits(self):
        # Test specific outputs to ensure consistent random values between JAX versions.
        key = random.PRNGKey(1701)

        # U8 and U16 are not supported on TPU.
        if jtu.device_under_test() != "tpu":
            bits8 = random._random_bits(key, 8, (3, ))
            expected8 = np.array([216, 115, 43], dtype=np.uint8)
            self.assertArraysEqual(bits8, expected8, check_dtypes=True)

            bits16 = random._random_bits(key, 16, (3, ))
            expected16 = np.array([41682, 1300, 55017], dtype=np.uint16)
            self.assertArraysEqual(bits16, expected16, check_dtypes=True)

        bits32 = random._random_bits(key, 32, (3, ))
        expected32 = np.array([56197195, 4200222568, 961309823],
                              dtype=np.uint32)
        self.assertArraysEqual(bits32, expected32, check_dtypes=True)

        bits64 = random._random_bits(key, 64, (3, ))
        if FLAGS.jax_enable_x64:
            expected64 = np.array([
                3982329540505020460, 16822122385914693683, 7882654074788531506
            ],
                                  dtype=np.uint64)
        else:
            expected64 = np.array([676898860, 3164047411, 4010691890],
                                  dtype=np.uint32)
        self.assertArraysEqual(bits64, expected64, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_{}".format(dtype),
            "dtype": np.dtype(dtype).name
        } for dtype in [np.float32, np.float64]))
    def testRngUniform(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.uniform(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_{}".format(dtype),
            "dtype": np.dtype(dtype).name
        } for dtype in [np.int32, np.int64]))
    def testRngRandint(self, dtype):
        lo = 5
        hi = 10

        key = random.PRNGKey(0)
        rand = lambda key: random.randint(key, (10000, ), lo, hi, dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self.assertTrue(np.all(lo <= samples))
            self.assertTrue(np.all(samples < hi))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_{}".format(dtype),
            "dtype": np.dtype(dtype).name
        } for dtype in [np.float32, np.float64]))
    def testNormal(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.normal(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_{}".format(dtype),
            "dtype": np.dtype(dtype).name
        } for dtype in [np.float32, np.float64, np.int32, np.int64]))
    def testShuffle(self, dtype):
        key = random.PRNGKey(0)
        x = np.arange(100).astype(dtype)
        rand = lambda key: random.shuffle(key, x)
        crand = api.jit(rand)

        with self.assertWarns(FutureWarning):
            perm1 = rand(key)
        with self.assertWarns(FutureWarning):
            perm2 = crand(key)

        self.assertAllClose(perm1, perm2, check_dtypes=True)
        self.assertFalse(np.all(perm1 == x))  # seems unlikely!
        self.assertAllClose(np.sort(perm1), x, check_dtypes=False)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "dtype":
            np.dtype(dtype).name,
            "shape":
            shape
        } for dtype in [np.float32, np.float64, np.int32, np.int64]
                            for shape in [100, (10, 10), (10, 5, 2)]))
    def testPermutationArray(self, dtype, shape):
        key = random.PRNGKey(0)
        x = jnp.arange(jnp.prod(shape)).reshape(shape).astype(dtype)
        rand = lambda key: random.permutation(key, x)
        crand = api.jit(rand)

        perm1 = rand(key)
        perm2 = crand(key)

        self.assertAllClose(perm1, perm2, check_dtypes=True)
        self.assertFalse(np.all(perm1 == x))  # seems unlikely!
        self.assertAllClose(np.sort(perm1.ravel()),
                            x.ravel(),
                            check_dtypes=False)
        self.assertArraysAllClose(
            x,
            jnp.arange(jnp.prod(shape)).reshape(shape).astype(dtype),
            check_dtypes=True)

    def testPermutationInteger(self):
        key = random.PRNGKey(0)
        x = 100
        rand = lambda key: random.permutation(key, x)
        crand = api.jit(rand)

        perm1 = rand(key)
        perm2 = crand(key)

        self.assertAllClose(perm1, perm2, check_dtypes=True)
        self.assertEqual(perm1.dtype, perm2.dtype)
        self.assertFalse(np.all(perm1 == np.arange(100)))  # seems unlikely!
        self.assertAllClose(np.sort(perm1), np.arange(100), check_dtypes=False)

    def testPermutationErrors(self):
        key = random.PRNGKey(0)
        with self.assertRaises(TypeError):
            random.permutation(key, 10.)
        with self.assertRaises(core.ConcretizationTypeError):
            api.jit(random.permutation)(key, 10)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_p={}_{}".format(p, dtype),
            "p": p,
            "dtype": np.dtype(dtype).name
        } for p in [0.1, 0.5, 0.9] for dtype in [np.float32, np.float64]))
    def testBernoulli(self, p, dtype):
        key = random.PRNGKey(0)
        p = np.array(p, dtype=dtype)
        rand = lambda key, p: random.bernoulli(key, p, (10000, ))
        crand = api.jit(rand)

        uncompiled_samples = rand(key, p)
        compiled_samples = crand(key, p)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_p={}_{}_{}".format(p, dtype, sample_shape),
                "p": p,
                "axis": axis,
                "dtype": np.dtype(dtype).name,
                'sample_shape': sample_shape
            } for (p, axis) in [
                ([.25] * 4, -1),
                ([.1, .2, .3, .4], -1),
                ([[.5, .5], [.1, .9]], 1),
                ([[.5, .1], [.5, .9]], 0),
            ] for sample_shape in [(10000, ), (5000, 2)]
            for dtype in [np.float32, np.float64]))
    def testCategorical(self, p, axis, dtype, sample_shape):
        key = random.PRNGKey(0)
        p = np.array(p, dtype=dtype)
        logits = np.log(p) - 42  # test unnormalized
        out_shape = tuple(np.delete(logits.shape, axis))
        shape = sample_shape + out_shape
        rand = lambda key, p: random.categorical(
            key, logits, shape=shape, axis=axis)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, p)
        compiled_samples = crand(key, p)

        if axis < 0:
            axis += len(logits.shape)

        for samples in [uncompiled_samples, compiled_samples]:
            assert samples.shape == shape
            samples = jnp.reshape(samples, (10000, ) + out_shape)
            if len(p.shape[:-1]) > 0:
                ps = np.transpose(p, (1, 0)) if axis == 0 else p
                for cat_samples, cat_p in zip(samples.transpose(), ps):
                    self._CheckChiSquared(cat_samples, pmf=lambda x: cat_p[x])
            else:
                self._CheckChiSquared(samples, pmf=lambda x: p[x])

    def testBernoulliShape(self):
        key = random.PRNGKey(0)
        x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2))
        assert x.shape == (3, 2)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_a={}_b={}_{}".format(a, b, dtype),
                "a": a,
                "b": b,
                "dtype": np.dtype(dtype).name
            } for a in [0.2, 5.] for b in [0.2, 5.]
            for dtype in [np.float64]))  # NOTE: KS test fails with float32
    def testBeta(self, a, b, dtype):
        if not FLAGS.jax_enable_x64:
            raise SkipTest("skip test except on X64")
        key = random.PRNGKey(0)
        rand = lambda key, a, b: random.beta(key, a, b, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, a, b)
        compiled_samples = crand(key, a, b)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples,
                                            scipy.stats.beta(a, b).cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_{}".format(dtype),
            "dtype": np.dtype(dtype).name
        } for dtype in [np.float32, np.float64]))
    def testCauchy(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.cauchy(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.cauchy().cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_alpha={}_{}".format(alpha, dtype),
                "alpha": alpha,
                "dtype": np.dtype(dtype).name
            } for alpha in [
                np.array([0.2, 1., 5.]),
            ] for dtype in [np.float32, np.float64]))
    def testDirichlet(self, alpha, dtype):
        key = random.PRNGKey(0)
        rand = lambda key, alpha: random.dirichlet(key, alpha,
                                                   (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, alpha)
        compiled_samples = crand(key, alpha)

        for samples in [uncompiled_samples, compiled_samples]:
            self.assertAllClose(samples.sum(-1),
                                np.ones(10000, dtype=dtype),
                                check_dtypes=True)
            alpha_sum = sum(alpha)
            for i, a in enumerate(alpha):
                self._CheckKolmogorovSmirnovCDF(
                    samples[..., i],
                    scipy.stats.beta(a, alpha_sum - a).cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_{}".format(dtype),
            "dtype": np.dtype(dtype).name
        } for dtype in [np.float32, np.float64]))
    def testExponential(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.exponential(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_a={}_{}".format(a, dtype),
            "a": a,
            "dtype": np.dtype(dtype).name
        } for a in [0.1, 1., 10.] for dtype in [np.float32, np.float64]))
    def testGamma(self, a, dtype):
        key = random.PRNGKey(0)
        rand = lambda key, a: random.gamma(key, a, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, a)
        compiled_samples = crand(key, a)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf)

    def testGammaShape(self):
        key = random.PRNGKey(0)
        x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2))
        assert x.shape == (3, 2)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_a={}".format(alpha),
            "alpha": alpha
        } for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]))
    def testGammaGrad(self, alpha):
        rng = random.PRNGKey(0)
        alphas = np.full((100, ), alpha)
        z = random.gamma(rng, alphas)
        actual_grad = api.grad(lambda x: random.gamma(rng, x).sum())(alphas)

        eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
        cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps) -
                   scipy.stats.gamma.cdf(z, alpha - eps)) / (2 * eps)
        pdf = scipy.stats.gamma.pdf(z, alpha)
        expected_grad = -cdf_dot / pdf

        self.assertAllClose(
            actual_grad,
            expected_grad,
            check_dtypes=True,
            rtol=2e-2 if jtu.device_under_test() == "tpu" else 5e-4)

    def testGammaGradType(self):
        # Regression test for https://github.com/google/jax/issues/2130
        key = random.PRNGKey(0)
        a = jnp.array(1., dtype=jnp.float32)
        b = jnp.array(3., dtype=jnp.float32)
        f = lambda x, y: random.gamma(key=key, a=x, dtype=jnp.float32) / y
        # Should not crash with a type error.
        api.vjp(f, a, b)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_lam={}_{}".format(lam, dtype),
            "lam": lam,
            "dtype": np.dtype(dtype).name
        } for lam in [0.5, 3, 9, 11, 50, 500]
                            for dtype in [np.int32, np.int64]))
    def testPoisson(self, lam, dtype):
        key = random.PRNGKey(0)
        rand = lambda key, lam: random.poisson(key, lam, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, lam)
        compiled_samples = crand(key, lam)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf)
            # TODO(shoyer): determine error bounds for moments more rigorously (e.g.,
            # based on the central limit theorem).
            self.assertAllClose(samples.mean(),
                                lam,
                                rtol=0.01,
                                check_dtypes=False)
            self.assertAllClose(samples.var(),
                                lam,
                                rtol=0.03,
                                check_dtypes=False)

    def testPoissonBatched(self):
        key = random.PRNGKey(0)
        lam = jnp.concatenate([2 * jnp.ones(10000), 20 * jnp.ones(10000)])
        samples = random.poisson(key, lam, shape=(20000, ))
        self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf)
        self._CheckChiSquared(samples[10000:], scipy.stats.poisson(20.0).pmf)

    def testPoissonShape(self):
        key = random.PRNGKey(0)
        x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2))
        assert x.shape == (3, 2)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_{}".format(dtype),
            "dtype": np.dtype(dtype).name
        } for dtype in [np.float32, np.float64]))
    def testGumbel(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.gumbel(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples,
                                            scipy.stats.gumbel_r().cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_{}".format(dtype),
            "dtype": np.dtype(dtype).name
        } for dtype in [np.float32, np.float64]))
    def testLaplace(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.laplace(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_{}".format(dtype),
            "dtype": np.dtype(dtype).name
        } for dtype in [np.float32, np.float64]))
    def testLogistic(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.logistic(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples,
                                            scipy.stats.logistic().cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_b={}_{}".format(b, dtype),
            "b": b,
            "dtype": np.dtype(dtype).name
        } for b in [0.1, 1., 10.] for dtype in [np.float32, np.float64]))
    def testPareto(self, b, dtype):
        key = random.PRNGKey(0)
        rand = lambda key, b: random.pareto(key, b, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, b)
        compiled_samples = crand(key, b)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.pareto(b).cdf)

    def testParetoShape(self):
        key = random.PRNGKey(0)
        x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2))
        assert x.shape == (3, 2)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_df={}_{}".format(df, dtype),
            "df": df,
            "dtype": np.dtype(dtype).name
        } for df in [0.1, 1., 10.] for dtype in [np.float32, np.float64]))
    @jtu.skip_on_devices("cpu",
                         "tpu")  # TODO(phawkins): slow compilation times
    def testT(self, df, dtype):
        key = random.PRNGKey(0)
        rand = lambda key, df: random.t(key, df, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, df)
        compiled_samples = crand(key, df)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_{}D_{}".format(dim,
                                                  np.dtype(dtype).name),
                "dim": dim,
                "dtype": dtype
            } for dim in [1, 3, 5] for dtype in [np.float32, np.float64]))
    def testMultivariateNormal(self, dim, dtype):
        r = np.random.RandomState(dim)
        mean = r.randn(dim)
        cov_factor = r.randn(dim, dim)
        cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim)

        key = random.PRNGKey(0)
        rand = partial(random.multivariate_normal,
                       mean=mean,
                       cov=cov,
                       shape=(10000, ))
        crand = api.jit(rand)

        uncompiled_samples = np.asarray(rand(key), np.float64)
        compiled_samples = np.asarray(crand(key), np.float64)

        inv_scale = scipy.linalg.lapack.dtrtri(np.linalg.cholesky(cov),
                                               lower=True)[0]
        for samples in [uncompiled_samples, compiled_samples]:
            centered = samples - mean
            whitened = np.einsum('nj,ij->ni', centered, inv_scale)

            # This is a quick-and-dirty multivariate normality check that tests that a
            # uniform mixture of the marginals along the covariance matrix's
            # eigenvectors follows a standard normal distribution.
            self._CheckKolmogorovSmirnovCDF(whitened.ravel(),
                                            scipy.stats.norm().cdf)
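            # In other words (illustrative note): if x ~ N(mean, cov) and L is
            # the lower Cholesky factor of cov, then solving L y = x - mean
            # gives y ~ N(0, I), so the flattened whitened samples should pass
            # a standard-normal KS test.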

    def testMultivariateNormalCovariance(self):
        # test code based on https://github.com/google/jax/issues/1869
        N = 100000
        cov = jnp.array([[0.19, 0.00, -0.13, 0.00], [0.00, 0.29, 0.00, -0.23],
                         [-0.13, 0.00, 0.39, 0.00], [0.00, -0.23, 0.00, 0.49]])
        mean = jnp.zeros(4)

        out_np = np.random.RandomState(0).multivariate_normal(mean, cov, N)

        key = random.PRNGKey(0)
        out_jnp = random.multivariate_normal(key,
                                             mean=mean,
                                             cov=cov,
                                             shape=(N, ))

        var_np = out_np.var(axis=0)
        var_jnp = out_jnp.var(axis=0)
        self.assertAllClose(var_np,
                            var_jnp,
                            rtol=1e-2,
                            atol=1e-2,
                            check_dtypes=False)

        var_np = np.cov(out_np, rowvar=False)
        var_jnp = np.cov(out_jnp, rowvar=False)
        self.assertAllClose(var_np,
                            var_jnp,
                            rtol=1e-2,
                            atol=1e-2,
                            check_dtypes=False)

    def testIssue222(self):
        x = random.randint(random.PRNGKey(10003), (), 0, 0)
        assert x == 0

    def testFoldIn(self):
        key = random.PRNGKey(0)
        keys = [random.fold_in(key, i) for i in range(10)]
        assert np.unique(np.ravel(keys)).shape == (20, )

    def testStaticShapeErrors(self):
        if config.read("jax_disable_jit"):
            raise SkipTest("test only relevant when jit enabled")

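        # Under jit, n and d are abstract tracers, so using them as array
        # shapes below should raise the "Shapes must be 1D..." TypeError that
        # the assertion at the end of this test expects.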
        @api.jit
        def feature_map(n, d, sigma=1.0, seed=123):
            key = random.PRNGKey(seed)
            W = random.normal(key, (d, n)) / sigma
            w = random.normal(key, (d, )) / sigma
            b = 2 * jnp.pi * random.uniform(key, (d, ))

            phi = lambda x, t: jnp.sqrt(2.0 / d) * jnp.cos(
                jnp.matmul(W, x) + w * t + b)
            return phi

        self.assertRaisesRegex(TypeError, 'Shapes must be 1D.*',
                               lambda: feature_map(5, 3))

    def testIssue756(self):
        key = random.PRNGKey(0)
        w = random.normal(key, ())
        if FLAGS.jax_enable_x64:
            self.assertEqual(np.result_type(w), np.float64)
        else:
            self.assertEqual(np.result_type(w), np.float32)

    def testIssue1789(self):
        def f(x):
            return random.gamma(random.PRNGKey(0), x)

        grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))

    def testNoOpByOpUnderHash(self):
        def fail(*args, **kwargs):
            assert False

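        # Swap op-by-op dispatch for a function that always fails, so the
        # threefry hash below must go through its compiled path rather than
        # applying primitives eagerly.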
        apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
        try:
            out = random.threefry_2x32(np.zeros(2, np.uint32),
                                       np.arange(10, dtype=np.uint32))
        finally:
            xla.apply_primitive = apply_primitive

    def testPRNGValues(self):
        # Test to ensure consistent random values between JAX versions
        k = random.PRNGKey(0)

        randints = random.randint(k, (3, 3), 0, 8)
        if FLAGS.jax_enable_x64:
            self.assertAllClose(random.randint(k, (3, 3), 0, 8),
                                np.array([[7, 2, 6], [2, 1, 0], [6, 7, 7]],
                                         dtype='int64'),
                                check_dtypes=True)
        else:
            self.assertAllClose(random.randint(k, (3, 3), 0, 8),
                                np.array([[2, 1, 3], [6, 1, 5], [6, 3, 4]],
                                         dtype='int32'),
                                check_dtypes=True)

        self.assertAllClose(
            random.split(k, 4),
            np.array([[2285895361, 1501764800], [1518642379, 4090693311],
                      [433833334, 4221794875], [839183663, 3740430601]],
                     dtype='uint32'),
            check_dtypes=True)

        self.assertAllClose(random.fold_in(k, 4),
                            np.array([2285895361, 433833334], dtype='uint32'),
                            check_dtypes=True)
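
As a quick reference for the PRNG behaviour exercised above, here is a minimal usage sketch of JAX's explicit-key API; the exact sample values depend on the JAX version, which is precisely what testPRNGValues pins down:

from jax import random

key = random.PRNGKey(0)             # a uint32[2] key; no hidden global state
key, subkey = random.split(key)     # derive independent streams explicitly
x = random.normal(subkey, (3, 2))   # same key and shape -> same samples
y = random.normal(random.fold_in(key, 7), (3, 2))  # mix an integer id into a key
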
Example #4
class LaxBackedScipyTests(jtu.JaxTestCase):
  def _fetch_preconditioner(self, preconditioner, A, rng=None,
                            return_function=False):
    """
    Returns one of various preconditioning matrices depending on the identifier
    `preconditioner' and the input matrix A whose inverse it supposedly
    approximates.
    """
    if preconditioner == 'identity':
      M = np.eye(A.shape[0], dtype=A.dtype)
    elif preconditioner == 'random':
      if rng is None:
        rng = jtu.rand_default(self.rng())
      M = np.linalg.inv(rand_sym_pos_def(rng, A.shape, A.dtype))
    elif preconditioner == 'exact':
      M = np.linalg.inv(A)
    else:
      M = None

    if M is None or not return_function:
      return M
    else:
      return lambda x: jnp.dot(M, x, precision=lax.Precision.HIGHEST)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            preconditioner),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(4, 4), (7, 7)]
      for dtype in [np.float64, np.complex128]
      for preconditioner in [None, 'identity', 'exact', 'random']))
  def test_cg_against_scipy(self, shape, dtype, preconditioner):
    if not config.FLAGS.jax_enable_x64:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rand_sym_pos_def(rng, shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

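    # With a fixed, small maxiter, both implementations run the same number of
    # preconditioned CG iterations from the same starting point, so their
    # iterates should agree almost exactly; the last check compares a converged
    # solve against np.linalg.solve.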
    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=1),
        partial(lax_cg, M=M, maxiter=1),
        args_maker,
        tol=1e-12)

    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=3),
        partial(lax_cg, M=M, maxiter=3),
        args_maker,
        tol=1e-12)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_cg, M=M, atol=1e-10),
        args_maker,
        tol=1e-6)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(2, 2)]
      for dtype in float_types + complex_types))
  def test_cg_as_solve(self, shape, dtype):

    rng = jtu.rand_default(self.rng())
    a = rng(shape, dtype)
    b = rng(shape[:1], dtype)

    expected = np.linalg.solve(posify(a), b)
    actual = lax_cg(posify(a), b)
    self.assertAllClose(expected, actual)

    actual = jit(lax_cg)(posify(a), b)
    self.assertAllClose(expected, actual)

    # numerical gradients are only well defined if ``a`` is guaranteed to be
    # positive definite.
    jtu.check_grads(
        lambda x, y: lax_cg(posify(x), y),
        (a, b), order=2, rtol=1e-2)

  def test_cg_ndarray(self):
    A = lambda x: 2 * x
    b = jnp.arange(9.0).reshape((3, 3))
    expected = b / 2
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertAllClose(expected, actual)

  def test_cg_pytree(self):
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=6)
    self.assertAlmostEqual(expected["b"], actual["b"], places=6)

  def test_cg_errors(self):
    A = lambda x: x
    b = jnp.zeros((2,))
    with self.assertRaisesRegex(
        ValueError, "x0 and b must have matching tree structure"):
      jax.scipy.sparse.linalg.cg(A, {'x': b}, {'y': b})
    with self.assertRaisesRegex(
        ValueError, "x0 and b must have matching shape"):
      jax.scipy.sparse.linalg.cg(A, b, b[:, np.newaxis])

  def test_cg_without_pytree_equality(self):

    @register_pytree_node_class
    class MinimalPytree:
      def __init__(self, value):
        self.value = value
      def tree_flatten(self):
        return [self.value], None
      @classmethod
      def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    A = lambda x: MinimalPytree(2 * x.value)
    b = MinimalPytree(jnp.arange(5.0))
    expected = b.value / 2
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertAllClose(expected, actual.value)

  # GMRES
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}_solve_method={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            preconditioner,
            solve_method),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
       "solve_method": solve_method}
      for shape in [(3, 3)]
      for dtype in [np.float64, np.complex128]
      for preconditioner in [None, 'identity', 'exact', 'random']
      for solve_method in ['incremental', 'batched']))
  def test_gmres_against_scipy(
      self, shape, dtype, preconditioner, solve_method):
    if not config.FLAGS.jax_enable_x64:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=1, maxiter=1),
        partial(lax_gmres, M=M, restart=1, maxiter=1, solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=1, maxiter=2),
        partial(lax_gmres, M=M, restart=1, maxiter=2, solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=2, maxiter=1),
        partial(lax_gmres, M=M, restart=2, maxiter=1, solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_gmres, M=M, atol=1e-6, solve_method=solve_method),
        args_maker,
        tol=1e-10)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}_solve_method={}".format(
         jtu.format_shape_dtype_string(shape, dtype),
         preconditioner,
         solve_method),
      "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
      "solve_method": solve_method}
      for shape in [(2, 2), (7, 7)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']
      for solve_method in ['batched', 'incremental']
      ))
  def test_gmres_on_identity_system(self, shape, dtype, preconditioner,
                                    solve_method):
    A = jnp.eye(shape[1], dtype=dtype)

    solution = jnp.ones(shape[1], dtype=dtype)
    @jax.tree_util.Partial
    def A_mv(x):
      return matmul_high_precision(A, x)
    rng = jtu.rand_default(self.rng())
    M = self._fetch_preconditioner(preconditioner, A, rng=rng,
                                   return_function=True)
    b = A_mv(solution)
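    # With restart equal to the system size, GMRES can build a full Krylov
    # basis, so (up to roundoff) it should recover the exact solution.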
    restart = shape[-1]
    tol = shape[0] * jnp.finfo(dtype).eps
    x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
                                            restart=restart,
                                            M=M, solve_method=solve_method)
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}_solve_method={}".format(
         jtu.format_shape_dtype_string(shape, dtype),
         preconditioner,
         solve_method),
      "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
      "solve_method": solve_method}
      for shape in [(2, 2), (4, 4)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']
      for solve_method in ['incremental', 'batched']
      ))
  def test_gmres_on_random_system(self, shape, dtype, preconditioner,
                                  solve_method):
    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)

    solution = rng(shape[1:], dtype)
    @jax.tree_util.Partial
    def A_mv(x):
      return matmul_high_precision(A, x)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng,
                                   return_function=True)
    b = A_mv(solution)
    restart = shape[-1]
    tol = shape[0] * jnp.finfo(A.dtype).eps
    x, info = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
                                            restart=restart,
                                            M=M, solve_method=solve_method)
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)

  def test_gmres_pytree(self):
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=5)
    self.assertAlmostEqual(expected["b"], actual["b"], places=5)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}".format(
         jtu.format_shape_dtype_string(shape, dtype),
         preconditioner),
      "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(2, 2), (3, 3)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity']))
  def test_gmres_arnoldi_step(self, shape, dtype, preconditioner):
    """
    The Arnoldi decomposition within GMRES is correct.
    """
    if not config.FLAGS.jax_enable_x64:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    if preconditioner is None:
      M = lambda x: x
    else:
      M = self._fetch_preconditioner(preconditioner, A, rng=rng,
                                     return_function=True)

    n = shape[0]
    x0 = rng(shape[:1], dtype)
    Q = np.zeros((n, n + 1), dtype=dtype)
    Q[:, 0] = x0/jnp.linalg.norm(x0)
    Q = jnp.array(Q)
    H = jnp.eye(n, n + 1, dtype=dtype)

    @jax.tree_util.Partial
    def A_mv(x):
      return matmul_high_precision(A, x)
    for k in range(n):
      Q, H, _ = jax._src.scipy.sparse.linalg._kth_arnoldi_iteration(
          k, A_mv, M, Q, H)
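    # With the identity-like preconditioners used here, n Arnoldi steps build an
    # orthonormal Krylov basis Q[:, :n] for which Q* A Q reproduces the
    # upper-Hessenberg matrix stored (transposed) in H.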
    QA = matmul_high_precision(Q[:, :n].conj().T, A)
    QAQ = matmul_high_precision(QA, Q[:, :n])
    self.assertAllClose(QAQ, H.T[:n, :], rtol=1e-5, atol=1e-5)
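
For orientation, here is a minimal sketch of how the solvers under test are typically called; the small SPD system and tolerances below are illustrative and not taken from these tests:

import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg, gmres

A = jnp.diag(jnp.arange(1.0, 5.0))          # 4x4 symmetric positive definite
b = jnp.ones(4)

x_cg, _ = cg(lambda v: A @ v, b, tol=1e-8)  # A may be a dense matrix or a matvec callable
x_gm, _ = gmres(A, b, restart=4, solve_method='incremental')
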
Example #5
class BatchingTest(jtu.JaxTestCase):

  def testConstantFunction(self):
    ans = vmap(lambda x: 3)(np.ones(4))
    expected = 3 * np.ones(4)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testNestedBatchingMatMat(self):
    matvec = vmap(jnp.vdot, in_axes=(0, None))
    matmat = vmap(matvec, in_axes=(None, 1), out_axes=1)

    R = np.random.RandomState(0).randn
    A = R(4, 3)
    B = R(3, 2)

    ans = matmat(A, B)
    expected = np.dot(A, B)
    self.assertAllClose(
        ans, expected, check_dtypes=False,
        rtol={np.float32:1e-2} if jtu.device_under_test() == "tpu" else None)

    jaxpr = make_jaxpr(matmat)(A, B)
    self.assertEqual(len(jaxpr.jaxpr.eqns), 1)

  def testPerExampleGradients(self):
    def predict(params, inputs):
      for W, b in params:
        outputs = jnp.dot(W, inputs) + b
        inputs = jnp.tanh(outputs)
      return outputs

    def loss(params, data):
      inputs, targets = data
      predictions = predict(params, inputs)
      return jnp.sum((predictions - targets)**2)

    batch_size = 5
    layer_sizes = [3, 2, 4]

    R = np.random.RandomState(0).randn
    params = [(R(m, n), R(m))
              for m, n in zip(layer_sizes[1:], layer_sizes[:-1])]

    input_batch = R(5, 3)
    target_batch = R(5, 4)
    batch = (input_batch, target_batch)

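    # Fixing params with partial and vmapping over the data batch yields a
    # pytree of per-example gradients whose leading axis is the batch dimension.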
    ans = vmap(partial(grad(loss), params))(batch)

    for ans_pair, param_pair in zip(ans, params):
      dW, db = ans_pair
      W, b = param_pair

      self.assertEqual(dW.shape, (batch_size,) + W.shape)
      self.assertEqual(db.shape, (batch_size,) + b.shape)

  def testJacobians(self):
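    # Build Jacobians both ways: vmap vjp (reverse mode) and jvp (forward mode)
    # over the standard basis; the two constructions must agree.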
    def jacbwd(f, x):
      y, pullback = vjp(f, x)
      std_basis = np.eye(np.size(y)).reshape((-1,) + np.shape(y))
      jac_flat, = vmap(pullback, out_axes=np.ndim(y))(std_basis)
      return jac_flat.reshape(np.shape(y) + np.shape(x))

    def jacfwd(f, x):
      pushfwd = lambda v: jvp(f, (x,), (v,))
      std_basis = np.eye(np.size(x)).reshape((-1,) + np.shape(x))
      y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis)
      return jac_flat.reshape(np.shape(y) + np.shape(x))

    R = np.random.RandomState(0).randn

    A = R(4, 3)
    b = R(4)
    f = lambda x: jnp.tanh(jnp.dot(A, x) + b)

    x = R(3)
    self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False)

  def testBatchOfCompile(self):
    side = []

    @jit
    def f(x):
      side.append(None)
      return x + x

    g = jit(vmap(f))
    self.assertAllClose(g(np.ones(2)), 2 * np.ones(2), check_dtypes=False)
    self.assertEqual(len(side), 1)
    self.assertAllClose(g(2 * np.ones(2)), 4 * np.ones(2),
                        check_dtypes=False)
    self.assertEqual(len(side), 1)

  def testSliceLax(self):
    fun = lambda x: lax.slice(x, (2,), (4,))
    R = np.random.RandomState(0).randn
    x = R(5, 10)

    ans = vmap(fun)(x)
    expected_ans = x[:, 2:4]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testSliceNumpy(self):
    fun = lambda x: x[:, 2]
    R = np.random.RandomState(0).randn
    x = R(10, 5, 3, 7)

    ans = vmap(fun)(x)
    expected_ans = x[:, :, 2]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testRevLax(self):
    fun = lambda x: lax.rev(x, [0])
    R = np.random.RandomState(0).randn
    x = R(2, 3)

    ans = vmap(fun)(x)
    expected_ans = x[:, ::-1]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    ans = vmap(fun, (1,), 1)(x)
    expected_ans = x[::-1, :]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testRevNumpy(self):
    fun = lambda x: x[:, ::-1]
    R = np.random.RandomState(0).randn
    x = R(3, 2, 4)

    ans = vmap(fun)(x)
    expected_ans = x[:, :, ::-1]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    ans = vmap(fun, (1,), 1)(x)
    expected_ans = x[:, :, ::-1]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    ans = vmap(fun, (2,), 2)(x)
    expected_ans = x[:, ::-1, :]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testNpMaximum(self):
    fun = lambda x: jnp.maximum(x, 0.0)
    R = np.random.RandomState(0).randn
    x = R(10, 5, 3, 7)

    ans = vmap(fun)(x)
    expected_ans = np.maximum(x, 0.0)
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testNpGtrThan(self):
    R = np.random.RandomState(0).randn
    x = R(10, 5, 3, 7)

    ans = vmap(lambda x: x > 1.0)(x)
    expected_ans = x > 1.0
    self.assertAllClose(ans, expected_ans)

  def testNpMaximumPerExampleGrad(self):
    R = np.random.RandomState(0).randn
    x = R(10, 5)
    W = R(5, 5)

    fun = lambda W, x: jnp.sum(jnp.maximum(jnp.dot(x, W), 0.0) ** 2)

    ans = vmap(partial(grad(fun), W))(x)

    W_t = jnp.transpose(W)
    for i in range(10):
      x_ex = x[i:i + 1]

      expected_ans = 2.0 * jnp.dot(
          jnp.maximum(jnp.dot(W_t, jnp.transpose(x_ex)), 0.0), x_ex)
      expected_ans = jnp.transpose(expected_ans)

      self.assertAllClose(
          ans[i], expected_ans, check_dtypes=False,
          atol={np.float32:5e-2} if jtu.device_under_test() == "tpu" else None)

  def testDotGeneral(self):
    R = np.random.RandomState(0).randn

    x = R(10, 3, 4, 5)
    y = R(10, 3, 5, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun)(x, y)
    expected = lax.dot_general(x, y, [((3,), (2,)), ((0, 1), (0, 1))])
    self.assertAllClose(ans, expected)

    x = R(3, 4, 10, 5)
    y = R(3, 10, 5, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun, in_axes=(2, 1))(x, y)
    expected = np.stack([fun(x[..., i, :], y[:, i, ...]) for i in range(10)])
    self.assertAllClose(ans, expected)

    x = R(3, 4, 5, 10)
    y = R(3, 5, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun, in_axes=(3, None))(x, y)
    expected = np.stack([fun(x[..., i], y) for i in range(10)])
    self.assertAllClose(ans, expected)

    x = R(3, 4, 5)
    y = R(3, 5, 10, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun, in_axes=(None, 2))(x, y)
    expected = np.stack([fun(x, y[..., i, :]) for i in range(10)])
    self.assertAllClose(ans, expected)

    x = R(4)
    y = R(4, 10)
    fun = lambda x, y: lax.dot_general(x, y, [((0,), (0,)), ((), ())])
    ans = vmap(fun, in_axes=(None, 1))(x, y)
    expected = np.stack([fun(x, y[..., i]) for i in range(10)])
    self.assertAllClose(ans, expected)

  def testDot(self):
    # these tests are based on @shoyer's notebook studying gufuncs

    def vecvec(a, b):
      dot = jnp.dot
      for ndim in range(1, max(a.ndim, b.ndim)):
        a_ax = 0 if a.ndim > ndim else None
        b_ax = 0 if b.ndim > ndim else None
        dot = vmap(dot, in_axes=(a_ax, b_ax))
      return dot(a, b)

    assert vecvec(jnp.zeros((3,)), jnp.zeros((3,))).shape == ()
    assert vecvec(jnp.zeros((2, 3)), jnp.zeros((3,))).shape == (2,)
    assert vecvec(jnp.zeros((4, 2, 3)), jnp.zeros((3,))).shape == (4, 2)

  def testDot2(self):
    R = np.random.RandomState(0).randn
    xs = R(10, 3)
    ys = R(10, 3)
    ans = vmap(jnp.dot)(xs, ys)
    expected = np.einsum('ni,ni->n', xs, ys)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testDot3(self):
    R = np.random.RandomState(0).randn
    xs = R(5, 8, 10)
    ys = R(10, 1)
    ans = vmap(jnp.dot, in_axes=(1, None))(xs, ys)
    expected = np.einsum('inj,jk->nik', xs, ys)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testDot4(self):
    R = np.random.RandomState(0).randn
    xs = R(3, 2)
    ys = R(3)
    ans = vmap(jnp.dot, in_axes=(1, None))(xs, ys)
    expected = np.einsum('ij,i->j', xs, ys)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testPad(self):
    R = np.random.RandomState(0).randn

    fun = lambda x: lax.pad(x, np.float32(0), [(1, 2, 1)])
    x = R(5, 10).astype(np.float32)
    ans = vmap(fun)(x)
    expected_ans = jnp.stack(list(map(fun, x)))
    self.assertAllClose(ans, expected_ans, check_dtypes=False)


    fun = lambda x: lax.pad(x, np.float32(0), [(1, 2, 1), (0, 1, 0)])
    x = R(5, 10, 3).astype(np.float32)
    ans = vmap(fun)(x)
    expected_ans = jnp.stack(list(map(fun, x)))
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testConcatenate(self):
    R = lambda *shape: np.random.RandomState(0).randn(*shape).astype(np.float32)

    fun = lambda *args: lax.concatenate(args, dimension=0)
    x, y, z = R(10, 2, 3), R(1, 10, 3), R(4, 3)
    ans = vmap(fun, in_axes=(0, 1, None))(x, y, z)
    expected_ans = np.concatenate([x, np.swapaxes(y, 0, 1),
                                    np.broadcast_to(z, (10, 4, 3))], 1)
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    fun = lambda *args: lax.concatenate(args, dimension=1)
    x, y, z = R(10, 2, 1), R(2, 3), R(2, 4, 10)
    ans = vmap(fun, in_axes=(0, None, 2))(x, y, z)
    expected_ans = np.concatenate([x, np.broadcast_to(y, (10, 2, 3)),
                                    np.moveaxis(z, 2, 0)], 2)
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testJacobianIssue54(self):
    # test modeling the code in https://github.com/google/jax/issues/54

    def func(xs):
      return jnp.array(list(xs))

    xs = jnp.ones((5, 1))
    jacrev(func)(xs)  # don't crash
    jacfwd(func)(xs)  # don't crash

  def testAny(self):
    # test modeling the code in https://github.com/google/jax/issues/108

    ans = vmap(jnp.any)(jnp.array([[True, False], [False, False]]))
    expected = jnp.array([True, False])
    self.assertAllClose(ans, expected)

  @jtu.skip_on_devices("tpu")
  def testHessian(self):
    # test based on code from sindhwani@google
    def fun(x, t):
      return jnp.sum(jnp.power(jnp.maximum(x, 0.0), 2)) + t

    x = np.array([-1., -0.5, 0., 0.5, 1.0])

    ans = hessian(lambda x: fun(x, 0.0))(x)
    expected = np.array([[0., 0., 0., 0., 0.],
                          [0., 0., 0., 0., 0.],
                          [0., 0.,0.5, 0., 0.],
                          [0., 0., 0., 2., 0.],
                          [0., 0., 0., 0., 2.]])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testDynamicSlice(self):
    # test dynamic_slice via numpy indexing syntax
    # see https://github.com/google/jax/issues/1613 for an explanation of why we
    # need to use np rather than jnp to create x and idx
    x = np.arange(30).reshape((10, 3))

    ans = vmap(lambda x, i: x[i], in_axes=(0, None))(x, 1)
    expected = x[:, 1]
    self.assertAllClose(ans, expected, check_dtypes=False)


    idx = np.array([0, 1, 2, 1, 0] * 2)
    ans = vmap(lambda x, i: x[i], in_axes=(0, 0))(x, idx)
    expected = x[np.arange(10), idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = np.arange(3)
    idx = np.array([0, 1, 2, 1, 0] * 2)
    ans = vmap(lambda x, i: x[i], in_axes=(None, 0))(x, idx)
    expected = x[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testDynamicUpdateSlice(self):
    x = np.random.randn(10, 3)
    y = np.random.randn(10)
    ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
               in_axes=(0, 0, None))(x, y, 1)
    expected = x.copy()
    expected[:, 1] = y
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = np.random.randn(3)
    idx = np.array([0, 1, 2, 1, 0] * 2)
    ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
               in_axes=(None, 0, 0))(x, y, idx)
    expected = np.broadcast_to(x, (10, 3)).copy()
    expected[np.arange(10), idx] = y
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testRandom(self):
    seeds = vmap(random.PRNGKey)(np.arange(10))
    ans = vmap(partial(random.normal, shape=(3, 2)))(seeds)
    expected = np.stack([random.normal(random.PRNGKey(seed), (3, 2))
                          for seed in np.arange(10)])
    self.assertAllClose(ans, expected, check_dtypes=False)
    assert len(np.unique(ans)) == 10 * 3 * 2

  def testSort(self):
    v = np.arange(12)[::-1].reshape(3, 4)

    sv = vmap(partial(lax.sort, dimension=0), (0,))(v)
    self.assertAllClose(sv, v[:, ::-1])

    sv = vmap(partial(lax.sort, dimension=-1), (0,))(v)
    self.assertAllClose(sv, v[:, ::-1])

    sv = vmap(partial(lax.sort, dimension=0), (1,))(v)
    self.assertAllClose(sv, v[::-1, :].T)

    sv = vmap(partial(lax.sort, dimension=0), (1,), 1)(v)
    self.assertAllClose(sv, v[::-1, :])

  def testSortKeyVal(self):
    k = np.arange(12)[::-1].reshape(3, 4)
    v = np.random.RandomState(0).permutation(12).reshape(3, 4)

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v)
    self.assertAllClose(sk, k[:, ::-1])
    self.assertAllClose(sv, v[:, ::-1])

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v)
    self.assertAllClose(sk, k[::-1, :])
    self.assertAllClose(sv, v[::-1, :])

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T)
    self.assertAllClose(sk, k[:, ::-1])
    self.assertAllClose(sv, v[:, ::-1])

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v)
    self.assertAllClose(sk, k[:, ::-1])
    self.assertAllClose(sv, v[:, ::-1])

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0], v)
    self.assertAllClose(sk, np.broadcast_to(k[0, ::-1], (3, 4)))
    self.assertAllClose(sv, v[:, ::-1])

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T, v[0])
    self.assertAllClose(sk, k[:, ::-1])
    self.assertAllClose(sv, np.broadcast_to(v[0, ::-1], (3, 4)))

  def testConvGeneralDilated(self):
    W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32)
    X = jnp.array(np.random.randn(10, 5, 5, 1), dtype=np.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
      y = lax.conv_general_dilated(
          x, params, one, 'SAME', one, one, dimension_numbers)
      return y
    grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2))

    # Test forward prop.
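    # Each example is reshaped to a singleton batch of shape (1, 5, 5, 1) so
    # that vmap over the leading axis runs the convolution once per example;
    # the stacked result should match applying the convolution to X directly.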
    per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example = jnp.reshape(per_example, (10, 5, 5, 5))
    per_example_direct = f(W, X)
    self.assertAllClose(per_example, per_example_direct)

    # Test gradients.
    per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example_direct = []
    for i in range(10):
      g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1)))
      per_example_direct += [
          jnp.reshape(g, (1,) + g.shape)]
    per_example_direct = jnp.concatenate(per_example_direct, axis=0)
    self.assertAllClose(per_example, per_example_direct,
                        rtol=2e-2, atol=2e-3)

  def testConvGeneralDilatedBatchNotMajor(self):
    W = jnp.array(np.random.randn(3, 3, 1, 4), dtype=np.float32)
    x = jnp.array(np.random.randn(3, 5, 7, 5, 1), dtype=np.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('HNWC', 'HWIO', 'HWNC')
      y = lax.conv_general_dilated(
          x, params, one, 'SAME', one, one, dimension_numbers)
      return y

    per_example = vmap(partial(f, W))(x)
    per_example = jnp.reshape(jnp.transpose(per_example, (1, 2, 0, 3, 4)),
                             (5, 5, 21, 4))
    per_example_direct = f(W, jnp.reshape(jnp.transpose(x, (1, 0, 2, 3, 4)),
                                         (5, 21, 5, 1)))
    self.assertAllClose(per_example, per_example_direct)

  @parameterized.named_parameters(
    {"testcase_name": "_op={}".format(name), "op": op, "unit": unit}
    for name, op, unit in [("max", lax.max, -jnp.inf), ("min", lax.min, jnp.inf)])
  def testMinMaxPool(self, op, unit):
    W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32)
    X = jnp.array(np.random.randn(10, 5, 5, 1), dtype=np.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
      y = lax.conv_general_dilated(
          x, params, one, 'SAME', one, one, dimension_numbers)
      y = lax.reduce_window(
          y, unit, op, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')
      return y
    grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2))

    # Test forward prop.
    per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example = jnp.reshape(per_example, (10, 5, 5, 5))
    per_example_direct = f(W, X)
    self.assertAllClose(per_example, per_example_direct)

    # Test gradients.
    per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example_direct = []
    for i in range(10):
      g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1)))
      per_example_direct += [
          jnp.reshape(g, (1,) + g.shape)]
    per_example_direct = jnp.concatenate(per_example_direct, axis=0)
    self.assertAllClose(per_example, per_example_direct, rtol=5e-2, atol=1e-3)

  def testSumPool(self):
    W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32)
    X = jnp.array(np.random.randn(10, 5, 5, 1), dtype=np.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
      y = lax.conv_general_dilated(
          x, params, one, 'SAME', one, one, dimension_numbers)
      y = lax.reduce_window(
          y, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')
      return y
    grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2))

    # Test forward prop.
    per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example = jnp.reshape(per_example, (10, 5, 5, 5))
    per_example_direct = f(W, X)
    self.assertAllClose(per_example, per_example_direct)

    # Test gradients.
    per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example_direct = []
    for i in range(10):
      g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1)))
      per_example_direct += [
          jnp.reshape(g, (1,) + g.shape)]
    per_example_direct = jnp.concatenate(per_example_direct, axis=0)
    self.assertAllClose(per_example, per_example_direct,
                        rtol=3e-2, atol=1e-3)

  def testCumProd(self):
    x = jnp.arange(9).reshape(3, 3) + 1
    y = vmap(lambda x: jnp.cumprod(x, axis=-1))(x)
    self.assertAllClose(np.cumprod(x, axis=1, dtype=jnp.int_), y)

  def testSelect(self):
    pred = np.array([True, False])
    on_true = np.array([0, 1])
    on_false = np.array([2, 3])
    ans = vmap(lax.select)(pred, on_true, on_false)
    expected = np.array([0, 3])
    self.assertAllClose(ans, expected)

    pred = np.array([False, True])
    on_true = np.array([0, 1])
    on_false = np.array([2, 3])
    ans = vmap(lax.select, (0, None, None))(pred, on_true, on_false)
    expected = np.array([[2, 3],
                          [0, 1]])
    self.assertAllClose(ans, expected)

    pred = True
    on_true = np.array([0, 1], np.float32)
    on_false = np.array(3, np.float32)
    ans = vmap(lax.select, (None, 0, None))(pred, on_true, on_false)
    expected = np.array([0, 1], np.float32)
    self.assertAllClose(ans, expected)

    pred = np.array([False, True])
    on_true = np.array([0, 1], np.float32)
    on_false = np.array(3, np.float32)
    ans = vmap(lax.select, (0, 0, None))(pred, on_true, on_false)
    expected = np.array([3, 1], np.float32)
    self.assertAllClose(ans, expected)

    pred = np.array([False, True])
    on_true = np.array([2], np.float32)
    on_false = np.array([[3, 4]], np.float32)
    ans = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false)
    expected = np.array([[3, 2]], np.float32)
    self.assertAllClose(ans, expected)

  def testLaxLinalgCholesky(self):
    a = np.random.RandomState(0).randn(10, 5, 5).astype(np.float32)
    a = np.matmul(a, np.conj(np.swapaxes(a, -1, -2)))

    ans = vmap(lax.linalg.cholesky)(a)
    expected = np.linalg.cholesky(a)
    self.assertAllClose(ans, expected, check_dtypes=False, rtol=1e-4)

    b = np.random.RandomState(0).randn(10, 5, 5).astype(np.float32)
    b = np.matmul(b, np.conj(np.swapaxes(b, -1, -2)))
    b_trans = np.swapaxes(b, 0, 1)  # shape is (5, 10, 5)

    ans = vmap(lax.linalg.cholesky, in_axes=1, out_axes=0)(b_trans)
    expected = np.linalg.cholesky(b)
    self.assertAllClose(ans, expected, check_dtypes=False, rtol=1e-4)

  def testLaxLinalgTriangularSolve(self):
    a = np.random.RandomState(0).randn(4, 10, 4).astype(np.float32)
    a += np.eye(4, dtype=jnp.float32)[:, None, :]
    b = np.random.RandomState(0).randn(5, 4, 10).astype(np.float32)

    ans = vmap(lax.linalg.triangular_solve, in_axes=(1, 2))(a, b)
    expected = np.stack(
      [lax.linalg.triangular_solve(a[:, i], b[..., i]) for i in range(10)])
    self.assertAllClose(ans, expected)

    ans = vmap(lax.linalg.triangular_solve, in_axes=(None, 2))(a[:, 0], b)
    expected = np.stack(
      [lax.linalg.triangular_solve(a[:, 0], b[..., i]) for i in range(10)])
    self.assertAllClose(ans, expected)

    ans = vmap(lax.linalg.triangular_solve, in_axes=(1, None))(a, b[..., 0])
    expected = np.stack(
      [lax.linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10)])
    self.assertAllClose(ans, expected)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_factory": rng_factory}
      for dtype in [np.float32, np.int32]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (3, 5), np.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,)),
          (1, (10, 3), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,)),
          (1, (10, 3, 5), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3)),
          (2, (10, 5, 3), np.array([[0, 2], [1, 0]]),
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,),
             start_index_map=(0, 1)),
            (1, 3)),
      ]
      for rng_factory in [jtu.rand_default])
  def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums,
                               slice_sizes, rng_factory):
    rng = rng_factory(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    operand = rng(shape, dtype)
    ans = vmap(fun, (axis, None))(operand, idxs)
    expected = np.stack([fun(operand[(slice(None),) * axis + (i,)], idxs)
                          for i in range(operand.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_factory": rng_factory}
      for dtype in [np.float32, np.float64]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (3, 5), np.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,)),
          (1, (10, 3), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,)),
          (1, (10, 3, 5), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3)),
          (2, (10, 5, 3), np.array([[0, 2], [1, 0]]),
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,),
             start_index_map=(0, 1)),
            (1, 3)),      ]
      for rng_factory in [jtu.rand_default])
  def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums,
                                   slice_sizes, rng_factory):
    rng = rng_factory(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
    operand = rng(shape, dtype)
    ans = vmap(gfun, (axis, None))(operand, idxs)
    expected = np.stack([gfun(operand[(slice(None),) * axis + (i,)], idxs)
                          for i in range(operand.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_factory": rng_factory}
      for dtype in [np.float32, np.int32]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (5,), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
              offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)),
          (1, (10,), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)),
          (1, (10, 5), np.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)),
          (0, (10, 5), np.array([[[0, 1], [2, 0]],
                                  [[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
      ]
      for rng_factory in [jtu.rand_default])
  def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums,
                               slice_sizes, rng_factory):
    rng = rng_factory(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    operand = rng(shape, dtype)
    ans = vmap(fun, (None, axis))(operand, idxs)
    expected = np.stack([fun(operand, idxs[(slice(None),) * axis + (i,)])
                          for i in range(idxs.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_factory": rng_factory}
      for dtype in [np.float32, np.float64]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (5,), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
              offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)),
          (1, (10,), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)),
          (1, (10, 5), np.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)),
          (0, (10, 5), np.array([[[0, 1], [2, 0]],
                                  [[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
      ]
      for rng_factory in [jtu.rand_default])
  def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums,
                                   slice_sizes, rng_factory):
    rng = rng_factory(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
    operand = rng(shape, dtype)
    ans = vmap(gfun, (None, axis))(operand, idxs)
    expected = np.stack([gfun(operand, idxs[(slice(None),) * axis + (i,)])
                          for i in range(idxs.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs,
          dnums, slice_sizes),
       "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
       dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes,
       "rng_factory": rng_factory}
      for dtype in [np.float32, np.int32]
      for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
          (0, 0, (2, 5), np.array([[[0], [2]], [[1], [3]]]),
           lax.GatherDimensionNumbers(
             offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
           (1,)),
          (1, 1, (10, 2), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
           (2,)),
          (0, 1, (2, 10, 5,), np.array([[[0, 2, 1], [0, 3, 3]]]).T,
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
           (1, 3)),
          (2, 0, (10, 5, 2), np.array([[[0, 2], [1, 0]],
                                        [[1, 0], [2, 0]]]),
          lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
           (1, 3)),
      ]
      for rng_factory in [jtu.rand_default])
  def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
                            slice_sizes, rng_factory):
    rng = rng_factory(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    operand = rng(shape, dtype)
    assert operand.shape[op_axis] == idxs.shape[idxs_axis]
    ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs)
    expected = np.stack([fun(operand[(slice(None),) * op_axis + (i,)],
                              idxs[(slice(None),) * idxs_axis + (i,)])
                          for i in range(idxs.shape[idxs_axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs,
          dnums, slice_sizes),
       "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
       dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes,
       "rng_factory": rng_factory}
      for dtype in [np.float32]
      for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
          (0, 0, (2, 5), np.array([[[0], [2]], [[1], [3]]]),
           lax.GatherDimensionNumbers(
             offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
           (1,)),
          (1, 1, (10, 2), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
           (2,)),
          (0, 1, (2, 10, 5,), np.array([[[0, 2, 1], [0, 3, 3]]]).T,
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
           (1, 3)),
          (2, 0, (10, 5, 2), np.array([[[0, 2], [1, 0]],
                                        [[1, 0], [2, 0]]]),
          lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
           (1, 3)),
      ]
      for rng_factory in [jtu.rand_default])
  def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
                                slice_sizes, rng_factory):
    rng = rng_factory(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
    operand = rng(shape, dtype)
    assert operand.shape[op_axis] == idxs.shape[idxs_axis]
    ans = vmap(gfun, (op_axis, idxs_axis))(operand, idxs)
    expected = np.stack([gfun(operand[(slice(None),) * op_axis + (i,)],
                              idxs[(slice(None),) * idxs_axis + (i,)])
                          for i in range(idxs.shape[idxs_axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testNumpyIndexing1(self):
    a = jnp.arange(2 * 3 * 4).reshape((2, 3, 4))
    ind = np.array([[0, 1],
                    [2, 0]])
    def f(a, ind):
      return a[:, ind]
    expected = np.stack([f(a, ind[i, :]) for i in range(ind.shape[0])])
    ans = vmap(f, (None, 0))(a, ind)
    assert np.all(ans == expected)

  def testNumpyIndexing2(self):
    a = jnp.arange(2 * 3 * 4).reshape((2, 3, 4))
    def f(a):
      inds = jnp.array([0, 2])
      return a[:, inds]
    ans = vmap(f)(a)
    expected = np.stack([f(a[:, i, :]) for i in range(a.shape[1])], axis=1)
    assert np.all(ans == expected)

  def testTranspose(self):
    x = np.arange(4 * 3 * 3).reshape((4, 3, 3))
    ans = vmap(lambda x: x + x.T)(x)
    expected = x + np.swapaxes(x, -1, -2)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testTransposePermutation(self):
    x = np.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
    ans = vmap(lambda x: jnp.transpose(x, (1, 0, 2)))(x)
    expected = np.transpose(x, (0, 2, 1, 3))
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = np.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
    ans = vmap(lambda x: jnp.transpose(x, (1, 2, 0)))(x)
    expected = np.transpose(x, (0, 2, 3, 1))
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = np.arange(6 * 3 * 4 * 5).reshape((3, 4, 6, 5))
    ans = vmap(lambda x: jnp.transpose(x, (1, 2, 0)), in_axes=2)(x)
    expected = np.transpose(x, (2, 1, 3, 0))
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testIssue354(self):
    psd_mat = np.random.randn(20, 10)
    psd_mat = psd_mat.T.dot(psd_mat)
    vec = np.random.randn(10)

    def f(scale):
      scaled_mat = scale * psd_mat
      chol = jnp.linalg.cholesky(scaled_mat)
      return -0.5 * jnp.sum((jnp.einsum('ij,j->i', chol, vec))**2)
    vmapped_f = vmap(f)
    vmapped_f_grad = grad(lambda x: jnp.sum(vmapped_f(x)))

    scales = np.array([[0.1], [0.2], [0.3], [0.4], [0.5]])
    ans = vmapped_f_grad(scales)  # don't crash!
    expected = np.stack([grad(f)(scale) for scale in scales])
    self.assertAllClose(ans, expected, check_dtypes=False,
                        rtol=jtu.default_gradient_tolerance)

  def testIssue387(self):
    # https://github.com/google/jax/issues/387
    R = np.random.RandomState(0).rand(100, 2)

    def dist_sq(R):
      dR = R[:, jnp.newaxis, :] - R[jnp.newaxis, :, :]
      zero = jnp.zeros_like(dR)
      dR = dR - jnp.where(jnp.abs(dR) < 0.5, zero, 0.5 * jnp.sign(dR))
      return jnp.sum(dR ** 2, axis=2)

    @jit
    def f(R):
      _ = dist_sq(R)
      return jnp.sum(R ** 2)

    _ = hessian(f)(R)  # don't crash on UnshapedArray

  def testIssue489(self):
    def f(key):
      def body_fn(uk):
        key = uk[1]
        u = random.uniform(key, (), dtype=jnp.float64)
        key, _ = random.split(key)
        return u, key

      u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn,
                            (jnp.float64(1.), key))
      return u

    print(vmap(f)(random.split(random.PRNGKey(0), 2)))  # no crash

  def testEmptyTuples(self):
    # Ensure there is no crash when a vectorized input contains empty tuples.
    result = vmap(lambda x, _: x + 1)(np.array([0, 1]), ())
    self.assertAllClose(result, np.array([1, 2]), check_dtypes=False)
    # Ensure there is no crash when a vectorized output contains empty tuples.
    result, empty_tuple = vmap(lambda x: (x + 1, ()))(np.array([0, 1]))
    self.assertAllClose(result, np.array([1, 2]), check_dtypes=False)
    self.assertEqual((), empty_tuple)

  def testIndexAddBatchedIndexesOnly(self):
    f = lambda x, idx, y: jax.ops.index_add(x, jax.ops.index[idx], y)
    result = vmap(f, (None, 0, None))(np.zeros((10,)), np.arange(10,), 1.)
    self.assertAllClose(result, np.eye(10), check_dtypes=False)

  def testIssue1170(self):
    def f(index1, index2):
      return jnp.arange(36).reshape(6, 6)[index1, index2]
    g = jax.jit(jax.pmap(f))
    ans = g(index1=np.asarray([1]), index2=np.asarray([2]))
    expected = g(np.asarray([1]), np.asarray([2]))
    self.assertAllClose(ans, expected)

  def testIssue3883(self):
    def scalar_f(x):
      return lax.dynamic_slice(x, [], [])

    xs = jnp.array([1, 2, 3, 4])
    ans = vmap(scalar_f)(xs)
    expected = jnp.array([scalar_f(x) for x in xs])
    self.assertAllClose(ans, expected)

    def scalar_f2(x):
      return lax.dynamic_update_slice(x, 7, [])

    xs = jnp.array([1, 2, 3, 4])
    ans = vmap(scalar_f2)(xs)
    expected = jnp.array([scalar_f2(x) for x in xs])
    self.assertAllClose(ans, expected)

  @parameterized.named_parameters(
      {"testcase_name": "_collective={}".format(seq.__name__).replace(" ", ""),
       "collective": collective,
       "seq": seq}
      for collective, seq in [(lax.psum, jnp.sum),
                              (lax.pmean, jnp.mean),
                              (lambda x, n: lax.pmax(x, n), jnp.max),
                              (lambda x, n: lax.pmin(x, n), jnp.min)])
  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testCollective(self, collective, seq):
    x = jnp.arange(64).reshape((4, 4, 4))
    self.assertAllClose(
      vmap(lambda x: x - collective(x, 'i'), axis_name='i')(x),
      x - seq(x, axis=0))

    self.assertAllClose(
      vmap(vmap(lambda x: x - collective(x, ('j', 'i')), axis_name='i'), axis_name='j')(x),
      x - seq(x, axis=(0, 1)))

    self.assertAllClose(
      vmap(vmap(lambda x: x - collective(x, ('i', 'j')), axis_name='i'), axis_name='j')(x),
      x - seq(x, axis=(1, 0)))

  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testPPermute(self):
    nelem = 10
    ntests = 10
    x = np.arange(nelem)
    rng = np.random.RandomState(1)
    for i in range(ntests):
      perm = np.arange(nelem)
      rng.shuffle(perm)
      perm_pairs = np.stack([np.arange(nelem), perm], axis=-1)
      rng.shuffle(perm_pairs)
      self.assertAllClose(
        vmap(lambda x: x - lax.ppermute(x, 'i', perm_pairs), axis_name='i')(x),
        x - x[perm])

  @parameterized.named_parameters(
      {"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
       "split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
      for split_axis, concat_axis, vmap_axis in it.product(range(3), range(3), range(4)))
  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testAllToAll(self, vmap_axis, split_axis, concat_axis):
    d = vmap_axis

    def shape_fun(x, out_d):
      shape = list(x.shape)
      vmap_dim_id = shape.pop(d)
      split_dim_id = shape.pop(split_axis)
      shape.insert(concat_axis, vmap_dim_id)
      shape.insert(out_d, split_dim_id)
      return tuple(shape)

    shape = (2, 3, 4, 5)
    x = np.arange(np.prod(shape)).reshape(shape)
    rule = batching.collective_rules[lax.all_to_all_p]
    (y,), (out_d,) = rule((x,), (d,), None, None, split_axis, concat_axis)
    exp_shape = shape_fun(x, out_d)
    self.assertEqual(y.shape, exp_shape)

  @parameterized.named_parameters(
      {"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
       "split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
      for split_axis, concat_axis, vmap_axis in it.product(range(2), range(2), range(3)))
  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testAllToAllSplitAxis(self, vmap_axis, split_axis, concat_axis):
    shape = (4, 4, 4)
    x = np.arange(np.prod(shape)).reshape(shape)

    @partial(vmap, in_axes=vmap_axis, axis_name='i')
    @partial(vmap, in_axes=vmap_axis, axis_name='j')
    def f(x):
      return lax.all_to_all(x, ('i', 'j'), split_axis, concat_axis)

    unroll_shape = (2, 2, *shape[1:])
    unroll_shape = list(shape)
    unroll_shape[vmap_axis:vmap_axis+1] = (2, 2)
    x_unroll = x.reshape(unroll_shape)
    y_unrolled = f(x_unroll)
    y = y_unrolled.reshape(shape)

    if vmap_axis <= split_axis:
      split_axis += 1
    ref = jnp.moveaxis(x, (vmap_axis, split_axis),
                          (concat_axis + 1, 0))
    self.assertAllClose(y, ref)

  def testNegativeAxes(self):
    x = np.arange(3*4*5).reshape(3, 4, 5)
    self.assertAllClose(jax.vmap(jnp.sum, in_axes=-3)(x),
                        jnp.sum(x, axis=(1, 2)))
    self.assertAllClose(jax.vmap(jnp.sum, in_axes=-2)(x),
                        jnp.sum(x, axis=(0, 2)))
    self.assertAllClose(jax.vmap(jnp.sum, in_axes=-1)(x),
                        jnp.sum(x, axis=(0, 1)))

    with self.assertRaisesRegex(ValueError, "vmap got arg 0 of rank 3 but axis to be mapped -4"):
      jax.vmap(jnp.sum, in_axes=-4)(x)

    id = lambda y: y
    self.assertAllClose(x, jax.vmap(id, in_axes=0, out_axes=-3)(x))
    self.assertAllClose(x.transpose(1, 0, 2),
                        jax.vmap(id, in_axes=0, out_axes=-2)(x))
    self.assertAllClose(x.transpose(1, 2, 0),
                        jax.vmap(id, in_axes=0, out_axes=-1)(x))

    with self.assertRaisesRegex(ValueError, "axis -4 is out of bounds.*"):
      jax.vmap(id, in_axes=0, out_axes=-4)(x)

    self.assertAllClose(
      np.full((5,), 7),
      jax.vmap(lambda *xs: xs, in_axes=(0, None), out_axes=(0, -1))(
        np.arange(5), 7)[1])

    with self.assertRaisesRegex(ValueError, "axis -2 is out of bounds.*"):
      jax.vmap(lambda *xs: xs, in_axes=(0, None), out_axes=(0, -2))(
        np.arange(5), 7)

  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testAxisIndex(self):
    x = np.arange(10)
    self.assertAllClose(
      vmap(lambda x: x - lax.axis_index('i'), axis_name='i')(x),
      x - np.arange(x.shape[0]))
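
A compact sketch of the vmap axis conventions these tests exercise; the shapes are arbitrary and only for illustration:

import jax.numpy as jnp
from jax import vmap

xs = jnp.arange(6.0).reshape(2, 3)   # a batch of two length-3 vectors
w = jnp.ones(3)

# in_axes picks the batch axis per argument (None broadcasts the argument);
# out_axes picks where the batch axis lands in each output.
dots = vmap(jnp.dot, in_axes=(0, None))(xs, w)                 # shape (2,)
outer = vmap(jnp.outer, in_axes=(0, None), out_axes=2)(xs, w)  # shape (3, 3, 2)
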
Example #6
class IndexedUpdateTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "{}_inshape={}_indexer={}_update={}_op={}".format(
                    name, jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in STATIC_INDEXING_TESTS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in (
                    [dtype] if op == UpdateOps.ADD else all_dtypes)
            for rng_factory in [jtu.rand_default]))
    def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
                           rng_factory, indexer, op):
        rng = rng_factory()
        args_maker = lambda: [
            rng(shape, dtype),
            rng(update_shape, update_dtype)
        ]
        onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
        jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
        self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
        self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "{}_inshape={}_indexer={}_update={}_op={}".format(
                    name, jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in (
                    [dtype] if op == UpdateOps.ADD else all_dtypes)
            for rng_factory in [jtu.rand_default]))
    def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                             rng_factory, indexer, op):
        rng = rng_factory()
        args_maker = lambda: [
            rng(shape, dtype),
            rng(update_shape, update_dtype)
        ]
        onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
        jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
        self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
        self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "{}_inshape={}_indexer={}_update={}_op={}".format(
                    name, jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in (
                    [dtype] if op == UpdateOps.ADD else all_dtypes)
            for rng_factory in [jtu.rand_default]))
    def testMixedAdvancedIndexing(self, shape, dtype, update_shape,
                                  update_dtype, rng_factory, indexer, op):
        rng = rng_factory()
        args_maker = lambda: [
            rng(shape, dtype),
            rng(update_shape, update_dtype)
        ]
        onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
        jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
        self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
        self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "{}_inshape={}_indexer={}_update={}_op={}".format(
                    name, jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in STATIC_INDEXING_TESTS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in float_dtypes
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in (
                    [dtype] if op == UpdateOps.ADD else float_dtypes)
            for rng_factory in [jtu.rand_default]))
    @jtu.skip_on_devices("tpu")  # TODO(mattjj,phawkins): tpu issues
    def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
                                rng_factory, indexer, op):
        rng = rng_factory()
        jax_op = ops.index_update if op == UpdateOps.UPDATE else ops.index_add
        jax_fn = lambda x, y: jax_op(x, indexer, y)
        x = rng(shape, dtype)
        y = rng(update_shape, update_dtype)
        check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.)

    def testSegmentSumBehavior(self):
        # testAdvancedIndexing compares against NumPy, and as a result doesn't check
        # repeated indices. This test is just a simple manual check, based on
        # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum
        data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
        segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])

        ans = ops.index_add(onp.zeros(onp.max(segment_ids) + 1), segment_ids,
                            data)
        expected = onp.array([13, 2, 7, 4])
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testSegmentSum(self):
        data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
        segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])

        # test with explicit num_segments
        ans = ops.segment_sum(data, segment_ids, num_segments=4)
        expected = onp.array([13, 2, 7, 4])
        self.assertAllClose(ans, expected, check_dtypes=False)

        # test without explicit num_segments
        ans = ops.segment_sum(data, segment_ids)
        expected = onp.array([13, 2, 7, 4])
        self.assertAllClose(ans, expected, check_dtypes=False)
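
# --- Illustrative sketch (not part of the test file above) ---
# Hedged example of the segment-sum behaviour checked in testSegmentSum and
# testSegmentSumBehavior, assuming the same (older) jax.ops interface these
# tests import as `ops`.
import numpy as np
from jax import ops

data = np.array([5, 1, 7, 2, 3, 4, 1, 3])
segment_ids = np.array([0, 0, 0, 1, 2, 2, 3, 3])
# Entries sharing a segment id are summed: [5+1+7, 2, 3+4, 1+3] -> [13, 2, 7, 4]
totals = ops.segment_sum(data, segment_ids, num_segments=4)
# index_add accumulates updates at (possibly repeated) indices:
bumped = ops.index_add(np.zeros(4), segment_ids, data)  # same totals as above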
Example #7
0
class NumpyLinalgTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]
                            for dtype in float_types()
                            for rng in [jtu.rand_default()]))
    def testCholesky(self, shape, dtype, rng):
        def args_maker():
            a = rng(shape, dtype)
            return [onp.matmul(a, np.conj(T(a)))]

        self._CheckAgainstNumpy(onp.linalg.cholesky,
                                np.linalg.cholesky,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.cholesky,
                              args_maker,
                              check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
            "n":
            n,
            "dtype":
            dtype,
            "rng":
            rng
        } for n in [0, 4, 5, 50] for dtype in float_types() | complex_types()
                            for rng in [jtu.rand_default()]))
    # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
    @jtu.skip_on_devices("gpu", "tpu")
    def testDet(self, n, dtype, rng):
        if not hasattr(lapack, "jax_getrf"):
            self.skipTest("No LU implementation available")
        args_maker = lambda: [rng((n, n), dtype)]

        self._CheckAgainstNumpy(onp.linalg.det,
                                np.linalg.det,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.det, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
            "n":
            n,
            "dtype":
            dtype,
            "rng":
            rng
        } for n in [0, 4, 10, 200]
                            for dtype in float_types() | complex_types()
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("gpu", "tpu")
    def testSlogdet(self, n, dtype, rng):
        if not hasattr(lapack, "jax_getrf"):
            self.skipTest("No LU implementation available")
        args_maker = lambda: [rng((n, n), dtype)]

        self._CheckAgainstNumpy(onp.linalg.slogdet,
                                np.linalg.slogdet,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.slogdet, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}_lower={}".format(
                jtu.format_shape_dtype_string((n, n), dtype), lower),
            "n":
            n,
            "dtype":
            dtype,
            "lower":
            lower,
            "rng":
            rng
        } for n in [0, 4, 5, 50] for dtype in float_types() | complex_types()
                            for lower in [False, True]
                            for rng in [jtu.rand_default()]))
    # TODO(phawkins): enable when there is an eigendecomposition implementation
    # for GPU/TPU.
    @jtu.skip_on_devices("gpu", "tpu")
    def testEigh(self, n, dtype, lower, rng):
        if not hasattr(lapack, "jax_syevd"):
            self.skipTest(
                "No symmetric eigendecomposition implementation available")
        args_maker = lambda: [rng((n, n), dtype)]

        uplo = "L" if lower else "U"

        # Norm, adjusted for dimension and type.
        def norm(x):
            norm = onp.linalg.norm(x, axis=(-2, -1))
            return norm / ((n + 1) * onp.finfo(dtype).eps)

        a, = args_maker()
        a = (a + onp.conj(a.T)) / 2
        w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a), UPLO=uplo)

        self.assertTrue(norm(onp.eye(n) - onp.matmul(onp.conj(T(v)), v)) < 5)
        self.assertTrue(norm(onp.matmul(a, v) - w * v) < 30)

        self._CompileAndCheck(partial(np.linalg.eigh, UPLO=uplo),
                              args_maker,
                              check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}_full_matrices={}_compute_uv={}".format(
                jtu.format_shape_dtype_string((
                    m, n), dtype), full_matrices, compute_uv),
            "m":
            m,
            "n":
            n,
            "dtype":
            dtype,
            "full_matrices":
            full_matrices,
            "compute_uv":
            compute_uv,
            "rng":
            rng
        } for m in [2, 7, 29, 53] for n in [2, 7, 29, 53]
                            for dtype in float_types() | complex_types()
                            for full_matrices in [False, True]
                            for compute_uv in [False, True]
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("gpu", "tpu")
    def testSVD(self, m, n, dtype, full_matrices, compute_uv, rng):
        if not hasattr(lapack, "jax_gesdd"):
            self.skipTest(
                "No singular value decomposition implementation available")

        args_maker = lambda: [rng((m, n), dtype)]

        # Norm, adjusted for dimension and type.
        def norm(x):
            norm = onp.linalg.norm(x, axis=(-2, -1))
            return norm / (max(m, n) * onp.finfo(dtype).eps)

        a, = args_maker()
        out = np.linalg.svd(a,
                            full_matrices=full_matrices,
                            compute_uv=compute_uv)

        if compute_uv:
            # Check the reconstructed matrices
            if full_matrices:
                k = min(m, n)
                if m < n:
                    self.assertTrue(
                        onp.all(
                            norm(a - onp.matmul(out[1] *
                                                out[0], out[2][:k, :])) < 50))
                else:
                    self.assertTrue(
                        onp.all(
                            norm(a - onp.matmul(out[1] *
                                                out[0][:, :k], out[2])) < 50))
            else:
                self.assertTrue(
                    onp.all(
                        norm(a - onp.matmul(out[1] * out[0], out[2])) < 50))

            # Check the unitary properties of the singular vector matrices.
            self.assertTrue(
                onp.all(
                    norm(
                        onp.eye(out[0].shape[1]) -
                        onp.matmul(onp.conj(T(out[0])), out[0])) < 10))
            if m >= n:
                self.assertTrue(
                    onp.all(
                        norm(
                            onp.eye(out[2].shape[1]) -
                            onp.matmul(onp.conj(T(out[2])), out[2])) < 10))
            else:
                self.assertTrue(
                    onp.all(
                        norm(
                            onp.eye(out[2].shape[0]) -
                            onp.matmul(out[2], onp.conj(T(out[2])))) < 20))

        else:
            self.assertTrue(
                onp.allclose(onp.linalg.svd(a, compute_uv=False),
                             onp.asarray(out),
                             atol=1e-4,
                             rtol=1e-4))

        self._CompileAndCheck(partial(np.linalg.svd,
                                      full_matrices=full_matrices,
                                      compute_uv=compute_uv),
                              args_maker,
                              check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_fullmatrices={}".format(
                jtu.format_shape_dtype_string(shape, dtype), full_matrices),
            "shape":
            shape,
            "dtype":
            dtype,
            "full_matrices":
            full_matrices,
            "rng":
            rng
        } for shape in [(1, 1), (3, 4), (2, 10, 5), (2, 200, 100)]
                            for dtype in float_types()
                            for full_matrices in [False, True]
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("cpu")
    def testQr(self, shape, dtype, full_matrices, rng):
        m, n = shape[-2:]

        if full_matrices:
            mode, k = "complete", m
        else:
            mode, k = "reduced", min(m, n)

        a = rng(shape, dtype)
        lq, lr = np.linalg.qr(a, mode=mode)

        # onp.linalg.qr doesn't support broadcasting. But it seems like an
        # inevitable extension so we support it in our version.
        nq = onp.zeros(shape[:-2] + (m, k), dtype)
        nr = onp.zeros(shape[:-2] + (k, n), dtype)
        for index in onp.ndindex(*shape[:-2]):
            nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode)

        max_rank = max(m, n)

        # Norm, adjusted for dimension and type.
        def norm(x):
            n = onp.linalg.norm(x, axis=(-2, -1))
            return n / (max_rank * onp.finfo(dtype).eps)

        def compare_orthogonal(q1, q2):
            # Q is unique up to sign, so normalize the sign first.
            sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
            phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios))
            q1 *= phases
            self.assertTrue(onp.all(norm(q1 - q2) < 30))

        # Check a ~= qr
        self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30))

        # Compare the first 'k' vectors of Q; the remainder form an arbitrary
        # orthonormal basis for the null space.
        compare_orthogonal(nq[..., :k], lq[..., :k])

        # Check that q is close to unitary.
        self.assertTrue(onp.all(norm(onp.eye(k) - onp.matmul(T(lq), lq)) < 5))

        if not full_matrices and m >= n:
            jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a, ))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs={}_rhs={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype)),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for lhs_shape, rhs_shape in [
            ((1, 1), (1, 1)),
            ((4, 4), (4, )),
            ((8, 8), (8, 4)),
        ] for dtype in float_types() | complex_types()
                            for rng in [jtu.rand_default()]))
    # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
    @jtu.skip_on_devices("gpu", "tpu")
    def testSolve(self, lhs_shape, rhs_shape, dtype, rng):
        if not hasattr(lapack, "jax_getrf"):
            self.skipTest("No LU implementation available")
        args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

        self._CheckAgainstNumpy(onp.linalg.solve,
                                np.linalg.solve,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.solve, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5)]
                            for dtype in float_types()
                            for rng in [jtu.rand_default()]))
    def testInv(self, shape, dtype, rng):
        def args_maker():
            invertible = False
            while not invertible:
                a = rng(shape, dtype)
                try:
                    onp.linalg.inv(a)
                    invertible = True
                except onp.linalg.LinAlgError:
                    pass
            return [a]

        self._CheckAgainstNumpy(onp.linalg.inv,
                                np.linalg.inv,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)
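
# --- Illustrative sketch (not part of the test file above) ---
# A small, hedged check of the pattern these linalg tests follow: build a
# well-conditioned matrix with NumPy and compare jax.numpy.linalg against
# numpy.linalg. The onp/np aliasing mirrors the tests; tolerances are guesses.
import numpy as onp
import jax.numpy as np

rng = onp.random.RandomState(0)
a = rng.randn(4, 4)
spd = onp.matmul(a, a.T) + 4 * onp.eye(4)   # symmetric positive definite
L = np.linalg.cholesky(spd)
assert onp.allclose(onp.matmul(L, onp.transpose(L)), spd, rtol=1e-3, atol=1e-3)

b = rng.randn(4)
x = np.linalg.solve(spd, b)
assert onp.allclose(onp.matmul(spd, x), b, rtol=1e-3, atol=1e-3)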
Example #8
0
class SparseObjectTest(jtu.JaxTestCase):
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(Obj.__name__),
        "Obj": Obj
    } for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO])
    def test_attrs(self, Obj, shape=(5, 8), dtype=np.float16):
        rng = rand_sparse(self.rng(), post=Obj.fromdense)
        M = rng(shape, dtype)

        assert isinstance(M, Obj)
        assert M.shape == shape
        assert M.dtype == dtype
        assert M.nnz == (M.todense() != 0).sum()
        assert M.data.dtype == dtype

        if isinstance(M, sparse_ops.CSR):
            assert len(M.data) == len(M.indices)
            assert len(M.indptr) == M.shape[0] + 1
        elif isinstance(M, sparse_ops.CSC):
            assert len(M.data) == len(M.indices)
            assert len(M.indptr) == M.shape[1] + 1
        elif isinstance(M, sparse_ops.COO):
            assert len(M.data) == len(M.row) == len(M.col)
        else:
            raise ValueError(f"Obj={Obj} not expected.")

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list({
                "testcase_name":
                "_{}_Obj={}".format(
                    jtu.format_shape_dtype_string(shape, dtype), Obj.__name__),
                "shape":
                shape,
                "dtype":
                dtype,
                "Obj":
                Obj
            } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                                for dtype in jtu.dtypes.floating +
                                jtu.dtypes.complex)
            for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO]))
    def test_dense_round_trip(self, shape, dtype, Obj):
        rng = rand_sparse(self.rng())
        M = rng(shape, dtype)
        Msparse = Obj.fromdense(M)
        self.assertArraysEqual(M, Msparse.todense())

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list({
                "testcase_name":
                "_{}_Obj={}".format(
                    jtu.format_shape_dtype_string(shape, dtype), Obj.__name__),
                "shape":
                shape,
                "dtype":
                dtype,
                "Obj":
                Obj
            } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                                for dtype in jtu.dtypes.floating +
                                jtu.dtypes.complex)
            for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO]))
    def test_transpose(self, shape, dtype, Obj):
        rng = rand_sparse(self.rng())
        M = rng(shape, dtype)
        Msparse = Obj.fromdense(M)
        self.assertArraysEqual(M.T, Msparse.T.todense())

    @unittest.skipIf(jtu.device_under_test() == "tpu",
                     "TPU has insufficient precision")
    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list(
                {
                    "testcase_name":
                    "_{}_Obj={}_bshape={}".format(
                        jtu.format_shape_dtype_string(shape, dtype),
                        Obj.__name__, bshape),
                    "shape":
                    shape,
                    "dtype":
                    dtype,
                    "Obj":
                    Obj,
                    "bshape":
                    bshape
                } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                for bshape in [shape[-1:] + s for s in [(), (3, ), (4, )]]
                for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
            for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO]))
    def test_matmul(self, shape, dtype, Obj, bshape):
        rng = rand_sparse(self.rng(), post=jnp.array)
        rng_b = jtu.rand_default(self.rng())
        M = rng(shape, dtype)
        Msp = Obj.fromdense(M)
        x = rng_b(bshape, dtype)
        x = jnp.asarray(x)

        self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)
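
# --- Illustrative sketch (not part of the test file above) ---
# Dense -> sparse -> dense round trip in the spirit of test_dense_round_trip.
# The import path below is the one these tests appear to use
# (jax.experimental.sparse_ops); it is an assumption and may differ in newer
# jax releases.
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse_ops  # assumed module path

M = jnp.array(np.diag(np.arange(1.0, 5.0)))
Msp = sparse_ops.COO.fromdense(M)
assert int(Msp.nnz) == 4
assert bool(jnp.all(Msp.todense() == M))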
Example #9
0
class cuSparseTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex))
    def test_csr_todense(self, shape, dtype):
        rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
        M = rng(shape, dtype)

        args = (M.data, M.indices, M.indptr)
        todense = lambda *args: sparse_ops.csr_todense(*args, shape=M.shape)

        self.assertArraysEqual(M.toarray(), todense(*args))
        self.assertArraysEqual(M.toarray(), jit(todense)(*args))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex))
    def test_csr_fromdense(self, shape, dtype):
        rng = rand_sparse(self.rng())
        M = rng(shape, dtype)
        M_csr = sparse.csr_matrix(M)

        nnz = M_csr.nnz
        index_dtype = jnp.int32
        fromdense = lambda M: sparse_ops.csr_fromdense(
            M, nnz=nnz, index_dtype=jnp.int32)

        data, indices, indptr = fromdense(M)
        self.assertArraysEqual(data, M_csr.data.astype(dtype))
        self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
        self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

        data, indices, indptr = jit(fromdense)(M)
        self.assertArraysEqual(data, M_csr.data.astype(dtype))
        self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
        self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype),
                              transpose),
            "shape":
            shape,
            "dtype":
            dtype,
            "transpose":
            transpose
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex for transpose in [True, False]))
    def test_csr_matvec(self, shape, dtype, transpose):
        op = lambda M: M.T if transpose else M

        v_rng = jtu.rand_default(self.rng())
        rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
        M = rng(shape, dtype)
        v = v_rng(op(M).shape[1], dtype)

        args = (M.data, M.indices, M.indptr, v)
        matvec = lambda *args: sparse_ops.csr_matvec(
            *args, shape=M.shape, transpose=transpose)

        self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
        self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype),
                              transpose),
            "shape":
            shape,
            "dtype":
            dtype,
            "transpose":
            transpose
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex for transpose in [True, False]))
    def test_csr_matmat(self, shape, dtype, transpose):
        op = lambda M: M.T if transpose else M

        B_rng = jtu.rand_default(self.rng())
        rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
        M = rng(shape, dtype)
        B = B_rng((op(M).shape[1], 4), dtype)

        args = (M.data, M.indices, M.indptr, B)
        matmat = lambda *args: sparse_ops.csr_matmat(
            *args, shape=shape, transpose=transpose)

        self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
        self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex))
    def test_coo_todense(self, shape, dtype):
        rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
        M = rng(shape, dtype)

        args = (M.data, M.row, M.col)
        todense = lambda *args: sparse_ops.coo_todense(*args, shape=M.shape)

        self.assertArraysEqual(M.toarray(), todense(*args))
        self.assertArraysEqual(M.toarray(), jit(todense)(*args))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex))
    def test_coo_fromdense(self, shape, dtype):
        rng = rand_sparse(self.rng())
        M = rng(shape, dtype)
        M_coo = sparse.coo_matrix(M)

        nnz = M_coo.nnz
        index_dtype = jnp.int32
        fromdense = lambda M: sparse_ops.coo_fromdense(
            M, nnz=nnz, index_dtype=jnp.int32)

        data, row, col = fromdense(M)
        self.assertArraysEqual(data, M_coo.data.astype(dtype))
        self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
        self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

        data, row, col = jit(fromdense)(M)
        self.assertArraysEqual(data, M_coo.data.astype(dtype))
        self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
        self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype),
                              transpose),
            "shape":
            shape,
            "dtype":
            dtype,
            "transpose":
            transpose
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex for transpose in [True, False]))
    def test_coo_matvec(self, shape, dtype, transpose):
        op = lambda M: M.T if transpose else M

        v_rng = jtu.rand_default(self.rng())
        rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
        M = rng(shape, dtype)
        v = v_rng(op(M).shape[1], dtype)

        args = (M.data, M.row, M.col, v)
        matvec = lambda *args: sparse_ops.coo_matvec(
            *args, shape=M.shape, transpose=transpose)

        self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
        self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype),
                              transpose),
            "shape":
            shape,
            "dtype":
            dtype,
            "transpose":
            transpose
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex for transpose in [True, False]))
    def test_coo_matmat(self, shape, dtype, transpose):
        op = lambda M: M.T if transpose else M

        B_rng = jtu.rand_default(self.rng())
        rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
        M = rng(shape, dtype)
        B = B_rng((op(M).shape[1], 4), dtype)

        args = (M.data, M.row, M.col, B)
        matmat = lambda *args: sparse_ops.coo_matmat(
            *args, shape=shape, transpose=transpose)

        self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
        self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

    @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
    def test_gpu_translation_rule(self):
        version = xla_bridge.get_backend().platform_version
        cuda_version = None if version == "<unknown>" else int(
            version.split()[-1])
        if cuda_version is None or cuda_version < 11000:
            self.assertFalse(cusparse and cusparse.is_supported)
            self.assertNotIn(sparse_ops.csr_todense_p,
                             xla.backend_specific_translations["gpu"])
        else:
            self.assertTrue(cusparse and cusparse.is_supported)
            self.assertIn(sparse_ops.csr_todense_p,
                          xla.backend_specific_translations["gpu"])

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_{}".format(jtu.format_shape_dtype_string(shape, dtype),
                            mat_type),
            "shape":
            shape,
            "dtype":
            dtype,
            "mat_type":
            mat_type
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex for mat_type in ['csr', 'coo']))
    def test_extra_nnz(self, shape, dtype, mat_type):
        rng = rand_sparse(self.rng())
        M = rng(shape, dtype)
        nnz = (M != 0).sum() + 5
        fromdense = getattr(sparse_ops, f"{mat_type}_fromdense")
        todense = getattr(sparse_ops, f"{mat_type}_todense")
        args = fromdense(M, nnz=nnz, index_dtype=jnp.int32)
        M_out = todense(*args, shape=M.shape)
        self.assertArraysEqual(M, M_out)
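
# --- Illustrative sketch (not part of the test file above) ---
# Sparse CSR matrix-vector product, mirroring test_csr_matvec: build the CSR
# buffers with scipy and feed them to sparse_ops.csr_matvec exactly as the
# tests do. The import path is an assumption and may have moved in newer jax.
import numpy as np
from scipy import sparse
from jax.experimental import sparse_ops  # assumed module path

M = sparse.csr_matrix(np.array([[1., 0., 2.],
                                [0., 3., 0.]]))
v = np.array([1., 1., 1.])
y = sparse_ops.csr_matvec(M.data, M.indices, M.indptr, v, shape=M.shape)
# y should match the dense product: M @ v == [3., 3.]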
Example #10
0
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
    """Tests for LAX-backed scipy.stats implementations"""
    @genNamedParametersNArgs(3)
    def testPoissonLogPmf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.poisson.logpmf
        lax_fun = lsp_stats.poisson.logpmf

        def args_maker():
            k, mu, loc = map(rng, shapes, dtypes)
            k = np.floor(k)
            # clipping to ensure that rate parameter is strictly positive
            mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
            loc = np.floor(loc)
            return [k, mu, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})

    @genNamedParametersNArgs(3)
    def testPoissonPmf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.poisson.pmf
        lax_fun = lsp_stats.poisson.pmf

        def args_maker():
            k, mu, loc = map(rng, shapes, dtypes)
            k = np.floor(k)
            # clipping to ensure that rate parameter is strictly positive
            mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
            loc = np.floor(loc)
            return [k, mu, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testPoissonCdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.poisson.cdf
        lax_fun = lsp_stats.poisson.cdf

        def args_maker():
            k, mu, loc = map(rng, shapes, dtypes)
            # clipping to ensure that rate parameter is strictly positive
            mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
            return [k, mu, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testBernoulliLogPmf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.bernoulli.logpmf
        lax_fun = lsp_stats.bernoulli.logpmf

        def args_maker():
            x, logit, loc = map(rng, shapes, dtypes)
            x = np.floor(x)
            p = expit(logit)
            loc = np.floor(loc)
            return [x, p, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testGeomLogPmf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.geom.logpmf
        lax_fun = lsp_stats.geom.logpmf

        def args_maker():
            x, logit, loc = map(rng, shapes, dtypes)
            x = np.floor(x)
            p = expit(logit)
            loc = np.floor(loc)
            return [x, p, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(5)
    def testBetaLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.beta.logpdf
        lax_fun = lsp_stats.beta.logpdf

        def args_maker():
            x, a, b, loc, scale = map(rng, shapes, dtypes)
            return [x, a, b, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun,
                              args_maker,
                              rtol={
                                  np.float32: 2e-3,
                                  np.float64: 1e-4
                              })

    @genNamedParametersNArgs(3)
    def testCauchyLogPdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.cauchy.logpdf
        lax_fun = lsp_stats.cauchy.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(2)
    def testDirichletLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.dirichlet.logpdf
        lax_fun = lsp_stats.dirichlet.logpdf
        dim = 4
        shapes = (shapes[0] + (dim, ), shapes[1] + (dim, ))

        def args_maker():
            x, alpha = map(rng, shapes, dtypes)
            x = x / np.sum(x, axis=-1, keepdims=True)
            return [x, alpha]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testExponLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.expon.logpdf
        lax_fun = lsp_stats.expon.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(4)
    def testGammaLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.gamma.logpdf
        lax_fun = lsp_stats.gamma.logpdf

        def args_maker():
            x, a, loc, scale = map(rng, shapes, dtypes)
            return [x, a, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=5e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testLaplaceLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.laplace.logpdf
        lax_fun = lsp_stats.laplace.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(scale, a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testLaplaceCdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.laplace.cdf
        lax_fun = lsp_stats.laplace.cdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # ensure that scale is not too low
            scale = np.clip(scale, a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol={
                                    np.float32: 1e-5,
                                    np.float64: 1e-6
                                })
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(1)
    def testLogisticCdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.logistic.cdf
        lax_fun = lsp_stats.logistic.cdf

        def args_maker():
            return list(map(rng, shapes, dtypes))

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-6)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(1)
    def testLogisticLogpdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.logistic.logpdf
        lax_fun = lsp_stats.logistic.logpdf

        def args_maker():
            return list(map(rng, shapes, dtypes))

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(1)
    def testLogisticPpf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.logistic.ppf
        lax_fun = lsp_stats.logistic.ppf

        def args_maker():
            return list(map(rng, shapes, dtypes))

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(1)
    def testLogisticSf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.logistic.sf
        lax_fun = lsp_stats.logistic.sf

        def args_maker():
            return list(map(rng, shapes, dtypes))

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-6)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testNormLogPdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.norm.logpdf
        lax_fun = lsp_stats.norm.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testNormLogCdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.norm.logcdf
        lax_fun = lsp_stats.norm.logcdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testNormCdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.norm.cdf
        lax_fun = lsp_stats.norm.cdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-6)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testNormPpf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.norm.ppf
        lax_fun = lsp_stats.norm.ppf

        def args_maker():
            q, loc, scale = map(rng, shapes, dtypes)
            # ensure probability is between 0 and 1:
            q = np.clip(np.abs(q / 3), a_min=None, a_max=1)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [q, loc, scale]

        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)

    @genNamedParametersNArgs(4)
    def testParetoLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.pareto.logpdf
        lax_fun = lsp_stats.pareto.logpdf

        def args_maker():
            x, b, loc, scale = map(rng, shapes, dtypes)
            return [x, b, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(4)
    def testTLogPdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.t.logpdf
        lax_fun = lsp_stats.t.logpdf

        def args_maker():
            x, df, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [x, df, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun,
                              args_maker,
                              rtol={np.float64: 1e-14},
                              atol={np.float64: 1e-14})

    @genNamedParametersNArgs(3)
    def testUniformLogPdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.uniform.logpdf
        lax_fun = lsp_stats.uniform.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            return [x, loc, np.abs(scale)]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(4)
    def testChi2LogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.chi2.logpdf
        lax_fun = lsp_stats.chi2.logpdf

        def args_maker():
            x, df, loc, scale = map(rng, shapes, dtypes)
            return [x, df, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=5e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    def testIssue972(self):
        self.assertAllClose(np.ones((4, ), np.float32),
                            lsp_stats.norm.cdf(
                                np.full((4, ), np.inf, np.float32)),
                            check_dtypes=False)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_x={}_mean={}_cov={}".format(
                jtu.format_shape_dtype_string(x_shape, x_dtype),
                jtu.format_shape_dtype_string(mean_shape, mean_dtype)
                if mean_shape is not None else None,
                jtu.format_shape_dtype_string(cov_shape, cov_dtype)
                if cov_shape is not None else None),
            "x_shape":
            x_shape,
            "x_dtype":
            x_dtype,
            "mean_shape":
            mean_shape,
            "mean_dtype":
            mean_dtype,
            "cov_shape":
            cov_shape,
            "cov_dtype":
            cov_dtype
        } for x_shape, mean_shape, cov_shape in [
            # # These test cases cover default values for mean/cov, but we don't
            # # support those yet (and they seem not very valuable).
            # [(), None, None],
            # [(), (), None],
            # [(2,), None, None],
            # [(2,), (), None],
            # [(2,), (2,), None],
            # [(3, 2), (3, 2,), None],
            # [(5, 3, 2), (5, 3, 2,), None],
            [(), (), ()],
            [(3, ), (), ()],
            [(3, ), (3, ), ()],
            [(3, ), (3, ), (3, 3)],
            [(3, 4), (4, ), (4, 4)],

            # # These test cases are where scipy flattens things, which has
            # # different batch semantics than some might expect
            # [(5, 3, 2), (5, 3, 2,), ()],
            # [(5, 3, 2), (5, 3, 2,), (5, 3, 2, 2)],
            # [(5, 3, 2), (3, 2,), (5, 3, 2, 2)],
            # [(5, 3, 2), (3, 2,), (2, 2)],
        ] for x_dtype, mean_dtype, cov_dtype in
                            itertools.combinations_with_replacement(
                                jtu.dtypes.floating, 3)
                            if (mean_shape is not None
                                or mean_dtype == np.float32) and
                            (cov_shape is not None or cov_dtype == np.float32))
    )
    def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape,
                                     mean_dtype, cov_shape, cov_dtype):
        rng = jtu.rand_default(self.rng())

        def args_maker():
            args = [rng(x_shape, x_dtype)]
            if mean_shape is not None:
                args.append(5 * rng(mean_shape, mean_dtype))
            if cov_shape is not None:
                if cov_shape == ():
                    args.append(0.1 + rng(cov_shape, cov_dtype)**2)
                else:
                    factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
                    factor = rng(factor_shape, cov_dtype)
                    args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
            return args

        self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf,
                                lsp_stats.multivariate_normal.logpdf,
                                args_maker,
                                tol=1e-3)
        self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf,
                              args_maker,
                              rtol=1e-4,
                              atol=1e-4)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_ndim={}_nbatch={}_dtype={}".format(ndim, nbatch, dtype.__name__),
            "ndim":
            ndim,
            "nbatch":
            nbatch,
            "dtype":
            dtype
        } for ndim in [2, 3] for nbatch in [1, 3, 5]
                            for dtype in jtu.dtypes.floating))
    def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype):
        # Regression test for #5570
        rng = jtu.rand_default(self.rng())
        x = rng((nbatch, ndim), dtype)
        mean = 5 * rng((nbatch, ndim), dtype)
        factor = rng((nbatch, ndim, 2 * ndim), dtype)
        cov = factor @ factor.transpose(0, 2, 1)

        result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
        result2 = api.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
        self.assertArraysEqual(result1, result2)
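
# --- Illustrative sketch (not part of the test file above) ---
# The pattern all of these stats tests share: a jax.scipy.stats function is
# expected to agree numerically with its scipy.stats counterpart. The aliases
# mirror the tests' osp_stats / lsp_stats naming; the tolerance is a guess.
import numpy as np
import scipy.stats as osp_stats
import jax.scipy.stats as lsp_stats

x = np.linspace(-2.0, 2.0, 5)
expected = osp_stats.norm.logpdf(x, loc=0.5, scale=2.0)
actual = lsp_stats.norm.logpdf(x, loc=0.5, scale=2.0)
assert np.allclose(actual, expected, atol=1e-5)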
Example #11
0
class LaxBackedScipyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Scipy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        return lambda: [
            rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)
        ]

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
                jtu.format_shape_dtype_string(shapes, dtype), axis, keepdims,
                return_sign, use_b),
            # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
            "shapes":
            shapes,
            "dtype":
            dtype,
            "axis":
            axis,
            "keepdims":
            keepdims,
            "return_sign":
            return_sign,
            "use_b":
            use_b
        } for shape_group in compatible_shapes for dtype in float_dtypes +
                            complex_dtypes + int_dtypes
                            for use_b in [False, True]
                            for shapes in itertools.product(
                                *((shape_group,
                                   shape_group) if use_b else (shape_group, )))
                            for axis in range(
                                -max(len(shape) for shape in shapes),
                                max(len(shape) for shape in shapes))
                            for keepdims in [False, True]
                            for return_sign in [False, True]))
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="invalid value encountered in .*")
    def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b):
        if jtu.device_under_test() != "cpu":
            rng = jtu.rand_some_inf_and_nan(self.rng())
        else:
            rng = jtu.rand_default(self.rng())
        # TODO(mattjj): test autodiff
        if use_b:

            def scipy_fun(array_to_reduce, scale_array):
                return osp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign,
                                             b=scale_array)

            def lax_fun(array_to_reduce, scale_array):
                return lsp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign,
                                             b=scale_array)

            args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
        else:

            def scipy_fun(array_to_reduce):
                return osp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign)

            def lax_fun(array_to_reduce):
                return lsp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign)

            args_maker = lambda: [rng(shapes[0], dtype)]
        tol = {np.float32: 1E-6, np.float64: 1E-14}
        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
        self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

    def testLogSumExpZeros(self):
        # Regression test for https://github.com/google/jax/issues/5370
        scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
        lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b)
        args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])]
        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
        self._CompileAndCheck(lax_fun, args_maker)
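
As a quick illustration of the b (scaling) and return_sign options exercised above, a minimal sketch with arbitrary values; a zero weight in b should not produce NaN, which is what the regression test for issue #5370 checks.

import numpy as np
from scipy.special import logsumexp as scipy_logsumexp
from jax.scipy.special import logsumexp

a = np.array([-1000.0, -2.0, 0.5])
b = np.array([1.0, 2.0, 0.0])  # note the zero weight

val, sign = logsumexp(a, b=b, return_sign=True)
ref_val, ref_sign = scipy_logsumexp(a, b=b, return_sign=True)
print(np.allclose(val, ref_val), sign == ref_sign)  # expected: True True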

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list(
                {
                    "testcase_name":
                    jtu.format_test_name_suffix(rec.test_name, shapes, dtypes),
                    "rng_factory":
                    rec.rng_factory,
                    "shapes":
                    shapes,
                    "dtypes":
                    dtypes,
                    "test_autodiff":
                    rec.test_autodiff,
                    "nondiff_argnums":
                    rec.nondiff_argnums,
                    "scipy_op":
                    getattr(osp_special, rec.name),
                    "lax_op":
                    getattr(lsp_special, rec.name)
                } for shapes in itertools.combinations_with_replacement(
                    all_shapes, rec.nargs)
                for dtypes in (itertools.combinations_with_replacement(
                    rec.dtypes, rec.nargs) if isinstance(rec.dtypes, list) else
                               itertools.product(*rec.dtypes)))
            for rec in JAX_SPECIAL_FUNCTION_RECORDS))
    def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes,
                            dtypes, test_autodiff, nondiff_argnums):
        if (jtu.device_under_test() == "cpu"
                and (lax_op is lsp_special.gammainc
                     or lax_op is lsp_special.gammaincc)):
            # TODO(b/173608403): re-enable test when LLVM bug is fixed.
            raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
        rng = rng_factory(self.rng())
        args_maker = self._GetArgsMaker(rng, shapes, dtypes)
        args = args_maker()
        self.assertAllClose(scipy_op(*args),
                            lax_op(*args),
                            atol=1e-3,
                            rtol=1e-3,
                            check_dtypes=False)
        self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

        if test_autodiff:

            def partial_lax_op(*vals):
                list_args = list(vals)
                for i in nondiff_argnums:
                    list_args.insert(i, args[i])
                return lax_op(*list_args)

            assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
            diff_args = [
                x for i, x in enumerate(args) if i not in nondiff_argnums
            ]
            jtu.check_grads(partial_lax_op,
                            diff_args,
                            order=1,
                            atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                            rtol=.1,
                            eps=1e-3)
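
The pattern above (compare forward values against SciPy, then check first-order gradients) can be reproduced for a single function in isolation. A minimal sketch using gammaln as a stand-in for whichever record is under test; the tolerances are looser than the test's and chosen arbitrarily.

import numpy as np
import scipy.special as osp_special
import jax
from jax.scipy.special import gammaln

x = np.linspace(0.5, 5.0, 7)

# Forward agreement with SciPy.
print(np.allclose(gammaln(x), osp_special.gammaln(x), atol=1e-5))

# First-order gradient: d/dx gammaln(x) is the digamma function.
grad_fn = jax.vmap(jax.grad(gammaln))
print(np.allclose(grad_fn(x), osp_special.digamma(x), atol=1e-5))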

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_d={}".format(
                jtu.format_shape_dtype_string(shape, dtype), d),
            "shape":
            shape,
            "dtype":
            dtype,
            "d":
            d
        } for shape in all_shapes for dtype in float_dtypes
                            for d in [1, 2, 5]))
    def testMultigammaln(self, shape, dtype, d):
        def scipy_fun(a):
            return osp_special.multigammaln(a, d)

        def lax_fun(a):
            return lsp_special.multigammaln(a, d)

        rng = jtu.rand_positive(self.rng())
        args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                tol={
                                    np.float32: 1e-3,
                                    np.float64: 1e-14
                                })
        self._CompileAndCheck(lax_fun, args_maker)
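
A standalone version of the same comparison; for d = 1 the multivariate log-gamma reduces to the ordinary gammaln, which gives an easy sanity check (values and tolerances below are arbitrary).

import numpy as np
import scipy.special as osp_special
from jax.scipy.special import gammaln, multigammaln

a = np.array([1.5, 2.0, 7.3])  # must exceed (d - 1) / 2
d = 3

print(np.allclose(multigammaln(a, d), osp_special.multigammaln(a, d), atol=1e-4))
print(np.allclose(multigammaln(a, 1), gammaln(a), atol=1e-5))  # degenerate case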

    def testIssue980(self):
        x = np.full((4, ), -1e20, dtype=np.float32)
        self.assertAllClose(np.zeros((4, ), dtype=np.float32),
                            lsp_special.expit(x))

    def testIssue3758(self):
        x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
        q = np.array([1., 40., 30.], dtype=np.float32)
        self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32),
                            lsp_special.zeta(x, q))

    def testXlogyShouldReturnZero(self):
        self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

    def testGradOfXlogyAtZero(self):
        partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
        self.assertAllClose(api.grad(partial_xlogy)(0.),
                            0.,
                            check_dtypes=False)

    def testXlog1pyShouldReturnZero(self):
        self.assertAllClose(lsp_special.xlog1py(0., -1.),
                            0.,
                            check_dtypes=False)

    def testGradOfXlog1pyAtZero(self):
        partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
        self.assertAllClose(api.grad(partial_xlog1py)(-1.),
                            0.,
                            check_dtypes=False)
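
The convention being tested, that xlogy(0, y) is exactly zero and that its derivative at the singular point is also zero, can be seen directly in a small sketch.

import jax
from jax.scipy.special import xlogy, xlog1py

print(xlogy(0.0, 0.0))     # 0.0 rather than nan, by the 0 * log(0) = 0 convention
print(xlog1py(0.0, -1.0))  # 0.0, even though log1p(-1) is -inf

# Gradients with respect to y at the singular points are defined to be zero.
print(jax.grad(lambda y: xlogy(0.0, y))(0.0))     # 0.0
print(jax.grad(lambda y: xlog1py(0.0, y))(-1.0))  # 0.0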

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_maxdegree={}_inputsize={}".format(
                    l_max, num_z),
                "l_max": l_max,
                "num_z": num_z
            } for l_max, num_z in zip([1, 2, 3], [6, 7, 8])))
    def testLpmn(self, l_max, num_z):
        # Points at which the associated Legendre functions are evaluated.
        z = np.linspace(-0.2, 0.9, num_z)
        actual_p_vals, actual_p_derivatives = lsp_special.lpmn(m=l_max,
                                                               n=l_max,
                                                               z=z)

        # The expected results are obtained from scipy.
        expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
        expected_p_derivatives = np.zeros((l_max + 1, l_max + 1, num_z))

        for i in range(num_z):
            val, derivative = osp_special.lpmn(l_max, l_max, z[i])
            expected_p_vals[:, :, i] = val
            expected_p_derivatives[:, :, i] = derivative

        with self.subTest('Test values.'):
            self.assertAllClose(actual_p_vals,
                                expected_p_vals,
                                rtol=1e-6,
                                atol=3.2e-6)

        with self.subTest('Test derivatives.'):
            self.assertAllClose(actual_p_derivatives,
                                expected_p_derivatives,
                                rtol=1e-6,
                                atol=8.4e-4)

        with self.subTest('Test JIT compatibility'):
            args_maker = lambda: [z]
            lsp_special_fn = lambda z: lsp_special.lpmn(l_max, l_max, z)
            self._CompileAndCheck(lsp_special_fn, args_maker)
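
A standalone sketch of the same comparison: jax.scipy.special.lpmn evaluates all orders and degrees up to m, n for a whole batch of points in one call, whereas scipy.special.lpmn accepts a single scalar z per call (sizes and tolerance below are arbitrary).

import numpy as np
import scipy.special as osp_special
import jax.numpy as jnp
from jax.scipy.special import lpmn

l_max = 3
z = jnp.linspace(-0.2, 0.9, 5)

p_vals, p_derivs = lpmn(m=l_max, n=l_max, z=z)  # each of shape (l_max + 1, l_max + 1, 5)

# scipy is scalar-only, so loop over z and stack along the last axis.
expected = np.stack([osp_special.lpmn(l_max, l_max, float(zi))[0] for zi in z], axis=-1)
print(np.allclose(p_vals, expected, atol=1e-5))  # expected: True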

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_maxdegree={}_inputsize={}".format(
                    l_max, num_z),
                "l_max": l_max,
                "num_z": num_z
            } for l_max, num_z in zip([3, 4, 6, 32], [2, 3, 4, 64])))
    def testNormalizedLpmnValues(self, l_max, num_z):
        # Points at which the associated Legendre functions are evaluated.
        z = np.linspace(-0.2, 0.9, num_z)
        is_normalized = True
        actual_p_vals = lsp_special.lpmn_values(l_max, l_max, z, is_normalized)

        # The expected results are obtained from scipy.
        expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
        for i in range(num_z):
            expected_p_vals[:, :, i] = osp_special.lpmn(l_max, l_max, z[i])[0]

        def apply_normalization(a):
            """Applies normalization to the associated Legendre functions."""
            num_m, num_l, _ = a.shape
            a_normalized = np.zeros_like(a)
            for m in range(num_m):
                for l in range(num_l):
                    c0 = (2.0 * l + 1.0) * osp_special.factorial(l - m)
                    c1 = (4.0 * np.pi) * osp_special.factorial(l + m)
                    c2 = np.sqrt(c0 / c1)
                    a_normalized[m, l] = c2 * a[m, l]
            return a_normalized

        # The results from scipy are not normalized and the comparison requires
        # normalizing the results.
        expected_p_vals_normalized = apply_normalization(expected_p_vals)

        with self.subTest('Test accuracy.'):
            self.assertAllClose(actual_p_vals,
                                expected_p_vals_normalized,
                                rtol=1e-6,
                                atol=3.2e-6)

        with self.subTest('Test JIT compatibility'):
            args_maker = lambda: [z]
            lsp_special_fn = lambda z: lsp_special.lpmn_values(
                l_max, l_max, z, is_normalized)
            self._CompileAndCheck(lsp_special_fn, args_maker)
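
The normalization applied above is the factor c(m, l) = sqrt((2l + 1)(l - m)! / (4 pi (l + m)!)) used for fully normalized associated Legendre functions. A compact vectorized sketch of the same check, with a smaller l_max chosen arbitrarily:

import numpy as np
import scipy.special as osp_special
import jax.numpy as jnp
from jax.scipy.special import lpmn_values

l_max = 2
z = jnp.linspace(-0.2, 0.9, 4)

p_norm = lpmn_values(l_max, l_max, z, is_normalized=True)

# Normalize scipy's unnormalized output; factorial() of a negative integer is 0,
# so entries with m > l are zeroed out, matching both scipy and JAX.
m_idx, l_idx = np.mgrid[0:l_max + 1, 0:l_max + 1]
c = np.sqrt((2.0 * l_idx + 1.0) * osp_special.factorial(l_idx - m_idx)
            / (4.0 * np.pi * osp_special.factorial(l_idx + m_idx)))
expected = np.stack([osp_special.lpmn(l_max, l_max, float(zi))[0] for zi in z], axis=-1)
print(np.allclose(p_norm, c[..., None] * expected, atol=1e-5))  # expected: True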

    def testSphHarmAccuracy(self):
        m = jnp.arange(-3, 3)[:, None]
        n = jnp.arange(3, 6)
        n_max = 5
        theta = 0.0
        phi = jnp.pi

        expected = lsp_special.sph_harm(m, n, theta, phi, n_max)

        actual = osp_special.sph_harm(m, n, theta, phi)

        self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)

    def testSphHarmOrderZeroDegreeZero(self):
        """Tests the spherical harmonics of order zero and degree zero."""
        theta = jnp.array([0.3])
        phi = jnp.array([2.3])
        n_max = 0

        expected = jnp.array([1.0 / jnp.sqrt(4.0 * np.pi)])
        actual = jnp.real(
            lsp_special.sph_harm(jnp.array([0]), jnp.array([0]), theta, phi,
                                 n_max))

        self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)

    def testSphHarmOrderZeroDegreeOne(self):
        """Tests the spherical harmonics of order one and degree zero."""
        theta = jnp.array([2.0])
        phi = jnp.array([3.1])
        n_max = 1

        expected = jnp.sqrt(3.0 / (4.0 * np.pi)) * jnp.cos(phi)
        actual = jnp.real(
            lsp_special.sph_harm(jnp.array([0]), jnp.array([1]), theta, phi,
                                 n_max))

        self.assertAllClose(actual, expected, rtol=7e-8, atol=1.5e-8)

    def testSphHarmOrderOneDegreeOne(self):
        """Tests the spherical harmonics of order one and degree one."""
        theta = jnp.array([2.0])
        phi = jnp.array([2.5])
        n_max = 1

        expected = (-1.0 / 2.0 * jnp.sqrt(3.0 / (2.0 * np.pi)) * jnp.sin(phi) *
                    jnp.exp(1j * theta))
        actual = lsp_special.sph_harm(jnp.array([1]), jnp.array([1]), theta,
                                      phi, n_max)

        self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_maxdegree={}_inputsize={}_dtype={}'.format(l_max, num_z, dtype),
            'l_max':
            l_max,
            'num_z':
            num_z,
            'dtype':
            dtype
        } for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
                            for dtype in jtu.dtypes.all_integer))
    def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
        """Tests against JIT compatibility and Numpy."""
        n_max = l_max
        shape = (num_z, )
        rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)

        lsp_special_fn = partial(lsp_special.sph_harm, n_max=n_max)

        def args_maker():
            m = rng(shape, dtype)
            n = abs(m)
            theta = jnp.linspace(-4.0, 5.0, num_z)
            phi = jnp.linspace(-2.0, 1.0, num_z)
            return m, n, theta, phi

        with self.subTest('Test JIT compatibility'):
            self._CompileAndCheck(lsp_special_fn, args_maker)

        with self.subTest('Test against numpy.'):
            self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn,
                                    args_maker)

    def testSphHarmCornerCaseWithWrongNmax(self):
        """Tests the corner case where `n_max` is not the maximum value of `n`."""
        m = jnp.array([2])
        n = jnp.array([10])
        n_clipped = jnp.array([6])
        n_max = 6
        theta = jnp.array([0.9])
        phi = jnp.array([0.2])

        expected = lsp_special.sph_harm(m, n, theta, phi, n_max)

        actual = lsp_special.sph_harm(m, n_clipped, theta, phi, n_max)

        self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
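
For reference, a standalone comparison of jax.scipy.special.sph_harm against SciPy. Both follow scipy's convention that theta is the azimuthal angle and phi the polar angle; JAX additionally takes a static n_max bound on the degree. The values and tolerance below are arbitrary.

import numpy as np
import scipy.special as osp_special
from jax.scipy.special import sph_harm

m = np.array([0, 1, -1])
n = np.array([1, 2, 3])            # degrees, with |m| <= n
theta = np.array([0.3, 1.2, 2.9])  # azimuthal angle
phi = np.array([0.2, 0.7, 2.0])    # polar angle

actual = sph_harm(m, n, theta, phi, n_max=3)
expected = osp_special.sph_harm(m, n, theta, phi)
print(np.allclose(actual, expected, atol=1e-4))  # expected: True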

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name':
                '_shape={}'
                '_n_zero_sv={}_degeneracy={}_geometric_spectrum={}'
                '_max_sv={}_method={}_side={}'
                '_nonzero_condition_number={}_seed={}'.format(
                    jtu.format_shape_dtype_string(
                        shape,
                        jnp.dtype(dtype).name).replace(" ", ""), n_zero_sv,
                    degeneracy, geometric_spectrum, max_sv, method, side,
                    nonzero_condition_number, seed),
                'n_zero_sv':
                n_zero_sv,
                'degeneracy':
                degeneracy,
                'geometric_spectrum':
                geometric_spectrum,
                'max_sv':
                max_sv,
                'shape':
                shape,
                'method':
                method,
                'side':
                side,
                'nonzero_condition_number':
                nonzero_condition_number,
                'dtype':
                dtype,
                'seed':
                seed
            } for n_zero_sv in n_zero_svs for degeneracy in degeneracies
            for geometric_spectrum in geometric_spectra for max_sv in max_svs
            for shape in polar_shapes for method in methods for side in sides
            for nonzero_condition_number in nonzero_condition_numbers
            for dtype in jtu.dtypes.floating for seed in seeds))
    def testPolar(self, n_zero_sv, degeneracy, geometric_spectrum, max_sv,
                  shape, method, side, nonzero_condition_number, dtype, seed):
        """ Tests jax.scipy.linalg.polar."""
        if jtu.device_under_test() != "cpu":
            if jnp.dtype(dtype).name in ("bfloat16", "float16"):
                raise unittest.SkipTest("Skip half precision off CPU.")
            if method == "svd":
                raise unittest.SkipTest("Can't use SVD mode on TPU/GPU.")

        np.random.seed(seed)
        matrix, _ = _initialize_polar_test(shape, n_zero_sv, degeneracy,
                                           geometric_spectrum, max_sv,
                                           nonzero_condition_number, dtype)
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            self.assertRaises(NotImplementedError,
                              jsp.linalg.polar,
                              matrix,
                              method=method,
                              side=side)
            return

        unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side)
        if shape[0] >= shape[1]:
            should_be_eye = np.matmul(unitary.conj().T, unitary)
        else:
            should_be_eye = np.matmul(unitary, unitary.conj().T)
        tol = 10 * jnp.finfo(matrix.dtype).eps
        eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
        with self.subTest('Test unitarity.'):
            self.assertAllClose(eye_mat, should_be_eye, atol=tol * min(shape))

        with self.subTest('Test Hermiticity.'):
            self.assertAllClose(posdef,
                                posdef.conj().T,
                                atol=tol * jnp.linalg.norm(posdef))

        ev, _ = np.linalg.eigh(posdef)
        ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)]
        negative_ev = jnp.sum(ev < 0.)
        with self.subTest('Test positive definiteness.'):
            assert negative_ev == 0.

        if side == "right":
            recon = jnp.matmul(unitary,
                               posdef,
                               precision=lax.Precision.HIGHEST)
        elif side == "left":
            recon = jnp.matmul(posdef,
                               unitary,
                               precision=lax.Precision.HIGHEST)
        with self.subTest('Test reconstruction.'):
            self.assertAllClose(matrix,
                                recon,
                                atol=tol * jnp.linalg.norm(matrix))
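
Outside the parameterized harness, the properties verified above (orthonormal unitary factor, symmetric positive semidefinite factor, and a = u @ p for the right-sided decomposition) look like this. A minimal sketch with a small square matrix and arbitrary tolerances:

import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import polar

a = jnp.asarray(np.random.RandomState(0).randn(4, 4), dtype=jnp.float32)

u, p = polar(a, side="right")  # a == u @ p

print(jnp.allclose(u.T @ u, jnp.eye(4), atol=1e-4))  # u is orthogonal
print(jnp.allclose(p, p.T, atol=1e-4))               # p is symmetric (and PSD)
print(jnp.allclose(u @ p, a, atol=1e-4))             # reconstruction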

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_linear_size_={}_seed={}_dtype={}'.format(linear_size, seed,
                                                       jnp.dtype(dtype).name),
            'linear_size':
            linear_size,
            'seed':
            seed,
            'dtype':
            dtype
        } for linear_size in linear_sizes for seed in seeds
                            for dtype in jtu.dtypes.floating))
    def test_spectral_dac_eigh(self, linear_size, seed, dtype):
        if jtu.device_under_test() != "cpu":
            raise unittest.SkipTest("Skip eigh off CPU for now.")
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            if jtu.device_under_test() != "cpu":
                raise unittest.SkipTest("Skip half precision off CPU.")

        np.random.seed(seed)
        H = np.random.randn(linear_size, linear_size)
        H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            self.assertRaises(NotImplementedError, jax._src.scipy.eigh.eigh, H)
            return
        evs, V = jax._src.scipy.eigh.eigh(H)
        ev_exp, eV_exp = jnp.linalg.eigh(H)
        HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
        vV = evs * V
        eps = jnp.finfo(H.dtype).eps
        atol = jnp.linalg.norm(H) * eps
        self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
        self.assertAllClose(HV, vV, atol=30 * atol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_linear_size_={}_seed={}_dtype={}'.format(linear_size, seed,
                                                       jnp.dtype(dtype).name),
            'linear_size':
            linear_size,
            'seed':
            seed,
            'dtype':
            dtype
        } for linear_size in linear_sizes for seed in seeds
                            for dtype in jtu.dtypes.floating))
    def test_spectral_dac_svd(self, linear_size, seed, dtype):
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            if jtu.device_under_test() != "cpu":
                raise unittest.SkipTest("Skip half precision off CPU.")

        np.random.seed(seed)
        A = np.random.randn(linear_size, linear_size).astype(dtype)
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            self.assertRaises(NotImplementedError, jax._src.scipy.eigh.svd, A)
            return
        S_expected = np.linalg.svd(A, compute_uv=False)
        U, S, V = jax._src.scipy.eigh.svd(A)
        recon = jnp.dot((U * S), V, precision=lax.Precision.HIGHEST)
        eps = jnp.finfo(dtype).eps
        eps = eps * jnp.linalg.norm(A) * 10
        self.assertAllClose(np.sort(S), np.sort(S_expected), atol=eps)
        self.assertAllClose(A, recon, atol=eps)

        # U is unitary.
        u_unitary_delta = jnp.dot(U.conj().T,
                                  U,
                                  precision=lax.Precision.HIGHEST)
        u_eye = jnp.eye(u_unitary_delta.shape[0], dtype=dtype)
        self.assertAllClose(u_unitary_delta, u_eye, atol=eps)

        # V is unitary.
        v_unitary_delta = jnp.dot(V.conj().T,
                                  V,
                                  precision=lax.Precision.HIGHEST)
        v_eye = jnp.eye(v_unitary_delta.shape[0], dtype=dtype)
        self.assertAllClose(v_unitary_delta, v_eye, atol=eps)
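
The test above goes through the internal jax._src.scipy.eigh routines; the same invariants (reconstruction and unitarity of the factors) can be checked against the public jnp.linalg.svd, shown here as a minimal sketch with arbitrary sizes and tolerances.

import numpy as np
import jax.numpy as jnp

a = jnp.asarray(np.random.RandomState(0).randn(6, 6), dtype=jnp.float32)
u, s, vh = jnp.linalg.svd(a)

tol = 100 * jnp.finfo(a.dtype).eps * jnp.linalg.norm(a)
print(jnp.allclose((u * s) @ vh, a, atol=tol))        # reconstruction
print(jnp.allclose(u.T @ u, jnp.eye(6), atol=tol))    # U is orthogonal
print(jnp.allclose(vh @ vh.T, jnp.eye(6), atol=tol))  # V is orthogonal
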
Example #12
class DLPackTest(jtu.JaxTestCase):
    def setUp(self):
        super(DLPackTest, self).setUp()
        if jtu.device_under_test() == "tpu":
            self.skipTest("DLPack not supported on TPU")

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_take_ownership={}".format(
                jtu.format_shape_dtype_string(shape, dtype), take_ownership),
            "shape":
            shape,
            "dtype":
            dtype,
            "take_ownership":
            take_ownership
        } for shape in all_shapes for dtype in dlpack_dtypes
                            for take_ownership in [False, True]))
    def testJaxRoundTrip(self, shape, dtype, take_ownership):
        rng = jtu.rand_default(self.rng())
        np = rng(shape, dtype)
        x = jnp.array(np)
        dlpack = jax.dlpack.to_dlpack(x, take_ownership=take_ownership)
        self.assertEqual(take_ownership, x.device_buffer.is_deleted())
        y = jax.dlpack.from_dlpack(dlpack)
        self.assertAllClose(np.astype(x.dtype), y)

        self.assertRaisesRegex(RuntimeError,
                               "DLPack tensor may be consumed at most once",
                               lambda: jax.dlpack.from_dlpack(dlpack))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in all_shapes for dtype in dlpack_dtypes))
    @unittest.skipIf(not tf, "Test requires TensorFlow")
    def testTensorFlowToJax(self, shape, dtype):
        if not config.x64_enabled and dtype in [
                jnp.int64, jnp.uint64, jnp.float64
        ]:
            raise self.skipTest("x64 types are disabled by jax_enable_x64")
        if (jtu.device_under_test() == "gpu"
                and not tf.config.list_physical_devices("GPU")):
            raise self.skipTest("TensorFlow not configured with GPU support")

        rng = jtu.rand_default(self.rng())
        np = rng(shape, dtype)
        with tf.device("/GPU:0" if jtu.device_under_test() ==
                       "gpu" else "/CPU:0"):
            x = tf.constant(np)
        dlpack = tf.experimental.dlpack.to_dlpack(x)
        y = jax.dlpack.from_dlpack(dlpack)
        self.assertAllClose(np, y)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in all_shapes for dtype in dlpack_dtypes))
    @unittest.skipIf(not tf, "Test requires TensorFlow")
    def testJaxToTensorFlow(self, shape, dtype):
        if not config.x64_enabled and dtype in [
                jnp.int64, jnp.uint64, jnp.float64
        ]:
            self.skipTest("x64 types are disabled by jax_enable_x64")
        if (jtu.device_under_test() == "gpu"
                and not tf.config.list_physical_devices("GPU")):
            raise self.skipTest("TensorFlow not configured with GPU support")
        rng = jtu.rand_default(self.rng())
        np = rng(shape, dtype)
        x = jnp.array(np)
        # TODO(b/171320191): this line works around a missing context initialization
        # bug in TensorFlow.
        _ = tf.add(1, 1)
        dlpack = jax.dlpack.to_dlpack(x)
        y = tf.experimental.dlpack.from_dlpack(dlpack)
        self.assertAllClose(np, y.numpy())

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in all_shapes for dtype in torch_dtypes))
    @unittest.skipIf(not torch, "Test requires PyTorch")
    def testTorchToJax(self, shape, dtype):
        if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
            self.skipTest("x64 types are disabled by jax_enable_x64")
        rng = jtu.rand_default(self.rng())
        np = rng(shape, dtype)
        x = torch.from_numpy(np)
        x = x.cuda() if jtu.device_under_test() == "gpu" else x
        dlpack = torch.utils.dlpack.to_dlpack(x)
        y = jax.dlpack.from_dlpack(dlpack)
        self.assertAllClose(np, y)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in all_shapes for dtype in torch_dtypes))
    @unittest.skipIf(not torch, "Test requires PyTorch")
    def testJaxToTorch(self, shape, dtype):
        if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
            self.skipTest("x64 types are disabled by jax_enable_x64")
        rng = jtu.rand_default(self.rng())
        np = rng(shape, dtype)
        x = jnp.array(np)
        dlpack = jax.dlpack.to_dlpack(x)
        y = torch.utils.dlpack.from_dlpack(dlpack)
        self.assertAllClose(np, y.cpu().numpy())
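
A minimal JAX-only round trip through DLPack, independent of TensorFlow and PyTorch, using the same jax.dlpack API exercised by the tests above (the array contents are arbitrary):

import jax
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)

# With take_ownership=False the source buffer remains valid after export.
capsule = jax.dlpack.to_dlpack(x, take_ownership=False)
y = jax.dlpack.from_dlpack(capsule)  # a capsule may be consumed at most once

print(jnp.array_equal(x, y))  # expected: True
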
Example #13
class LaxBackedNumpyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Numpy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        return lambda: [
            rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)
        ]

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list(
                {
                    "testcase_name":
                    jtu.format_test_name_suffix(rec.test_name, shapes, dtypes),
                    "rng":
                    rec.rng,
                    "shapes":
                    shapes,
                    "dtypes":
                    dtypes,
                    "onp_op":
                    getattr(onp, rec.name),
                    "lnp_op":
                    getattr(lnp, rec.name)
                } for shapes in filter(
                    _shapes_are_broadcast_compatible,
                    CombosWithReplacement(rec.shapes, rec.nargs))
                for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
            for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
                                       JAX_COMPOUND_OP_RECORDS)))
    def testOp(self, onp_op, lnp_op, rng, shapes, dtypes):
        args_maker = self._GetArgsMaker(rng, shapes, dtypes)
        self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
        self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list(
                {
                    "testcase_name":
                    jtu.format_test_name_suffix(rec.test_name, shapes, dtypes),
                    "rng":
                    rec.rng,
                    "shapes":
                    shapes,
                    "dtypes":
                    dtypes,
                    "onp_op":
                    getattr(onp, rec.name),
                    "lnp_op":
                    getattr(lnp, rec.name)
                } for shapes in filter(
                    _shapes_are_broadcast_compatible,
                    CombosWithReplacement(rec.shapes, rec.nargs))
                for dtypes in filter(
                    _dtypes_are_compatible_for_bitwise_ops,
                    CombosWithReplacement(rec.dtypes, rec.nargs)))
            for rec in JAX_BITWISE_OP_RECORDS))
    def testBitwiseOp(self, onp_op, lnp_op, rng, shapes, dtypes):
        if not FLAGS.jax_enable_x64 and any(
                onp.iinfo(dtype).bits == 64 for dtype in dtypes):
            self.skipTest("x64 types are disabled by jax_enable_x64")
        args_maker = self._GetArgsMaker(rng, shapes, dtypes)
        self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
        self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "{}_inshape={}_axis={}_dtype={}_keepdims={}".format(
                rec.test_name.capitalize(),
                jtu.format_shape_dtype_string(shape, dtype), axis, "None"
                if out_dtype is None else onp.dtype(out_dtype).name, keepdims),
            "rng":
            rec.rng,
            "shape":
            shape,
            "dtype":
            dtype,
            "out_dtype":
            out_dtype,
            "onp_op":
            getattr(onp, rec.name),
            "lnp_op":
            getattr(lnp, rec.name),
            "axis":
            axis,
            "keepdims":
            keepdims
        } for rec in JAX_REDUCER_RECORDS for shape in rec.shapes
                            for dtype in rec.dtypes
                            for out_dtype in [None] + rec.dtypes
                            for axis in range(-len(shape), len(shape))
                            for keepdims in [False, True]))
    def testReducer(self, onp_op, lnp_op, rng, shape, dtype, out_dtype, axis,
                    keepdims):
        onp_fun = lambda x: onp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
        lnp_fun = lambda x: lnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
        args_maker = lambda: [rng(shape, dtype)]
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
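
As a concrete instance of the reducer pattern above, jnp.sum with an explicit output dtype and keepdims compared against NumPy; a minimal sketch with arbitrary data, using the same onp/lnp aliases as this class.

import numpy as onp
import jax.numpy as lnp

x = onp.arange(12, dtype=onp.float32).reshape(3, 4)

onp_out = onp.sum(x, axis=1, dtype=onp.float32, keepdims=True)
lnp_out = lnp.sum(x, axis=1, dtype=lnp.float32, keepdims=True)

print(onp_out.shape, lnp_out.shape)    # (3, 1) for both
print(onp.allclose(onp_out, lnp_out))  # expected: True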

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "{}_inshape={}_axis={}_keepdims={}".format(
                rec.test_name.capitalize(),
                jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
            "rng":
            rec.rng,
            "shape":
            shape,
            "dtype":
            dtype,
            "onp_op":
            getattr(onp, rec.name),
            "lnp_op":
            getattr(lnp, rec.name),
            "axis":
            axis,
            "keepdims":
            keepdims
        } for rec in JAX_REDUCER_NO_DTYPE_RECORDS for shape in rec.shapes
                            for dtype in rec.dtypes
                            for axis in range(-len(shape), len(shape))
                            for keepdims in [False, True]))
    def testReducerNoDtype(self, onp_op, lnp_op, rng, shape, dtype, axis,
                           keepdims):
        onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims)
        lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims)
        args_maker = lambda: [rng(shape, dtype)]
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "{}_inshape={}_axis={}".format(
                rec.test_name.capitalize(),
                jtu.format_shape_dtype_string(shape, dtype), axis),
            "rng":
            rec.rng,
            "shape":
            shape,
            "dtype":
            dtype,
            "onp_op":
            getattr(onp, rec.name),
            "lnp_op":
            getattr(lnp, rec.name),
            "axis":
            axis
        } for rec in JAX_ARGMINMAX_RECORDS for shape in rec.shapes
                            for dtype in rec.dtypes
                            for axis in range(-len(shape), len(shape))))
    def testArgMinMax(self, onp_op, lnp_op, rng, shape, dtype, axis):
        def onp_fun(array_to_reduce):
            return onp_op(array_to_reduce, axis)

        def lnp_fun(array_to_reduce):
            return lnp_op(array_to_reduce, axis)

        args_maker = lambda: [rng(shape, dtype)]
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_{}_{}".format(
                name, jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
                jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
            "lhs_shape":
            lhs_shape,
            "lhs_dtype":
            lhs_dtype,
            "rhs_shape":
            rhs_shape,
            "rhs_dtype":
            rhs_dtype,
            "rng":
            rng
        } for rng in [jtu.rand_default()]
                            for name, lhs_shape, rhs_shape in [
                                ("matrix-scalar", (3, 3), ()),
                                ("scalar-matrix", (), (3, 3)),
                                ("matrix-vector", (4, 5), (5,)),
                                ("vector-matrix", (6,), (6, 4)),
                                ("matrix-matrix", (3, 4), (4, 5)),
                                ("tensor-vector", (4, 3, 2), (2,)),
                                ("vector-tensor", (2,), (3, 2, 4)),
                                ("tensor-matrix", (4, 3, 2), (2, 5)),
                                ("matrix-tensor", (5, 2), (3, 2, 4)),
                                ("tensor-tensor", (2, 3, 4), (5, 4, 1))]
                            for lhs_dtype, rhs_dtype in CombosWithReplacement(
                                float_dtypes, 2)))
    def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
        args_maker = lambda: [
            rng(lhs_shape, lhs_dtype),
            rng(rhs_shape, rhs_dtype)
        ]
        self._CheckAgainstNumpy(onp.dot,
                                lnp.dot,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp.dot, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_{}_{}".format(
                name, jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
                jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
            "lhs_shape":
            lhs_shape,
            "lhs_dtype":
            lhs_dtype,
            "rhs_shape":
            rhs_shape,
            "rhs_dtype":
            rhs_dtype,
            "rng":
            rng
        } for rng in [jtu.rand_default()]
                            for name, lhs_shape, rhs_shape in [
                                ("vector-vector", (3,), (3,)),
                                ("matrix-vector", (3, 3), (3,)),
                                ("vector-matrix", (3,), (3, 3)),
                                ("matrix-matrix", (3, 3), (3, 3)),
                                ("vector-tensor", (3,), (5, 3, 2)),
                                ("tensor-vector", (5, 3, 2), (2,)),
                                ("matrix-tensor", (5, 2), (3, 2, 4)),
                                ("tensor-matrix", (5, 2, 3), (3, 2)),
                                ("tensor-tensor", (5, 3, 4), (5, 4, 1)),
                                ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]
                            for lhs_dtype, rhs_dtype in CombosWithReplacement(
                                float_dtypes, 2)))
    def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
        args_maker = lambda: [
            rng(lhs_shape, lhs_dtype),
            rng(rhs_shape, rhs_dtype)
        ]
        self._CheckAgainstNumpy(onp.matmul,
                                lnp.matmul,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp.matmul, args_maker, check_dtypes=True)
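
The "tensor-tensor-broadcast" case above relies on matmul's batch broadcasting; a standalone sketch using the shapes from that case (the tolerance is arbitrary).

import numpy as onp
import jax.numpy as lnp

rng = onp.random.RandomState(0)
lhs = rng.randn(3, 1, 3, 4).astype(onp.float32)
rhs = rng.randn(5, 4, 1).astype(onp.float32)

# Batch dims (3, 1) and (5,) broadcast to (3, 5); the result has shape (3, 5, 3, 1).
out = lnp.matmul(lhs, rhs)
print(out.shape, onp.allclose(out, onp.matmul(lhs, rhs), atol=1e-5))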

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_amin={}_amax={}".format(
                jtu.format_shape_dtype_string(shape, dtype), a_min, a_max),
            "shape":
            shape,
            "dtype":
            dtype,
            "a_min":
            a_min,
            "a_max":
            a_max,
            "rng":
            jtu.rand_default()
        } for shape in all_shapes for dtype in float_dtypes
                            for a_min, a_max in [(-1, None), (None, 1), (-1, 1)]))
    def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng):
        onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max)
        lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max)
        args_maker = lambda: [rng(shape, dtype)]
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_decimals={}".format(
                jtu.format_shape_dtype_string(shape, dtype), decimals),
            "shape":
            shape,
            "dtype":
            dtype,
            "decimals":
            decimals,
            "rng":
            jtu.rand_default()
        } for shape in all_shapes for dtype in float_dtypes
                            for decimals in [0, 1, -2]))
    def testRoundStaticDecimals(self, shape, dtype, decimals, rng):
        onp_fun = lambda x: onp.round(x, decimals=decimals)
        lnp_fun = lambda x: lnp.round(x, decimals=decimals)
        args_maker = lambda: [rng(shape, dtype)]
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
                    axis, ",".join(str(d) for d in base_shape), ",".join(
                        onp.dtype(dtype).name for dtype in dtypes)),
                "axis":
                axis,
                "base_shape":
                base_shape,
                "dtypes":
                dtypes,
                "rng":
                jtu.rand_default()
            } for num_arrs in [3]
            for dtypes in CombosWithReplacement(default_dtypes, num_arrs)
            for base_shape in [(4, ), (3, 4), (2, 3, 4)]
            for axis in range(-len(base_shape) + 1, len(base_shape))))
    def testConcatenate(self, axis, base_shape, dtypes, rng):
        wrapped_axis = axis % len(base_shape)
        shapes = [
            base_shape[:wrapped_axis] + (size, ) +
            base_shape[wrapped_axis + 1:]
            for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)
        ]
        onp_fun = lambda *args: onp.concatenate(args, axis=axis)
        lnp_fun = lambda *args: lnp.concatenate(args, axis=axis)

        def args_maker():
            return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape=[{}]_axis={}_repeats={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, repeats),
            "axis":
            axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "repeats":
            repeats,
            "rng":
            jtu.rand_default()
        } for repeats in [0, 1, 2] for dtype in default_dtypes
                            for shape in all_shapes for axis in [None] +
                            list(range(-len(shape), len(shape)))))
    def testRepeat(self, axis, shape, dtype, repeats, rng):
        onp_fun = lambda arg: onp.repeat(arg, repeats=repeats, axis=axis)
        lnp_fun = lambda arg: lnp.repeat(arg, repeats=repeats, axis=axis)

        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_dtype={}_m={}_n={}_k={}".format(onp.dtype(dtype).name, m, n, k),
            "m":
            m,
            "n":
            n,
            "k":
            k,
            "dtype":
            dtype,
            "rng":
            jtu.rand_default()
        } for dtype in default_dtypes for n in [0, 4]
                            for m in [None, 0, 1, 3, 4]
                            for k in list(range(-4, 4))))
    def testTri(self, m, n, k, dtype, rng):
        onp_fun = lambda: onp.tri(n, M=m, k=k, dtype=dtype)
        lnp_fun = lambda: lnp.tri(n, M=m, k=k, dtype=dtype)
        args_maker = lambda: []
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_op={}_shape={}_k={}".format(
                    op, jtu.format_shape_dtype_string(shape, dtype), k),
                "dtype":
                dtype,
                "shape":
                shape,
                "op":
                op,
                "k":
                k,
                "rng":
                jtu.rand_default()
            } for dtype in default_dtypes
            for shape in [shape for shape in all_shapes if len(shape) >= 1]
            for op in ["tril", "triu"] for k in list(range(-3, 3))))
    def testTriLU(self, dtype, shape, op, k, rng):
        onp_fun = lambda arg: getattr(onp, op)(arg, k=k)
        lnp_fun = lambda arg: getattr(lnp, op)(arg, k=k)
        args_maker = lambda: [rng(shape, dtype)]
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_shape={}_k={}".format(
                    jtu.format_shape_dtype_string(shape, dtype), k),
                "dtype":
                dtype,
                "shape":
                shape,
                "k":
                k,
                "rng":
                jtu.rand_default()
            } for dtype in default_dtypes for shape in
            [shape for shape in all_shapes if len(shape) in (1, 2)]
            for k in list(range(-4, 4))))
    def testDiag(self, shape, dtype, k, rng):
        onp_fun = lambda arg: onp.diag(arg, k)
        lnp_fun = lambda arg: lnp.diag(arg, k)
        args_maker = lambda: [rng(shape, dtype)]
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_shape={}_offset={}_axis1={}_axis2={}".format(
                    jtu.format_shape_dtype_string(shape, dtype), offset, axis1,
                    axis2),
                "dtype":
                dtype,
                "shape":
                shape,
                "offset":
                offset,
                "axis1":
                axis1,
                "axis2":
                axis2,
                "rng":
                jtu.rand_default()
            } for dtype in default_dtypes
            for shape in [shape for shape in all_shapes if len(shape) >= 2]
            for (axis1, axis2) in itertools.combinations(range(len(shape)), 2)
            for offset in list(range(-4, 4))))
    def testDiagonal(self, shape, dtype, offset, axis1, axis2, rng):
        onp_fun = lambda arg: onp.diagonal(arg, offset, axis1, axis2)
        lnp_fun = lambda arg: lnp.diagonal(arg, offset, axis1, axis2)
        args_maker = lambda: [rng(shape, dtype)]
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(
                jtu.format_test_name_suffix("", [shape] *
                                            len(dtypes), dtypes)),
            "shape":
            shape,
            "dtypes":
            dtypes,
            "rng":
            rng
        } for dtypes in [
            [onp.float32],
            [onp.float32, onp.float32],
            [onp.float32, onp.int32, onp.float32],
            [onp.float32, onp.int64, onp.float32],
            [onp.float32, onp.int32, onp.float64],
        ] for shape in [(), (2, ), (3, 4), (1, 100)]
                            for rng in [jtu.rand_default()]))
    def testStack(self, shape, dtypes, rng):
        args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
        self._CheckAgainstNumpy(lnp.stack,
                                onp.stack,
                                args_maker,
                                check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_outdtype={}".format(
                jtu.format_shape_dtype_string(shape, fill_value_dtype),
                onp.dtype(out_dtype).name),
            "shape":
            shape,
            "fill_value_dtype":
            fill_value_dtype,
            "out_dtype":
            out_dtype,
            "rng":
            jtu.rand_default()
        } for shape in array_shapes for fill_value_dtype in default_dtypes
                            for out_dtype in default_dtypes))
    def testFull(self, shape, fill_value_dtype, out_dtype, rng):
        onp_fun = lambda fill_value: onp.full(
            shape, fill_value, dtype=out_dtype)
        lnp_fun = lambda fill_value: lnp.full(
            shape, fill_value, dtype=out_dtype)
        args_maker = lambda: [rng((), fill_value_dtype)]
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_axis={}_{}sections".format(
                jtu.format_shape_dtype_string(shape, dtype), axis,
                num_sections),
            "shape":
            shape,
            "num_sections":
            num_sections,
            "axis":
            axis,
            "dtype":
            dtype,
            "rng":
            jtu.rand_default()
        } for shape, axis, num_sections in [
            ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4),
            ((12, 4), 1, 2), ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)]
                            for dtype in default_dtypes))
    def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng):
        onp_fun = lambda x: onp.split(x, num_sections, axis=axis)
        lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis)
        args_maker = lambda: [rng(shape, dtype)]
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_outshape={}".format(
                jtu.format_shape_dtype_string(arg_shape, dtype),
                jtu.format_shape_dtype_string(out_shape, dtype)),
            "arg_shape":
            arg_shape,
            "out_shape":
            out_shape,
            "dtype":
            dtype,
            "rng":
            jtu.rand_default()
        } for dtype in default_dtypes for arg_shape, out_shape in [
            (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)), ((), (1, 1, 1)),
            ((7, 0), (0, 42, 101)), ((3, 4), 12), ((3, 4), (12,)),
            ((3, 4), -1), ((2, 1, 4), (-1,)), ((2, 2, 4), (2, 8))]))
    def testReshape(self, arg_shape, out_shape, dtype, rng):
        onp_fun = lambda x: onp.reshape(x, out_shape)
        lnp_fun = lambda x: lnp.reshape(x, out_shape)
        args_maker = lambda: [rng(arg_shape, dtype)]
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_expanddim={}".format(
                jtu.format_shape_dtype_string(arg_shape, dtype), dim),
            "arg_shape":
            arg_shape,
            "dtype":
            dtype,
            "dim":
            dim,
            "rng":
            jtu.rand_default()
        } for arg_shape in [(), (3, ), (3, 4)] for dtype in default_dtypes
                            for dim in range(-len(arg_shape) +
                                             1, len(arg_shape))))
    def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng):
        onp_fun = lambda x: onp.expand_dims(x, dim)
        lnp_fun = lambda x: lnp.expand_dims(x, dim)
        args_maker = lambda: [rng(arg_shape, dtype)]
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_axes=({},{})".format(
                jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2),
            "arg_shape":
            arg_shape,
            "dtype":
            dtype,
            "ax1":
            ax1,
            "ax2":
            ax2,
            "rng":
            jtu.rand_default()
        } for arg_shape, ax1, ax2 in [
            ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2),
            ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)]
                            for dtype in default_dtypes))
    def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng):
        onp_fun = lambda x: onp.swapaxes(x, ax1, ax2)
        lnp_fun = lambda x: lnp.swapaxes(x, ax1, ax2)
        args_maker = lambda: [rng(arg_shape, dtype)]
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_axis={}".format(
                jtu.format_shape_dtype_string(arg_shape, dtype), ax),
            "arg_shape":
            arg_shape,
            "dtype":
            dtype,
            "ax":
            ax,
            "rng":
            jtu.rand_default()
        } for arg_shape, ax in [
            ((3, 1), None), ((3, 1), 1), ((1, 3, 1), (0, 2)), ((1, 4, 1), (0,))]
                            for dtype in default_dtypes))
    def testSqueeze(self, arg_shape, dtype, ax, rng):
        onp_fun = lambda x: onp.squeeze(x, ax)
        lnp_fun = lambda x: lnp.squeeze(x, ax)
        args_maker = lambda: [rng(arg_shape, dtype)]
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_arg{}".format(i),
            "arg": arg
        } for i, arg in enumerate([
            [1, 2, 3],
            [1., 2., 3.],
            [[1, 2], [3, 4], [5, 6]],
            [[1, 2.], [3, 4], [5, 6]],
            [[3, onp.array(2), 1], onp.arange(3.)],
        ])))
    def testArray(self, arg):
        args_maker = lambda: [arg]
        self._CheckAgainstNumpy(onp.array,
                                lnp.array,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp.array, args_maker, check_dtypes=True)

    def testArrayAsarrayMethod(self):
        class arraylike(object):
            # NumPy's array-conversion protocol hook is __array__ (there is no
            # __asarray__ protocol), so lnp.array(a) should pick up this value.
            def __array__(self, dtype=None):
                return 3.

        a = arraylike()
        ans = lnp.array(a)
        assert ans == 3.

    def testAllClose(self):
        rng = onp.random.RandomState(0)
        x = rng.randn(2, 2)
        y = rng.randn(2)

        def same(list1, list2):
            allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3)
            elements_close = list(map(allclose, list1, list2))
            return lnp.all(lnp.array(elements_close))

        csame = api.jit(same)

        a1 = same((x, y), (x, y))
        a2 = csame((x, y), (x, y))
        a3 = csame((x, y), (x, 2 * y))

        self.assertTrue(a1)
        self.assertTrue(a2)
        self.assertFalse(a3)

    @jtu.skip_on_devices("tpu")  # TODO(mattjj): investigate this failure
    def DISABLED_testOnesBroadcastingConstantHandler(self):
        # TODO(mattjj): update this test for jax3

        def fun(x):
            ones = lnp.ones((3, 4))
            assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0)

            # To check that the constant handler generates a Broadcast for stride-zero
            # arrays, we monkey-patch the client instance.
            # TODO(mattjj): once we have better HLO dumping and inspecting facilities,
            # we can check the HLO more directly.
            c = x._node.c
            Broadcast = c.Broadcast  # pylint: disable=invalid-name
            was_called = []
            c.Broadcast = lambda *args: was_called.append(True) or Broadcast(
                *args)
            out = x + ones  # the ndarray constant handler should call Broadcast here
            assert was_called, "Broadcast was not called."

            return out

        fun = api.jit(fun)
        out_val = fun(lnp.ones(4))
        self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False)

    def testZeroStridesConstantHandler(self):
        raw_const = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1)
        const = onp.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6))

        def fun(x):
            return x * const

        fun = api.jit(fun)
        out_val = fun(3.)
        self.assertAllClose(out_val, 3. * const, check_dtypes=False)
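    # Illustrative sketch, not part of the original test file: the constant in
    # the test above is interesting because onp.broadcast_to returns a view
    # whose broadcast axes have stride 0, which is exactly the case the
    # constant handler must handle. A minimal standalone check (numpy only):
    #
    #   import numpy as onp
    #   raw = onp.arange(3.0).reshape(1, 3)
    #   view = onp.broadcast_to(raw, (4, 3))
    #   assert view.shape == (4, 3) and view.strides[0] == 0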

    def testIsInstanceNdarrayDuringTracing(self):
        arr = onp.ones(3)

        @api.jit
        def f(x):
            self.assertIsInstance(x, lnp.ndarray)
            return lnp.sum(x)

        f(arr)

    def testNonArrayErrorMessage(self):
        x = [1., 2.]
        y = onp.array([3., 4.])

        def g(x, y):
            return lnp.add(x, y)

        def f(x, y):
            return lnp.dot(x, y)

        self.assertRaises(TypeError, lambda: g(x, y))
        self.assertRaises(TypeError, lambda: f(x, y))
        self.assertRaises(TypeError, lambda: api.jit(g)(x, y))
        self.assertRaises(TypeError, lambda: api.jit(f)(x, y))

    def testAbstractionErrorMessage(self):
        @api.jit
        def f(x, n):
            for _ in range(n):
                x = x * x
            return x

        self.assertRaises(TypeError, lambda: f(3., 3))

        @api.jit
        def g(x):
            if x > 0.:
                return x * 2
            else:
                return x + 2

        self.assertRaises(TypeError, lambda: g(3.))
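    # Illustrative note, not part of the original test file: both failures
    # above come from Python control flow depending on traced values. A common
    # fix, shown here on a hypothetical re-definition of the loop example, is
    # to mark the Python-level argument as static:
    #
    #   def power_loop(x, n):
    #       for _ in range(n):
    #           x = x * x
    #       return x
    #   power_loop_jit = api.jit(power_loop, static_argnums=(1,))
    #   power_loop_jit(3., 3)  # n is a compile-time constant, so tracing works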

    def DISABLED_testTracingPrimitiveWithNoTranslationErrorMessage(self):
        # TODO(mattjj): update this for jax3
        foo = lnp._not_implemented(lambda x: x)

        # No error if there's no tracing.
        foo(onp.arange(3))

        cfoo = api.jit(foo)
        self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3)))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_axis={}".format(jtu.format_shape_dtype_string(shape, dtype),
                                 axis),
            "rng":
            rng,
            "shape":
            shape,
            "dtype":
            dtype,
            "axis":
            axis
        } for shape in [(3, ), (2, 3)] for dtype in default_dtypes
                            for axis in range(len(shape))
                            for rng in [jtu.rand_default()]))
    def testFlip(self, shape, dtype, axis, rng):
        args_maker = self._GetArgsMaker(rng, [shape], [dtype])
        lnp_op = lambda x: lnp.flip(x, axis)
        onp_op = lambda x: onp.flip(x, axis)
        self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
        self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_k={}_axes={}".format(
                jtu.format_shape_dtype_string(shape, dtype), k, axes),
            "rng":
            rng,
            "shape":
            shape,
            "dtype":
            dtype,
            "k":
            k,
            "axes":
            axes
        } for shape, axes in [
            [(2, 3), (0, 1)],
            [(2, 3), (1, 0)],
            [(4, 3, 2), (0, 2)],
            [(4, 3, 2), (2, 1)],
        ] for k in range(-3, 4) for dtype in default_dtypes
                            for rng in [jtu.rand_default()]))
    def testRot90(self, shape, dtype, k, axes, rng):
        args_maker = self._GetArgsMaker(rng, [shape], [dtype])
        lnp_op = lambda x: lnp.rot90(x, k, axes)
        onp_op = lambda x: onp.rot90(x, k, axes)
        self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
        self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

    # TODO(mattjj): test infix operator overrides

    def DISABLED_testRavel(self):
        # TODO(mattjj): support this method-based syntax?
        rng = onp.random.RandomState(0)
        args_maker = lambda: [rng.randn(3, 4).astype("float32")]
        self._CompileAndCheck(lambda x: x.ravel(),
                              args_maker,
                              check_dtypes=True)
Example #14
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
          jtu.format_shape_dtype_string(shapes, dtype),
          axis, keepdims, return_sign, use_b),
       # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
       "shapes": shapes, "dtype": dtype,
       "axis": axis, "keepdims": keepdims,
       "return_sign": return_sign, "use_b": use_b}
      for shape_group in compatible_shapes for dtype in float_dtypes
      for use_b in [False, True]
      for shapes in itertools.product(*(
        (shape_group, shape_group) if use_b else (shape_group,)))
      for axis in range(-max(len(shape) for shape in shapes),
                         max(len(shape) for shape in shapes))
      for keepdims in [False, True]
      for return_sign in [False, True]))
  @jtu.ignore_warning(category=RuntimeWarning,
                      message="invalid value encountered in .*")
  def testLogSumExp(self, shapes, dtype, axis,
                    keepdims, return_sign, use_b):
    if jtu.device_under_test() != "cpu":
      rng = jtu.rand_some_inf_and_nan(self.rng())
    else:
      rng = jtu.rand_default(self.rng())
    # TODO(mattjj): test autodiff
    if use_b:
      def scipy_fun(array_to_reduce, scale_array):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      def lax_fun(array_to_reduce, scale_array):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
    else:
      def scipy_fun(array_to_reduce):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      def lax_fun(array_to_reduce):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      args_maker = lambda: [rng(shapes[0], dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker)

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.test_name, shapes, dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
         "test_autodiff": rec.test_autodiff,
         "nondiff_argnums": rec.nondiff_argnums,
         "scipy_op": getattr(osp_special, rec.name),
         "lax_op": getattr(lsp_special, rec.name)}
        for shapes in itertools.combinations_with_replacement(all_shapes, rec.nargs)
        for dtypes in (itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
          if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)))
      for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                          test_autodiff, nondiff_argnums):
    if (jtu.device_under_test() == "cpu" and
        (lax_op is lsp_special.gammainc or lax_op is lsp_special.gammaincc)):
      # TODO(b/173608403): re-enable test when LLVM bug is fixed.
      raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
    rng = rng_factory(self.rng())
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

    if test_autodiff:
      def partial_lax_op(*vals):
        list_args = list(vals)
        for i in nondiff_argnums:
          list_args.insert(i, args[i])
        return lax_op(*list_args)

      assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
      diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums]
      jtu.check_grads(partial_lax_op, diff_args, order=1,
                      atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                      rtol=.1, eps=1e-3)
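  # Illustrative note, not part of the original test file: partial_lax_op above
  # re-inserts the arguments listed in nondiff_argnums at their original
  # positions, so check_grads only differentiates the remaining arguments. For
  # example, with nondiff_argnums=(0,) and args=(a, b), diff_args is [b] and
  # partial_lax_op(b) evaluates lax_op(a, b).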

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_d={}".format(
          jtu.format_shape_dtype_string(shape, dtype), d),
       "rng_factory": jtu.rand_positive, "shape": shape, "dtype": dtype,
       "d": d}
      for shape in all_shapes
      for dtype in float_dtypes
      for d in [1, 2, 5]))
  def testMultigammaln(self, rng_factory, shape, dtype, d):
    def scipy_fun(a):
      return osp_special.multigammaln(a, d)

    def lax_fun(a):
      return lsp_special.multigammaln(a, d)

    rng = rng_factory(self.rng())
    args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(lax_fun, args_maker)

  def testIssue980(self):
    x = np.full((4,), -1e20, dtype=np.float32)
    self.assertAllClose(np.zeros((4,), dtype=np.float32),
                        lsp_special.expit(x))

  def testIssue3758(self):
    x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
    q = np.array([1., 40., 30.], dtype=np.float32)
    self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32), lsp_special.zeta(x, q))

  def testXlogyShouldReturnZero(self):
    self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

  def testGradOfXlogyAtZero(self):
    partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
    self.assertAllClose(api.grad(partial_xlogy)(0.), 0., check_dtypes=False)

  def testXlog1pyShouldReturnZero(self):
    self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False)

  def testGradOfXlog1pyAtZero(self):
    partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
    self.assertAllClose(api.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)
Example #15
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_axis={}_keepdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
       "rng_factory": jtu.rand_some_inf_and_nan
                      if jtu.device_under_test() != "cpu"
                      else jtu.rand_default,
       "shape": shape, "dtype": dtype,
       "axis": axis, "keepdims": keepdims}
      for shape in all_shapes for dtype in float_dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True]))
  @jtu.skip_on_flag("jax_xla_backend", "xrt")
  def testLogSumExp(self, rng_factory, shape, dtype, axis, keepdims):
    rng = rng_factory()
    # TODO(mattjj): test autodiff
    def scipy_fun(array_to_reduce):
      return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    def lax_fun(array_to_reduce):
      return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.test_name, shapes, dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
         "test_autodiff": rec.test_autodiff,
         "scipy_op": getattr(osp_special, rec.name),
         "lax_op": getattr(lsp_special, rec.name)}
        for shapes in CombosWithReplacement(all_shapes, rec.nargs)
        for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
      for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                          test_autodiff):
    rng = rng_factory()
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True, rtol=1e-5)

    if test_autodiff:
      jtu.check_grads(lax_op, args, order=1,
                      atol=jtu.if_device_under_test("tpu", 2e-3, 1e-3), rtol=3e-3, eps=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_d={}".format(
          jtu.format_shape_dtype_string(shape, dtype), d),
       "rng_factory": jtu.rand_positive, "shape": shape, "dtype": dtype,
       "d": d}
      for shape in all_shapes
      for dtype in float_dtypes
      for d in [1, 2, 5]))
  def testMultigammaln(self, rng_factory, shape, dtype, d):
    def scipy_fun(a):
      return osp_special.multigammaln(a, d)

    def lax_fun(a):
      return lsp_special.multigammaln(a, d)

    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol={onp.float32: 1e-3, onp.float64: 1e-14})
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  def testIssue980(self):
    x = onp.full((4,), -1e20, dtype=onp.float32)
    self.assertAllClose(onp.zeros((4,), dtype=onp.float32),
                        lsp_special.expit(x), check_dtypes=True)

  def testXlogyShouldReturnZero(self):
    self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

  def testGradOfXlogyAtZero(self):
    partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
    self.assertAllClose(api.grad(partial_xlogy)(0.), 0., check_dtypes=False)

  def testXlog1pyShouldReturnZero(self):
    self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False)

  def testGradOfXlog1pyAtZero(self):
    partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
    self.assertAllClose(api.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)
Example #16
class LaxVmapTest(jtu.JaxTestCase):

  def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
                     rtol=None, atol=None):
    batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
    args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)]
    args_slice = args_slicer(args, bdims)
    ans = api.vmap(op, bdims)(*args)
    expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)])
    self.assertAllClose(ans, expected, rtol=rtol, atol=atol)
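  # Illustrative sketch, not part of the original test file: _CheckBatching
  # rests on the identity vmap(op)(xs) == stack([op(x) for x in xs]). A
  # minimal standalone version of that check, assuming the usual imports:
  #
  #   import numpy as np
  #   from jax import vmap, numpy as jnp
  #   xs = np.arange(6.0).reshape(3, 2).astype(np.float32)
  #   batched = vmap(jnp.tanh)(xs)
  #   looped = np.stack([np.tanh(x) for x in xs])
  #   np.testing.assert_allclose(batched, looped, rtol=1e-5)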

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": "{}_bdims={}".format(
            jtu.format_test_name_suffix(rec.op, shapes,
                                        itertools.repeat(dtype)), bdims),
         "op_name": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes,
         "dtype": dtype, "bdims": bdims, "tol": rec.tol}
        for shape_group in compatible_shapes
        for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs)
        for bdims in all_bdims(*shapes)
        for dtype in rec.dtypes)
      for rec in LAX_OPS))
  def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol):
    rng = rng_factory(self.rng())
    op = getattr(lax, op_name)
    self._CheckBatching(op, 10, bdims, shapes, [dtype] * len(shapes), rng,
                        atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
       "_lhs_bdim={}_rhs_bdim={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
               feature_group_count, batch_group_count, lhs_bdim, rhs_bdim),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums,
       "perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim,
       "feature_group_count": feature_group_count,
       "batch_group_count": batch_group_count,
       }
      for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)])
      for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in [
          ((b * batch_group_count, i * feature_group_count, 6, 7),  # lhs_shape
           (j * batch_group_count * feature_group_count, i, 1, 2),  # rhs_shape
           [(1, 1), (1, 2), (2, 1)],  # strides
           [((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))],  # pads
           [(1, 1), (2, 1)],  # lhs_dils
           [(1, 1), (2, 2)])  # rhs_dils
          for b, i, j in itertools.product([1, 2], repeat=3)]
      for strides in all_strides
      for rhs_dil in rhs_dils
      for lhs_dil in lhs_dils
      for dtype in [np.float32]
      for padding in all_pads
      for dim_nums, perms in [
          (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
      for lhs_bdim in itertools.chain([cast(Optional[int], None)],
                                      range(len(lhs_shape) + 1))
      for rhs_bdim in itertools.chain([cast(Optional[int], None)],
                                      range(len(rhs_shape) + 1))
      if (lhs_bdim, rhs_bdim) != (None, None)
      for rng_factory in [jtu.rand_default]
  ))
  def testConvGeneralDilatedBatching(
      self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil,
      dimension_numbers, perms, feature_group_count, batch_group_count,
      lhs_bdim, rhs_bdim, rng_factory):
    rng = rng_factory(self.rng())
    tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3

    # permute shapes to match dim_spec, scale by feature_group_count
    lhs_perm, rhs_perm = perms
    lhs_shape = list(np.take(lhs_shape, lhs_perm))
    rhs_shape = list(np.take(rhs_shape, rhs_perm))

    conv = partial(lax.conv_general_dilated, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   dimension_numbers=dimension_numbers,
                   feature_group_count=feature_group_count,
                   batch_group_count=batch_group_count,
                   precision=lax.Precision.HIGHEST)
    self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape),
                        (dtype, dtype), rng, rtol=tol, atol=tol)
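  # Illustrative note, not part of the original test file: dim_nums such as
  # ("NHWC", "HWIO", "NHWC") name the layouts of the lhs, rhs and output, and
  # the accompanying perms are simply the permutations that take the canonical
  # (N, C, H, W) / (O, I, H, W) shapes generated above into those layouts
  # before the batching check runs.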

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
          shape, from_dtype, to_dtype, bdims),
       "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype,
       "bdims": bdims, "rng_factory": rng_factory}
      for from_dtype, to_dtype in itertools.product(
          [np.float32, np.int32, "float32", "int32"], repeat=2)
      for shape in [(2, 3)]
      for bdims in all_bdims(shape)
      for rng_factory in [jtu.rand_default]))
  def testConvertElementType(self, shape, from_dtype, to_dtype, bdims, rng_factory):
    rng = rng_factory(self.rng())
    op = lambda x: lax.convert_element_type(x, to_dtype)
    self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
          shape, from_dtype, to_dtype, bdims),
       "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype,
       "bdims": bdims, "rng_factory": rng_factory}
      for from_dtype, to_dtype in itertools.product(
          [np.float32, np.int32, "float32", "int32"], repeat=2)
      for shape in [(2, 3)]
      for bdims in all_bdims(shape)
      for rng_factory in [jtu.rand_default]))
  def testBitcastElementType(self, shape, from_dtype, to_dtype, bdims, rng_factory):
    rng = rng_factory(self.rng())
    op = lambda x: lax.bitcast_convert_type(x, to_dtype)
    self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}_bdims={}"
       .format(jtu.format_shape_dtype_string(min_shape, dtype),
               jtu.format_shape_dtype_string(operand_shape, dtype),
               jtu.format_shape_dtype_string(max_shape, dtype),
               bdims),
       "min_shape": min_shape, "operand_shape": operand_shape,
       "max_shape": max_shape, "dtype": dtype, "bdims": bdims, "rng_factory": rng_factory}
      for min_shape, operand_shape, max_shape in [
          [(), (2, 3), ()],
          [(2, 3), (2, 3), ()],
          [(), (2, 3), (2, 3)],
          [(2, 3), (2, 3), (2, 3)],
      ]
      for dtype in default_dtypes
      for bdims in all_bdims(min_shape, operand_shape, max_shape)
      for rng_factory in [jtu.rand_default]))
  def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims, rng_factory):
    rng = rng_factory(self.rng())
    raise SkipTest("batching rule for clamp not implemented")  # TODO(mattj)
    shapes = [min_shape, operand_shape, max_shape]
    self._CheckBatching(lax.clamp, 10, bdims, shapes, [dtype] * 3, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}_bdims={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),
          jtu.format_shape_dtype_string(rhs_shape, dtype),
          bdims),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "bdims": bdims, "rng_factory": rng_factory}
      for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
      for bdims in all_bdims(lhs_shape, rhs_shape)
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDot(self, lhs_shape, rhs_shape, dtype, bdims, rng_factory):
    rng = rng_factory(self.rng())
    op = partial(lax.dot, precision=lax.Precision.HIGHEST)
    self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng, rtol={np.float16: 5e-2, np.float64: 5e-14})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_bdims={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               lhs_contracting, rhs_contracting, bdims),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting,
       "bdims": bdims, "rng_factory": rng_factory}
      for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
          [(5,), (5,), [0], [0]],
          [(5, 7), (5,), [0], [0]],
          [(7, 5), (5,), [1], [0]],
          [(3, 5), (2, 5), [1], [1]],
          [(5, 3), (5, 2), [0], [0]],
          [(5, 3, 2), (5, 2, 4), [0], [0]],
          [(5, 3, 2), (5, 2, 4), [0,2], [0,1]],
          [(5, 3, 2), (3, 5, 2, 4), [0,2], [1,2]],
          [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
          [(3, 2), (2, 4), [1], [0]],
      ]
      for bdims in all_bdims(lhs_shape, rhs_shape)
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_small]))
  def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                                 lhs_contracting, rhs_contracting, bdims, rng_factory):
    rng = rng_factory(self.rng())
    dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))
    dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
    self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_bdims={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers, bdims),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers, "bdims": bdims, "rng_factory": rng_factory}
      for lhs_shape, rhs_shape, dimension_numbers in [
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
      ]
      for bdims in all_bdims(lhs_shape, rhs_shape)
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_small]))
  def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                     dimension_numbers, bdims, rng_factory):
    rng = rng_factory(self.rng())
    dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
    self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng)

    # Checks that batching didn't introduce any transposes or broadcasts.
    jaxpr = api.make_jaxpr(dot)(np.zeros(lhs_shape, dtype),
                                np.zeros(rhs_shape, dtype))
    for eqn in jtu.iter_eqns(jaxpr.jaxpr):
      # Compare primitive names: primitives are objects, not strings, so a
      # direct membership test against string names would be vacuously false.
      self.assertNotIn(eqn.primitive.name, ["transpose", "broadcast"])
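  # Illustrative sketch, not part of the original test file: api.make_jaxpr is
  # what the check above uses to see which primitives a function lowers to.
  # For instance, under the same imports:
  #
  #   jaxpr = api.make_jaxpr(lambda x, y: lax.dot(x, y))(
  #       np.ones((3, 2), np.float32), np.ones((2, 4), np.float32))
  #   print(jaxpr)  # a single dot_general equation, no transpose/broadcast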

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format(
          shape, np.dtype(dtype).name, broadcast_sizes, bdims),
       "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
       "bdims": bdims, "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in default_dtypes
      for broadcast_sizes in [(), (2,), (1, 2)]
      for bdims in all_bdims(shape)
      for rng_factory in [jtu.rand_default]))
  def testBroadcast(self, shape, dtype, broadcast_sizes, bdims, rng_factory):
    rng = rng_factory(self.rng())
    op = lambda x: lax.broadcast(x, broadcast_sizes)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_bcdims={}_bdims={}".format(
          jtu.format_shape_dtype_string(inshape, dtype),
          outshape, broadcast_dimensions, bdims),
       "inshape": inshape, "dtype": dtype, "outshape": outshape,
       "dimensions": broadcast_dimensions, "bdims": bdims,
       "rng_factory": rng_factory}
      for inshape, outshape, broadcast_dimensions in [
          ([2], [2, 2], [0]),
          ([2], [2, 2], [1]),
          ([2], [2, 3], [0]),
          ([], [2, 3], []),
      ]
      for dtype in default_dtypes
      for bdims in all_bdims(inshape)
      for rng_factory in [jtu.rand_default]))
  def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims, rng_factory):
    rng = rng_factory(self.rng())
    raise SkipTest("this test has failures in some cases")  # TODO(mattjj)
    op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    self._CheckBatching(op, 5, bdims, (inshape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_dimensions={}_bdims={}".format(
          jtu.format_shape_dtype_string(arg_shape, np.float32),
          dimensions, bdims),
       "arg_shape": arg_shape, "dimensions": dimensions, "bdims": bdims,
       "rng_factory": rng_factory}
      for arg_shape, dimensions in [
          [(1,), (0,)],
          [(1,), (-1,)],
          [(2, 1, 4), (1,)],
          [(2, 1, 4), (-2,)],
          [(2, 1, 3, 1), (1,)],
          [(2, 1, 3, 1), (1, 3)],
          [(2, 1, 3, 1), (3,)],
          [(2, 1, 3, 1), (1, -1)],
      ]
      for bdims in all_bdims(arg_shape)
      for rng_factory in [jtu.rand_default]))
  def testSqueeze(self, arg_shape, dimensions, bdims, rng_factory):
    dtype = np.float32
    rng = rng_factory(self.rng())
    op = lambda x: lax.squeeze(x, dimensions)
    self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_dims={}_bdims={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype),
          dimensions, bdims),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "dimensions": dimensions, "bdims": bdims, "rng_factory": rng_factory}
      for dtype in default_dtypes
      for arg_shape, dimensions, out_shape in [
          [(3, 4), None, (12,)],
          [(2, 1, 4), None, (8,)],
          [(2, 2, 4), None, (2, 8)],
          [(2, 2, 4), (0, 1, 2), (2, 8)],
          [(2, 2, 4), (1, 0, 2), (8, 2)],
          [(2, 2, 4), (2, 1, 0), (4, 2, 2)]
      ]
      for bdims in all_bdims(arg_shape)
      for rng_factory in [jtu.rand_default]))
  def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims, rng_factory):
    rng = rng_factory(self.rng())
    op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions)
    self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_pads={}_bdims={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), pads, bdims),
       "shape": shape, "dtype": dtype, "pads": pads,
       "rng_factory": jtu.rand_small, "bdims": bdims}
      for shape in [(2, 3)]
      for bdims in all_bdims(shape)
      for dtype in default_dtypes
      for pads in [[(1, 2, 1), (0, 1, 0)]]))
  def testPad(self, shape, dtype, pads, bdims, rng_factory):
    rng = rng_factory(self.rng())
    fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_predshape={}_argshapes={}_bdims={}".format(
          jtu.format_shape_dtype_string(pred_shape, np.bool_),
          jtu.format_shape_dtype_string(arg_shape, arg_dtype),
          bdims),
       "pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype,
       "bdims": bdims, "rng_factory": rng_factory}
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
      for bdims in all_bdims(pred_shape, arg_shape, arg_shape)
      for arg_dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims, rng_factory):
    rng = rng_factory(self.rng())
    op = lambda c, x, y: lax.select(c < 0, x, y)
    self._CheckBatching(op, 5, bdims, (pred_shape, arg_shape, arg_shape,),
                        (np.bool_, arg_dtype, arg_dtype), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_start_indices={}_limit_indices={}_strides={}_bdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, limit_indices, strides, bdims),
       "shape": shape, "dtype": dtype, "starts": start_indices,
       "limits": limit_indices, "strides": strides, "bdims": bdims, "rng_factory": rng_factory}
      for shape, start_indices, limit_indices, strides in [
        [(3,), (1,), (2,), None],
        [(7,), (4,), (7,), None],
        [(5,), (1,), (5,), (2,)],
        [(8,), (1,), (6,), (2,)],
        [(5, 3), (1, 1), (3, 2), None],
        [(5, 3), (1, 1), (3, 1), None],
        [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
        [(5, 3), (1, 1), (2, 1), (1, 1)],
        [(5, 3), (1, 1), (5, 3), (2, 1)],
      ]
      for bdims in all_bdims(shape)
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSlice(self, shape, dtype, starts, limits, strides, bdims, rng_factory):
    rng = rng_factory(self.rng())
    op = lambda x: lax.slice(x, starts, limits, strides)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_perm={}_bdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), perm, bdims),
       "shape": shape, "dtype": dtype, "perm": perm, "bdims": bdims, "rng_factory": rng_factory}
      for shape, perm in [
        [(3, 4), (1, 0)],
        [(3, 4), (0, 1)],
        [(3, 4, 5), (2, 1, 0)],
        [(3, 4, 5), (1, 0, 2)],
      ]
      for bdims in all_bdims(shape)
      for dtype in default_dtypes
      for rng_factory in [jtu.rand_default]))
  def testTranspose(self, shape, dtype, perm, bdims, rng_factory):
    rng = rng_factory(self.rng())
    op = lambda x: lax.transpose(x, perm)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}_bdims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
               init_val, bdims),
       "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
       "dims": dims, "bdims": bdims, "rng_factory": rng_factory}
      for init_val, op, dtypes in [
          (0, lax.add, default_dtypes),
          (1, lax.mul, default_dtypes),
          (0, lax.max, all_dtypes), # non-monoidal
          (-np.inf, lax.max, float_dtypes),
          (dtypes.iinfo(np.int32).min, lax.max, [np.int32]),
          (dtypes.iinfo(np.int64).min, lax.max, [np.int64]),
          (dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]),
          (dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]),
          (np.inf, lax.min, float_dtypes),
          (dtypes.iinfo(np.int32).max, lax.min, [np.int32]),
          (dtypes.iinfo(np.int64).max, lax.min, [np.int64]),
          (dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]),
          (dtypes.iinfo(np.uint64).max, lax.min, [np.uint64]),
      ]
      for dtype in dtypes
      for shape, dims in [
          [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
      ]
      for bdims in all_bdims(shape)
      for rng_factory in [jtu.rand_small]))
  def testReduce(self, op, init_val, shape, dtype, dims, bdims, rng_factory):
    rng = rng_factory(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)
    fun = lambda operand: lax.reduce(operand, init_val, op, dims)
    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}_bdims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dim,
               bdims),
       "op": op, "shape": shape, "dtype": dtype,
       "dim": dim, "bdims": bdims}
      for op in [lax.argmin, lax.argmax]
      for dtype in default_dtypes
      for shape in [(3, 4, 5)]
      for dim in range(len(shape))
      for bdims in all_bdims(shape)))
  def testArgminmax(self, op, shape, dtype, dim, bdims):
    rng = jtu.rand_default(self.rng())
    fun = lambda operand: op(operand, dim, np.int32)
    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
                         "_basedilation={}_windowdilation={}")
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype),
               dims, strides, padding, base_dilation, window_dilation),
       "op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
       "dims": dims, "strides": strides, "padding": padding,
       "base_dilation": base_dilation, "window_dilation": window_dilation}
      for init_val, op, dtypes in [
          (0, lax.add, [np.float32]),
          (-np.inf, lax.max, [np.float32]),
          (np.inf, lax.min, [np.float32]),
      ]
      for shape, dims, strides, padding, base_dilation, window_dilation in (
        itertools.chain(
          itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)],
            ["VALID", "SAME", [(0, 3), (1, 2)]],
            [(1, 1), (2, 3)],
            [(1, 1), (1, 2)]),
          itertools.product(
            [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)],
            ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
            [(1, 1, 1, 1), (2, 1, 3, 2)],
            [(1, 1, 1, 1), (1, 2, 2, 1)])))
      for dtype in dtypes))
  def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding,
                       base_dilation, window_dilation):
    rng = jtu.rand_small(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)

    def fun(operand):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding,
                               base_dilation, window_dilation)

    for bdims in all_bdims(shape):
      self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_shape={}_axis={}_bdims={}_reverse={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
               bdims, reverse),
       "op": op, "shape": shape, "dtype": dtype, "bdims": bdims,
       "axis": axis, "reverse": reverse}
      for op, types in [
          (lax.cumsum, [np.float32, np.float64]),
          (lax.cumprod, [np.float32, np.float64]),
      ]
      for dtype in types
      for shape in [[10], [3, 4, 5]]
      for axis in range(len(shape))
      for bdims in all_bdims(shape)
      for reverse in [False, True]))
  def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse):
    rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
                   else jtu.rand_small)
    rng = rng_factory(self.rng())
    self._CheckBatching(partial(op, axis=axis, reverse=reverse), 7, bdims,
                        (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_padding={}".format(np.dtype(dtype).name,
                                                      padding),
       "dtype": dtype, "padding": padding, "rng_factory": rng_factory}
      for dtype in float_dtypes
      for padding in ["VALID", "SAME"]
      for rng_factory in [jtu.rand_small]))
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jtu.ignore_warning(message="Using reduced precision for gradient.*")
  def testSelectAndGatherAdd(self, dtype, padding, rng_factory):
    if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
      raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")
    rng = rng_factory(self.rng())
    all_configs = itertools.chain(
        itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)]),
        itertools.product(
            [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)]))

    def fun(operand, tangents):
      pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
      ones = (1,) * len(operand.shape)
      return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims,
                                        strides, pads, ones, ones)

    for shape, dims, strides in all_configs:
      for bdims in all_bdims(shape, shape):
        self._CheckBatching(fun, 3, bdims, (shape, shape), (dtype, dtype), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_dtype={jtu.format_shape_dtype_string(shape, dtype)}"
      f"_padding={padding}_dims={dims}_strides={strides}",
       "dtype": dtype, "padding": padding, "shape": shape,
       "dims": dims, "strides": strides}
      for dtype in float_dtypes
      for padding in ["VALID", "SAME"]
      for shape in [(3, 2, 4, 6)]
      for dims in [(1, 1, 2, 1)]
      for strides in [(1, 2, 2, 1), (1, 1, 1, 1)]))
  def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides):
    rng = jtu.rand_small(self.rng())

    pads = lax.padtype_to_pads(shape, dims, strides, padding)

    def fun(operand, cotangents):
      return lax._select_and_scatter_add(operand, cotangents, lax.ge_p, dims,
                                         strides, pads)
    ones = (1,) * len(shape)
    cotangent_shape = api.eval_shape(
      lambda x: lax._select_and_gather_add(x, x, lax.ge_p, dims, strides,
                                           pads, ones, ones),
      np.ones(shape, dtype)).shape

    for bdims in all_bdims(cotangent_shape, shape):
      self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape),
                          (dtype, dtype), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_bdims={}_fft_ndims={}"
       .format(shape, bdims, fft_ndims),
       "shape": shape, "bdims": bdims, "fft_ndims": fft_ndims, "rng_factory": rng_factory}
      for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)]
      for bdims in all_bdims(shape)
      for fft_ndims in range(0, min(3, len(shape)) + 1)
      for rng_factory in [jtu.rand_default]))
  @jtu.skip_on_devices("tpu")  # TODO(b/137993701): unimplemented cases.
  def testFft(self, fft_ndims, shape, bdims, rng_factory):
    rng = rng_factory(self.rng())
    ndims = len(shape)
    axes = range(ndims - fft_ndims, ndims)
    fft_lengths = [shape[axis] for axis in axes]
    op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths)
    self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
               slice_sizes, bdims),
       "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "bdims": bdims}
      for dtype in all_dtypes
      for shape, idxs, dnums, slice_sizes in [
          ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,)),
          ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,)),
          ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3)),
          ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
            (1, 3)),
      ]
      for bdims in all_bdims(shape, idxs.shape)))
  def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    self._CheckBatching(fun, 5, bdims, [shape, idxs.shape], [dtype, idxs.dtype],
                        jtu.rand_default(self.rng()))
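  # Illustrative sketch, not part of the original test file: the first case
  # above (operand shape (5,), slice_sizes=(1,)) is equivalent to NumPy fancy
  # indexing:
  #
  #   x = np.arange(5.0)
  #   idxs = np.array([[0], [2]])
  #   dnums = lax.GatherDimensionNumbers(
  #       offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
  #   out = lax.gather(x, idxs, dimension_numbers=dnums, slice_sizes=(1,))
  #   # out ~ x[[0, 2]] -> [0., 2.]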

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}_bdims={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums, bdims),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "bdims": bdims}
      for dtype in float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for bdims in all_bdims(arg_shape, idxs.shape, update_shape)))
  def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims):
    fun = partial(lax.scatter_add, dimension_numbers=dnums)
    self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape, update_shape],
                        [dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()),
                        rtol={np.float16: 5e-3})

  def testShapeUsesBuiltinInt(self):
    x = lax.iota(np.int32, 3) + 1
    self.assertIsInstance(x.shape[0], int)  # not np.int64

  def testBroadcastShapesReturnsPythonInts(self):
    shape1, shape2 = (1, 2, 3), (2, 3)
    out_shape = lax.broadcast_shapes(shape1, shape2)
    self.assertTrue(all(type(s) is int for s in out_shape))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}_bdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k, bdims),
       "shape": shape, "dtype": dtype, "k": k, "bdims": bdims, "rng_factory": rng_factory}
      for shape in [(4,), (3, 5, 3)]
      for k in [1, 3]
      for bdims in all_bdims(shape)
      # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug is fixed:
      # The top_k indices for integer arrays with identical entries won't match between
      # the vmap'd version and the manual reference, so only test unique integer arrays
      # for int_dtypes.
      # Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of
      # values a bfloat16 can represent exactly to avoid ties.
      for dtype, rng_factory in itertools.chain(
        unsafe_zip(default_dtypes, itertools.repeat(jtu.rand_unique_int)))))
  def testTopK(self, shape, dtype, k, bdims, rng_factory):
    rng = rng_factory(self.rng())
    # _CheckBatching doesn't work with tuple outputs, so test outputs separately.
    op1 = lambda x: lax.top_k(x, k=k)[0]
    self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng)
    op2 = lambda x: lax.top_k(x, k=k)[1]
    self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng)


  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dimension={}_arity={}_bdims={}_isstable={}"
       .format(jtu.format_shape_dtype_string(shape, np.float32), dimension,
               arity, bdims, is_stable),
       "shape": shape, "dimension": dimension, "arity": arity, "bdims": bdims,
       "is_stable": is_stable}
      for shape in [(2, 3)]
      for dimension in [0, 1]
      for arity in range(3)
      for bdims in all_bdims(*((shape,) * arity))
      for is_stable in [False, True]))
  def testSort(self, shape, dimension, arity, bdims, is_stable):
    rng = jtu.rand_default(self.rng())
    if arity == 1:
      fun = partial(lax.sort, dimension=dimension)
      self._CheckBatching(fun, 5, bdims, (shape,) * arity, (np.float32,) * arity,
                          rng)
    else:
      for i in range(arity):
        fun = lambda *args, i=i: lax.sort(args,
                                          dimension=dimension,
                                          is_stable=is_stable)[i]
        self._CheckBatching(fun, 5, bdims, (shape,) * arity,
                            (np.float32,) * arity, rng)
Example #17
class IndexingTest(jtu.JaxTestCase):
    """Tests for Numpy indexing translation rules."""
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "{}_inshape={}_indexer={}".format(
                name, jtu.format_shape_dtype_string(shape, dtype), indexer),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng_factory":
            rng_factory,
            "indexer":
            indexer
        } for name, index_specs in STATIC_INDEXING_TESTS
                            for shape, indexer in index_specs
                            for dtype in all_dtypes
                            for rng_factory in [jtu.rand_default]))
    def testStaticIndexing(self, shape, dtype, rng_factory, indexer):
        rng = rng_factory()
        args_maker = lambda: [rng(shape, dtype)]
        fun = lambda x: x[indexer]
        self._CompileAndCheck(fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng_factory":
        rng_factory,
        "indexer":
        indexer
    } for name, index_specs in STATIC_INDEXING_GRAD_TESTS
                                    for shape, indexer in index_specs
                                    for dtype in float_dtypes
                                    for rng_factory in [jtu.rand_default])
    def testStaticIndexingGrads(self, shape, dtype, rng_factory, indexer):
        rng = rng_factory()
        tol = 1e-2 if lnp.finfo(dtype).bits == 32 else None
        arg = rng(shape, dtype)
        fun = lambda x: x[indexer]**2
        check_grads(fun, (arg, ), 2, tol, tol, tol)

    def _ReplaceSlicesWithTuples(self, idx):
        """Helper method to replace slices with tuples for dynamic indexing args."""
        if isinstance(idx, slice):
            triple = idx.start, idx.stop, idx.step
            isnone = [i for i, elt in enumerate(triple) if elt is None]
            zeros = itertools.repeat(0)
            nones = itertools.repeat(None)
            out = util.subvals(triple, zip(isnone, zeros))
            return out, lambda out: slice(*util.subvals(
                out, zip(isnone, nones)))
        elif isinstance(idx, (tuple, list)) and idx:
            t = type(idx)
            elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx))
            return elts, lambda elts: t(
                (pack(i) for pack, i in zip(packs, elts)))
        else:
            return idx, lambda x: x
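    # Illustrative note, not part of the original test file: for a single
    # slice the helper behaves as follows (None start/stop/step entries become
    # 0 in the unpacked tuple and are restored by the packer):
    #
    #   unpacked, pack = self._ReplaceSlicesWithTuples(slice(1, 3))
    #   # unpacked == (1, 3, 0) and pack(unpacked) == slice(1, 3, None)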

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng_factory":
        rng_factory,
        "indexer":
        indexer
    } for name, index_specs in [
        ("OneSliceIndex", [
            IndexSpec(shape=(5, ), indexer=slice(1, 3)),
            IndexSpec(shape=(5, 4), indexer=slice(1, 3))
        ]),
        ("TwoSliceIndices", [
            IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))),
            IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))
        ]),
        ("NonUnitStrides", [
            IndexSpec(shape=(3, ), indexer=slice(None, None, -1)),
            IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)),
            IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2))
        ]),
        ("OnlyStartOrStopDynamic", [
            IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))),
            IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None)))
        ]),
    ] for shape, indexer in index_specs for dtype in all_dtypes
                                    for rng_factory in [jtu.rand_default])
    def testDynamicIndexingWithSlicesErrors(self, shape, dtype, rng_factory,
                                            indexer):
        rng = rng_factory()
        unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

        @api.jit
        def fun(x, unpacked_indexer):
            indexer = pack_indexer(unpacked_indexer)
            return x[indexer]

        args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
        self.assertRaises(IndexError, lambda: fun(*args_maker()))
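    # Illustrative note, not part of the original test file: the IndexError is
    # expected because slice bounds must be static under jit; here they arrive
    # as traced arguments, so the rebuilt slice cannot be lowered to a static
    # lax.slice.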

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng_factory":
        rng_factory,
        "indexer":
        indexer
    } for name, index_specs in [
        ("OneIntIndex", [
            IndexSpec(shape=(3, ), indexer=1),
            IndexSpec(shape=(3, 3), indexer=0),
            IndexSpec(shape=(3, 4, 5), indexer=2),
            IndexSpec(shape=(3, ), indexer=-1),
            IndexSpec(shape=(3, ), indexer=-2)
        ]),
        ("TwoIntIndices", [
            IndexSpec(shape=(3, 3), indexer=(2, 1)),
            IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
            IndexSpec(shape=(3, 4, 5), indexer=(-1, 2))
        ]),
        ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
    ] for shape, indexer in index_specs for dtype in all_dtypes
                                    for rng_factory in [jtu.rand_default])
    def testDynamicIndexingWithIntegers(self, shape, dtype, rng_factory,
                                        indexer):
        rng = rng_factory()
        unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

        def fun(x, unpacked_indexer):
            indexer = pack_indexer(unpacked_indexer)
            return x[indexer]

        args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
        self._CompileAndCheck(fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng_factory":
        rng_factory,
        "indexer":
        indexer
    } for name, index_specs in [
        ("OneIntIndex", [
            IndexSpec(shape=(3, ), indexer=1),
            IndexSpec(shape=(3, 3), indexer=0),
            IndexSpec(shape=(3, 4, 5), indexer=2),
            IndexSpec(shape=(3, ), indexer=-1),
            IndexSpec(shape=(3, ), indexer=-2),
        ]),
        ("TwoIntIndices", [
            IndexSpec(shape=(3, 3), indexer=(2, 1)),
            IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
            IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)),
        ]),
        ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
    ] for shape, indexer in index_specs for dtype in float_dtypes
                                    for rng_factory in [jtu.rand_default])
    def testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng_factory,
                                             indexer):
        rng = rng_factory()
        tol = 1e-2 if lnp.finfo(dtype).bits == 32 else None
        unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

        @api.jit
        def fun(unpacked_indexer, x):
            indexer = pack_indexer(unpacked_indexer)
            return x[indexer]

        arr = rng(shape, dtype)
        check_grads(partial(fun, unpacked_indexer), (arr, ), 2, tol, tol, tol)

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng_factory":
        rng_factory,
        "indexer":
        indexer
    } for name, index_specs in ADVANCED_INDEXING_TESTS
                                    for shape, indexer in index_specs
                                    for dtype in all_dtypes
                                    for rng_factory in [jtu.rand_default])
    def testAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer):
        rng = rng_factory()
        args_maker = lambda: [rng(shape, dtype), indexer]
        fun = lambda x, idx: x[idx]
        self._CompileAndCheck(fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng_factory":
        rng_factory,
        "indexer":
        indexer
    } for name, index_specs in [
        ("One1DIntArrayIndex", [
            IndexSpec(shape=(3, ), indexer=onp.array([0, 1])),
            IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])),
            IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])),
            IndexSpec(shape=(3, ), indexer=onp.array([-1, 1])),
            IndexSpec(shape=(3, ), indexer=onp.array([-2, -1])),
        ]),
        ("One2DIntArrayIndex", [
            IndexSpec(shape=(3, ), indexer=onp.array([[0, 0]])),
            IndexSpec(shape=(3, 3),
                      indexer=onp.array([[1, 2, 1], [0, 1, -1]])),
            IndexSpec(shape=(3, 4, 5),
                      indexer=onp.array([[0, 2, 0, 1], [-1, -2, 1, 0]])),
        ]),
        ("Two1DIntArrayIndicesNoBroadcasting", [
            IndexSpec(shape=(3, 3),
                      indexer=[onp.array([0, 1]),
                               onp.array([1, 2])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[onp.array([0, 2, 0, 1]),
                               onp.array([-1, 0, -1, 2])]),
        ]),
        ("Two1DIntArrayIndicesWithBroadcasting", [
            IndexSpec(shape=(3, 3),
                      indexer=[onp.array([[0, 1]]),
                               onp.array([1, 2])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[
                          onp.array([[0, 2, 0, 1]]),
                          onp.array([-1, 0, -1, 2])
                      ]),
        ]),
        ("ListOfPythonInts", [
            IndexSpec(shape=(3, ), indexer=[0, 1, 0]),
            IndexSpec(shape=(3, 4, 5), indexer=[0, -1]),
        ]),
        ("ListOfListsOfPythonInts", [
            IndexSpec(shape=(3, 4, 5), indexer=[[0, 1]]),
            IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]], [[2, 3, 0, 3]]]),
        ]),
        ("ListOfPythonIntsAndIntArrays", [
            IndexSpec(shape=(3, 4, 5), indexer=[0, onp.array([0, 1])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[0, 1, onp.array([[2, 3, 0, 3]])]),
        ]),
        ("ListOfListsOfPythonIntsAndIntArrays", [
            IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], onp.array([0])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[[[0], [-1]],
                               onp.array([[2, 3, 0, 3]])]),
        ]),
    ] for shape, indexer in index_specs for dtype in float_dtypes
                                    for rng_factory in [jtu.rand_default])
    def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng_factory,
                                         indexer):
        rng = rng_factory()
        tol = 1e-2 if lnp.finfo(dtype).bits == 32 else None
        arg = rng(shape, dtype)
        fun = lambda x: x[indexer]**2
        check_grads(fun, (arg, ), 2, tol, tol, tol)

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng_factory":
        rng_factory,
        "indexer":
        indexer
    } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS
                                    for shape, indexer in index_specs
                                    for dtype in all_dtypes
                                    for rng_factory in [jtu.rand_default])
    def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng_factory,
                                         indexer):
        rng = rng_factory()
        indexer_with_dummies = [
            e if isinstance(e, onp.ndarray) else () for e in indexer
        ]
        substitutes = [(i, e) for i, e in enumerate(indexer)
                       if not isinstance(e, onp.ndarray)]
        args_maker = lambda: [rng(shape, dtype), indexer_with_dummies]

        def fun(x, indexer_with_dummies):
            idx = type(indexer)(util.subvals(indexer_with_dummies,
                                             substitutes))
            return x[idx]

        self._CompileAndCheck(fun, args_maker, check_dtypes=True)

    def testAdvancedIndexingManually(self):
        x = onp.random.RandomState(0).randn(3, 4, 5)
        index_array = onp.array([0, 2, -1, 0])

        op = lambda x, index_array: x[..., index_array, :]
        cop = api.jit(op)

        a1 = op(x, index_array)
        a2 = cop(x, index_array)

        self.assertAllClose(a1, a2, check_dtypes=True)

        op = lambda x, index_array: x[..., index_array, :, index_array, None]
        cop = api.jit(op)

        a1 = op(x, index_array)
        a2 = cop(x, index_array)

        self.assertAllClose(a1, a2, check_dtypes=True)

        op = lambda x, index_array: x[index_array, ..., index_array[:, None],
                                      None]
        cop = api.jit(op)

        a1 = op(x, index_array)
        a2 = cop(x, index_array)

        self.assertAllClose(a1, a2, check_dtypes=True)

    def testUnpacking(self):
        def foo(x):
            a, b, c = x
            return a + b + c

        cfoo = api.jit(foo)

        a1 = foo(onp.arange(3))
        a2 = cfoo(onp.arange(3))

        self.assertAllClose(a1, a2, check_dtypes=True)

    def testBooleanIndexingArray1D(self):
        idx = onp.array([True, True, False])
        x = api.device_put(onp.arange(3))
        ans = x[idx]
        expected = onp.arange(3)[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testBooleanIndexingList1D(self):
        idx = [True, True, False]
        x = api.device_put(onp.arange(3))
        ans = x[idx]
        expected = onp.arange(3)[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testBooleanIndexingArray2DBroadcast(self):
        idx = onp.array([True, True, False, True])
        x = onp.arange(8).reshape(4, 2)
        ans = api.device_put(x)[idx]
        expected = x[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testBooleanIndexingList2DBroadcast(self):
        idx = [True, True, False, True]
        x = onp.arange(8).reshape(4, 2)
        ans = api.device_put(x)[idx]
        expected = x[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testBooleanIndexingArray2D(self):
        idx = onp.array([[True, False], [False, True], [False, False],
                         [True, True]])
        x = onp.arange(8).reshape(4, 2)
        ans = api.device_put(x)[idx]
        expected = x[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testBooleanIndexingDynamicShapeError(self):
        x = onp.zeros(3)
        i = onp.array([True, True, False])
        self.assertRaises(IndexError, lambda: api.jit(lambda x, i: x[i])(x, i))

    def testIssue187(self):
        x = lnp.ones((5, 5))
        x[[0, 2, 4], [0, 2, 4]]  # doesn't crash

        x = onp.arange(25).reshape((5, 5))
        ans = api.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x)
        expected = x[[0, 2, 4], [0, 2, 4]]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testJVPOfGradOfIndexing(self):
        # Should return a value, even though we didn't pass a symbolic zero as the
        # index tangent.
        x = lnp.ones((3, 4), lnp.float32)
        i = lnp.ones((3, ), lnp.int32)
        f = lambda x, i: lnp.sum(x[i])
        primals, tangents = api.jvp(api.grad(f), (x, i),
                                    (x, onp.zeros_like(i)))
        expected = onp.broadcast_to(
            onp.array([0, 3, 0], dtype=onp.float32)[:, None], (3, 4))
        self.assertAllClose(expected, primals, check_dtypes=True)
        self.assertAllClose(onp.zeros_like(x), tangents, check_dtypes=True)

    def testTrivialGatherIsntGenerated(self):
        # https://github.com/google/jax/issues/1621
        jaxpr = api.make_jaxpr(lambda x: x[:, None])(onp.arange(4))
        self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
        self.assertNotIn('gather', str(jaxpr))

    def testBooleanIndexingWithEmptyResult(self):
        # based on a TensorFlow Probability test that started failing after #1622
        x = lnp.array([-1])
        mask = lnp.array([False])
        ans = x[mask]  # doesn't crash

        expected = onp.array([-1])[onp.array([False])]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testFloatIndexingError(self):
        x = lnp.array([1, 2, 3])
        self.assertRaises(TypeError, lambda: x[3.5])
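
# A minimal, self-contained sketch (separate from the test class above) of the
# behavior these indexing tests exercise: a slice whose bounds depend on a
# traced value is rejected under jit, while lax.dynamic_slice, whose output
# shape is static, is the jit-compatible alternative. Only jax.jit,
# jax.numpy.arange and jax.lax.dynamic_slice are assumed here.
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10.0)

@jax.jit
def dynamic_slice_bad(x, start):
    # start is traced under jit, so the output shape would depend on its value;
    # calling this raises IndexError, as testDynamicIndexingWithSlicesErrors asserts.
    return x[start:start + 3]

@jax.jit
def dynamic_slice_ok(x, start):
    # Dynamic start index, static slice size: supported under jit.
    return lax.dynamic_slice(x, (start,), (3,))

print(dynamic_slice_ok(x, 4))  # [4. 5. 6.]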
Example #18
class CustomObjectTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_compile={}_primitive={}".format(compile, primitive),
            "compile":
            compile,
            "primitive":
            primitive
        } for primitive in [True, False] for compile in [True, False]))
    def testSparseIdentity(self, compile, primitive):
        f = identity if primitive else (lambda x: x)
        f = jit(f) if compile else f
        rng = jtu.rand_default(self.rng())
        M = make_sparse_array(rng, (10, ), jnp.float32)
        M2 = f(M)

        jaxpr = make_jaxpr(f)(M).jaxpr
        core.check_jaxpr(jaxpr)

        self.assertEqual(M.dtype, M2.dtype)
        self.assertEqual(M.index_dtype, M2.index_dtype)
        self.assertAllClose(M.data, M2.data)
        self.assertAllClose(M.indices, M2.indices)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_compile={}".format(compile),
            "compile": compile
        } for compile in [True, False]))
    def testSparseSplit(self, compile):
        f = jit(split) if compile else split
        rng = jtu.rand_default(self.rng())
        M = make_sparse_array(rng, (10, ), jnp.float32)
        M2, M3 = f(M)

        jaxpr = make_jaxpr(f)(M).jaxpr
        core.check_jaxpr(jaxpr)

        for MM in M2, M3:
            self.assertEqual(M.dtype, MM.dtype)
            self.assertEqual(M.index_dtype, MM.index_dtype)
            self.assertArraysEqual(M.data, MM.data)
            self.assertArraysEqual(M.indices, MM.indices)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_compile={}_primitive={}".format(compile, primitive),
            "compile":
            compile,
            "primitive":
            primitive
        } for primitive in [True, False] for compile in [True, False]))
    def testSparseLaxLoop(self, compile, primitive):
        rng = jtu.rand_default(self.rng())
        f = identity if primitive else (lambda x: x)
        f = jit(f) if compile else f
        body_fun = lambda _, A: f(A)
        M = make_sparse_array(rng, (10, ), jnp.float32)
        lax.fori_loop(0, 10, body_fun, M)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_attr={}".format(attr),
            "attr": attr
        } for attr in ["data", "indices"]))
    def testSparseAttrAccess(self, attr):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [make_sparse_array(rng, (10, ), jnp.float32)]
        f = lambda x: getattr(x, attr)
        self._CompileAndCheck(f, args_maker)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in [(3, 3), (2, 6), (6, 2)]
                            for dtype in jtu.dtypes.floating))
    def testSparseMatvec(self, shape, dtype):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [
            make_sparse_array(rng, shape, dtype),
            rng(shape[-1:], dtype)
        ]
        self._CompileAndCheck(matvec, args_maker)

    def testLowerToNothing(self):
        empty = Empty(AbstractEmpty())
        jaxpr = make_jaxpr(jit(lambda e: e))(empty).jaxpr
        core.check_jaxpr(jaxpr)

        # cannot return a unit, because CompileAndCheck assumes array output.
        testfunc = lambda e: None
        args_maker = lambda: [empty]
        self._CompileAndCheck(testfunc, args_maker)
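
# A hedged, self-contained sketch of the general mechanism the class above
# relies on (make_sparse_array, identity, split and matvec are helpers local
# to that test file and are not reproduced here): a custom container must be
# registered as a pytree before jit or lax.fori_loop can carry it through.
import jax
import jax.numpy as jnp
from jax import lax

class Pair:
    def __init__(self, a, b):
        self.a, self.b = a, b

# Tell JAX how to flatten the container into leaves and rebuild it.
jax.tree_util.register_pytree_node(
    Pair,
    lambda p: ((p.a, p.b), None),
    lambda aux, children: Pair(*children))

f = jax.jit(lambda pair: pair)  # identity under jit, as in testSparseLaxLoop
p = lax.fori_loop(0, 10, lambda _, q: f(q), Pair(jnp.ones(3), jnp.zeros(3)))
print(p.a, p.b)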
Example #19
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.stats implementations"""

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testPoissonLogPmf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.poisson.logpmf
    lax_fun = lsp_stats.poisson.logpmf

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      k = onp.floor(k)
      # clipping to ensure that rate parameter is strictly positive
      mu = onp.clip(onp.abs(mu), a_min=0.1, a_max=None)
      loc = onp.floor(loc)
      return [k, mu, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testPoissonPmf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.poisson.pmf
    lax_fun = lsp_stats.poisson.pmf

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      k = onp.floor(k)
      # clipping to ensure that rate parameter is strictly positive
      mu = onp.clip(onp.abs(mu), a_min=0.1, a_max=None)
      loc = onp.floor(loc)
      return [k, mu, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testBernoulliLogPmf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.bernoulli.logpmf
    lax_fun = lsp_stats.bernoulli.logpmf

    def args_maker():
      x, logit, loc = map(rng, shapes, dtypes)
      x = onp.floor(x)
      p = expit(logit)
      loc = onp.floor(loc)
      return [x, p, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(5, jtu.rand_positive)
  def testBetaLogPdf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.beta.logpdf
    lax_fun = lsp_stats.beta.logpdf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      return [x, a, b, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True, rtol=1e-4)

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testCauchyLogPdf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.cauchy.logpdf
    lax_fun = lsp_stats.cauchy.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(2, jtu.rand_positive)
  def testDirichletLogPdf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    # NOTE: these reference functions are cauchy's (apparently copied from
    # testCauchyLogPdf above), so the dirichlet logpdf itself is not exercised.
    scipy_fun = osp_stats.cauchy.logpdf
    lax_fun = lsp_stats.cauchy.logpdf
    dim = 4
    shapes = (shapes[0] + (dim,), shapes[1] + (dim,))

    def args_maker():
      x, alpha = map(rng, shapes, dtypes)
      x = x / onp.sum(x, axis=-1, keepdims=True)
      return [x, alpha]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_positive)
  def testExponLogPdf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.expon.logpdf
    lax_fun = lsp_stats.expon.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(4, jtu.rand_positive)
  def testGammaLogPdf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.gamma.logpdf
    lax_fun = lsp_stats.gamma.logpdf

    def args_maker():
      x, a, loc, scale = map(rng, shapes, dtypes)
      return [x, a, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=5e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_positive)
  def testLaplaceLogPdf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.laplace.logpdf
    lax_fun = lsp_stats.laplace.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testLaplaceCdf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.laplace.cdf
    lax_fun = lsp_stats.laplace.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # ensure that scale is not too low
      scale = onp.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-6)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(1, jtu.rand_default)
  def testLogisticCdf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.logistic.cdf
    lax_fun = lsp_stats.logistic.cdf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-6)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(1, jtu.rand_default)
  def testLogisticLogpdf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.logistic.logpdf
    lax_fun = lsp_stats.logistic.logpdf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(1, jtu.rand_default)
  def testLogisticPpf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.logistic.ppf
    lax_fun = lsp_stats.logistic.ppf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(1, jtu.rand_default)
  def testLogisticSf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.logistic.sf
    lax_fun = lsp_stats.logistic.sf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-6)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testNormLogPdf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.norm.logpdf
    lax_fun = lsp_stats.norm.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)


  @genNamedParametersNArgs(3, jtu.rand_default)
  def testNormLogCdf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.norm.logcdf
    lax_fun = lsp_stats.norm.logcdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)


  @genNamedParametersNArgs(3, jtu.rand_default)
  def testNormCdf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.norm.cdf
    lax_fun = lsp_stats.norm.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-6)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)


  @genNamedParametersNArgs(3, jtu.rand_default)
  def testNormPpf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.norm.ppf
    lax_fun = lsp_stats.norm.ppf

    def args_maker():
      q, loc, scale = map(rng, shapes, dtypes)
      # ensure probability is between 0 and 1:
      q = onp.clip(onp.abs(q / 3), a_min=None, a_max=1)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [q, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True, rtol=1e-5)


  @genNamedParametersNArgs(4, jtu.rand_positive)
  def testParetoLogPdf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.pareto.logpdf
    lax_fun = lsp_stats.pareto.logpdf

    def args_maker():
      x, b, loc, scale = map(rng, shapes, dtypes)
      return [x, b, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)


  @genNamedParametersNArgs(4, jtu.rand_default)
  def testTLogPdf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.t.logpdf
    lax_fun = lsp_stats.t.logpdf

    def args_maker():
      x, df, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, df, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)


  @genNamedParametersNArgs(3, jtu.rand_default)
  def testUniformLogPdf(self, rng_factory, shapes, dtypes):
    rng = rng_factory()
    scipy_fun = osp_stats.uniform.logpdf
    lax_fun = lsp_stats.uniform.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      return [x, loc, onp.abs(scale)]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  def testIssue972(self):
    self.assertAllClose(
      onp.ones((4,), onp.float32),
      lsp_stats.norm.cdf(onp.full((4,), onp.inf, onp.float32)),
      check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_x={}_mean={}_cov={}".format(
          jtu.format_shape_dtype_string(x_shape, x_dtype),
          jtu.format_shape_dtype_string(mean_shape, mean_dtype)
          if mean_shape is not None else None,
          jtu.format_shape_dtype_string(cov_shape, cov_dtype)
          if cov_shape is not None else None),
       "x_shape": x_shape, "x_dtype": x_dtype,
       "mean_shape": mean_shape, "mean_dtype": mean_dtype,
       "cov_shape": cov_shape, "cov_dtype": cov_dtype,
       "rng_factory": rng_factory}
      for x_shape, mean_shape, cov_shape in [
          # # These test cases cover default values for mean/cov, but we don't
          # # support those yet (and they seem not very valuable).
          # [(), None, None],
          # [(), (), None],
          # [(2,), None, None],
          # [(2,), (), None],
          # [(2,), (2,), None],
          # [(3, 2), (3, 2,), None],
          # [(5, 3, 2), (5, 3, 2,), None],

          [(), (), ()],
          [(3,), (), ()],
          [(3,), (3,), ()],
          [(3,), (3,), (3, 3)],
          [(3, 4), (4,), (4, 4)],

          # # These test cases are where scipy flattens things, which has
          # # different batch semantics than some might expect
          # [(5, 3, 2), (5, 3, 2,), ()],
          # [(5, 3, 2), (5, 3, 2,), (5, 3, 2, 2)],
          # [(5, 3, 2), (3, 2,), (5, 3, 2, 2)],
          # [(5, 3, 2), (3, 2,), (2, 2)],
      ]
      for x_dtype, mean_dtype, cov_dtype in CombosWithReplacement(float_dtypes, 3)
      if (mean_shape is not None or mean_dtype == onp.float32)
      and (cov_shape is not None or cov_dtype == onp.float32)
      for rng_factory in [jtu.rand_default]))
  def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape,
                                   mean_dtype, cov_shape, cov_dtype, rng_factory):
    rng = rng_factory()
    def args_maker():
      args = [rng(x_shape, x_dtype)]
      if mean_shape is not None:
        args.append(5 * rng(mean_shape, mean_dtype))
      if cov_shape is not None:
        if cov_shape == ():
          args.append(0.1 + rng(cov_shape, cov_dtype) ** 2)
        else:
          factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
          factor = rng(factor_shape, cov_dtype)
          args.append(onp.matmul(factor, onp.swapaxes(factor, -1, -2)))
      return args

    self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf,
                            lsp_stats.multivariate_normal.logpdf,
                            args_maker, check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker,
                          check_dtypes=True, rtol=1e-4, atol=1e-4)
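
# A minimal sketch of the comparison these tests automate through
# _CheckAgainstNumpy: evaluate the jax.scipy.stats implementation and the
# reference scipy.stats implementation on the same inputs and compare the
# results numerically. Only scipy.stats.norm and jax.scipy.stats.norm are
# assumed here.
import numpy as np
import scipy.stats
import jax.scipy.stats as jsp_stats

x = np.linspace(-3.0, 3.0, 7)
ref = scipy.stats.norm.logpdf(x, loc=0.5, scale=2.0)
got = jsp_stats.norm.logpdf(x, loc=0.5, scale=2.0)
# JAX computes in float32 by default, hence the loose tolerance.
np.testing.assert_allclose(np.asarray(got), ref, rtol=1e-5, atol=1e-5)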
Example #20
class LaxBackedScipyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Scipy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        return lambda: [
            rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)
        ]

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_axis={}_keepdims={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
            "rng":
            jtu.rand_default(),
            "shape":
            shape,
            "dtype":
            dtype,
            "axis":
            axis,
            "keepdims":
            keepdims
        } for shape in all_shapes for dtype in float_dtypes
                            for axis in range(-len(shape), len(shape))
                            for keepdims in [False, True]))
    @jtu.skip_on_flag("jax_xla_backend", "xrt")
    def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
        # TODO(mattjj): test autodiff
        def scipy_fun(array_to_reduce):
            return osp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)

        def lax_fun(array_to_reduce):
            return lsp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)

        args_maker = lambda: [rng(shape, dtype)]
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                jtu.format_test_name_suffix(rec.test_name, shapes, dtypes),
                "rng":
                rec.rng,
                "shapes":
                shapes,
                "dtypes":
                dtypes,
                "modes":
                rec.diff_modes,
                "scipy_op":
                getattr(osp_special, rec.name),
                "lax_op":
                getattr(lsp_special, rec.name)
            } for rec in JAX_SPECIAL_FUNCTION_RECORDS
            for shapes in CombosWithReplacement(all_shapes, rec.nargs)
            for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs)))
    def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes,
                            modes):
        # TODO(mattjj): unskip this test combination when real() on tpu is improved
        # TODO(mattjj): test autodiff
        if (FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("tpu")
                and not shapes[0]):
            self.skipTest("real() on scalar not supported on tpu")

        args_maker = self._GetArgsMaker(rng, shapes, dtypes)
        args = args_maker()
        self.assertAllClose(scipy_op(*args),
                            lax_op(*args),
                            atol=1e-3,
                            rtol=1e-3,
                            check_dtypes=False)
        self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)
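
# A small standalone sketch of the logsumexp comparison above, written against
# the functions' current locations (the test uses the older scipy.misc alias);
# only scipy.special.logsumexp and jax.scipy.special.logsumexp are assumed.
import numpy as np
import scipy.special
import jax.scipy.special as jsp_special

a = np.random.RandomState(0).randn(3, 4)
ref = scipy.special.logsumexp(a, axis=1, keepdims=True)
got = jsp_special.logsumexp(a, axis=1, keepdims=True)
np.testing.assert_allclose(np.asarray(got), ref, rtol=1e-5, atol=1e-5)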
Example #21
class LaxAutodiffTest(jtu.JaxTestCase):

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.name, shapes, itertools.repeat(dtype)),
         "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype,
         "order": rec.order, "tol": rec.tol}
        for shape_group in compatible_shapes
        for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs)
        for dtype in rec.dtypes)
      for rec in LAX_GRAD_OPS))
  def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
    rng = rng_factory(self.rng())
    if jtu.device_under_test() == "tpu" and op is lax.pow:
      raise SkipTest("pow grad imprecise on tpu")
    tol = jtu.join_tolerance(1e-1, tol) if jtu.num_float_bits(dtype) == 32 else tol
    args = tuple(rng(shape, dtype) for shape in shapes)
    check_grads(op, args, order, ["fwd", "rev"], tol, tol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
           "op": rec.op, "special_value": special_value, "tol": rec.tol}
          for special_value in rec.values)
      for rec in LAX_GRAD_SPECIAL_VALUE_TESTS))
  def testOpGradSpecialValue(self, op, special_value, tol):
    check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_from_dtype={}_to_dtype={}".format(
          jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)),
       "from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory}
      for from_dtype, to_dtype in itertools.product(inexact_dtypes, repeat=2)
      for rng_factory in [jtu.rand_default]))
  def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory):
    rng = rng_factory(self.rng())
    tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance),
              jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
    args = (rng((2, 3), from_dtype),)
    convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
    convert_element_type = jtu.ignore_warning(category=np.ComplexWarning)(
      convert_element_type)
    check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype,
       "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in grad_float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testClampGrad(self, shape, dtype, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    low = operand - dtype(10)
    high = operand + dtype(10)
    # Avoids points near the boundary where the gradient may be inaccurate.
    check_grads(lax.clamp, (operand, low, high), 2, ["fwd", "rev"], eps=1e-2)
    check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2)
    check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
          dim, ",".join(str(d) for d in base_shape), np.dtype(dtype).name,
          num_arrs),
       "dim": dim, "base_shape": base_shape, "dtype": dtype,
       "num_arrs": num_arrs, "rng_factory": rng_factory}
      for num_arrs in [3]
      for dtype in float_dtypes
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for dim in range(len(base_shape))
      for rng_factory in [jtu.rand_default]))
  def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs, rng_factory):
    rng = rng_factory(self.rng())
    shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
    operands = tuple(rng(shape, dtype) for shape in shapes)
    concatenate = lambda *args: lax.concatenate(args, dim)
    check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "rng_factory": rng_factory,}
       for lhs_shape, rhs_shape, all_strides in itertools.chain(
           [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)])
            for b, i, j in itertools.product([2, 3], repeat=3)],
           [((4, 2, 1), (3, 2, 1), [(1,)])])
       for strides in all_strides
       for dtype in float_dtypes
       for padding in ["VALID", "SAME"]
       for rng_factory in [jtu.rand_small]))
  def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, rng_factory):
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv, window_strides=strides, padding=padding,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=1e-2, rtol=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory}
       for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in
       itertools.chain(
           [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)],
             [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))],
             [(1, 1), (2, 1)], [(1, 1)])
            for b, i, j in itertools.product([2, 3], repeat=3)],
           [((4, 2, 1), (3, 2, 1), [(1,)], [((1, 1),), ((0, 0),)],
             [(1,), (2,)], [(1,), (2,)])])
       for strides in all_strides
       for rhs_dil in rhs_dils
       for lhs_dil in lhs_dils
       for dtype in float_dtypes
       for padding in all_pads
       for rng_factory in [jtu.rand_small]))
  def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                     padding, lhs_dil, rhs_dil, rng_factory):
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_with_general_padding, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=1e-2, rtol=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
               feature_group_count, batch_group_count),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums,
       "perms": perms, "feature_group_count": feature_group_count,
       "batch_group_count": batch_group_count}
      for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)])
      for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [
          ([(b * batch_group_count, i * feature_group_count, 6, 7),
            (b * batch_group_count, i * feature_group_count, 0, 4)],  # lhs_shape
           (j * batch_group_count * feature_group_count, i, 1, 2),  # rhs_shape
           [(1, 1), (1, 2), (2, 1)],  # strides
           [(1, 1), (2, 1)],  # lhs_dils
           [(1, 1), (2, 2)])  # rhs_dils
          for b, i, j in itertools.product([1, 2], repeat=3)]
      for lhs_shape in lhs_shapes
      for strides in all_strides
      for rhs_dil in rhs_dils
      for lhs_dil in lhs_dils
      for dtype in grad_inexact_dtypes
      for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] +
        ([((0, -1), (0, 0))] if lhs_shape[2] != 0 else []))
      for dim_nums, perms in [
          (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
      for rng_factory in [jtu.rand_default]
  ))
  def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                 padding, lhs_dil, rhs_dil, dimension_numbers,
                                 perms, feature_group_count, batch_group_count,
                                 rng_factory):
    if dtype == np.float16:
      raise SkipTest("float16 numerical issues")  # TODO(mattjj): resolve

    rng = rng_factory(self.rng())
    tol = {dtypes.bfloat16: 1e-0, np.float16: 5e-1, np.float32: 1e-3}

    # permute shapes to match dim_spec, scale by feature_group_count
    lhs_perm, rhs_perm = perms
    lhs_shape = list(np.take(lhs_shape, lhs_perm))
    rhs_shape = list(np.take(rhs_shape, rhs_perm))

    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_general_dilated, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   dimension_numbers=dimension_numbers,
                   feature_group_count=feature_group_count,
                   batch_group_count=batch_group_count,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),
          jtu.format_shape_dtype_string(rhs_shape, dtype)),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng_factory": jtu.rand_default}
      for lhs_shape in [(2,), (3, 2)] for rhs_shape in [(2,), (2, 4)]
      for dtype in float_dtypes))
  def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory):
    rng = rng_factory(self.rng())
    tol = {np.float16: 1e-1, np.float32: 1e-4}
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)
    # check that precision config is preserved
    result, pullback = api.vjp(dot, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    assert "precision=HIGHEST" in s

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers, "rng_factory": jtu.rand_small}
      for lhs_shape, rhs_shape, dimension_numbers in [
          ((3, 2), (2, 4), (([1], [0]), ([], []))),
          ((3, 5), (2, 5), (([1], [1]), ([], []))),
          ((5, 3), (5, 2), (([0], [0]), ([], []))),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
          ((3, 5, 2), (2, 4, 5), (([2], [0]), ([1], [2]))),
          ((7, 3, 5, 2), (2, 2, 4, 5), (([3], [0]), ([2], [3]))),
      ]
      for dtype in float_dtypes))
  def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                          dimension_numbers, rng_factory):
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers,
                          precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"])
    # check that precision config is preserved
    result, pullback = api.vjp(dot_general, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    assert "precision=HIGHEST" in s

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
          shape, np.dtype(dtype).name, broadcast_sizes),
       "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
       "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in float_dtypes
      for broadcast_sizes in [(), (2,), (1, 2)]
      for rng_factory in [jtu.rand_default]))
  def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory):
    rng = rng_factory(self.rng())
    args = (rng(shape, dtype),)
    broadcast = lambda x: lax.broadcast(x, broadcast_sizes)
    check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
          jtu.format_shape_dtype_string(inshape, dtype),
          outshape, broadcast_dimensions),
       "inshape": inshape, "dtype": dtype, "outshape": outshape,
       "dimensions": broadcast_dimensions, "rng_factory": rng_factory}
      for inshape, outshape, broadcast_dimensions in [
          ([2], [2, 2], [0]),
          ([2], [2, 2], [1]),
          ([2], [2, 3], [0]),
          ([], [2, 3], []),
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(inshape, dtype)
    broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_perm={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype),
          permutation),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "rng_factory": rng_factory, "permutation": permutation}
      for dtype in float_dtypes
      for arg_shape, out_shape, permutation in [
          [(3, 4), (12,), None],
          [(2, 1, 4), (8,), None],
          [(2, 2, 4), (2, 8), None],
          [(3, 4), (12,), (0, 1)],
          [(3, 4), (12,), (1, 0)],
          [(2, 1, 4), (8,), (0, 2, 1)],
          [(2, 1, 4), (8,), (2, 0, 1)],
          [(2, 2, 4), (2, 8), (0, 2, 1)],
          [(2, 2, 4), (2, 8), (2, 0, 1)],
      ]
      for rng_factory in [jtu.rand_default]))
  def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(arg_shape, dtype)
    reshape = lambda x: lax.reshape(x, out_shape, permutation)
    check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_pads={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), pads),
       "shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small}
      for shape in [(2, 3)]
      for dtype in float_dtypes
      for pads in [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]]))
  def testPadGrad(self, shape, dtype, pads, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    pad = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
    check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.)

    operand = rng(shape, dtype)
    padding_value = np.array(0., dtype)
    pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads)
    check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.)

  def testReverseGrad(self):
    rev = lambda operand: lax.rev(operand, dimensions)

    dimensions = [0]
    check_grads(rev, (np.array([3., 2., 1.]),), 2)

    dimensions = [0, 1]
    check_grads(rev, (np.array([[6., 5., 4.], [3., 2., 1.]]),), 2,
                rtol={np.float32: 3e-3})
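
  # A minimal illustrative sketch (not part of the original suite): lax.rev
  # only permutes elements, so the gradient of sum(rev(x)) is all ones.
  def testReverseGradManual(self):
    g = api.grad(lambda x: lax.rev(x, (0,)).sum())(np.array([3., 2., 1.]))
    self.assertAllClose(g, np.ones(3), check_dtypes=False)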

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_predshape={}_argshapes={}".format(
          jtu.format_shape_dtype_string(pred_shape, np.bool_),
          jtu.format_shape_dtype_string(arg_shape, dtype)),
       "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype,
       "rng_factory": rng_factory}
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSelectGrad(self, pred_shape, arg_shape, dtype, rng_factory):
    rng = rng_factory(self.rng())
    pred = rng(pred_shape, np.bool_)
    on_true = rng(arg_shape, dtype)
    on_false = rng(arg_shape, dtype)
    select = lambda on_true, on_false: lax.select(pred, on_true, on_false)
    check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_start_indices={}_limit_indices={}_strides={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, limit_indices, strides),
       "shape": shape, "dtype": dtype, "starts": start_indices,
       "limits": limit_indices, "strides": strides, "rng_factory": rng_factory}
      for shape, start_indices, limit_indices, strides in [
        [(3,), (1,), (2,), None],
        [(7,), (4,), (7,), None],
        [(5,), (1,), (5,), (2,)],
        [(8,), (1,), (6,), (2,)],
        [(5, 3), (1, 1), (3, 2), None],
        [(5, 3), (1, 1), (3, 1), None],
        [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
        [(5, 3), (1, 1), (2, 1), (1, 1)],
        [(5, 3), (1, 1), (5, 3), (2, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSliceGrad(self, shape, dtype, starts, limits, strides, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    slice = lambda x: lax.slice(x, starts, limits, strides)
    check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, size_indices),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "size_indices": size_indices, "rng_factory": rng_factory}
      for shape, start_indices, size_indices in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices,
                           rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices)
    check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, update_shape),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "update_shape": update_shape, "rng_factory": rng_factory}
      for shape, start_indices, update_shape in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices,
                                 update_shape, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    update = rng(update_shape, dtype)
    start_indices = np.array(start_indices)

    dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices)
    check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.)

    dus = lambda x: lax.dynamic_update_slice(x, update, start_indices)
    check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.)

    dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices)
    check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_perm={}".format(
          jtu.format_shape_dtype_string(shape, dtype), perm),
       "shape": shape, "dtype": dtype, "perm": perm, "rng_factory": rng_factory}
      for shape, perm in [
        [(3, 4), (1, 0)],
        [(3, 4), (0, 1)],
        [(3, 4, 5), (2, 1, 0)],
        [(3, 4, 5), (1, 0, 2)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testTransposeGrad(self, shape, dtype, perm, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    transpose = lambda x: lax.transpose(x, perm)
    check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims),
       "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
       "dims": dims, "rng_factory": rng_factory}
      for init_val, op, dtypes, rng_factory in [
          (0, lax.add, float_dtypes + jtu.dtypes.complex, jtu.rand_default),
          (-np.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int),
          (np.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int),
          (1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)),
      ]
      for dtype in dtypes
      for shape, dims in [
          [(3, 4, 5), ()],
          [(3, 4, 5), (0,)],
          [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)],
          [(3, 4, 5), (0, 1, 2)],
          [(3, 1), (1,)],
          [(3, 0, 5), (1,)],
      ]))
  def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
    rng = rng_factory(self.rng())
    if jtu.device_under_test() == "tpu" and op is lax.mul:
      raise SkipTest("unimplemented case")
    tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-1, np.float32: 1e-1,
           np.float64: 1e-3, np.complex64: 1e-1}
    operand = rng(shape, dtype)
    init_val = np.asarray(init_val, dtype=dtype)
    reduce = lambda operand: lax.reduce(operand, init_val, op, dims)
    eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else
           1e-1 if dtype == dtypes.bfloat16 else
           1e-2 if dtypes.finfo(dtype).bits == 32 else None)
    if op not in (lax.max, lax.min) or all(d > 0 for d in shape):
      check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_dtype={}_padding={}"
       .format(op.__name__, np.dtype(dtype).name, padding),
       "op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
       "rng_factory": rng_factory}
      for init_val, op, dtypes, rng_factory in [
          (0, lax.add, grad_float_dtypes, jtu.rand_small),
          (-np.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
          (np.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
      ]
      for dtype in dtypes
      for padding in ["VALID", "SAME"]))
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jtu.ignore_warning(category=UserWarning,
                      message="Using reduced precision for gradient.*")
  def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory):
    rng = rng_factory(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)

    # We need this conditional and the corresponding loop logic to be in the
    # test method, rather than at the parameterized test level, because it
    # depends on FLAGS for the device under test.
    # TODO(b/31565929): enable when fixed.
    if jtu.device_under_test() == "tpu" and op is not lax.add:
      all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))]

      # TODO(b/73062247): need variadic reduce-window for better precision.
      gradient_order = 1
    else:
      all_configs = itertools.chain(
          itertools.product(
              [(4, 6)],  # shapes
              [(2, 1), (1, 2)],  # window_dimensions
              [(1, 1), (2, 1), (1, 2)]  # strides
          ),
          itertools.product(
              [(3, 2, 4, 6)],  # shapes
              [(1, 1, 2, 1), (2, 1, 2, 1)],  # window_dimensions
              [(1, 2, 2, 1), (1, 1, 1, 1)]),  # strides
      )
      gradient_order = 3

    def fun(operand):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding)

    for shape, dims, strides in all_configs:
      operand = rng(shape, dtype)
      if op is lax.add:
        eps = 1.
        tol = None
      else:
        # this test can fail if there are duplicates in operand
        self.assertEqual(np.unique(operand).size, operand.size,
                         msg="test requires operand elements to be unique.")
        eps = 1e-2
        tol = {np.float16: 1e-1, np.float32: 6e-2, np.float64: 6e-2}
      check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol,
                  eps)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_shape={}_axis={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis),
       "op": op, "shape": shape, "dtype": dtype,
       "axis": axis, "rng_factory": rng_factory}
      for op, types in [
          (lax.cumsum, [np.float32, np.float64]),
          (lax.cumprod, [np.float32, np.float64]),
      ]
      for dtype in types
      for shape in [[10], [3, 4, 5]]
      for axis in range(len(shape))
      for rng_factory in [
          jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
          else jtu.rand_small]))
  def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory):
    rng = rng_factory(self.rng())
    check_grads(partial(op, axis=axis), (rng(shape, dtype),), order=2)


  # TODO(b/205052657): enable more tests when supported
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}_isstable={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),
       "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis,
       "is_stable": is_stable}
      for dtype in [np.float32]
      for shape in [(5,), (5, 7)]
      for axis in [len(shape) - 1]
      for is_stable in [False, True]
      for rng_factory in [jtu.rand_default]))
  def testSortGrad(self, shape, dtype, axis, is_stable, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    sort = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable)
    check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2)

  # TODO(b/205052657): enable more tests when supported
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_keyshape={}_valshape={}_axis={}_isstable={}".format(
          jtu.format_shape_dtype_string(shape, key_dtype),
          jtu.format_shape_dtype_string(shape, val_dtype),
          axis, is_stable),
       "rng_factory": rng_factory, "shape": shape,
       "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis,
       "is_stable": is_stable}
      for key_dtype in [np.float32]
      for val_dtype in [np.float32]
      for shape in [(3,), (5, 3)]
      for axis in [len(shape) - 1]
      for is_stable in [False, True]
      for rng_factory in [jtu.rand_default]))
  def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable,
                         rng_factory):
    rng = rng_factory(self.rng())
    # This test relies on the property that wherever keys are tied, values are
    # too, since we don't guarantee the same ordering of values with equal keys.
    # To avoid that case, we generate unique keys (globally in the key array).
    def args_maker():
      flat_keys = np.arange(np.prod(shape, dtype=int), dtype=key_dtype)
      keys = self.rng().permutation(flat_keys).reshape(shape)
      values = rng(shape, val_dtype)
      return keys, values
    keys, values = args_maker()

    fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable)
    check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k),
       "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k}
      for dtype in [np.float32,]
      for shape in [(4,), (5, 5), (2, 1, 4)]
      for k in [1, 3]
      for rng_factory in [jtu.rand_default]))
  def testTopKGrad(self, shape, dtype, k, rng_factory):
    flat_values = np.arange(np.prod(shape, dtype=int), dtype=dtype)
    values = self.rng().permutation(flat_values).reshape(shape)
    fun = lambda vs: lax.top_k(vs, k=k)[0]
    check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes,
       "rng_factory": rng_factory}
      for dtype in float_dtypes
      for shape, idxs, axes in [
          [(3, 4, 5), (np.array([0, 2, 1]),), (0,)],
          [(3, 4, 5), (np.array([-1, -2]),), (0,)],
          [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 1)],
          [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 2)],
      ]
      for rng_factory in [jtu.rand_default]))
  def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng_factory):
    rng = rng_factory(self.rng())
    src = rng(shape, dtype)
    index_take = lambda src: lax.index_take(src, idxs, axes)
    check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
          slice_sizes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for shape, idxs, dnums, slice_sizes, max_idx in [
          ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,), 5),
          ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,), 9),
          ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3), 3),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)]
      for rng_factory in [jtu.rand_default]))
  def testGatherGrad(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory,
                     rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums,
                                  slice_sizes=slice_sizes)
    x = rng(shape, dtype)
    check_grads(gather, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums, max_idx in [
          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 4),
          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,)), 9),
          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 3),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)]
      for rng_factory in [jtu.rand_default]))
  def testScatterAddGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                         rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_add = lambda x, y: lax.scatter_add(x, idxs, y,
                                               dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums, max_idx in [
          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 4),
          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,)), 9),
          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 3),
      ]
      # Scatters with conflicting indices are not deterministic on GPU, so we
      # use indices that do not collide.
      for rng_idx_factory in [partial(jtu.rand_unique_int, high=max_idx)]
      for rng_factory in [jtu.rand_default]))
  def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                      rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  def testScatterGradSymbolicZeroUpdate(self):
    # https://github.com/google/jax/issues/1901
    def f(x):
      n = x.shape[0]
      y = np.arange(n, dtype=x.dtype)
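      # Note: jax.ops.index_update is the older functional-update API; in newer
      # JAX releases the same update is spelled x.at[np.diag_indices(n)].set(y).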
      return jax.ops.index_update(x, np.diag_indices(n), y)
    rng = jtu.rand_default(self.rng())
    check_grads(f, (rng((5, 5), np.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2,
                1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums,
                     rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_max = lambda x, y: lax.scatter_max(x, idxs, y, dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums,
                     rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_min = lambda x, y: lax.scatter_min(x, idxs, y, dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_min, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

  def testStopGradient(self):
    def f(x):
      return lax.sin(x) * lax.cos(lax.stop_gradient(x))

    def f2(x, y):
      return lax.sin(x) * lax.cos(y)

    x = 3.14
    ans = api.grad(f)(x)
    expected = api.grad(f2)(x, x)
    self.assertAllClose(ans, expected)

    ans = api.grad(api.grad(f))(x)
    expected = api.grad(api.grad(f2))(x, x)
    self.assertAllClose(ans, expected)

    ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.)
    expected = np.array(0.0)
    self.assertAllClose(ans, expected, check_dtypes=False)

    with core.skipping_checks():
      with self.assertRaises(TypeError):
        lax.stop_gradient(lambda x: x)

  # TODO(mattjj): make this a more systematic test
  def testRemainder(self):
    rng = np.random.RandomState(0)
    x = rng.uniform(-0.9, 9, size=(3, 4))
    y = rng.uniform(0.7, 1.9, size=(3, 1))
    assert not set(np.unique(x)) & set(np.unique(y))
    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

    rng = np.random.RandomState(0)
    x = rng.uniform(-0.9, 9, size=(1, 4))
    y = rng.uniform(0.7, 1.9, size=(3, 4))
    assert not set(np.unique(x)) & set(np.unique(y))
    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

  def testHigherOrderGradientOfReciprocal(self):
    # Regression test for https://github.com/google/jax/issues/3136
    def inv(x):
      # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x)
      return 1 / x
    grad_fn = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv))))))
    self.assertAllClose(np.float32(0.0439453125), grad_fn(np.float32(4.)))
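
The gradient tests above all follow the same recipe: wrap the operation as a function of its differentiable inputs and hand it to jax.test_util.check_grads, which compares forward-mode ("fwd") and reverse-mode ("rev") derivatives against numerical differences up to the requested order. A minimal stand-alone sketch of that recipe (assuming only a standard jax install; not part of the test suite itself):

import numpy as np
from jax import lax
from jax.test_util import check_grads

x = np.linspace(0.1, 1.0, 6, dtype=np.float32).reshape(2, 3)

# Check first- and second-order JVPs and VJPs of a lax primitive against
# finite differences, mirroring the calls made in the tests above.
check_grads(lambda a: lax.transpose(a, (1, 0)), (x,), order=2,
            modes=["fwd", "rev"], eps=1e-2)
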
Example #22
0
class DLPackTest(jtu.JaxTestCase):
    def setUp(self):
        if jtu.device_under_test() == "tpu":
            self.skipTest("DLPack not supported on TPU")

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in all_shapes for dtype in dlpack_dtypes))
    def testJaxRoundTrip(self, shape, dtype):
        rng = jtu.rand_default()
        np_arr = rng(shape, dtype)
        x = jnp.array(np_arr)
        dlpack = jax.dlpack.to_dlpack(x)
        y = jax.dlpack.from_dlpack(dlpack)
        self.assertAllClose(np_arr.astype(x.dtype), y, check_dtypes=True)

        self.assertRaisesRegex(RuntimeError,
                               "DLPack tensor may be consumed at most once",
                               lambda: jax.dlpack.from_dlpack(dlpack))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in all_shapes for dtype in torch_dtypes))
    @unittest.skipIf(not torch, "Test requires PyTorch")
    def testTorchToJax(self, shape, dtype):
        rng = jtu.rand_default()
        np_arr = rng(shape, dtype)
        x = torch.from_numpy(np_arr)
        x = x.cuda() if jtu.device_under_test() == "gpu" else x
        dlpack = torch.utils.dlpack.to_dlpack(x)
        y = jax.dlpack.from_dlpack(dlpack)
        self.assertAllClose(np_arr, y, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in all_shapes for dtype in torch_dtypes))
    @unittest.skipIf(not torch, "Test requires PyTorch")
    def testJaxToTorch(self, shape, dtype):
        rng = jtu.rand_default()
        np_arr = rng(shape, dtype)
        x = jnp.array(np_arr)
        dlpack = jax.dlpack.to_dlpack(x)
        y = torch.utils.dlpack.from_dlpack(dlpack)
        self.assertAllClose(np_arr, y.numpy(), check_dtypes=True)
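
Outside the test harness, the same DLPack round trip fits in a few lines. A hedged sketch, assuming jax and torch builds that expose the jax.dlpack.to_dlpack / torch.utils.dlpack APIs used above:

import numpy as np
import jax
import jax.numpy as jnp
import torch

x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)

# JAX -> PyTorch: export a DLPack capsule and import it on the torch side.
t = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))

# PyTorch -> JAX: a capsule may be consumed at most once, so make a fresh one.
y = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))

np.testing.assert_allclose(np.asarray(x), np.asarray(y))
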
Example #23
0
class DoubleDoubleTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_{}".format(op.__name__,
                            jtu.format_shape_dtype_string(shape, dtype)),
            "dtype":
            dtype,
            "shape":
            shape,
            "op":
            op
        } for dtype in (jnp.float16, jnp.float32, jnp.float64)
                            for shape in ((), (5, ), (2, 3), (2, 3, 4))
                            for op in (abs, operator.neg, operator.pos,
                                       jnp.sqrt)))
    def testUnaryOp(self, dtype, shape, op):
        rng = jtu.rand_default(self.rng())
        op_doubled = doubledouble(op)
        args = (rng(shape, dtype), )
        self.assertAllClose(op(*args), op_doubled(*args))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_{}".format(op.__name__,
                            jtu.format_shape_dtype_string(shape, dtype)),
            "dtype":
            dtype,
            "shape":
            shape,
            "op":
            op
        } for dtype in (jnp.float16, jnp.float32, jnp.float64)
                            for shape in ((), (5, ), (2, 3), (2, 3, 4))
                            for op in (operator.add, operator.sub,
                                       operator.mul, operator.truediv,
                                       operator.gt, operator.ge, operator.lt,
                                       operator.le, operator.eq, operator.ne)))
    def testBinaryOp(self, dtype, shape, op):
        rng = jtu.rand_default(self.rng())
        op_doubled = doubledouble(op)
        args = rng(shape, dtype), rng(shape, dtype)
        self.assertAllClose(op(*args), op_doubled(*args))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_{}_{}".format(jtu.format_shape_dtype_string(shape, dtype),
                                label),
                "shape":
                shape,
                "dtype":
                dtype,
                "op1":
                op1,
                "op2":
                op2
            } for dtype in (jnp.float32, jnp.float64)
            for shape in ((), (5, ), (2, 3), (2, 3, 4))
            for label, op1, op2 in [
                ('add_sub', lambda x, y: x + y - x, lambda x, y: y),
                ("add_neg_add", lambda x, y: -(x + y) + x, lambda x, y: -y),
                ("add_mul_sub", lambda x, y: 2 * (x + y) - 2 * x,
                 lambda x, y: 2 * y),
                ("add_div_sub", lambda x, y: (x + y) / 2 - x / 2,
                 lambda x, y: y / 2),
            ]))
    def testDoubledPrecision(self, shape, dtype, op1, op2):
        """Test operations that would lose precision without doubling."""
        rng = jtu.rand_default(self.rng())
        double_op1 = doubledouble(op1)
        args = 1E20 * rng(shape, dtype), rng(shape, dtype)
        check_dtypes = not FLAGS.jax_enable_x64

        self.assertAllClose(double_op1(*args),
                            op2(*args),
                            check_dtypes=check_dtypes)

        # Sanity check: make sure test fails for regular precision.
        with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
            self.assertAllClose(op1(*args),
                                op2(*args),
                                check_dtypes=check_dtypes)

    def testTypeConversion(self):
        x = jnp.arange(10, dtype='float16')
        f = lambda x, y: (x + y).astype('float32')
        g = doubledouble(f)
        self.assertAllClose(f(1E2 * x, 1E-2 * x), 1E2 * x.astype('float32'))
        self.assertAllClose(g(1E2 * x, 1E-2 * x), 100.01 * x.astype('float32'))

    def testRepeatedDoubling(self):
        def f(x, y, z):
            return x + y + z - x - y

        f2 = doubledouble(f)
        f4 = doubledouble(f2)
        dtype = jnp.float32
        x, y, z = dtype(1E20), dtype(1.0), dtype(1E-20)

        self.assertEqual(f(x, y, z), -y)
        self.assertEqual(f2(x, y, z), 0)
        self.assertEqual(f4(x, y, z), z)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_{}_{}".format(dtype, val),
                "dtype": dtype,
                "val": val
            } for dtype in ["float16", "float32", "float64"]
            for val in ["6.0221409e23", "3.14159265358", "0", 123456789]))
    def testClassInstantiation(self, dtype, val):
        dtype = jnp.dtype(dtype).type
        self.assertEqual(dtype(val), _DoubleDouble(val, dtype).to_array())

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_{}".format(jtu.format_shape_dtype_string(shape, dtype),
                            op.__name__),
            "shape":
            shape,
            "dtype":
            dtype,
            "op":
            op
        } for dtype in (jnp.float32, jnp.float64)
                            for shape in ((), (5, ), (2, 3), (2, 3, 4))
                            for op in (operator.neg, operator.abs)))
    def testClassUnaryOp(self, dtype, shape, op):
        rng = jtu.rand_default(self.rng())
        args = (rng(shape, dtype), )
        class_op = lambda x: op(_DoubleDouble(x)).to_array()
        self.assertAllClose(op(*args), class_op(*args))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_{}".format(jtu.format_shape_dtype_string(shape, dtype),
                            op.__name__),
            "shape":
            shape,
            "dtype":
            dtype,
            "op":
            op
        } for dtype in (jnp.float32, jnp.float64)
                            for shape in ((), (5, ), (2, 3), (2, 3, 4))
                            for op in (operator.add, operator.sub,
                                       operator.mul, operator.truediv,
                                       operator.gt, operator.ge, operator.lt,
                                       operator.le, operator.eq, operator.ne)))
    def testClassBinaryOp(self, dtype, shape, op):
        rng = jtu.rand_default(self.rng())
        args = rng(shape, dtype), rng(shape, dtype)

        def class_op(x, y):
            result = op(_DoubleDouble(x), _DoubleDouble(y))
            if isinstance(result, _DoubleDouble):
                result = result.to_array()
            return result

        self.assertAllClose(op(*args), class_op(*args))
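
The precision loss these doubledouble tests guard against is easy to reproduce directly: in float32, adding and then subtracting a term around 1e20 discards the smaller operand, which is exactly the cancellation testDoubledPrecision exercises. A small plain-NumPy illustration (hypothetical values, independent of the test suite):

import numpy as np

x = np.float32(1e20)
y = np.float32(1.0)

# Single precision swallows the small term entirely:
print((x + y) - x)    # 0.0, not 1.0

# The compensated-arithmetic idea behind double-double: keep the rounding
# error of the sum as a second, low-order term ("fast two-sum").
s = x + y
err = y - (s - x)     # recovered low-order part
print((s - x) + err)  # 1.0
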
Example #24
0
class IndexingTest(jtu.JaxTestCase):
    """Tests for Numpy indexing translation rules."""
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "{}_inshape={}_indexer={}".format(
                name, jtu.format_shape_dtype_string(shape, dtype), indexer),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng,
            "indexer":
            indexer
        } for name, index_specs in STATIC_INDEXING_TESTS
                            for shape, indexer in index_specs
                            for dtype in all_dtypes
                            for rng in [jtu.rand_default()]))
    def testStaticIndexing(self, shape, dtype, rng, indexer):
        args_maker = lambda: [rng(shape, dtype)]
        fun = lambda x: x[indexer]
        self._CompileAndCheck(fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng":
        rng,
        "indexer":
        indexer
    } for name, index_specs in STATIC_INDEXING_GRAD_TESTS
                                    for shape, indexer in index_specs
                                    for dtype in float_dtypes
                                    for rng in [jtu.rand_default()])
    def testStaticIndexingGrads(self, shape, dtype, rng, indexer):
        tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
        arg = rng(shape, dtype)
        fun = lambda x: x[indexer]**2
        check_grads(fun, (arg, ), 2, tol, tol, tol)

    def _ReplaceSlicesWithTuples(self, idx):
        """Helper method to replace slices with tuples for dynamic indexing args."""
        if isinstance(idx, slice):
            triple = idx.start, idx.stop, idx.step
            isnone = [i for i, elt in enumerate(triple) if elt is None]
            zeros = itertools.repeat(0)
            nones = itertools.repeat(None)
            out = lax.subvals(triple, zip(isnone, zeros))
            return out, lambda out: slice(*lax.subvals(out, zip(isnone, nones)))
        elif isinstance(idx, (tuple, list)) and idx:
            t = type(idx)
            elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx))
            return elts, lambda elts: t(
                (pack(i) for pack, i in zip(packs, elts)))
        else:
            return idx, lambda x: x

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng":
        rng,
        "indexer":
        indexer
    } for name, index_specs in [
        ("OneSliceIndex", [
            IndexSpec(shape=(5, ), indexer=slice(1, 3)),
            IndexSpec(shape=(5, 4), indexer=slice(1, 3))
        ]),
        ("TwoSliceIndices", [
            IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))),
            IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))
        ]),
        ("NonUnitStrides", [
            IndexSpec(shape=(3, ), indexer=slice(None, None, -1)),
            IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)),
            IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2))
        ]),
        ("OnlyStartOrStopDynamic", [
            IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))),
            IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None)))
        ]),
    ] for shape, indexer in index_specs for dtype in all_dtypes
                                    for rng in [jtu.rand_default()])
    def testDynamicIndexingWithSlicesErrors(self, shape, dtype, rng, indexer):
        unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

        @api.jit
        def fun(x, unpacked_indexer):
            indexer = pack_indexer(unpacked_indexer)
            return x[indexer]

        args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
        self.assertRaises(IndexError, lambda: fun(*args_maker()))

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng":
        rng,
        "indexer":
        indexer
    } for name, index_specs in [
        ("OneIntIndex", [
            IndexSpec(shape=(3, ), indexer=1),
            IndexSpec(shape=(3, 3), indexer=0),
            IndexSpec(shape=(3, 4, 5), indexer=2),
            IndexSpec(shape=(3, ), indexer=-1),
            IndexSpec(shape=(3, ), indexer=-2)
        ]),
        ("TwoIntIndices", [
            IndexSpec(shape=(3, 3), indexer=(2, 1)),
            IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
            IndexSpec(shape=(3, 4, 5), indexer=(-1, 2))
        ]),
        ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
    ] for shape, indexer in index_specs for dtype in all_dtypes
                                    for rng in [jtu.rand_default()])
    def testDynamicIndexingWithIntegers(self, shape, dtype, rng, indexer):
        unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

        def fun(x, unpacked_indexer):
            indexer = pack_indexer(unpacked_indexer)
            return x[indexer]

        args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
        self._CompileAndCheck(fun, args_maker, check_dtypes=True)

    @unittest.skip
    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng":
        rng,
        "indexer":
        indexer
    } for name, index_specs in [
        ("OneIntIndex", [
            IndexSpec(shape=(3, ), indexer=1),
            IndexSpec(shape=(3, 3), indexer=0),
            IndexSpec(shape=(3, 4, 5), indexer=2),
            IndexSpec(shape=(3, ), indexer=-1),
            IndexSpec(shape=(3, ), indexer=-2),
        ]),
        ("TwoIntIndices", [
            IndexSpec(shape=(3, 3), indexer=(2, 1)),
            IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
            IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)),
        ]),
        ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
    ] for shape, indexer in index_specs for dtype in float_dtypes
                                    for rng in [jtu.rand_default()])
    def DISABLED_testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng,
                                                      indexer):
        # TODO(mattjj): re-enable (test works but for grad-of-compile, in flux)
        tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
        unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

        @api.jit
        def fun(unpacked_indexer, x):
            indexer = pack_indexer(unpacked_indexer)
            return x[indexer]

        arr = rng(shape, dtype)
        check_grads(partial(fun, unpacked_indexer), (arr, ), 2, tol, tol, tol)

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng":
        rng,
        "indexer":
        indexer
    } for name, index_specs in ADVANCED_INDEXING_TESTS
                                    for shape, indexer in index_specs
                                    for dtype in all_dtypes
                                    for rng in [jtu.rand_default()])
    def testAdvancedIntegerIndexing(self, shape, dtype, rng, indexer):
        args_maker = lambda: [rng(shape, dtype), indexer]
        fun = lambda x, idx: x[idx]
        self._CompileAndCheck(fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng":
        rng,
        "indexer":
        indexer
    } for name, index_specs in [
        ("One1DIntArrayIndex", [
            IndexSpec(shape=(3, ), indexer=onp.array([0, 1])),
            IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])),
            IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])),
            IndexSpec(shape=(3, ), indexer=onp.array([-1, 1])),
            IndexSpec(shape=(3, ), indexer=onp.array([-2, -1])),
        ]),
        ("One2DIntArrayIndex", [
            IndexSpec(shape=(3, ), indexer=onp.array([[0, 0]])),
            IndexSpec(shape=(3, 3), indexer=onp.array([[1, 2, 1], [0, 1, -1]])),
            IndexSpec(shape=(3, 4, 5),
                      indexer=onp.array([[0, 2, 0, 1], [-1, -2, 1, 0]])),
        ]),
        ("Two1DIntArrayIndicesNoBroadcasting", [
            IndexSpec(shape=(3, 3),
                      indexer=[onp.array([0, 1]),
                               onp.array([1, 2])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[onp.array([0, 2, 0, 1]),
                               onp.array([-1, 0, -1, 2])]),
        ]),
        ("Two1DIntArrayIndicesWithBroadcasting", [
            IndexSpec(shape=(3, 3),
                      indexer=[onp.array([[0, 1]]),
                               onp.array([1, 2])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[
                          onp.array([[0, 2, 0, 1]]),
                          onp.array([-1, 0, -1, 2])
                      ]),
        ]),
        ("ListOfPythonInts", [
            IndexSpec(shape=(3, ), indexer=[0, 1, 0]),
            IndexSpec(shape=(3, 4, 5), indexer=[0, -1]),
        ]),
        ("ListOfListsOfPythonInts", [
            IndexSpec(shape=(3, 4, 5), indexer=[[0, 1]]),
            IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]], [[2, 3, 0, 3]]]),
        ]),
        ("ListOfPythonIntsAndIntArrays", [
            IndexSpec(shape=(3, 4, 5), indexer=[0, onp.array([0, 1])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[0, 1, onp.array([[2, 3, 0, 3]])]),
        ]),
        ("ListOfListsOfPythonIntsAndIntArrays", [
            IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], onp.array([0])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[[[0], [-1]],
                               onp.array([[2, 3, 0, 3]])]),
        ]),
    ] for shape, indexer in index_specs for dtype in float_dtypes
                                    for rng in [jtu.rand_default()])
    def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng, indexer):
        tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
        arg = rng(shape, dtype)
        fun = lambda x: x[indexer]**2
        check_grads(fun, (arg, ), 2, tol, tol, tol)

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng":
        rng,
        "indexer":
        indexer
    } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS
                                    for shape, indexer in index_specs
                                    for dtype in all_dtypes
                                    for rng in [jtu.rand_default()])
    def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng, indexer):
        indexer_with_dummies = [
            e if isinstance(e, onp.ndarray) else () for e in indexer
        ]
        substitutes = [(i, e) for i, e in enumerate(indexer)
                       if not isinstance(e, onp.ndarray)]
        args_maker = lambda: [rng(shape, dtype), indexer_with_dummies]

        def fun(x, indexer_with_dummies):
            idx = type(indexer)(lax.subvals(indexer_with_dummies, substitutes))
            return x[idx]

        self._CompileAndCheck(fun, args_maker, check_dtypes=True)

    def testAdvancedIndexingManually(self):
        x = onp.random.RandomState(0).randn(3, 4, 5)
        index_array = onp.array([0, 2, -1, 0])

        op = lambda x, index_array: x[..., index_array, :]
        cop = api.jit(op)

        a1 = op(x, index_array)
        a2 = cop(x, index_array)

        self.assertAllClose(a1, a2, check_dtypes=True)

        op = lambda x, index_array: x[..., index_array, :, index_array, None]
        cop = api.jit(op)

        a1 = op(x, index_array)
        a2 = cop(x, index_array)

        self.assertAllClose(a1, a2, check_dtypes=True)

        op = lambda x, index_array: x[index_array, ..., index_array[:, None],
                                      None]
        cop = api.jit(op)

        a1 = op(x, index_array)
        a2 = cop(x, index_array)

        self.assertAllClose(a1, a2, check_dtypes=True)

    def testUnpacking(self):
        def foo(x):
            a, b, c = x
            return a + b + c

        cfoo = api.jit(foo)

        a1 = foo(onp.arange(3))
        a2 = cfoo(onp.arange(3))

        self.assertAllClose(a1, a2, check_dtypes=True)

    def testBooleanIndexingArray1D(self):
        idx = onp.array([True, True, False])
        x = api.device_put(onp.arange(3))
        ans = x[idx]
        expected = onp.arange(3)[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testBooleanIndexingList1D(self):
        idx = [True, True, False]
        x = api.device_put(onp.arange(3))
        ans = x[idx]
        expected = onp.arange(3)[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testBooleanIndexingArray2DBroadcast(self):
        idx = onp.array([True, True, False, True])
        x = onp.arange(8).reshape(4, 2)
        ans = api.device_put(x)[idx]
        expected = x[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testBooleanIndexingList2DBroadcast(self):
        idx = [True, True, False, True]
        x = onp.arange(8).reshape(4, 2)
        ans = api.device_put(x)[idx]
        expected = x[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testBooleanIndexingArray2D(self):
        idx = onp.array([[True, False], [False, True], [False, False],
                         [True, True]])
        x = onp.arange(8).reshape(4, 2)
        ans = api.device_put(x)[idx]
        expected = x[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testBooleanIndexingDynamicShapeError(self):
        x = onp.zeros(3)
        i = onp.array([True, True, False])
        self.assertRaises(IndexError, lambda: api.jit(lambda x, i: x[i])(x, i))

    def testIssue187(self):
        x = lnp.ones((5, 5))
        x[[0, 2, 4], [0, 2, 4]]  # doesn't crash

        x = onp.arange(25).reshape((5, 5))
        ans = api.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x)
        expected = x[[0, 2, 4], [0, 2, 4]]
        self.assertAllClose(ans, expected, check_dtypes=False)
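
The testBooleanIndexingDynamicShapeError case above reflects a general constraint: under jit, a boolean mask cannot be used as an index because the result's shape would depend on runtime values. A hedged sketch of the usual shape-preserving workaround with jnp.where (not part of the test suite):

import jax
import jax.numpy as jnp

x = jnp.arange(3.0)
mask = jnp.array([True, True, False])

@jax.jit
def masked_sum(x, mask):
    # Instead of dropping unselected entries (which would give a dynamic
    # shape), zero them out so every intermediate keeps a static shape.
    return jnp.where(mask, x, 0.0).sum()

print(masked_sum(x, mask))  # 1.0
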
Example #25
0
class ScipyLinalgTest(jtu.JaxTestCase):

  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 5), (10, 5), (50, 50)]
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testLu(self, shape, dtype, rng):
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(jsp.linalg.lu, osp.linalg.lu, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 5), (10, 5), (10, 10)]
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testLuGrad(self, shape, dtype, rng):
    a = rng(shape, dtype)

    jtu.check_grads(jsp.linalg.lu, (a,), 2, rtol=1e-1)

  @jtu.skip_on_devices("gpu", "tpu")
  def testLuBatching(self):
    self.skipTest("Test disabled until Jaxlib 0.1.14 is released")
    shape = (4, 5)
    dtype = np.float32
    rng = jtu.rand_default()
    args = [rng(shape, np.float32) for _ in range(10)]
    expected = list(osp.linalg.lu(x) for x in args)
    ps = onp.stack([out[0] for out in expected])
    ls = onp.stack([out[1] for out in expected])
    us = onp.stack([out[2] for out in expected])

    actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(np.stack(args))
    self.assertAllClose(ps, actual_ps, check_dtypes=True)
    self.assertAllClose(ls, actual_ls, check_dtypes=True)
    self.assertAllClose(us, actual_us, check_dtypes=True)

  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
       "n": n, "dtype": dtype, "rng": rng}
      for n in [1, 4, 5, 200]
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testLuFactor(self, n, dtype, rng):
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(jsp.linalg.lu_factor, osp.linalg.lu_factor,
                            args_maker, check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}_sym_pos={}_lower={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           sym_pos, lower),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "sym_pos": sym_pos, "lower": lower, "rng": rng}
      for lhs_shape, rhs_shape in [
          ((1, 1), (1, 1)),
          ((4, 4), (4,)),
          ((8, 8), (8, 4)),
      ]
      for sym_pos, lower in [
        (False, False),
        (True, False),
        (True, True),
      ]
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower, rng):
    osp_fun = lambda lhs, rhs: osp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower)
    jsp_fun = lambda lhs, rhs: jsp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower)

    def args_maker():
      a = rng(lhs_shape, dtype)
      if sym_pos:
        a = onp.matmul(a, onp.conj(T(a)))
        a = onp.tril(a) if lower else onp.triu(a)
      return [a, rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}_lower={}_transposea={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           lower, transpose_a),
       "lower": lower, "transpose_a": transpose_a,
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng": rng}
      for lower, transpose_a in itertools.product([False, True], repeat=2)
      for lhs_shape, rhs_shape in [
          ((4, 4), (4,)),
          ((4, 4), (4, 3)),
          ((2, 8, 8), (2, 8, 10)),
      ]
      for dtype in float_types()
      for rng in [jtu.rand_default()]))
  def testSolveTriangular(self, lower, transpose_a, lhs_shape, rhs_shape, dtype,
                          rng):
    k = rng(lhs_shape, dtype)
    l = onp.linalg.cholesky(onp.matmul(k, T(k))
                            + lhs_shape[-1] * onp.eye(lhs_shape[-1]))
    l = l.astype(k.dtype)
    b = rng(rhs_shape, dtype)

    a = l if lower else T(l)
    inv = onp.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
    if len(lhs_shape) == len(rhs_shape):
      onp_ans = onp.matmul(inv, b)
    else:
      onp_ans = onp.einsum("...ij,...j->...i", inv, b)

    # The standard scipy.linalg.solve_triangular doesn't support broadcasting.
    # But it seems like an inevitable extension so we support it.
    ans = jsp.linalg.solve_triangular(
        l if lower else T(l), b, trans=1 if transpose_a else 0, lower=lower)

    self.assertAllClose(onp_ans, ans, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}_lower={}_transposea={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           lower, transpose_a),
       "lower": lower, "transpose_a": transpose_a,
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng": rng}
      for lower, transpose_a in itertools.product([False, True], repeat=2)
      for lhs_shape, rhs_shape in [
          ((4, 4), (4,)),
          ((4, 4), (4, 3)),
          ((2, 8, 8), (2, 8, 10)),
      ]
      for dtype in float_types()
      for rng in [jtu.rand_default()]))
  def testSolveTriangularGrad(self, lower, transpose_a, lhs_shape,
                                     rhs_shape, dtype, rng):
    # TODO(frostig): change ensemble to support a bigger rtol
    self.skipTest("rtol does not cover all devices and precision modes")
    A = np.tril(rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype))
    A = A if lower else T(A)
    B = rng(rhs_shape, dtype)
    f = partial(jsp.linalg.solve_triangular, lower=lower,
                trans=1 if transpose_a else 0)
    jtu.check_grads(f, (A, B), 2, rtol=1e-3)
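
For context, the LU routines exercised above are usually paired with lu_solve, so one factorization can be reused across right-hand sides. A short sketch assuming jax.scipy.linalg exposes the SciPy-style lu_factor/lu_solve interface used in these tests:

import numpy as np
import jax.numpy as jnp
import jax.scipy.linalg as jsp_linalg

rng = np.random.RandomState(0)
# A well-conditioned 4x4 system with three right-hand sides.
a = jnp.asarray((rng.randn(4, 4) + 4 * np.eye(4)).astype(np.float32))
b = jnp.asarray(rng.randn(4, 3).astype(np.float32))

# Factor once, then reuse the factorization for any number of right-hand sides.
lu_and_piv = jsp_linalg.lu_factor(a)
x = jsp_linalg.lu_solve(lu_and_piv, b)

np.testing.assert_allclose(np.asarray(a @ x), np.asarray(b), rtol=1e-4, atol=1e-4)
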
Example #26
0
class ScipyLinalgTest(jtu.JaxTestCase):

    # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 5), (10, 5), (50, 50)]
                            for dtype in float_types() | complex_types()
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("gpu", "tpu")
    def testLu(self, shape, dtype, rng):
        # TODO(phawkins): remove this after a jaxlib release.
        if not hasattr(lapack, "jax_getrf"):
            self.skipTest("No LU implementation available")
        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(jsp.linalg.lu,
                                osp.linalg.lu,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)

    # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
            "n":
            n,
            "dtype":
            dtype,
            "rng":
            rng
        } for n in [1, 4, 5, 200] for dtype in float_types() | complex_types()
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("gpu", "tpu")
    def testLuFactor(self, n, dtype, rng):
        # TODO(phawkins): remove this after a jaxlib release.
        if not hasattr(lapack, "jax_getrf"):
            self.skipTest("No LU implementation available")
        args_maker = lambda: [rng((n, n), dtype)]

        self._CheckAgainstNumpy(jsp.linalg.lu_factor,
                                osp.linalg.lu_factor,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(jsp.linalg.lu_factor,
                              args_maker,
                              check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs={}_rhs={}_lower={}_transposea={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), lower,
                transpose_a),
            "lower":
            lower,
            "transpose_a":
            transpose_a,
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for lower, transpose_a in itertools.product([False, True], repeat=2)
                            for lhs_shape, rhs_shape in [
                                ((4, 4), (4, )),
                                ((4, 4), (4, 3)),
                                ((2, 8, 8), (2, 8, 10)),
                            ] for dtype in float_types()
                            for rng in [jtu.rand_default()]))
    def testSolveTriangular(self, lower, transpose_a, lhs_shape, rhs_shape,
                            dtype, rng):
        k = rng(lhs_shape, dtype)
        l = onp.linalg.cholesky(
            onp.matmul(k, T(k)) + lhs_shape[-1] * onp.eye(lhs_shape[-1]))
        l = l.astype(k.dtype)
        b = rng(rhs_shape, dtype)

        a = l if lower else T(l)
        inv = onp.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
        if len(lhs_shape) == len(rhs_shape):
            onp_ans = onp.matmul(inv, b)
        else:
            onp_ans = onp.einsum("...ij,...j->...i", inv, b)

        # The standard scipy.linalg.solve_triangular doesn't support broadcasting.
        # But it seems like an inevitable extension so we support it.
        ans = jsp.linalg.solve_triangular(l if lower else T(l),
                                          b,
                                          trans=1 if transpose_a else 0,
                                          lower=lower)

        self.assertAllClose(onp_ans, ans, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs={}_rhs={}_lower={}_transposea={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), lower,
                transpose_a),
            "lower":
            lower,
            "transpose_a":
            transpose_a,
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for lower, transpose_a in itertools.product([False, True], repeat=2)
                            for lhs_shape, rhs_shape in [
                                ((4, 4), (4, )),
                                ((4, 4), (4, 3)),
                                ((2, 8, 8), (2, 8, 10)),
                            ] for dtype in float_types()
                            for rng in [jtu.rand_default()]))
    def testSolveTriangularGrad(self, lower, transpose_a, lhs_shape, rhs_shape,
                                dtype, rng):
        # TODO(frostig): change ensemble to support a bigger rtol
        A = np.tril(
            rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype))
        A = A if lower else T(A)
        B = rng(rhs_shape, dtype)
        f = partial(jsp.linalg.solve_triangular,
                    lower=lower,
                    trans=1 if transpose_a else 0)
        jtu.check_grads(f, (A, B), 2, rtol=1e-3)
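
The comment in testSolveTriangular notes that scipy.linalg.solve_triangular does not broadcast, while the JAX version accepts batched (leading-axis) inputs, as the (2, 8, 8)/(2, 8, 10) case exercises. A hedged sketch of that batched call with smaller shapes:

import numpy as np
import jax.numpy as jnp
import jax.scipy.linalg as jsp_linalg

rng = np.random.RandomState(0)
# A batch of two well-conditioned lower-triangular systems.
a = np.tril(rng.randn(2, 4, 4)).astype(np.float32) + 4 * np.eye(4, dtype=np.float32)
b = rng.randn(2, 4, 3).astype(np.float32)

# jax.scipy.linalg.solve_triangular broadcasts over the leading batch axis.
x = jsp_linalg.solve_triangular(jnp.asarray(a), jnp.asarray(b), lower=True)

np.testing.assert_allclose(a @ np.asarray(x), b, rtol=1e-4, atol=1e-4)
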
Example #27
0
class ScipyLinalgTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 5), (10, 5), (50, 50)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testLu(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng(shape, dtype)]
        x, = args_maker()
        p, l, u = jsp.linalg.lu(x)
        self.assertAllClose(x,
                            onp.matmul(p, onp.matmul(l, u)),
                            check_dtypes=True)
        self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)

    # TODO(phawkins): figure out why this test fails on Travis and reenable.
    @unittest.skip("Test fails on travis")
    def testLuOfSingularMatrixReturnsNans(self):
        xs = np.array([[-1., 3. / 2], [2. / 3, -1.]])
        lu, _ = jsp.linalg.lu_factor(xs)
        self.assertTrue(onp.all(onp.isnan(lu)))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 5), (10, 5), (10, 10), (6, 7, 7)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("tpu")  # TODO(phawkins): precision problems on TPU.
    def testLuGrad(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        a = rng(shape, dtype)
        lu = vmap(jsp.linalg.lu) if len(shape) > 2 else jsp.linalg.lu
        jtu.check_grads(lu, (a, ), 2, atol=5e-2, rtol=1e-1)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(4, 5), (6, 5)] for dtype in [np.float32]
                            for rng in [jtu.rand_default()]))
    def testLuBatching(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args = [rng(shape, np.float32) for _ in range(10)]
        expected = list(osp.linalg.lu(x) for x in args)
        ps = onp.stack([out[0] for out in expected])
        ls = onp.stack([out[1] for out in expected])
        us = onp.stack([out[2] for out in expected])

        actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(np.stack(args))
        self.assertAllClose(ps, actual_ps, check_dtypes=True)
        self.assertAllClose(ls, actual_ls, check_dtypes=True)
        self.assertAllClose(us, actual_us, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
            "n":
            n,
            "dtype":
            dtype,
            "rng":
            rng
        } for n in [1, 4, 5, 200] for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testLuFactor(self, n, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng((n, n), dtype)]

        x, = args_maker()
        lu, piv = jsp.linalg.lu_factor(x)
        # lu_factor packs the unit lower-triangular L and upper-triangular U
        # into a single matrix; piv records the row interchanges.
        l = onp.tril(lu, -1) + onp.eye(n, dtype=dtype)
        u = onp.triu(lu)
        # Apply the recorded row swaps to x so that the permuted x equals l @ u.
        for i in range(n):
            x[[i, piv[i]], ] = x[[piv[i], i], ]
        self.assertAllClose(x, onp.matmul(l, u), check_dtypes=True, rtol=1e-3)
        self._CompileAndCheck(jsp.linalg.lu_factor,
                              args_maker,
                              check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs={}_rhs={}_sym_pos={}_lower={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), sym_pos,
                lower),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "sym_pos":
            sym_pos,
            "lower":
            lower,
            "rng":
            rng
        } for lhs_shape, rhs_shape in [
            ((1, 1), (1, 1)),
            ((4, 4), (4, )),
            ((8, 8), (8, 4)),
        ] for sym_pos, lower in [
            (False, False),
            (True, False),
            (True, True),
        ] for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower, rng):
        _skip_if_unsupported_type(dtype)
        if (sym_pos and onp.issubdtype(dtype, onp.complexfloating)
                and jtu.device_under_test() == "tpu"):
            raise unittest.SkipTest(
                "Complex Cholesky decomposition not implemented on TPU")
        osp_fun = lambda lhs, rhs: osp.linalg.solve(
            lhs, rhs, sym_pos=sym_pos, lower=lower)
        jsp_fun = lambda lhs, rhs: jsp.linalg.solve(
            lhs, rhs, sym_pos=sym_pos, lower=lower)

        def args_maker():
            a = rng(lhs_shape, dtype)
            if sym_pos:
                a = onp.matmul(a, onp.conj(T(a)))
                a = onp.tril(a) if lower else onp.triu(a)
            return [a, rng(rhs_shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs={}_rhs={}_lower={}_transposea={}_unit_diagonal={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), lower,
                transpose_a, unit_diagonal),
            "lower":
            lower,
            "transpose_a":
            transpose_a,
            "unit_diagonal":
            unit_diagonal,
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for lower in [False, True] for transpose_a in [False, True]
                            for unit_diagonal in [False, True]
                            for lhs_shape, rhs_shape in [
                                ((4, 4), (4, )),
                                ((4, 4), (4, 3)),
                                ((2, 8, 8), (2, 8, 10)),
                            ] for dtype in float_types
                            for rng in [jtu.rand_default()]))
    def testSolveTriangular(self, lower, transpose_a, unit_diagonal, lhs_shape,
                            rhs_shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        k = rng(lhs_shape, dtype)
        l = onp.linalg.cholesky(
            onp.matmul(k, T(k)) + lhs_shape[-1] * onp.eye(lhs_shape[-1]))
        l = l.astype(k.dtype)
        b = rng(rhs_shape, dtype)

        if unit_diagonal:
            a = onp.tril(l, -1) + onp.eye(lhs_shape[-1], dtype=dtype)
        else:
            a = l
        a = a if lower else T(a)

        inv = onp.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
        if len(lhs_shape) == len(rhs_shape):
            onp_ans = onp.matmul(inv, b)
        else:
            onp_ans = onp.einsum("...ij,...j->...i", inv, b)

        # The standard scipy.linalg.solve_triangular doesn't support broadcasting.
        # But it seems like an inevitable extension so we support it.
        ans = jsp.linalg.solve_triangular(l if lower else T(l),
                                          b,
                                          trans=1 if transpose_a else 0,
                                          lower=lower,
                                          unit_diagonal=unit_diagonal)

        self.assertAllClose(onp_ans, ans, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_lhs={}_rhs={}_lower={}_transposea={}_unit_diagonal={}".
                format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                       jtu.format_shape_dtype_string(rhs_shape, dtype), lower,
                       transpose_a, unit_diagonal),
                "lower":
                lower,
                "transpose_a":
                transpose_a,
                "unit_diagonal":
                unit_diagonal,
                "lhs_shape":
                lhs_shape,
                "rhs_shape":
                rhs_shape,
                "dtype":
                dtype,
                "rng":
                rng
            } for lower in [False, True] for unit_diagonal in [False, True]
            for dtype in float_types + complex_types for transpose_a in (
                [0, 1] if onp.issubdtype(dtype, np.floating) else [0, 1, 2])
            for lhs_shape, rhs_shape in [
                ((4, 4), (4, )),
                ((4, 4), (4, 3)),
                ((2, 8, 8), (2, 8, 10)),
            ] for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("tpu")  # TODO(phawkins): Test fails on TPU.
    def testSolveTriangularGrad(self, lower, transpose_a, unit_diagonal,
                                lhs_shape, rhs_shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        A = np.tril(
            rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype))
        A = A if lower else T(A)
        B = rng(rhs_shape, dtype)
        f = partial(jsp.linalg.solve_triangular,
                    lower=lower,
                    trans=transpose_a,
                    unit_diagonal=unit_diagonal)
        jtu.check_grads(f, (A, B), 2, rtol=2e-2, eps=1e-3)
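
The solve_triangular tests above shift the diagonal by 5 * eye so the triangular system stays well conditioned, then drive the solver through the lower/trans flags. A minimal standalone sketch of that call pattern (not part of the test class above; it assumes only the public jax.scipy.linalg.solve_triangular API and JAX's default float32 precision):

import numpy as onp
from jax.scipy.linalg import solve_triangular

rng = onp.random.RandomState(0)
L = onp.tril(rng.randn(4, 4) + 5 * onp.eye(4))    # well-conditioned lower-triangular matrix
b = rng.randn(4, 3)

x = solve_triangular(L, b, lower=True)            # solves L @ x = b
onp.testing.assert_allclose(L @ x, b, atol=1e-4)  # loose atol: JAX defaults to float32

# trans=1 solves L.T @ x = b instead, mirroring scipy.linalg.solve_triangular.
xt = solve_triangular(L, b, trans=1, lower=True)
onp.testing.assert_allclose(L.T @ xt, b, atol=1e-4)

With unit_diagonal=True the diagonal of the first argument is ignored and treated as ones, which is what the onp.tril(l, -1) + onp.eye(...) construction in testSolveTriangular exercises.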
Example #28
class NumpyLinalgTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]
                            for dtype in float_types()
                            for rng in [jtu.rand_default()]))
    def testCholesky(self, shape, dtype, rng):
        def args_maker():
            a = rng(shape, dtype)
            return [onp.matmul(a, T(a))]

        self._CheckAgainstNumpy(onp.linalg.cholesky,
                                np.linalg.cholesky,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.cholesky,
                              args_maker,
                              check_dtypes=True)

    # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
            "n":
            n,
            "dtype":
            dtype,
            "rng":
            rng
        } for n in [0, 4, 5, 50] for dtype in float_types() | complex_types()
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("gpu", "tpu")
    def testDet(self, n, dtype, rng):
        if not hasattr(lapack, "jax_getrf"):
            self.skipTest("No LU implementation available")
        args_maker = lambda: [rng((n, n), dtype)]

        self._CheckAgainstNumpy(onp.linalg.det,
                                np.linalg.det,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.det, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
            "n":
            n,
            "dtype":
            dtype,
            "rng":
            rng
        } for n in [0, 4, 10, 200]
                            for dtype in float_types() | complex_types()
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("gpu", "tpu")
    def testSlogdet(self, n, dtype, rng):
        if not hasattr(lapack, "jax_getrf"):
            self.skipTest("No LU implementation available")
        args_maker = lambda: [rng((n, n), dtype)]

        self._CheckAgainstNumpy(onp.linalg.slogdet,
                                np.linalg.slogdet,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.slogdet, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_fullmatrices={}".format(
                jtu.format_shape_dtype_string(shape, dtype), full_matrices),
            "shape":
            shape,
            "dtype":
            dtype,
            "full_matrices":
            full_matrices,
            "rng":
            rng
        } for shape in [(1, 1), (3, 4), (2, 10, 5), (2, 200, 100)]
                            for dtype in float_types()
                            for full_matrices in [False, True]
                            for rng in [jtu.rand_default()]))
    def testQr(self, shape, dtype, full_matrices, rng):
        m, n = shape[-2:]

        if full_matrices:
            mode, k = "complete", m
        else:
            mode, k = "reduced", min(m, n)

        a = rng(shape, dtype)
        lq, lr = np.linalg.qr(a, mode=mode)

        # onp.linalg.qr doesn't support broadcasting. But it seems like an
        # inevitable extension so we support it in our version.
        nq = onp.zeros(shape[:-2] + (m, k), dtype)
        nr = onp.zeros(shape[:-2] + (k, n), dtype)
        for index in onp.ndindex(*shape[:-2]):
            nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode)

        max_rank = max(m, n)

        # Norm, adjusted for dimension and type.
        def norm(x):
            n = onp.linalg.norm(x, axis=(-2, -1))
            return n / (max_rank * onp.finfo(dtype).eps)

        def compare_orthogonal(q1, q2):
            # Q is unique up to sign, so normalize the sign first.
            sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
            phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios))
            q1 *= phases
            self.assertTrue(onp.all(norm(q1 - q2) < 30))

        # Check a ~= qr
        self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30))

        # Compare the first 'k' columns of Q; the remaining columns form an
        # arbitrary orthonormal basis for the orthogonal complement of the
        # column space of a.
        compare_orthogonal(nq[..., :k], lq[..., :k])

        # Check that q is close to unitary.
        self.assertTrue(onp.all(norm(onp.eye(k) - onp.matmul(T(lq), lq)) < 5))

        if not full_matrices and m >= n:
            jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a, ))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5)]
                            for dtype in float_types()
                            for rng in [jtu.rand_default()]))
    def testInv(self, shape, dtype, rng):
        def args_maker():
            invertible = False
            while not invertible:
                a = rng(shape, dtype)
                try:
                    onp.linalg.inv(a)
                    invertible = True
                except onp.linalg.LinAlgError:
                    pass
            return [a]

        self._CheckAgainstNumpy(onp.linalg.inv,
                                np.linalg.inv,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)
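
testQr above verifies three properties: the reconstruction a ~= q @ r, that q has (nearly) orthonormal columns, and, via compare_orthogonal, that q is only compared up to a per-column sign. A minimal standalone sketch of the same checks against jnp.linalg.qr (it assumes a full-column-rank input, which a random matrix is with probability one):

import numpy as onp
import jax.numpy as jnp

a = onp.random.RandomState(0).randn(10, 5).astype(onp.float32)
q, r = jnp.linalg.qr(a, mode="reduced")

onp.testing.assert_allclose(q @ r, a, atol=1e-4)             # a ~= q @ r
onp.testing.assert_allclose(q.T @ q, onp.eye(5), atol=1e-4)  # orthonormal columns

# q is only determined up to a sign per column, so normalize the signs before
# comparing against NumPy's QR of the same matrix.
nq, _ = onp.linalg.qr(a, mode="reduced")
signs = onp.sign(onp.sum(onp.asarray(q) * nq, axis=0))
onp.testing.assert_allclose(onp.asarray(q) * signs, nq, atol=1e-3)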
Example #29
class LaxBackedScipyTests(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_preconditioner={}".format(
                jtu.format_shape_dtype_string(shape, dtype), preconditioner),
            "shape":
            shape,
            "dtype":
            dtype,
            "preconditioner":
            preconditioner
        } for shape in [(4, 4), (7, 7), (32, 32)]
                            for dtype in float_types + complex_types
                            for preconditioner in [None, 'identity', 'exact']))
    # TODO(#2951): reenable 'random' preconditioner.
    def test_cg_against_scipy(self, shape, dtype, preconditioner):

        rng = jtu.rand_default(self.rng())
        A = rand_sym_pos_def(rng, shape, dtype)
        b = rng(shape[:1], dtype)

        if preconditioner == 'identity':
            M = np.eye(shape[0], dtype=dtype)
        elif preconditioner == 'random':
            M = np.linalg.inv(rand_sym_pos_def(rng, shape, dtype))
        elif preconditioner == 'exact':
            M = np.linalg.inv(A)
        else:
            M = None

        def args_maker():
            return A, b

        self._CheckAgainstNumpy(partial(scipy_cg, M=M, maxiter=1),
                                partial(lax_cg, M=M, maxiter=1),
                                args_maker,
                                tol=1e-3)

        # TODO(shoyer,mattjj): I had to loosen the tolerance for complex64[7,7]
        # with preconditioner=random
        self._CheckAgainstNumpy(partial(scipy_cg, M=M, maxiter=3),
                                partial(lax_cg, M=M, maxiter=3),
                                args_maker,
                                tol=3e-3)

        self._CheckAgainstNumpy(np.linalg.solve,
                                partial(lax_cg, M=M, atol=1e-6),
                                args_maker,
                                tol=2e-2)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in [(2, 2)] for dtype in float_types + complex_types))
    def test_cg_as_solve(self, shape, dtype):

        rng = jtu.rand_default(self.rng())
        a = rng(shape, dtype)
        b = rng(shape[:1], dtype)

        expected = np.linalg.solve(posify(a), b)
        actual = lax_cg(posify(a), b)
        self.assertAllClose(expected, actual)

        actual = jit(lax_cg)(posify(a), b)
        self.assertAllClose(expected, actual)

        # numerical gradients are only well defined if ``a`` is guaranteed to be
        # positive definite.
        jtu.check_grads(lambda x, y: lax_cg(posify(x), y), (a, b),
                        order=2,
                        rtol=1e-2)

    def test_cg_ndarray(self):
        A = lambda x: 2 * x
        b = jnp.arange(9.0).reshape((3, 3))
        expected = b / 2
        actual, _ = jax.scipy.sparse.linalg.cg(A, b)
        self.assertAllClose(expected, actual)

    def test_cg_pytree(self):
        A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
        b = {"a": 1.0, "b": -4.0}
        expected = {"a": 4.0, "b": -6.0}
        actual, _ = jax.scipy.sparse.linalg.cg(A, b)
        self.assertEqual(expected.keys(), actual.keys())
        self.assertAlmostEqual(expected["a"], actual["a"], places=6)
        self.assertAlmostEqual(expected["b"], actual["b"], places=6)

    def test_cg_errors(self):
        A = lambda x: x
        b = jnp.zeros((2, 1))
        x0 = jnp.zeros((2, ))
        with self.assertRaisesRegex(ValueError,
                                    "x0 and b must have matching shape"):
            jax.scipy.sparse.linalg.cg(A, b, x0)
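
test_cg_as_solve above only checks gradients after passing the matrix through posify, because CG (and its numerical gradient) is only well behaved for symmetric positive-definite operators. A minimal standalone sketch of that pattern using the public jax.scipy.sparse.linalg.cg API, with the operator supplied as a matvec closure (the name spd and the diagonal shift of 4 are illustrative choices):

import numpy as onp
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

rng = onp.random.RandomState(0)
a = rng.randn(4, 4).astype(onp.float32)
spd = jnp.asarray(a @ a.T + 4 * onp.eye(4, dtype=onp.float32))  # symmetric positive definite
b = jnp.asarray(rng.randn(4).astype(onp.float32))

# cg only needs the action of the operator on a vector (or pytree), so a
# closure works just as well as a dense matrix.
x, _ = cg(lambda v: spd @ v, b, tol=1e-6)

onp.testing.assert_allclose(spd @ x, b, atol=1e-3)                   # small residual
onp.testing.assert_allclose(x, jnp.linalg.solve(spd, b), atol=1e-3)  # matches a dense solve

The same "operator as a function" idea is what test_cg_pytree relies on: the callable maps a pytree of arrays to a pytree of the same structure, and cg never needs to materialize a matrix.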
Example #30
class NdimageTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_{}_coordinates={}_order={}_mode={}_cval={}_impl={}_round={}".
                format(
                    jtu.format_shape_dtype_string(shape, dtype),
                    jtu.format_shape_dtype_string(coords_shape, coords_dtype),
                    order, mode, cval, impl, round_),
                "rng_factory":
                rng_factory,
                "shape":
                shape,
                "coords_shape":
                coords_shape,
                "dtype":
                dtype,
                "coords_dtype":
                coords_dtype,
                "order":
                order,
                "mode":
                mode,
                "cval":
                cval,
                "impl":
                impl,
                "round_":
                round_
            } for shape in [(5, ), (3, 4), (3, 4, 5)]
            for coords_shape in [(7, ), (2, 3, 4)]
            for dtype in float_dtypes + int_dtypes
            for coords_dtype in float_dtypes for order in [0, 1]
            for mode in ['wrap', 'constant', 'nearest']
            for cval in ([0, -1] if mode == 'constant' else [0])
            for impl, rng_factory in [
                ("original", partial(jtu.rand_uniform, low=0, high=1)),
                ("fixed", partial(jtu.rand_uniform, low=-0.75, high=1.75)),
            ] for round_ in [True, False]))
    def testMapCoordinates(self, shape, dtype, coords_shape, coords_dtype,
                           order, mode, cval, impl, round_, rng_factory):
        rng = rng_factory(self.rng())

        def args_maker():
            x = onp.arange(onp.prod(shape), dtype=dtype).reshape(shape)
            coords = [(size - 1) * rng(coords_shape, coords_dtype)
                      for size in shape]
            if round_:
                coords = [c.round().astype(int) for c in coords]
            return x, coords

        lsp_op = lambda x, c: lsp_ndimage.map_coordinates(
            x, c, order=order, mode=mode, cval=cval)
        impl_fun = (osp_ndimage.map_coordinates
                    if impl == "original" else _fixed_ref_map_coordinates)
        osp_op = lambda x, c: impl_fun(x, c, order=order, mode=mode, cval=cval)
        if dtype in float_dtypes:
            epsilon = max([
                dtypes.finfo(dtypes.canonicalize_dtype(d)).eps
                for d in [dtype, coords_dtype]
            ])
            self._CheckAgainstNumpy(lsp_op,
                                    osp_op,
                                    args_maker,
                                    tol=100 * epsilon)
        else:
            self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=0)

    def testMapCoordinatesErrors(self):
        x = onp.arange(5.0)
        c = [onp.linspace(0, 5, num=3)]
        with self.assertRaisesRegex(NotImplementedError, 'requires order<=1'):
            lsp_ndimage.map_coordinates(x, c, order=2)
        with self.assertRaisesRegex(NotImplementedError,
                                    'does not yet support mode'):
            lsp_ndimage.map_coordinates(x, c, order=1, mode='reflect')
        with self.assertRaisesRegex(ValueError, 'sequence of length'):
            lsp_ndimage.map_coordinates(x, [c, c], order=1)

    def testMapCoordinateDocstring(self):
        self.assertIn("Only linear interpolation",
                      lsp_ndimage.map_coordinates.__doc__)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_{}_order={}".format(onp.dtype(dtype),
                                                       order),
                "dtype": dtype,
                "order": order
            } for dtype in float_dtypes + int_dtypes for order in [0, 1]))
    def testMapCoordinatesRoundHalf(self, dtype, order):
        x = onp.arange(-3, 3, dtype=dtype)
        c = onp.array([[.5, 1.5, 2.5, 3.5]])

        def args_maker():
            return x, c

        lsp_op = lambda x, c: lsp_ndimage.map_coordinates(x, c, order=order)
        osp_op = lambda x, c: osp_ndimage.map_coordinates(x, c, order=order)
        self._CheckAgainstNumpy(lsp_op, osp_op, args_maker)

    def testContinuousGradients(self):
        # regression test for https://github.com/google/jax/issues/3024

        def loss(delta):
            x = onp.arange(100.0)
            border = 10
            indices = onp.arange(x.size) + delta
            # linear interpolation of the linear function y=x should be exact
            shifted = lsp_ndimage.map_coordinates(x, [indices], order=1)
            return ((x - shifted)**2)[border:-border].mean()

        # analytical gradient of (x - (x - delta)) ** 2 is 2 * delta
        self.assertAllClose(grad(loss)(0.5), 1.0, check_dtypes=False)
        self.assertAllClose(grad(loss)(1.0), 2.0, check_dtypes=False)
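
Both testContinuousGradients and the comments inside it lean on the fact that order-1 map_coordinates is plain linear interpolation: interpolating the identity signal y = x at shifted positions reproduces the shift exactly, so the per-element squared error is delta**2 and its gradient is 2 * delta. A minimal standalone sketch of both facts (the helper name shifted_mse is illustrative):

import numpy as onp
from jax import grad
import jax.scipy.ndimage as lsp_ndimage

x = onp.arange(5.0)                        # the "image" is the identity map y = x
coords = [onp.array([0.0, 1.25, 3.5])]     # fractional sample positions, one array per axis
out = lsp_ndimage.map_coordinates(x, coords, order=1)
print(out)                                 # [0.   1.25 3.5 ] -- linear interpolation of y = x is exact

def shifted_mse(delta):
    indices = onp.arange(x.size) + delta
    shifted = lsp_ndimage.map_coordinates(x, [indices], order=1)
    # Drop the border samples, where out-of-range coordinates fall back to cval.
    return ((x - shifted) ** 2)[1:-1].mean()

print(grad(shifted_mse)(0.5))              # ~1.0, i.e. 2 * delta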