Example #1
class LaxBackedScipyTests(jtu.JaxTestCase):

  def _fetch_preconditioner(self, preconditioner, A, rng=None):
    """
    Returns one of various preconditioning matrices depending on the identifier
    `preconditioner' and the input matrix A whose inverse it supposedly
    approximates.
    """
    if preconditioner == 'identity':
      M = np.eye(A.shape[0], dtype=A.dtype)
    elif preconditioner == 'random':
      if rng is None:
        rng = jtu.rand_default(self.rng())
      M = np.linalg.inv(rand_sym_pos_def(rng, A.shape, A.dtype))
    elif preconditioner == 'exact':
      M = np.linalg.inv(A)
    else:
      M = None
    return M

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            preconditioner),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(4, 4), (7, 7)]
      for dtype in [np.float64, np.complex128]
      for preconditioner in [None, 'identity', 'exact', 'random']))
  def test_cg_against_scipy(self, shape, dtype, preconditioner):
    if not config.x64_enabled:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rand_sym_pos_def(rng, shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=1),
        partial(lax_cg, M=M, maxiter=1),
        args_maker,
        tol=1e-12)

    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=3),
        partial(lax_cg, M=M, maxiter=3),
        args_maker,
        tol=1e-12)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_cg, M=M, atol=1e-10),
        args_maker,
        tol=1e-6)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(2, 2)]
      for dtype in float_types + complex_types))
  def test_cg_as_solve(self, shape, dtype):

    rng = jtu.rand_default(self.rng())
    a = rng(shape, dtype)
    b = rng(shape[:1], dtype)

    expected = np.linalg.solve(posify(a), b)
    actual = lax_cg(posify(a), b)
    self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)

    actual = jit(lax_cg)(posify(a), b)
    self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)

    # numerical gradients are only well defined if ``a`` is guaranteed to be
    # positive definite.
    jtu.check_grads(
        lambda x, y: lax_cg(posify(x), y),
        (a, b), order=2, rtol=2e-1)

  def test_cg_ndarray(self):
    A = lambda x: 2 * x
    b = jnp.arange(9.0).reshape((3, 3))
    expected = b / 2
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertAllClose(expected, actual)

  def test_cg_pytree(self):
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=6)
    self.assertAlmostEqual(expected["b"], actual["b"], places=6)

  def test_cg_errors(self):
    A = lambda x: x
    b = jnp.zeros((2,))
    with self.assertRaisesRegex(
        ValueError, "x0 and b must have matching tree structure"):
      jax.scipy.sparse.linalg.cg(A, {'x': b}, {'y': b})
    with self.assertRaisesRegex(
        ValueError, "x0 and b must have matching shape"):
      jax.scipy.sparse.linalg.cg(A, b, b[:, np.newaxis])
    with self.assertRaisesRegex(ValueError, "must be a square matrix"):
      jax.scipy.sparse.linalg.cg(jnp.zeros((3, 2)), jnp.zeros((2,)))
    with self.assertRaisesRegex(
        TypeError, "linear operator must be either a function or ndarray"):
      jax.scipy.sparse.linalg.cg([[1]], jnp.zeros((1,)))

  def test_cg_without_pytree_equality(self):

    @register_pytree_node_class
    class MinimalPytree:
      def __init__(self, value):
        self.value = value
      def tree_flatten(self):
        return [self.value], None
      @classmethod
      def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    A = lambda x: MinimalPytree(2 * x.value)
    b = MinimalPytree(jnp.arange(5.0))
    expected = b.value / 2
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertAllClose(expected, actual.value)

  def test_cg_weak_types(self):
    x, _ = jax.scipy.sparse.linalg.cg(lambda x: x, 1.0)
    self.assertTrue(dtypes.is_weakly_typed(x))

  # BICGSTAB
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            preconditioner),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(5, 5)]
      for dtype in [np.float64, np.complex128]
      for preconditioner in [None, 'identity', 'exact', 'random']
  ))
  def test_bicgstab_against_scipy(
      self, shape, dtype, preconditioner):
    if not config.jax_enable_x64:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_bicgstab, M=M, maxiter=1),
        partial(lax_bicgstab, M=M, maxiter=1),
        args_maker,
        tol=1e-5)

    self._CheckAgainstNumpy(
        partial(scipy_bicgstab, M=M, maxiter=2),
        partial(lax_bicgstab, M=M, maxiter=2),
        args_maker,
        tol=1e-4)

    self._CheckAgainstNumpy(
        partial(scipy_bicgstab, M=M, maxiter=1),
        partial(lax_bicgstab, M=M, maxiter=1),
        args_maker,
        tol=1e-4)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_bicgstab, M=M, atol=1e-6),
        args_maker,
        tol=1e-4)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}".format(
         jtu.format_shape_dtype_string(shape, dtype),
         preconditioner),
      "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(2, 2), (7, 7)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']
      ))
  @jtu.skip_on_devices("gpu")
  def test_bicgstab_on_identity_system(self, shape, dtype, preconditioner):
    A = jnp.eye(shape[1], dtype=dtype)
    solution = jnp.ones(shape[1], dtype=dtype)
    rng = jtu.rand_default(self.rng())
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    tol = shape[0] * jnp.finfo(dtype).eps
    x, info = jax.scipy.sparse.linalg.bicgstab(A, b, tol=tol, atol=tol,
                                               M=M)
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}".format(
         jtu.format_shape_dtype_string(shape, dtype),
         preconditioner),
      "shape": shape, "dtype": dtype, "preconditioner": preconditioner
      }
      for shape in [(2, 2), (4, 4)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']
      ))
  @jtu.skip_on_devices("gpu")
  def test_bicgstab_on_random_system(self, shape, dtype, preconditioner):
    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    solution = rng(shape[1:], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    tol = shape[0] * jnp.finfo(A.dtype).eps
    x, info = jax.scipy.sparse.linalg.bicgstab(A, b, tol=tol, atol=tol, M=M)
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
    # solve = lambda A, b: jax.scipy.sparse.linalg.bicgstab(A, b)[0]
    # jtu.check_grads(solve, (A, b), order=1, rtol=3e-1)


  def test_bicgstab_pytree(self):
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.bicgstab(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=5)
    self.assertAlmostEqual(expected["b"], actual["b"], places=5)

  def test_bicgstab_weak_types(self):
    x, _ = jax.scipy.sparse.linalg.bicgstab(lambda x: x, 1.0)
    self.assertTrue(dtypes.is_weakly_typed(x))

  # GMRES
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}_solve_method={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            preconditioner,
            solve_method),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
       "solve_method": solve_method}
      for shape in [(3, 3)]
      for dtype in [np.float64, np.complex128]
      for preconditioner in [None, 'identity', 'exact', 'random']
      for solve_method in ['incremental', 'batched']))
  def test_gmres_against_scipy(
      self, shape, dtype, preconditioner, solve_method):
    if not config.x64_enabled:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=1, maxiter=1),
        partial(lax_gmres, M=M, restart=1, maxiter=1, solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=1, maxiter=2),
        partial(lax_gmres, M=M, restart=1, maxiter=2, solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=2, maxiter=1),
        partial(lax_gmres, M=M, restart=2, maxiter=1, solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_gmres, M=M, atol=1e-6, solve_method=solve_method),
        args_maker,
        tol=1e-10)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}_solve_method={}".format(
         jtu.format_shape_dtype_string(shape, dtype),
         preconditioner,
         solve_method),
      "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
      "solve_method": solve_method}
      for shape in [(2, 2), (7, 7)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']
      for solve_method in ['batched', 'incremental']
      ))
  @jtu.skip_on_devices("gpu")
  def test_gmres_on_identity_system(self, shape, dtype, preconditioner,
                                    solve_method):
    A = jnp.eye(shape[1], dtype=dtype)

    solution = jnp.ones(shape[1], dtype=dtype)
    rng = jtu.rand_default(self.rng())
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    restart = shape[-1]
    tol = shape[0] * jnp.finfo(dtype).eps
    x, info = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol,
                                            restart=restart,
                                            M=M, solve_method=solve_method)
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}_solve_method={}".format(
         jtu.format_shape_dtype_string(shape, dtype),
         preconditioner,
         solve_method),
      "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
      "solve_method": solve_method}
      for shape in [(2, 2), (4, 4)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']
      for solve_method in ['incremental', 'batched']
      ))
  @jtu.skip_on_devices("gpu")
  def test_gmres_on_random_system(self, shape, dtype, preconditioner,
                                  solve_method):
    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)

    solution = rng(shape[1:], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    restart = shape[-1]
    tol = shape[0] * jnp.finfo(A.dtype).eps
    x, info = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol,
                                            restart=restart,
                                            M=M, solve_method=solve_method)
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
    # solve = lambda A, b: jax.scipy.sparse.linalg.gmres(A, b)[0]
    # jtu.check_grads(solve, (A, b), order=1, rtol=2e-1)

  def test_gmres_pytree(self):
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=5)
    self.assertAlmostEqual(expected["b"], actual["b"], places=5)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}".format(
         jtu.format_shape_dtype_string(shape, dtype),
         preconditioner),
      "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(2, 2), (3, 3)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity']))
  def test_gmres_arnoldi_step(self, shape, dtype, preconditioner):
    """
    The Arnoldi decomposition within GMRES is correct.
    """
    if not config.x64_enabled:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    if preconditioner is None:
      M = lambda x: x
    else:
      M = partial(matmul_high_precision, M)
    n = shape[0]
    x0 = rng(shape[:1], dtype)
    Q = np.zeros((n, n + 1), dtype=dtype)
    Q[:, 0] = x0/jnp.linalg.norm(x0)
    Q = jnp.array(Q)
    H = jnp.eye(n, n + 1, dtype=dtype)

    @jax.tree_util.Partial
    def A_mv(x):
      return matmul_high_precision(A, x)
    for k in range(n):
      Q, H, _ = jax._src.scipy.sparse.linalg._kth_arnoldi_iteration(
          k, A_mv, M, Q, H)
    QA = matmul_high_precision(Q[:, :n].conj().T, A)
    QAQ = matmul_high_precision(QA, Q[:, :n])
    self.assertAllClose(QAQ, H.T[:n, :], rtol=1e-5, atol=1e-5)

  def test_gmres_weak_types(self):
    x, _ = jax.scipy.sparse.linalg.gmres(lambda x: x, 1.0)
    self.assertTrue(dtypes.is_weakly_typed(x))
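
The tests above all follow the same pattern: build a symmetric positive-definite
system, optionally fetch a preconditioner M approximating the inverse of A via
_fetch_preconditioner, and compare jax.scipy.sparse.linalg.cg against SciPy or a
dense solve. Below is a minimal standalone sketch of that pattern, outside the
test harness; the matrix, dtype, and tolerances are illustrative assumptions, not
values taken from the suite.

import numpy as onp
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Build a small symmetric positive-definite system A x = b.
rng = onp.random.RandomState(0)
Q = rng.randn(4, 4)
A = jnp.asarray(Q @ Q.T + 4 * onp.eye(4), dtype=jnp.float32)
b = jnp.asarray(rng.randn(4), dtype=jnp.float32)

# "exact" preconditioner: the dense inverse of A (cf. _fetch_preconditioner).
M = jnp.asarray(onp.linalg.inv(onp.asarray(A)))

# Both A and M may be given as matrices or as matrix-vector callables.
x, _ = cg(A, b, M=M, atol=1e-6)
assert onp.allclose(onp.asarray(A @ x), onp.asarray(b), atol=1e-3)
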
Example #2
class CustomObjectTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_compile={}_primitive={}".format(compile, primitive),
       "compile": compile, "primitive": primitive}
      for primitive in [True, False]
      for compile in [True, False]))
  def testSparseIdentity(self, compile, primitive):
    f = identity if primitive else (lambda x: x)
    f = jit(f) if compile else f
    rng = jtu.rand_default(self.rng())
    M = make_sparse_array(rng, (10,), jnp.float32)
    M2 = f(M)

    jaxpr = make_jaxpr(f)(M).jaxpr
    core.check_jaxpr(jaxpr)

    self.assertEqual(M.dtype, M2.dtype)
    self.assertEqual(M.index_dtype, M2.index_dtype)
    self.assertAllClose(M.data, M2.data)
    self.assertAllClose(M.indices, M2.indices)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_compile={}".format(compile),
       "compile": compile}
      for compile in [True, False]))
  def testSparseSplit(self, compile):
    f = jit(split) if compile else split
    rng = jtu.rand_default(self.rng())
    M = make_sparse_array(rng, (10,), jnp.float32)
    M2, M3 = f(M)

    jaxpr = make_jaxpr(f)(M).jaxpr
    core.check_jaxpr(jaxpr)

    for MM in M2, M3:
      self.assertEqual(M.dtype, MM.dtype)
      self.assertEqual(M.index_dtype, MM.index_dtype)
      self.assertArraysEqual(M.data, MM.data)
      self.assertArraysEqual(M.indices, MM.indices)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_compile={}_primitive={}".format(compile, primitive),
       "compile": compile, "primitive": primitive}
      for primitive in [True, False]
      for compile in [True, False]))
  def testSparseLaxLoop(self, compile, primitive):
    rng = jtu.rand_default(self.rng())
    f = identity if primitive else (lambda x: x)
    f = jit(f) if compile else f
    body_fun = lambda _, A: f(A)
    M = make_sparse_array(rng, (10,), jnp.float32)
    lax.fori_loop(0, 10, body_fun, M)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_attr={}".format(attr), "attr": attr}
      for attr in ["data", "indices"]))
  def testSparseAttrAccess(self, attr):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [make_sparse_array(rng, (10,), jnp.float32)]
    f = lambda x: getattr(x, attr)
    self._CompileAndCheck(f, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
         jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(3, 3), (2, 6), (6, 2)]
      for dtype in jtu.dtypes.floating))
  def testSparseMatvec(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [make_sparse_array(rng, shape, dtype), rng(shape[-1:], dtype)]
    self._CompileAndCheck(matvec, args_maker)

  def testLowerToNothing(self):
    empty = Empty(AbstractEmpty())
    jaxpr = make_jaxpr(jit(lambda e: e))(empty).jaxpr
    core.check_jaxpr(jaxpr)

    # cannot return a unit, because CompileAndCheck assumes array output.
    testfunc = lambda e: None
    args_maker = lambda: [empty]
    self._CompileAndCheck(testfunc, args_maker)

  def testConstantHandler(self):
    def make_const_array():
      data = np.arange(3.0)
      indices = np.arange(3)[:, None]
      shape = (5,)
      aval = AbstractSparseArray(shape, data.dtype, indices.dtype, len(indices))
      return SparseArray(aval, data, indices)
    out1 = make_const_array()
    out2 = jit(make_const_array)()
    self.assertArraysEqual(out1.data, out2.data)
    self.assertArraysEqual(out1.indices, out2.indices)
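
The custom-object tests above repeatedly trace a function with make_jaxpr and then
validate the result with core.check_jaxpr. A minimal sketch of that round-trip check
on a plain function, independent of the SparseArray helpers (which are defined
elsewhere in the test module):

import jax.numpy as jnp
from jax import core, make_jaxpr

f = lambda x: 2 * x + 1
jaxpr = make_jaxpr(f)(jnp.arange(3.0)).jaxpr
core.check_jaxpr(jaxpr)  # raises an exception if the traced jaxpr is malformed
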
Example #3
class TestPolynomial(jtu.JaxTestCase):
    def assertSetsAllClose(self, x, y, rtol=None, atol=None,
                           check_dtypes=True):
        """Assert that x and y contain permutations of the same approximate set of values.

        For non-complex inputs, this is accomplished by comparing the sorted inputs.
        For complex inputs, such a comparison can be confounded by numerical errors. In
        that case, we compute the structural rank of the pairwise comparison matrix: if
        the structural rank is full, the matrix can be permuted so that its diagonal is
        non-zero, which implies a one-to-one approximate match between the permuted sets.
        """
        x = np.asarray(x).ravel()
        y = np.asarray(y).ravel()

        atol = max(jtu.tolerance(x.dtype, atol), jtu.tolerance(y.dtype, atol))
        rtol = max(jtu.tolerance(x.dtype, rtol), jtu.tolerance(y.dtype, rtol))

        if not (np.issubdtype(x.dtype, np.complexfloating)
                or np.issubdtype(y.dtype, np.complexfloating)):
            return self.assertAllClose(np.sort(x),
                                       np.sort(y),
                                       atol=atol,
                                       rtol=rtol,
                                       check_dtypes=check_dtypes)

        if check_dtypes:
            self.assertEqual(x.dtype, y.dtype)
        self.assertEqual(x.size, y.size)

        pairwise = np.isclose(x[:, None],
                              y[None, :],
                              atol=atol,
                              rtol=rtol,
                              equal_nan=True)
        rank = csgraph.structural_rank(csr_matrix(pairwise))
        self.assertEqual(rank, x.size)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name":
             "_dtype={}_leading={}_trailing={}".format(
                 jtu.format_shape_dtype_string(
                     (length + leading + trailing,), dtype),
                 leading, trailing),
             "dtype": dtype, "length": length,
             "leading": leading, "trailing": trailing}
            for dtype in all_dtypes
            for length in [0, 3, 5]
            for leading in [0, 2]
            for trailing in [0, 2]))
    # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
    @jtu.skip_on_devices("gpu", "tpu")
    def testRoots(self, dtype, length, leading, trailing):
        rng = jtu.rand_some_zero(self.rng())

        def args_maker():
            p = rng((length, ), dtype)
            return [
                jnp.concatenate([
                    jnp.zeros(leading, p.dtype), p,
                    jnp.zeros(trailing, p.dtype)
                ])
            ]

        jnp_fun = jnp.roots

        def np_fun(arg):
            return np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype))

        # Note: outputs have no defined order, so we need to use a special comparator.
        args = args_maker()
        np_roots = np_fun(*args)
        jnp_roots = jnp_fun(*args)
        self.assertSetsAllClose(np_roots, jnp_roots)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name":
             "_dtype={}_leading={}_trailing={}".format(
                 jtu.format_shape_dtype_string(
                     (length + leading + trailing,), dtype),
                 leading, trailing),
             "dtype": dtype, "length": length,
             "leading": leading, "trailing": trailing}
            for dtype in all_dtypes
            for length in [0, 3, 5]
            for leading in [0, 2]
            for trailing in [0, 2]))
    # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
    @jtu.skip_on_devices("gpu", "tpu")
    def testRootsNoStrip(self, dtype, length, leading, trailing):
        rng = jtu.rand_some_zero(self.rng())

        def args_maker():
            p = rng((length, ), dtype)
            return [
                jnp.concatenate([
                    jnp.zeros(leading, p.dtype), p,
                    jnp.zeros(trailing, p.dtype)
                ])
            ]

        jnp_fun = partial(jnp.roots, strip_zeros=False)

        def np_fun(arg):
            roots = np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype))
            if len(roots) < len(arg) - 1:
                roots = np.pad(roots, (0, len(arg) - len(roots) - 1),
                               constant_values=complex(np.nan, np.nan))
            return roots

        # Note: outputs have no defined order, so we need to use a special comparator.
        args = args_maker()
        np_roots = np_fun(*args)
        jnp_roots = jnp_fun(*args)
        self.assertSetsAllClose(np_roots, jnp_roots)
        self._CompileAndCheck(jnp_fun, args_maker)
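
The assertSetsAllClose helper above compares two unordered collections of (possibly
complex) roots by checking that the pairwise np.isclose matrix has full structural
rank. A small worked sketch of that idea on its own; the values are chosen here purely
for illustration.

import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse import csgraph

# Two permutations of (approximately) the same complex values.
x = np.array([1 + 1j, 2 - 1j, 3 + 0j])
y = np.array([3 + 0j, 1 + 1j + 1e-12, 2 - 1j])

# pairwise[i, j] is True when x[i] approximately equals y[j].
pairwise = np.isclose(x[:, None], y[None, :], atol=1e-8, rtol=1e-8)

# Full structural rank means the Boolean matrix can be permuted so that its
# diagonal is nonzero, i.e. there is a one-to-one approximate matching.
assert csgraph.structural_rank(csr_matrix(pairwise)) == x.size
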
Example #4
class LaxVmapTest(jtu.JaxTestCase):

  def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
                     rtol=None, atol=None, multiple_results=False):
    batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
    args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)]
    args_slice = args_slicer(args, bdims)
    ans = jax.vmap(op, bdims)(*args)
    if bdim_size == 0:
      args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
      out = op(*args)
      if not multiple_results:
        expected = np.zeros((0,) + out.shape, out.dtype)
      else:
        expected = [np.zeros((0,) + o.shape, o.dtype) for o in out]
    else:
      outs = [op(*args_slice(i)) for i in range(bdim_size)]
      if not multiple_results:
        expected = np.stack(outs)
      else:
        expected = [np.stack(xs) for xs in zip(*outs)]
    self.assertAllClose(ans, expected, rtol=rtol, atol=atol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": "{}_bdims={}".format(
            jtu.format_test_name_suffix(rec.op, shapes,
                                        itertools.repeat(dtype)), bdims),
         "op_name": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes,
         "dtype": dtype, "bdims": bdims, "tol": rec.tol}
        for shape_group in compatible_shapes
        for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs)
        for bdims in all_bdims(*shapes)
        for dtype in rec.dtypes)
      for rec in LAX_OPS))
  def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol):
    rng = rng_factory(self.rng())
    op = getattr(lax, op_name)
    self._CheckBatching(op, 10, bdims, shapes, [dtype] * len(shapes), rng,
                        atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
       "testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
       "_lhs_bdim={}_rhs_bdim={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
               feature_group_count, batch_group_count, lhs_bdim, rhs_bdim),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "dimension_numbers": dim_nums,
       "perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim,
       "feature_group_count": feature_group_count,
       "batch_group_count": batch_group_count,
     } for batch_group_count, feature_group_count in s([(1, 1), (2, 1), (1, 2)])
       for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in s([
           ((b * batch_group_count, i * feature_group_count, 6, 7),  # lhs_shape
            (j * batch_group_count * feature_group_count, i, 1, 2),  # rhs_shape
            [(1, 1), (1, 2), (2, 1)],  # strides
            [((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))],  # pads
            [(1, 1), (2, 1)],  # lhs_dils
            [(1, 1), (2, 2)])  # rhs_dils
           for b, i, j in itertools.product([1, 2], repeat=3)])
       for strides in s(all_strides)
       for rhs_dil in s(rhs_dils)
       for lhs_dil in s(lhs_dils)
       for dtype in s([np.float32])
       for padding in s(all_pads)
       for dim_nums, perms in s([
           (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
           (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
           (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))])
       for lhs_bdim in s(itertools.chain([cast(Optional[int], None)],
                                         range(len(lhs_shape) + 1)))
       for rhs_bdim in s(itertools.chain([cast(Optional[int], None)],
                                         range(len(rhs_shape) + 1)))
       if (lhs_bdim, rhs_bdim) != (None, None)
       )))
  def testConvGeneralDilatedBatching(
      self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil,
      dimension_numbers, perms, feature_group_count, batch_group_count,
      lhs_bdim, rhs_bdim):
    rng = jtu.rand_default(self.rng())
    tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3

    # permute shapes to match dim_spec, scale by feature_group_count
    lhs_perm, rhs_perm = perms
    lhs_shape = list(np.take(lhs_shape, lhs_perm))
    rhs_shape = list(np.take(rhs_shape, rhs_perm))

    conv = partial(lax.conv_general_dilated, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   dimension_numbers=dimension_numbers,
                   feature_group_count=feature_group_count,
                   batch_group_count=batch_group_count,
                   precision=lax.Precision.HIGHEST)
    self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape),
                        (dtype, dtype), rng, rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
          shape, from_dtype, to_dtype, bdims),
       "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype,
       "bdims": bdims}
      for from_dtype, to_dtype in itertools.product(
          [np.float32, np.int32, "float32", "int32"], repeat=2)
      for shape in [(2, 3)]
      for bdims in all_bdims(shape)))
  def testConvertElementType(self, shape, from_dtype, to_dtype, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.convert_element_type(x, to_dtype)
    self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nmant={}_nexp={}_bdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), nmant, nexp, bdims),
       "shape": shape, "dtype": dtype, "nmant": nmant, "nexp": nexp, "bdims": bdims}
      for dtype in float_dtypes
      for shape in [(2, 4)]
      for nexp in [1, 3, 5]
      for nmant in [0, 2, 4]
      for bdims in all_bdims(shape)))
  def testReducePrecision(self, shape, dtype, nmant, nexp, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.reduce_precision(x, exponent_bits=nexp, mantissa_bits=nmant)
    self._CheckBatching(op, 10, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
          shape, from_dtype, to_dtype, bdims),
       "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype,
       "bdims": bdims}
      for from_dtype, to_dtype in itertools.product(
          [np.float32, np.int32, "float32", "int32"], repeat=2)
      for shape in [(2, 3)]
      for bdims in all_bdims(shape)))
  def testBitcastElementType(self, shape, from_dtype, to_dtype, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.bitcast_convert_type(x, to_dtype)
    self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}_bdims={}"
       .format(jtu.format_shape_dtype_string(min_shape, dtype),
               jtu.format_shape_dtype_string(operand_shape, dtype),
               jtu.format_shape_dtype_string(max_shape, dtype),
               bdims),
       "min_shape": min_shape, "operand_shape": operand_shape,
       "max_shape": max_shape, "dtype": dtype, "bdims": bdims}
      for min_shape, operand_shape, max_shape in [
          [(), (2, 3), ()],
          [(2, 3), (2, 3), ()],
          [(), (2, 3), (2, 3)],
          [(2, 3), (2, 3), (2, 3)],
      ]
      for dtype in default_dtypes
      for bdims in all_bdims(min_shape, operand_shape, max_shape)))
  def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims):
    rng = jtu.rand_default(self.rng())
    shapes = [min_shape, operand_shape, max_shape]
    self._CheckBatching(lax.clamp, 10, bdims, shapes, [dtype] * 3, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}_bdims={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),
          jtu.format_shape_dtype_string(rhs_shape, dtype),
          bdims),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "bdims": bdims}
      for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
      for bdims in all_bdims(lhs_shape, rhs_shape)
      for dtype in default_dtypes))
  def testDot(self, lhs_shape, rhs_shape, dtype, bdims):
    rng = jtu.rand_default(self.rng())
    op = partial(lax.dot, precision=lax.Precision.HIGHEST)
    self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng, rtol={np.float16: 5e-2, np.float64: 5e-14})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_bdims={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               lhs_contracting, rhs_contracting, bdims),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting,
       "bdims": bdims}
      for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
          [(5,), (5,), [0], [0]],
          [(5, 7), (5,), [0], [0]],
          [(7, 5), (5,), [1], [0]],
          [(3, 5), (2, 5), [1], [1]],
          [(5, 3), (5, 2), [0], [0]],
          [(5, 3, 2), (5, 2, 4), [0], [0]],
          [(5, 3, 2), (5, 2, 4), [0,2], [0,1]],
          [(5, 3, 2), (3, 5, 2, 4), [0,2], [1,2]],
          [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
          [(3, 2), (2, 4), [1], [0]],
      ]
      for bdims in all_bdims(lhs_shape, rhs_shape)
      for dtype in default_dtypes))
  def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                                 lhs_contracting, rhs_contracting, bdims):
    rng = jtu.rand_small(self.rng())
    dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))
    dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
    self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_bdims={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers, bdims),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers, "bdims": bdims}
      for lhs_shape, rhs_shape, dimension_numbers in [
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
      ]
      for bdims in all_bdims(lhs_shape, rhs_shape)
      for dtype in default_dtypes))
  def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                     dimension_numbers, bdims):
    rng = jtu.rand_small(self.rng())
    dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
    self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng)

    # Checks that batching didn't introduce any transposes or broadcasts.
    jaxpr = jax.make_jaxpr(dot)(np.zeros(lhs_shape, dtype),
                                np.zeros(rhs_shape, dtype))
    for eqn in jtu.iter_eqns(jaxpr.jaxpr):
      self.assertNotIn(eqn.primitive.name, ["transpose", "broadcast"])

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format(
          shape, np.dtype(dtype).name, broadcast_sizes, bdims),
       "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
       "bdims": bdims}
      for shape in [(), (2, 3)]
      for dtype in default_dtypes
      for broadcast_sizes in [(), (2,), (1, 2)]
      for bdims in all_bdims(shape)))
  def testBroadcast(self, shape, dtype, broadcast_sizes, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.broadcast(x, broadcast_sizes)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_bcdims={}_bdims={}".format(
          jtu.format_shape_dtype_string(inshape, dtype),
          outshape, broadcast_dimensions, bdims),
       "inshape": inshape, "dtype": dtype, "outshape": outshape,
       "dimensions": broadcast_dimensions, "bdims": bdims}
      for inshape, outshape, broadcast_dimensions in [
          ([2], [2, 2], [0]),
          ([2], [2, 2], [1]),
          ([2], [2, 3], [0]),
          ([], [2, 3], []),
      ]
      for dtype in default_dtypes
      for bdims in all_bdims(inshape)))
  @unittest.skip("this test has failures in some cases")  # TODO(mattjj)
  def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    self._CheckBatching(op, 5, bdims, (inshape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_dimensions={}_bdims={}".format(
          jtu.format_shape_dtype_string(arg_shape, np.float32),
          dimensions, bdims),
       "arg_shape": arg_shape, "dimensions": dimensions, "bdims": bdims}
      for arg_shape, dimensions in [
          [(1,), (0,)],
          [(1,), (-1,)],
          [(2, 1, 4), (1,)],
          [(2, 1, 4), (-2,)],
          [(2, 1, 3, 1), (1,)],
          [(2, 1, 3, 1), (1, 3)],
          [(2, 1, 3, 1), (3,)],
          [(2, 1, 3, 1), (1, -1)],
      ]
      for bdims in all_bdims(arg_shape)))
  def testSqueeze(self, arg_shape, dimensions, bdims):
    dtype = np.float32
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.squeeze(x, dimensions)
    self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_dims={}_bdims={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype),
          dimensions, bdims),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "dimensions": dimensions, "bdims": bdims}
      for dtype in default_dtypes
      for arg_shape, dimensions, out_shape in [
          [(3, 4), None, (12,)],
          [(2, 1, 4), None, (8,)],
          [(2, 2, 4), None, (2, 8)],
          [(2, 2, 4), (0, 1, 2), (2, 8)],
          [(2, 2, 4), (1, 0, 2), (8, 2)],
          [(2, 2, 4), (2, 1, 0), (4, 2, 2)]
      ]
      for bdims in all_bdims(arg_shape)))
  def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions)
    self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_pads={}_bdims={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), pads, bdims),
       "shape": shape, "dtype": dtype, "pads": pads, "bdims": bdims}
      for shape in [(2, 3)]
      for bdims in all_bdims(shape, ())
      for dtype in default_dtypes
      for pads in [[(1, 2, 1), (0, 1, 0)]]))
  def testPad(self, shape, dtype, pads, bdims):
    rng = jtu.rand_small(self.rng())
    fun = lambda operand, padding: lax.pad(operand, padding, pads)
    self._CheckBatching(fun, 5, bdims, (shape, ()), (dtype, dtype), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_predshape={}_argshapes={}_bdims={}".format(
          jtu.format_shape_dtype_string(pred_shape, np.bool_),
          jtu.format_shape_dtype_string(arg_shape, arg_dtype),
          bdims),
       "pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype,
       "bdims": bdims}
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
      for bdims in all_bdims(pred_shape, arg_shape, arg_shape)
      for arg_dtype in default_dtypes))
  def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda c, x, y: lax.select(c < 0, x, y)
    self._CheckBatching(op, 5, bdims, (pred_shape, arg_shape, arg_shape,),
                        (np.bool_, arg_dtype, arg_dtype), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_start_indices={}_limit_indices={}_strides={}_bdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, limit_indices, strides, bdims),
       "shape": shape, "dtype": dtype, "starts": start_indices,
       "limits": limit_indices, "strides": strides, "bdims": bdims}
      for shape, start_indices, limit_indices, strides in [
        [(3,), (1,), (2,), None],
        [(7,), (4,), (7,), None],
        [(5,), (1,), (5,), (2,)],
        [(8,), (1,), (6,), (2,)],
        [(5, 3), (1, 1), (3, 2), None],
        [(5, 3), (1, 1), (3, 1), None],
        [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
        [(5, 3), (1, 1), (2, 1), (1, 1)],
        [(5, 3), (1, 1), (5, 3), (2, 1)],
      ]
      for bdims in all_bdims(shape)
      for dtype in default_dtypes))
  def testSlice(self, shape, dtype, starts, limits, strides, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.slice(x, starts, limits, strides)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_perm={}_bdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), perm, bdims),
       "shape": shape, "dtype": dtype, "perm": perm, "bdims": bdims}
      for shape, perm in [
        [(3, 4), (1, 0)],
        [(3, 4), (0, 1)],
        [(3, 4, 5), (2, 1, 0)],
        [(3, 4, 5), (1, 0, 2)],
      ]
      for bdims in all_bdims(shape)
      for dtype in default_dtypes))
  def testTranspose(self, shape, dtype, perm, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.transpose(x, perm)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}_bdims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
               init_val, bdims),
       "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
       "dims": dims, "bdims": bdims}
      for init_val, op, dtypes in [
          (0, lax.add, default_dtypes),
          (1, lax.mul, default_dtypes),
          (0, lax.max, all_dtypes), # non-monoidal
          (-np.inf, lax.max, float_dtypes),
          (dtypes.iinfo(np.int32).min, lax.max, [np.int32]),
          (dtypes.iinfo(np.int64).min, lax.max, [np.int64]),
          (dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]),
          (dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]),
          (np.inf, lax.min, float_dtypes),
          (dtypes.iinfo(np.int32).max, lax.min, [np.int32]),
          (dtypes.iinfo(np.int64).max, lax.min, [np.int64]),
          (dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]),
          (dtypes.iinfo(np.uint64).max, lax.min, [np.uint64]),
      ]
      for dtype in dtypes
      for shape, dims in [
          [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
      ]
      for bdims in all_bdims(shape)))
  def testReduce(self, op, init_val, shape, dtype, dims, bdims):
    rng = jtu.rand_small(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)
    fun = lambda operand: lax.reduce(operand, init_val, op, dims)
    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_reducedims={}_bdims={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), dims, bdims),
       "shape": shape, "dtype": dtype, "dims": dims, "bdims": bdims}
      for dtype in default_dtypes
      for shape, dims in [
          [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
      ]
      for bdims in all_bdims(shape, shape)))
  def testVariadicReduce(self, shape, dtype, dims, bdims):
    def op(a, b):
      x1, y1 = a
      x2, y2 = b
      return x1 + x2, y1 * y2
    rng = jtu.rand_small(self.rng())
    init_val = tuple(np.asarray([0, 1], dtype=dtype))
    fun = lambda x, y: lax.reduce((x, y), init_val, op, dims)
    self._CheckBatching(fun, 5, bdims, (shape, shape), (dtype, dtype), rng,
                        multiple_results=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}_bdims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dim,
               bdims),
       "op": op, "shape": shape, "dtype": dtype,
       "dim": dim, "bdims": bdims}
      for op in [lax.argmin, lax.argmax]
      for dtype in default_dtypes
      for shape in [(3, 4, 5)]
      for dim in range(len(shape))
      for bdims in all_bdims(shape)))
  def testArgminmax(self, op, shape, dtype, dim, bdims):
    rng = jtu.rand_default(self.rng())
    fun = lambda operand: op(operand, dim, np.int32)
    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
                         "_basedilation={}_windowdilation={}")
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype),
               dims, strides, padding, base_dilation, window_dilation),
       "op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
       "dims": dims, "strides": strides, "padding": padding,
       "base_dilation": base_dilation, "window_dilation": window_dilation}
      for init_val, op, dtypes in [
          (0, lax.add, [np.float32]),
          (-np.inf, lax.max, [np.float32]),
          (np.inf, lax.min, [np.float32]),
      ]
      for shape, dims, strides, padding, base_dilation, window_dilation in (
        itertools.chain(
          itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)],
            ["VALID", "SAME", [(0, 3), (1, 2)]],
            [(1, 1), (2, 3)],
            [(1, 1), (1, 2)]),
          itertools.product(
            [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)],
            ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
            [(1, 1, 1, 1), (2, 1, 3, 2)],
            [(1, 1, 1, 1), (1, 2, 2, 1)])))
      for dtype in dtypes))
  def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding,
                       base_dilation, window_dilation):
    rng = jtu.rand_small(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)

    def fun(operand):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding,
                               base_dilation, window_dilation)

    for bdims in all_bdims(shape):
      self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_shape={}_axis={}_bdims={}_reverse={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
               bdims, reverse),
       "op": op, "shape": shape, "dtype": dtype, "bdims": bdims,
       "axis": axis, "reverse": reverse}
      for op, types in [
          (lax.cumsum, [np.float32, np.float64]),
          (lax.cumprod, [np.float32, np.float64]),
      ]
      for dtype in types
      for shape in [[10], [3, 4, 5]]
      for axis in range(len(shape))
      for bdims in all_bdims(shape)
      for reverse in [False, True]))
  def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse):
    rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
                   else jtu.rand_small)
    rng = rng_factory(self.rng())
    self._CheckBatching(partial(op, axis=axis, reverse=reverse), 7, bdims,
                        (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_padding={}".format(np.dtype(dtype).name,
                                                      padding),
       "dtype": dtype, "padding": padding}
      for dtype in float_dtypes
      for padding in ["VALID", "SAME"]))
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jtu.ignore_warning(message="Using reduced precision for gradient.*")
  def testSelectAndGatherAdd(self, dtype, padding):
    rng = jtu.rand_small(self.rng())
    all_configs = itertools.chain(
        itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)]),
        itertools.product(
            [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)]))

    def fun(operand, tangents):
      pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
      ones = (1,) * len(operand.shape)
      return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims,
                                        strides, pads, ones, ones)

    for shape, dims, strides in all_configs:
      for bdims in all_bdims(shape, shape):
        self._CheckBatching(fun, 3, bdims, (shape, shape), (dtype, dtype), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_dtype={jtu.format_shape_dtype_string(shape, dtype)}"
      f"_padding={padding}_dims={dims}_strides={strides}",
       "dtype": dtype, "padding": padding, "shape": shape,
       "dims": dims, "strides": strides}
      for dtype in float_dtypes
      for padding in ["VALID", "SAME"]
      for shape in [(3, 2, 4, 6)]
      for dims in [(1, 1, 2, 1)]
      for strides in [(1, 2, 2, 1), (1, 1, 1, 1)]))
  def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides):
    rng = jtu.rand_small(self.rng())

    pads = lax.padtype_to_pads(shape, dims, strides, padding)

    def fun(operand, cotangents):
      return lax._select_and_scatter_add(operand, cotangents, lax.ge_p, dims,
                                         strides, pads)
    ones = (1,) * len(shape)
    cotangent_shape = jax.eval_shape(
      lambda x: lax._select_and_gather_add(x, x, lax.ge_p, dims, strides,
                                           pads, ones, ones),
      np.ones(shape, dtype)).shape

    for bdims in all_bdims(cotangent_shape, shape):
      self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape),
                          (dtype, dtype), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_bdims={}_fft_ndims={}"
       .format(shape, bdims, fft_ndims),
       "shape": shape, "bdims": bdims, "fft_ndims": fft_ndims}
      for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)]
      for bdims in all_bdims(shape)
      for fft_ndims in range(0, min(3, len(shape)) + 1)))
  def testFft(self, fft_ndims, shape, bdims):
    rng = jtu.rand_default(self.rng())
    ndims = len(shape)
    axes = range(ndims - fft_ndims, ndims)
    fft_lengths = tuple(shape[axis] for axis in axes)
    op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths)
    self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
               slice_sizes, bdims),
       "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "bdims": bdims}
      for dtype in all_dtypes
      for shape, idxs, dnums, slice_sizes in [
          ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,)),
          ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,)),
          ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3)),
          ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
            (1, 3)),
      ]
      for bdims in all_bdims(shape, idxs.shape)))
  def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    self._CheckBatching(fun, 0, bdims, [shape, idxs.shape], [dtype, idxs.dtype],
                        jtu.rand_default(self.rng()))
    self._CheckBatching(fun, 5, bdims, [shape, idxs.shape], [dtype, idxs.dtype],
                        jtu.rand_default(self.rng()))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}_bdims={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums, bdims),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "bdims": bdims}
      for dtype in float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for bdims in all_bdims(arg_shape, idxs.shape, update_shape)))
  def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims):
    fun = partial(lax.scatter_add, dimension_numbers=dnums)
    self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape, update_shape],
                        [dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()),
                        rtol={np.float16: 5e-3, dtypes.bfloat16: 3e-2})

  def testShapeUsesBuiltinInt(self):
    x = lax.iota(np.int32, 3) + 1
    self.assertIsInstance(x.shape[0], int)  # not np.int64

  def testBroadcastShapesReturnsPythonInts(self):
    shape1, shape2 = (1, 2, 3), (2, 3)
    out_shape = lax.broadcast_shapes(shape1, shape2)
    self.assertTrue(all(type(s) is int for s in out_shape))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}_bdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k, bdims),
       "shape": shape, "dtype": dtype, "k": k, "bdims": bdims, "rng_factory": rng_factory}
      for shape in [(4,), (3, 5, 3)]
      for k in [1, 3]
      for bdims in all_bdims(shape)
      # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug is fixed:
      # The top_k indices for integer arrays with identical entries won't match between
      # vmap'd version and manual reference, so only test unique integer arrays for int_dtypes.
      # Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of
      # values a bfloat16 can represent exactly to avoid ties.
      for dtype, rng_factory in itertools.chain(
        unsafe_zip(default_dtypes, itertools.repeat(jtu.rand_unique_int)))))
  def testTopK(self, shape, dtype, k, bdims, rng_factory):
    rng = rng_factory(self.rng())
    # _CheckBatching doesn't work with tuple outputs, so test outputs separately.
    op1 = lambda x: lax.top_k(x, k=k)[0]
    self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng)
    op2 = lambda x: lax.top_k(x, k=k)[1]
    self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dimension={}_arity={}_bdims={}_isstable={}"
       .format(jtu.format_shape_dtype_string(shape, np.float32), dimension,
               arity, bdims, is_stable),
       "shape": shape, "dimension": dimension, "arity": arity, "bdims": bdims,
       "is_stable": is_stable}
      for shape in [(2, 3)]
      for dimension in [0, 1]
      for arity in range(3)
      for bdims in all_bdims(*((shape,) * arity))
      for is_stable in [False, True]))
  def testSort(self, shape, dimension, arity, bdims, is_stable):
    rng = jtu.rand_default(self.rng())
    if arity == 1:
      fun = partial(lax.sort, dimension=dimension)
      self._CheckBatching(fun, 5, bdims, (shape,) * arity, (np.float32,) * arity,
                          rng)
    else:
      for i in range(arity):
        fun = lambda *args, i=i: lax.sort(args,
                                          dimension=dimension,
                                          is_stable=is_stable)[i]
        self._CheckBatching(fun, 5, bdims, (shape,) * arity,
                            (np.float32,) * arity, rng)
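
# A minimal standalone sketch (not part of the test suite above) of the property
# these batching tests check: vmap of an op over a leading batch axis should agree
# with applying the op to each example and stacking the results.
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

xs = jnp.asarray(np.random.RandomState(0).randn(5, 2, 3).astype(np.float32))
batched = jax.vmap(lambda x: lax.sort(x, dimension=1))(xs)
looped = jnp.stack([lax.sort(x, dimension=1) for x in xs])
np.testing.assert_allclose(np.asarray(batched), np.asarray(looped))

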
class DLPackTest(jtu.JaxTestCase):
  def setUp(self):
    super().setUp()
    if jtu.device_under_test() == "tpu":
      self.skipTest("DLPack not supported on TPU")

  @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_{}_take_ownership={}_gpu={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        take_ownership, gpu),
      "shape": shape, "dtype": dtype, "take_ownership": take_ownership,
      "gpu": gpu}
     for shape in all_shapes
     for dtype in dlpack_dtypes
     for take_ownership in [False, True]
     for gpu in [False, True]))
  @jtu.skip_on_devices("rocm") # TODO(sharadmv,phawkins): see GH issue #10973
  def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    if gpu and jax.default_backend() == "cpu":
      raise unittest.SkipTest("Skipping GPU test case on CPU")
    device = jax.devices("gpu" if gpu else "cpu")[0]
    x = jax.device_put(np, device)
    dlpack = jax.dlpack.to_dlpack(x, take_ownership=take_ownership)
    self.assertEqual(take_ownership, x.device_buffer.is_deleted())
    y = jax.dlpack.from_dlpack(dlpack)
    self.assertEqual(y.device(), device)
    self.assertAllClose(np.astype(x.dtype), y)

    self.assertRaisesRegex(RuntimeError,
                           "DLPack tensor may be consumed at most once",
                           lambda: jax.dlpack.from_dlpack(dlpack))

  @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_{}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
     "shape": shape, "dtype": dtype}
     for shape in all_shapes
     for dtype in dlpack_dtypes))
  @unittest.skipIf(not tf, "Test requires TensorFlow")
  @jtu.skip_on_devices("rocm") # TODO(sharadmv,phawkins): see GH issue #10973
  def testTensorFlowToJax(self, shape, dtype):
    if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64, jnp.float64]:
      raise self.skipTest("x64 types are disabled by jax_enable_x64")
    if (jtu.device_under_test() == "gpu" and
        not tf.config.list_physical_devices("GPU")):
      raise self.skipTest("TensorFlow not configured with GPU support")

    if jtu.device_under_test() == "gpu" and dtype == jnp.int32:
      raise self.skipTest("TensorFlow does not place int32 tensors on GPU")

    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    with tf.device("/GPU:0" if jtu.device_under_test() == "gpu" else "/CPU:0"):
      x = tf.identity(tf.constant(np))
    dlpack = tf.experimental.dlpack.to_dlpack(x)
    y = jax.dlpack.from_dlpack(dlpack)
    self.assertAllClose(np, y)

  @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_{}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
     "shape": shape, "dtype": dtype}
     for shape in all_shapes
     for dtype in dlpack_dtypes))
  @unittest.skipIf(not tf, "Test requires TensorFlow")
  def testJaxToTensorFlow(self, shape, dtype):
    if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64,
                                              jnp.float64]:
      self.skipTest("x64 types are disabled by jax_enable_x64")
    if (jtu.device_under_test() == "gpu" and
        not tf.config.list_physical_devices("GPU")):
      raise self.skipTest("TensorFlow not configured with GPU support")
    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    x = jnp.array(np)
    # TODO(b/171320191): this line works around a missing context initialization
    # bug in TensorFlow.
    _ = tf.add(1, 1)
    dlpack = jax.dlpack.to_dlpack(x)
    y = tf.experimental.dlpack.from_dlpack(dlpack)
    self.assertAllClose(np, y.numpy())

  @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_{}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
     "shape": shape, "dtype": dtype}
     for shape in all_shapes
     for dtype in torch_dtypes))
  @unittest.skipIf(not torch, "Test requires PyTorch")
  def testTorchToJax(self, shape, dtype):
    if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
      self.skipTest("x64 types are disabled by jax_enable_x64")
    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    x = torch.from_numpy(np)
    x = x.cuda() if jtu.device_under_test() == "gpu" else x
    dlpack = torch.utils.dlpack.to_dlpack(x)
    y = jax.dlpack.from_dlpack(dlpack)
    self.assertAllClose(np, y)

  @unittest.skipIf(not torch, "Test requires PyTorch")
  def testTorchToJaxFailure(self):
    x = torch.arange(6).reshape((2, 3))
    y = torch.utils.dlpack.to_dlpack(x[:, :2])

    backend = xla_bridge.get_backend()
    client = getattr(backend, "client", backend)

    regex_str = (r'UNIMPLEMENTED: Only DLPack tensors with trivial \(compact\) '
                 r'striding are supported')
    with self.assertRaisesRegex(RuntimeError, regex_str):
      xla_client._xla.dlpack_managed_tensor_to_buffer(
          y, client)

  @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_{}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
     "shape": shape, "dtype": dtype}
     for shape in all_shapes
     for dtype in torch_dtypes))
  @unittest.skipIf(not torch, "Test requires PyTorch")
  def testJaxToTorch(self, shape, dtype):
    if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
      self.skipTest("x64 types are disabled by jax_enable_x64")
    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    x = jnp.array(np)
    dlpack = jax.dlpack.to_dlpack(x)
    y = torch.utils.dlpack.from_dlpack(dlpack)
    self.assertAllClose(np, y.cpu().numpy())
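
# A minimal standalone sketch (not part of the test suite above) of the DLPack
# round trip exercised by DLPackTest, assuming a CPU device is available: a JAX
# array is exported to a DLPack capsule and imported back; the capsule is single-use.
import jax.dlpack
import jax.numpy as jnp

x = jnp.arange(4.0)
capsule = jax.dlpack.to_dlpack(x)    # export without taking ownership (default)
y = jax.dlpack.from_dlpack(capsule)  # import; the capsule cannot be consumed twice
assert bool((x == y).all())
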
Example no. 6
class LaxBackedScipySignalTests(jtu.JaxTestCase):
    """Tests for LAX-backed scipy.stats implementations"""
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_op={}_xshape={}_yshape={}_mode={}".format(
                    op, jtu.format_shape_dtype_string(xshape, dtype),
                    jtu.format_shape_dtype_string(yshape, dtype), mode),
                "xshape":
                xshape,
                "yshape":
                yshape,
                "dtype":
                dtype,
                "mode":
                mode,
                "jsp_op":
                getattr(jsp_signal, op),
                "osp_op":
                getattr(osp_signal, op)
            } for mode in ['full', 'same', 'valid']
            for op in ['convolve', 'correlate'] for dtype in default_dtypes
            for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes]
            for xshape in shapeset for yshape in shapeset))
    def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
        osp_fun = partial(osp_op, mode=mode)
        jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
        tol = {
            np.float16: 1e-2,
            np.float32: 1e-2,
            np.float64: 1e-12,
            np.complex64: 1e-2,
            np.complex128: 1e-12
        }
        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "op={}_xshape={}_yshape={}_mode={}".format(
                op, jtu.format_shape_dtype_string(xshape, dtype),
                jtu.format_shape_dtype_string(yshape, dtype), mode),
            "xshape":
            xshape,
            "yshape":
            yshape,
            "dtype":
            dtype,
            "mode":
            mode,
            "jsp_op":
            getattr(jsp_signal, op),
            "osp_op":
            getattr(osp_signal, op)
        } for mode in ['full', 'same', 'valid']
                            for op in ['convolve2d', 'correlate2d']
                            for dtype in default_dtypes
                            for xshape in twodim_shapes
                            for yshape in twodim_shapes))
    def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
        osp_fun = partial(osp_op, mode=mode)
        jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
        tol = {
            np.float16: 1e-2,
            np.float32: 1e-2,
            np.float64: 1e-12,
            np.complex64: 1e-2,
            np.complex128: 1e-12
        }
        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_axis={}_type={}_bp={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, type, bp),
            "shape":
            shape,
            "dtype":
            dtype,
            "axis":
            axis,
            "type":
            type,
            "bp":
            bp
        } for shape in [(5, ), (4, 5), (3, 4, 5)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.integer for axis in [0, -1]
                            for type in ['constant', 'linear']
                            for bp in [0, [0, 2]]))
    @jtu.skip_on_devices("rocm")  # will be fixed in rocm-5.1
    def testDetrend(self, shape, dtype, axis, type, bp):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]
        osp_fun = partial(osp_signal.detrend, axis=axis, type=type, bp=bp)
        jsp_fun = partial(jsp_signal.detrend, axis=axis, type=type, bp=bp)
        tol = {np.float32: 1e-5, np.float64: 1e-12}
        self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
            f"_fs={fs}_window={window}_boundary={boundary}_detrend={detrend}"
            f"_padded={padded}_nperseg={nperseg}_noverlap={noverlap}"
            f"_axis={timeaxis}_nfft={nfft}",
            "shape":
            shape,
            "dtype":
            dtype,
            "fs":
            fs,
            "window":
            window,
            "nperseg":
            nperseg,
            "noverlap":
            noverlap,
            "nfft":
            nfft,
            "detrend":
            detrend,
            "boundary":
            boundary,
            "padded":
            padded,
            "timeaxis":
            timeaxis
        } for shape, nperseg, noverlap, timeaxis in stft_test_shapes
                        for dtype in default_dtypes for fs in [1.0, 16000.0]
                        for window in
                        ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
                        for nfft in
                        [None, nperseg,
                         int(nperseg * 1.5), nperseg * 2]
                        for detrend in ['constant', 'linear', False]
                        for boundary in [None, 'even', 'odd', 'zeros']
                        for padded in [True, False]))
    @jtu.skip_on_devices("rocm")  # will be fixed in ROCm 5.1
    def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                             noverlap, nfft, detrend, boundary, padded,
                             timeaxis):
        is_complex = np.dtype(dtype).kind == 'c'
        if is_complex and detrend is not None:
            raise unittest.SkipTest(
                "Complex signal is not supported in lax-backed `signal.detrend`."
            )

        osp_fun = partial(osp_signal.stft,
                          fs=fs,
                          window=window,
                          nfft=nfft,
                          boundary=boundary,
                          padded=padded,
                          detrend=detrend,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          axis=timeaxis,
                          return_onesided=not is_complex)
        jsp_fun = partial(jsp_signal.stft,
                          fs=fs,
                          window=window,
                          nfft=nfft,
                          boundary=boundary,
                          padded=padded,
                          detrend=detrend,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          axis=timeaxis,
                          return_onesided=not is_complex)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    # Tests with `average == 'median'` are excluded from `testCsd*`
    # due to the issue:
    #   https://github.com/scipy/scipy/issues/15601
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                f"_xshape={jtu.format_shape_dtype_string(xshape, dtype)}"
                f"_yshape={jtu.format_shape_dtype_string(yshape, dtype)}"
                f"_average={average}_scaling={scaling}_nfft={nfft}"
                f"_fs={fs}_window={window}_detrend={detrend}"
                f"_nperseg={nperseg}_noverlap={noverlap}"
                f"_axis={timeaxis}",
                "xshape":
                xshape,
                "yshape":
                yshape,
                "dtype":
                dtype,
                "fs":
                fs,
                "window":
                window,
                "nperseg":
                nperseg,
                "noverlap":
                noverlap,
                "nfft":
                nfft,
                "detrend":
                detrend,
                "scaling":
                scaling,
                "timeaxis":
                timeaxis,
                "average":
                average
            }
            for xshape, yshape, nperseg, noverlap, timeaxis in csd_test_shapes
            for dtype in default_dtypes for fs in [1.0, 16000.0]
            for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
            for nfft in [None, nperseg,
                         int(nperseg * 1.5), nperseg * 2]
            for detrend in ['constant', 'linear', False]
            for scaling in ['density', 'spectrum'] for average in ['mean']))
    @jtu.skip_on_devices("rocm")  # will be fixed in next ROCm version
    def testCsdAgainstNumpy(self, *, xshape, yshape, dtype, fs, window,
                            nperseg, noverlap, nfft, detrend, scaling,
                            timeaxis, average):
        is_complex = np.dtype(dtype).kind == 'c'
        if is_complex and detrend is not None:
            raise unittest.SkipTest(
                "Complex signal is not supported in lax-backed `signal.detrend`."
            )

        osp_fun = partial(osp_signal.csd,
                          fs=fs,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          detrend=detrend,
                          return_onesided=not is_complex,
                          scaling=scaling,
                          axis=timeaxis,
                          average=average)
        jsp_fun = partial(jsp_signal.csd,
                          fs=fs,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          detrend=detrend,
                          return_onesided=not is_complex,
                          scaling=scaling,
                          axis=timeaxis,
                          average=average)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                f"_average={average}_scaling={scaling}_nfft={nfft}"
                f"_fs={fs}_window={window}_detrend={detrend}"
                f"_nperseg={nperseg}_noverlap={noverlap}"
                f"_axis={timeaxis}",
                "shape":
                shape,
                "dtype":
                dtype,
                "fs":
                fs,
                "window":
                window,
                "nperseg":
                nperseg,
                "noverlap":
                noverlap,
                "nfft":
                nfft,
                "detrend":
                detrend,
                "scaling":
                scaling,
                "timeaxis":
                timeaxis,
                "average":
                average
            } for shape, unused_yshape, nperseg, noverlap, timeaxis in
            csd_test_shapes for dtype in default_dtypes
            for fs in [1.0, 16000.0]
            for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
            for nfft in [None, nperseg,
                         int(nperseg * 1.5), nperseg * 2]
            for detrend in ['constant', 'linear', False]
            for scaling in ['density', 'spectrum'] for average in ['mean']))
    @jtu.skip_on_devices("rocm")  # will be fixed in next rocm release
    def testCsdWithSameParamAgainstNumpy(self, *, shape, dtype, fs, window,
                                         nperseg, noverlap, nfft, detrend,
                                         scaling, timeaxis, average):
        is_complex = np.dtype(dtype).kind == 'c'
        if is_complex and detrend is not None:
            raise unittest.SkipTest(
                "Complex signal is not supported in lax-backed `signal.detrend`."
            )

        def osp_fun(x, y):
            # When identical arrays are passed for both inputs, the jsp version
            # behaves as if the second argument had been copied, so the reference
            # copies `y` to match.
            freqs, Pxy = osp_signal.csd(x,
                                        y.copy(),
                                        fs=fs,
                                        window=window,
                                        nperseg=nperseg,
                                        noverlap=noverlap,
                                        nfft=nfft,
                                        detrend=detrend,
                                        return_onesided=not is_complex,
                                        scaling=scaling,
                                        axis=timeaxis,
                                        average=average)
            return freqs, Pxy

        jsp_fun = partial(jsp_signal.csd,
                          fs=fs,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          detrend=detrend,
                          return_onesided=not is_complex,
                          scaling=scaling,
                          axis=timeaxis,
                          average=average)

        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)] * 2

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                f"_fs={fs}_window={window}"
                f"_nperseg={nperseg}_noverlap={noverlap}_nfft={nfft}"
                f"_detrend={detrend}_return_onesided={return_onesided}"
                f"_scaling={scaling}_axis={timeaxis}_average={average}",
                "shape":
                shape,
                "dtype":
                dtype,
                "fs":
                fs,
                "window":
                window,
                "nperseg":
                nperseg,
                "noverlap":
                noverlap,
                "nfft":
                nfft,
                "detrend":
                detrend,
                "return_onesided":
                return_onesided,
                "scaling":
                scaling,
                "timeaxis":
                timeaxis,
                "average":
                average
            } for shape, nperseg, noverlap, timeaxis in welch_test_shapes
            for dtype in default_dtypes for fs in [1.0, 16000.0]
            for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
            for nfft in [None, nperseg,
                         int(nperseg * 1.5), nperseg * 2]
            for detrend in ['constant', 'linear', False]
            for return_onesided in [True, False]
            for scaling in ['density', 'spectrum']
            for average in ['mean', 'median']))
    @jtu.skip_on_devices("rocm")  # will be fixed in next ROCm release
    def testWelchAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                              noverlap, nfft, detrend, return_onesided,
                              scaling, timeaxis, average):
        if np.dtype(dtype).kind == 'c':
            return_onesided = False
            if detrend is not None:
                raise unittest.SkipTest(
                    "Complex signal is not supported in lax-backed `signal.detrend`."
                )

        osp_fun = partial(osp_signal.welch,
                          fs=fs,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          detrend=detrend,
                          return_onesided=return_onesided,
                          scaling=scaling,
                          axis=timeaxis,
                          average=average)
        jsp_fun = partial(jsp_signal.welch,
                          fs=fs,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          detrend=detrend,
                          return_onesided=return_onesided,
                          scaling=scaling,
                          axis=timeaxis,
                          average=average)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                f"_nperseg={nperseg}_noverlap={noverlap}"
                f"_use_nperseg={use_nperseg}_use_overlap={use_noverlap}"
                f"_axis={timeaxis}",
                "shape":
                shape,
                "dtype":
                dtype,
                "nperseg":
                nperseg,
                "noverlap":
                noverlap,
                "use_nperseg":
                use_nperseg,
                "use_noverlap":
                use_noverlap,
                "timeaxis":
                timeaxis
            } for shape, nperseg, noverlap, timeaxis in welch_test_shapes
            for use_nperseg in [False, True] for use_noverlap in [False, True]
            for dtype in jtu.dtypes.floating + jtu.dtypes.integer))
    def testWelchWithDefaultStepArgsAgainstNumpy(self, *, shape, dtype,
                                                 nperseg, noverlap,
                                                 use_nperseg, use_noverlap,
                                                 timeaxis):
        kwargs = {'axis': timeaxis}

        if use_nperseg:
            kwargs['nperseg'] = nperseg
        else:
            kwargs['window'] = osp_signal.get_window('hann', nperseg)
        if use_noverlap:
            kwargs['noverlap'] = noverlap

        osp_fun = partial(osp_signal.welch, **kwargs)
        jsp_fun = partial(jsp_signal.welch, **kwargs)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
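
# A minimal standalone sketch (assuming a CPU backend and float32 input) of the
# comparison pattern used throughout this class: jax.scipy.signal.welch is expected
# to agree numerically with scipy.signal.welch on the same input.
import numpy as np
import scipy.signal as osp_signal
import jax.numpy as jnp
import jax.scipy.signal as jsp_signal

x = np.random.RandomState(0).randn(1024).astype(np.float32)
f_ref, p_ref = osp_signal.welch(x, fs=1.0, nperseg=256)
f_jax, p_jax = jsp_signal.welch(jnp.asarray(x), fs=1.0, nperseg=256)
np.testing.assert_allclose(f_ref, np.asarray(f_jax), rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(p_ref, np.asarray(p_jax), rtol=1e-3, atol=1e-6)
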
Example no. 7
class NNInitializersTest(jtu.JaxTestCase):
    def setUp(self):
        super().setUp()
        config.update("jax_numpy_rank_promotion", "raise")

    def tearDown(self):
        super().tearDown()
        config.update("jax_numpy_rank_promotion", "allow")

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_{}".format(rec.name,
                            jtu.format_shape_dtype_string(shape, dtype)),
            "initializer":
            rec.initializer(),
            "shape":
            shape,
            "dtype":
            dtype
        } for rec in INITIALIZER_RECS for shape in rec.shapes
                            for dtype in rec.dtypes))
    def testInitializer(self, initializer, shape, dtype):
        rng = random.PRNGKey(0)
        val = initializer(rng, shape, dtype)

        self.assertEqual(shape, jnp.shape(val))
        self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_{}".format(rec.name,
                            jtu.format_shape_dtype_string(shape, dtype)),
            "initializer_provider":
            rec.initializer,
            "shape":
            shape,
            "dtype":
            dtype
        } for rec in INITIALIZER_RECS for shape in rec.shapes
                            for dtype in rec.dtypes))
    def testInitializerProvider(self, initializer_provider, shape, dtype):
        rng = random.PRNGKey(0)
        initializer = initializer_provider(dtype=dtype)
        val = initializer(rng, shape)

        self.assertEqual(shape, jnp.shape(val))
        self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))

    def testVarianceScalingMultiAxis(self):
        rng = random.PRNGKey(0)
        shape = (2, 3, 4, 5)
        initializer = nn.initializers.variance_scaling(
            scale=1.0,
            mode='fan_avg',
            distribution='truncated_normal',
            in_axis=(0, 1),
            out_axis=(-2, -1))
        val = initializer(rng, shape)

        self.assertEqual(shape, jnp.shape(val))
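
# A minimal standalone sketch of the initializer API exercised above, assuming the
# public jax.nn.initializers entry point: an initializer is a callable taking
# (key, shape, dtype) and returning an array of that shape and dtype.
import jax.numpy as jnp
from jax import random
from jax.nn import initializers

key = random.PRNGKey(0)
init = initializers.variance_scaling(
    scale=1.0, mode='fan_avg', distribution='truncated_normal')
w = init(key, (3, 4), jnp.float32)
assert w.shape == (3, 4) and w.dtype == jnp.float32
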
Example no. 8
class LaxBackedScipySignalTests(jtu.JaxTestCase):
    """Tests for LAX-backed scipy.stats implementations"""
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_op={}_xshape={}_yshape={}_mode={}".format(
                    op, jtu.format_shape_dtype_string(xshape, dtype),
                    jtu.format_shape_dtype_string(yshape, dtype), mode),
                "xshape":
                xshape,
                "yshape":
                yshape,
                "dtype":
                dtype,
                "mode":
                mode,
                "jsp_op":
                getattr(jsp_signal, op),
                "osp_op":
                getattr(osp_signal, op)
            } for mode in ['full', 'same', 'valid']
            for op in ['convolve', 'correlate'] for dtype in default_dtypes
            for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes]
            for xshape in shapeset for yshape in shapeset))
    def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
        osp_fun = partial(osp_op, mode=mode)
        jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
        tol = {
            np.float16: 1e-2,
            np.float32: 1e-2,
            np.float64: 1e-12,
            np.complex64: 1e-2,
            np.complex128: 1e-12
        }
        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "op={}_xshape={}_yshape={}_mode={}".format(
                op, jtu.format_shape_dtype_string(xshape, dtype),
                jtu.format_shape_dtype_string(yshape, dtype), mode),
            "xshape":
            xshape,
            "yshape":
            yshape,
            "dtype":
            dtype,
            "mode":
            mode,
            "jsp_op":
            getattr(jsp_signal, op),
            "osp_op":
            getattr(osp_signal, op)
        } for mode in ['full', 'same', 'valid']
                            for op in ['convolve2d', 'correlate2d']
                            for dtype in default_dtypes
                            for xshape in twodim_shapes
                            for yshape in twodim_shapes))
    def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
        osp_fun = partial(osp_op, mode=mode)
        jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
        tol = {
            np.float16: 1e-2,
            np.float32: 1e-2,
            np.float64: 1e-12,
            np.complex64: 1e-2,
            np.complex128: 1e-12
        }
        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_axis={}_type={}_bp={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, type, bp),
            "shape":
            shape,
            "dtype":
            dtype,
            "axis":
            axis,
            "type":
            type,
            "bp":
            bp
        } for shape in [(5, ), (4, 5), (3, 4, 5)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.integer for axis in [0, -1]
                            for type in ['constant', 'linear']
                            for bp in [0, [0, 2]]))
    def testDetrend(self, shape, dtype, axis, type, bp):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]
        kwds = dict(axis=axis, type=type, bp=bp)

        def osp_fun(x):
            return osp_signal.detrend(x, **kwds).astype(
                dtypes._to_inexact_dtype(x.dtype))

        jsp_fun = partial(jsp_signal.detrend, **kwds)

        if jtu.device_under_test() == 'tpu':
            tol = {np.float32: 3e-2, np.float64: 1e-12}
        else:
            tol = {np.float32: 1e-5, np.float64: 1e-12}

        self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
            f"_fs={fs}_window={window}_boundary={boundary}_detrend={detrend}"
            f"_padded={padded}_nperseg={nperseg}_noverlap={noverlap}"
            f"_axis={timeaxis}_nfft={nfft}",
            "shape":
            shape,
            "dtype":
            dtype,
            "fs":
            fs,
            "window":
            window,
            "nperseg":
            nperseg,
            "noverlap":
            noverlap,
            "nfft":
            nfft,
            "detrend":
            detrend,
            "boundary":
            boundary,
            "padded":
            padded,
            "timeaxis":
            timeaxis
        } for shape, nperseg, noverlap, timeaxis in stft_test_shapes
                        for dtype in default_dtypes for fs in [1.0, 16000.0]
                        for window in
                        ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
                        for nfft in
                        [None, nperseg,
                         int(nperseg * 1.5), nperseg * 2]
                        for detrend in ['constant', 'linear', False]
                        for boundary in [None, 'even', 'odd', 'zeros']
                        for padded in [True, False]))
    def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                             noverlap, nfft, detrend, boundary, padded,
                             timeaxis):
        is_complex = dtypes.issubdtype(dtype, np.complexfloating)
        if is_complex and detrend is not None:
            self.skipTest(
                "Complex signal is not supported in lax-backed `signal.detrend`."
            )

        kwds = dict(fs=fs,
                    window=window,
                    nfft=nfft,
                    boundary=boundary,
                    padded=padded,
                    detrend=detrend,
                    nperseg=nperseg,
                    noverlap=noverlap,
                    axis=timeaxis,
                    return_onesided=not is_complex)

        def osp_fun(x):
            freqs, time, Pxx = osp_signal.stft(x, **kwds)
            return freqs.astype(_real_dtype(dtype)), time.astype(
                _real_dtype(dtype)), Pxx.astype(_complex_dtype(dtype))

        jsp_fun = partial(jsp_signal.stft, **kwds)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    # Tests with `average == 'median'` are excluded from `testCsd*`
    # due to the issue:
    #   https://github.com/scipy/scipy/issues/15601
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                f"_xshape={jtu.format_shape_dtype_string(xshape, dtype)}"
                f"_yshape={jtu.format_shape_dtype_string(yshape, dtype)}"
                f"_average={average}_scaling={scaling}_nfft={nfft}"
                f"_fs={fs}_window={window}_detrend={detrend}"
                f"_nperseg={nperseg}_noverlap={noverlap}"
                f"_axis={timeaxis}",
                "xshape":
                xshape,
                "yshape":
                yshape,
                "dtype":
                dtype,
                "fs":
                fs,
                "window":
                window,
                "nperseg":
                nperseg,
                "noverlap":
                noverlap,
                "nfft":
                nfft,
                "detrend":
                detrend,
                "scaling":
                scaling,
                "timeaxis":
                timeaxis,
                "average":
                average
            }
            for xshape, yshape, nperseg, noverlap, timeaxis in csd_test_shapes
            for dtype in default_dtypes for fs in [1.0, 16000.0]
            for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
            for nfft in [None, nperseg,
                         int(nperseg * 1.5), nperseg * 2]
            for detrend in ['constant', 'linear', False]
            for scaling in ['density', 'spectrum'] for average in ['mean']))
    def testCsdAgainstNumpy(self, *, xshape, yshape, dtype, fs, window,
                            nperseg, noverlap, nfft, detrend, scaling,
                            timeaxis, average):
        is_complex = dtypes.issubdtype(dtype, np.complexfloating)
        if is_complex and detrend is not None:
            self.skipTest(
                "Complex signal is not supported in lax-backed `signal.detrend`."
            )

        kwds = dict(fs=fs,
                    window=window,
                    nperseg=nperseg,
                    noverlap=noverlap,
                    nfft=nfft,
                    detrend=detrend,
                    return_onesided=not is_complex,
                    scaling=scaling,
                    axis=timeaxis,
                    average=average)

        def osp_fun(x, y):
            freqs, Pxy = osp_signal.csd(x, y, **kwds)
            # Make type-casting the same as JAX.
            return freqs.astype(_real_dtype(dtype)), Pxy.astype(
                _complex_dtype(dtype))

        jsp_fun = partial(jsp_signal.csd, **kwds)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                f"_average={average}_scaling={scaling}_nfft={nfft}"
                f"_fs={fs}_window={window}_detrend={detrend}"
                f"_nperseg={nperseg}_noverlap={noverlap}"
                f"_axis={timeaxis}",
                "shape":
                shape,
                "dtype":
                dtype,
                "fs":
                fs,
                "window":
                window,
                "nperseg":
                nperseg,
                "noverlap":
                noverlap,
                "nfft":
                nfft,
                "detrend":
                detrend,
                "scaling":
                scaling,
                "timeaxis":
                timeaxis,
                "average":
                average
            } for shape, unused_yshape, nperseg, noverlap, timeaxis in
            csd_test_shapes for dtype in default_dtypes
            for fs in [1.0, 16000.0]
            for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
            for nfft in [None, nperseg,
                         int(nperseg * 1.5), nperseg * 2]
            for detrend in ['constant', 'linear', False]
            for scaling in ['density', 'spectrum'] for average in ['mean']))
    def testCsdWithSameParamAgainstNumpy(self, *, shape, dtype, fs, window,
                                         nperseg, noverlap, nfft, detrend,
                                         scaling, timeaxis, average):
        is_complex = dtypes.issubdtype(dtype, np.complexfloating)
        if is_complex and detrend is not None:
            self.skipTest(
                "Complex signal is not supported in lax-backed `signal.detrend`."
            )

        kwds = dict(fs=fs,
                    window=window,
                    nperseg=nperseg,
                    noverlap=noverlap,
                    nfft=nfft,
                    detrend=detrend,
                    return_onesided=not is_complex,
                    scaling=scaling,
                    axis=timeaxis,
                    average=average)

        def osp_fun(x, y):
            # When identical arrays are passed for both inputs, the jsp version
            # behaves as if the second argument had been copied, so the reference
            # copies `y` to match.
            freqs, Pxy = osp_signal.csd(x, y.copy(), **kwds)
            # Make type-casting the same as JAX.
            return freqs.astype(_real_dtype(dtype)), Pxy.astype(
                _complex_dtype(dtype))

        jsp_fun = partial(jsp_signal.csd, **kwds)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)] * 2

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                f"_fs={fs}_window={window}"
                f"_nperseg={nperseg}_noverlap={noverlap}_nfft={nfft}"
                f"_detrend={detrend}_return_onesided={return_onesided}"
                f"_scaling={scaling}_axis={timeaxis}_average={average}",
                "shape":
                shape,
                "dtype":
                dtype,
                "fs":
                fs,
                "window":
                window,
                "nperseg":
                nperseg,
                "noverlap":
                noverlap,
                "nfft":
                nfft,
                "detrend":
                detrend,
                "return_onesided":
                return_onesided,
                "scaling":
                scaling,
                "timeaxis":
                timeaxis,
                "average":
                average
            } for shape, nperseg, noverlap, timeaxis in welch_test_shapes
            for dtype in default_dtypes for fs in [1.0, 16000.0]
            for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
            for nfft in [None, nperseg,
                         int(nperseg * 1.5), nperseg * 2]
            for detrend in ['constant', 'linear', False]
            for return_onesided in [True, False]
            for scaling in ['density', 'spectrum']
            for average in ['mean', 'median']))
    def testWelchAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                              noverlap, nfft, detrend, return_onesided,
                              scaling, timeaxis, average):
        if np.dtype(dtype).kind == 'c':
            return_onesided = False
            if detrend is not None:
                raise unittest.SkipTest(
                    "Complex signal is not supported in lax-backed `signal.detrend`."
                )

        kwds = dict(fs=fs,
                    window=window,
                    nperseg=nperseg,
                    noverlap=noverlap,
                    nfft=nfft,
                    detrend=detrend,
                    return_onesided=return_onesided,
                    scaling=scaling,
                    axis=timeaxis,
                    average=average)

        def osp_fun(x):
            freqs, Pxx = osp_signal.welch(x, **kwds)
            return freqs.astype(_real_dtype(dtype)), Pxx.astype(
                _real_dtype(dtype))

        jsp_fun = partial(jsp_signal.welch, **kwds)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                f"_nperseg={nperseg}_noverlap={noverlap}"
                f"_use_nperseg={use_nperseg}_use_overlap={use_noverlap}"
                f"_axis={timeaxis}",
                "shape":
                shape,
                "dtype":
                dtype,
                "nperseg":
                nperseg,
                "noverlap":
                noverlap,
                "use_nperseg":
                use_nperseg,
                "use_noverlap":
                use_noverlap,
                "timeaxis":
                timeaxis
            } for shape, nperseg, noverlap, timeaxis in welch_test_shapes
            for use_nperseg in [False, True] for use_noverlap in [False, True]
            for dtype in jtu.dtypes.floating + jtu.dtypes.integer))
    def testWelchWithDefaultStepArgsAgainstNumpy(self, *, shape, dtype,
                                                 nperseg, noverlap,
                                                 use_nperseg, use_noverlap,
                                                 timeaxis):
        kwargs = {'axis': timeaxis}

        if use_nperseg:
            kwargs['nperseg'] = nperseg
        else:
            kwargs['window'] = jnp.array(
                osp_signal.get_window('hann', nperseg),
                dtype=dtypes._to_complex_dtype(dtype))
        if use_noverlap:
            kwargs['noverlap'] = noverlap

        def osp_fun(x):
            freqs, Pxx = osp_signal.welch(x, **kwargs)
            return freqs.astype(_real_dtype(dtype)), Pxx.astype(
                _real_dtype(dtype))

        jsp_fun = partial(jsp_signal.welch, **kwargs)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                f"_fs={fs}_window={window}_boundary={boundary}"
                f"_nperseg={nperseg}_noverlap={noverlap}_onesided={onesided}"
                f"_timeaxis={timeaxis}_freqaxis{freqaxis}_nfft={nfft}",
                "shape":
                shape,
                "dtype":
                dtype,
                "fs":
                fs,
                "window":
                window,
                "nperseg":
                nperseg,
                "noverlap":
                noverlap,
                "nfft":
                nfft,
                "onesided":
                onesided,
                "boundary":
                boundary,
                "timeaxis":
                timeaxis,
                "freqaxis":
                freqaxis
            } for shape, nperseg, noverlap, timeaxis, freqaxis in
            istft_test_shapes for dtype in default_dtypes
            for fs in [1.0, 16000.0]
            for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
            for nfft in [None, nperseg,
                         int(nperseg * 1.5), nperseg * 2]
            for onesided in [False, True] for boundary in [False, True]))
    def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                              noverlap, nfft, onesided, boundary, timeaxis,
                              freqaxis):
        if not onesided:
            new_freq_len = (shape[freqaxis] - 1) * 2
            shape = shape[:freqaxis] + (new_freq_len, ) + shape[freqaxis + 1:]

        kwds = dict(fs=fs,
                    window=window,
                    nperseg=nperseg,
                    noverlap=noverlap,
                    nfft=nfft,
                    input_onesided=onesided,
                    boundary=boundary,
                    time_axis=timeaxis,
                    freq_axis=freqaxis)

        osp_fun = partial(osp_signal.istft, **kwds)
        osp_fun = jtu.ignore_warning(
            message="NOLA condition failed, STFT may not be invertible")(
                osp_fun)
        jsp_fun = partial(jsp_signal.istft, **kwds)

        tol = {
            np.float32: 1e-4,
            np.float64: 1e-6,
            np.complex64: 1e-4,
            np.complex128: 1e-6
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]

        # The dtype of the output signal differs across scipy versions, and hence
        # across test environments, so the dtype check is disabled.
        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol,
                                check_dtypes=False)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
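
# A minimal standalone sketch (assuming a CPU backend) of the STFT/ISTFT pair that
# the tests above compare against scipy: jax.scipy.signal.istft applied to the
# output of jax.scipy.signal.stft should approximately reconstruct the input.
import numpy as np
import jax.numpy as jnp
import jax.scipy.signal as jsp_signal

x = np.sin(2 * np.pi * 5 * np.linspace(0.0, 1.0, 400)).astype(np.float32)
_, _, Zxx = jsp_signal.stft(jnp.asarray(x), nperseg=100)
_, x_rec = jsp_signal.istft(Zxx, nperseg=100)
np.testing.assert_allclose(np.asarray(x_rec)[:x.size], x, atol=1e-4)
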
Example no. 9
class AnnTest(jtu.JaxTestCase):

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_qy={}_db={}_k={}_recall={}".format(
                  jtu.format_shape_dtype_string(qy_shape, dtype),
                  jtu.format_shape_dtype_string(db_shape, dtype), k, recall),
          "qy_shape": qy_shape, "db_shape": db_shape, "dtype": dtype,
          "k": k, "recall": recall }
        for qy_shape in [(200, 128), (128, 128)]
        for db_shape in [(128, 500), (128, 3000)]
        for dtype in jtu.dtypes.all_floating
        for k in [1, 10, 50] for recall in [0.9, 0.95]))
  def test_approx_max_k(self, qy_shape, db_shape, dtype, k, recall):
    rng = jtu.rand_default(self.rng())
    qy = rng(qy_shape, dtype)
    db = rng(db_shape, dtype)
    scores = lax.dot(qy, db)
    _, gt_args = lax.top_k(scores, k)
    _, ann_args = ann.approx_max_k(scores, k, recall_target=recall)
    self.assertEqual(k, len(ann_args[0]))
    gt_args_sets = [set(np.asarray(x)) for x in gt_args]
    hits = sum(
        len(list(x
                 for x in ann_args_per_q
                 if x.item() in gt_args_sets[q]))
        for q, ann_args_per_q in enumerate(ann_args))
    self.assertGreater(hits / (qy_shape[0] * k), recall)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_qy={}_db={}_k={}_recall={}".format(
                  jtu.format_shape_dtype_string(qy_shape, dtype),
                  jtu.format_shape_dtype_string(db_shape, dtype), k, recall),
          "qy_shape": qy_shape, "db_shape": db_shape, "dtype": dtype,
          "k": k, "recall": recall }
        for qy_shape in [(200, 128), (128, 128)]
        for db_shape in [(128, 500), (128, 3000)]
        for dtype in jtu.dtypes.all_floating
        for k in [1, 10, 50] for recall in [0.9, 0.95]))
  def test_approx_min_k(self, qy_shape, db_shape, dtype, k, recall):
    rng = jtu.rand_default(self.rng())
    qy = rng(qy_shape, dtype)
    db = rng(db_shape, dtype)
    scores = lax.dot(qy, db)
    _, gt_args = lax.top_k(-scores, k)
    _, ann_args = ann.approx_min_k(scores, k, recall_target=recall)
    self.assertEqual(k, len(ann_args[0]))
    gt_args_sets = [set(np.asarray(x)) for x in gt_args]
    hits = sum(
        len(list(x
                 for x in ann_args_per_q
                 if x.item() in gt_args_sets[q]))
        for q, ann_args_per_q in enumerate(ann_args))
    self.assertGreater(hits / (qy_shape[0] * k), recall)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shape={}_k={}_max_k={}".format(
                  jtu.format_shape_dtype_string(shape, dtype), k, is_max_k),
          "shape": shape, "dtype": dtype, "k": k, "is_max_k": is_max_k }
        for dtype in [np.float32]
        for shape in [(4,), (5, 5), (2, 1, 4)]
        for k in [1, 3]
        for is_max_k in [True, False]))
  def test_autodiff(self, shape, dtype, k, is_max_k):
    vals = np.arange(prod(shape), dtype=dtype)
    vals = self.rng().permutation(vals).reshape(shape)
    if is_max_k:
      fn = lambda vs: ann.approx_max_k(vs, k=k)[0]
    else:
      fn = lambda vs: ann.approx_min_k(vs, k=k)[0]
    jtu.check_grads(fn, (vals,), 2, ["fwd", "rev"], eps=1e-2)
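
# --- Usage sketch (added, not part of the original example) -----------------
# A minimal, hedged illustration of the approximate top-k API exercised by the
# tests above, via the public jax.lax entry point.  The shapes, k, and
# recall_target below are arbitrary demonstration values.
import jax.numpy as jnp
from jax import lax

scores = jnp.arange(12.0).reshape(3, 4)   # 3 queries, 4 candidates each
vals, idx = lax.approx_max_k(scores, k=2, recall_target=0.95)
# vals and idx both have shape (3, 2): the (approximately) two largest
# scores per row and their column indices.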
Example No. 10
class SparsifyTest(jtu.JaxTestCase):
    @classmethod
    def sparsify(cls, f):
        return sparsify(f, use_tracer=False)

    def testNotImplementedMessages(self):
        x = BCOO.fromdense(jnp.arange(5.0))
        # Test a densifying primitive
        with self.assertRaisesRegex(
                NotImplementedError,
                r"^sparse rule for cos is not implemented because it would result in dense output\."
        ):
            self.sparsify(lax.cos)(x)

        # Test a generic not implemented primitive.
        with self.assertRaisesRegex(
                NotImplementedError,
                r"^sparse rule for complex is not implemented\.$"):
            self.sparsify(lax.complex)(x, x)

    def testTracerIsInstanceCheck(self):
        @self.sparsify
        def f(x):
            self.assertNotIsInstance(x, SparseTracer)

        f(jnp.arange(5))

    def assertBcooIdentical(self, x, y):
        self.assertIsInstance(x, BCOO)
        self.assertIsInstance(y, BCOO)
        self.assertEqual(x.shape, y.shape)
        self.assertArraysEqual(x.data, y.data)
        self.assertArraysEqual(x.indices, y.indices)

    def testSparsifyValue(self):
        X = jnp.arange(5)
        X_BCOO = BCOO.fromdense(X)

        args = (X, X_BCOO, X_BCOO)

        # Independent index
        spenv = SparsifyEnv()
        spvalues = arrays_to_spvalues(spenv, args)
        self.assertEqual(len(spvalues), len(args))
        self.assertLen(spenv._buffers, 5)
        self.assertEqual(
            spvalues,
            (SparsifyValue(
                X.shape, 0, None, indices_sorted=False, unique_indices=False),
             SparsifyValue(
                 X.shape, 1, 2, indices_sorted=True, unique_indices=True),
             SparsifyValue(
                 X.shape, 3, 4, indices_sorted=True, unique_indices=True)))

        args_out = spvalues_to_arrays(spenv, spvalues)
        self.assertEqual(len(args_out), len(args))
        self.assertArraysEqual(args[0], args_out[0])
        self.assertBcooIdentical(args[1], args_out[1])
        self.assertBcooIdentical(args[2], args_out[2])

        # Shared index
        spvalues = (SparsifyValue(X.shape, 0, None),
                    SparsifyValue(X.shape, 1, 2), SparsifyValue(X.shape, 3, 2))
        spenv = SparsifyEnv([X, X_BCOO.data, X_BCOO.indices, X_BCOO.data])

        args_out = spvalues_to_arrays(spenv, spvalues)
        self.assertEqual(len(args_out), len(args))
        self.assertArraysEqual(args[0], args_out[0])
        self.assertBcooIdentical(args[1], args_out[1])
        self.assertBcooIdentical(args[2], args_out[2])

    def testDropvar(self):
        def inner(x):
            return x * 2, x * 3

        def f(x):
            _, y = jit(inner)(x)
            return y * 4

        x_dense = jnp.arange(5)
        x_sparse = BCOO.fromdense(x_dense)
        self.assertArraysEqual(
            self.sparsify(f)(x_sparse).todense(), f(x_dense))

    def testPytreeInput(self):
        f = self.sparsify(lambda x: x)
        args = (jnp.arange(4), BCOO.fromdense(jnp.arange(4)))
        out = f(args)
        self.assertLen(out, 2)
        self.assertArraysEqual(args[0], out[0])
        self.assertBcooIdentical(args[1], out[1])

    @jax.numpy_dtype_promotion(
        'standard')  # explicitly exercises implicit dtype promotion.
    def testSparsify(self):
        M_dense = jnp.arange(24).reshape(4, 6)
        M_sparse = BCOO.fromdense(M_dense)
        v = jnp.arange(M_dense.shape[0])

        @self.sparsify
        def func(x, v):
            return -jnp.sin(jnp.pi * x).T @ (v + 1)

        with jtu.ignore_warning(
                category=CuSparseEfficiencyWarning,
                message=
                "bcoo_dot_general GPU lowering requires matrices with sorted indices*"
        ):
            result_sparse = func(M_sparse, v)
        result_dense = func(M_dense, v)
        self.assertAllClose(result_sparse, result_dense)

    def testSparsifyWithConsts(self):
        M_dense = jnp.arange(24).reshape(4, 6)
        M_sparse = BCOO.fromdense(M_dense)

        @self.sparsify
        def func(x):
            return jit(lambda x: jnp.sum(x, 1))(x)

        result_dense = func(M_dense)
        result_sparse = func(M_sparse)

        self.assertAllClose(result_sparse.todense(), result_dense)

    def testSparseMatmul(self):
        X = jnp.arange(16.0).reshape(4, 4)
        Xsp = BCOO.fromdense(X)
        Y = jnp.ones(4)
        Ysp = BCOO.fromdense(Y)

        func = self.sparsify(operator.matmul)

        # dot_general
        result_sparse = func(Xsp, Y)
        result_dense = operator.matmul(X, Y)
        self.assertAllClose(result_sparse, result_dense)

        # rdot_general
        result_sparse = func(Y, Xsp)
        result_dense = operator.matmul(Y, X)
        self.assertAllClose(result_sparse, result_dense)

        # spdot_general
        result_sparse = self.sparsify(operator.matmul)(Xsp, Ysp)
        result_dense = operator.matmul(X, Y)
        self.assertAllClose(result_sparse.todense(), result_dense)

    def testSparseAdd(self):
        x = BCOO.fromdense(jnp.arange(5))
        y = BCOO.fromdense(2 * jnp.arange(5))

        # Distinct indices
        out = self.sparsify(operator.add)(x, y)
        self.assertEqual(out.nse, 8)  # uses concatenation.
        self.assertArraysEqual(out.todense(), 3 * jnp.arange(5))

        # Shared indices – requires lower level call
        spenv = SparsifyEnv([x.indices, x.data, y.data])
        spvalues = [
            spenv.sparse(x.shape, data_ref=1, indices_ref=0),
            spenv.sparse(y.shape, data_ref=2, indices_ref=0)
        ]

        result = sparsify_raw(operator.add)(spenv, *spvalues)
        args_out, _ = result
        out, = spvalues_to_arrays(spenv, args_out)

        self.assertAllClose(out.todense(), x.todense() + y.todense())

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_{}_nbatch={}_ndense={}_unique_indices={}".format(
                    jtu.format_shape_dtype_string(shape, dtype), n_batch,
                    n_dense, unique_indices),
                "shape":
                shape,
                "dtype":
                dtype,
                "n_batch":
                n_batch,
                "n_dense":
                n_dense,
                "unique_indices":
                unique_indices
            } for shape in [(5, ), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
            for dtype in (jtu.dtypes.integer + jtu.dtypes.floating +
                          jtu.dtypes.complex)
            for n_batch in range(len(shape) + 1)
            for n_dense in range(len(shape) + 1 - n_batch)
            for unique_indices in [True, False]))
    def testSparseMul(self, shape, dtype, n_batch, n_dense, unique_indices):
        rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
        x = BCOO.fromdense(rng_sparse(shape, dtype),
                           n_batch=n_batch,
                           n_dense=n_dense)

        # Scalar multiplication
        scalar = 2
        y = self.sparsify(operator.mul)(x, scalar)
        self.assertArraysEqual(x.todense() * scalar, y.todense())

        # Shared indices – requires lower level call
        spenv = SparsifyEnv([x.indices, x.data, y.data])
        spvalues = [
            spenv.sparse(x.shape,
                         data_ref=1,
                         indices_ref=0,
                         unique_indices=unique_indices),
            spenv.sparse(y.shape,
                         data_ref=2,
                         indices_ref=0,
                         unique_indices=unique_indices)
        ]

        result = sparsify_raw(operator.mul)(spenv, *spvalues)
        args_out, _ = result
        out, = spvalues_to_arrays(spenv, args_out)

        self.assertAllClose(out.todense(), x.todense() * y.todense())

    def testSparseSubtract(self):
        x = BCOO.fromdense(3 * jnp.arange(5))
        y = BCOO.fromdense(jnp.arange(5))

        # Distinct indices
        out = self.sparsify(operator.sub)(x, y)
        self.assertEqual(out.nse, 8)  # uses concatenation.
        self.assertArraysEqual(out.todense(), 2 * jnp.arange(5))

        # Shared indices – requires lower level call
        spenv = SparsifyEnv([x.indices, x.data, y.data])
        spvalues = [
            spenv.sparse(x.shape, data_ref=1, indices_ref=0),
            spenv.sparse(y.shape, data_ref=2, indices_ref=0)
        ]

        result = sparsify_raw(operator.sub)(spenv, *spvalues)
        args_out, _ = result
        out, = spvalues_to_arrays(spenv, args_out)

        self.assertAllClose(out.todense(), x.todense() - y.todense())

    def testSparseSum(self):
        x = jnp.arange(20).reshape(4, 5)
        xsp = BCOO.fromdense(x)

        def f(x):
            return x.sum(), x.sum(0), x.sum(1), x.sum((0, 1))

        result_dense = f(x)
        result_sparse = self.sparsify(f)(xsp)

        assert len(result_dense) == len(result_sparse)

        for res_dense, res_sparse in zip(result_dense, result_sparse):
            if isinstance(res_sparse, BCOO):
                res_sparse = res_sparse.todense()
            self.assertArraysAllClose(res_dense, res_sparse)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_dimensions={}_nbatch={}_ndense={}".format(
                jtu.format_shape_dtype_string(shape, np.float32), dimensions,
                n_batch, n_dense),
            "shape":
            shape,
            "dimensions":
            dimensions,
            "n_batch":
            n_batch,
            "n_dense":
            n_dense
        } for shape, dimensions in [
            [(1, ), (0, )],
            [(1, ), (-1, )],
            [(2, 1, 4), (1, )],
            [(2, 1, 3, 1), (1, )],
            [(2, 1, 3, 1), (1, 3)],
            [(2, 1, 3, 1), (3, )],
        ] for n_batch in range(len(shape) + 1)
                            for n_dense in range(len(shape) - n_batch + 1)))
    def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense):
        rng = jtu.rand_default(self.rng())

        M_dense = rng(shape, np.float32)
        M_sparse = BCOO.fromdense(M_dense, n_batch=n_batch, n_dense=n_dense)
        func = self.sparsify(partial(lax.squeeze, dimensions=dimensions))

        result_dense = func(M_dense)
        result_sparse = func(M_sparse).todense()

        self.assertAllClose(result_sparse, result_dense)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_shapes={shapes}_func={func}_nbatch={n_batch}",
            "shapes": shapes,
            "func": func,
            "n_batch": n_batch
        } for shapes, func, n_batch in [
            ([(4, ), (4, )], "concatenate", 0),
            ([(4, ), (4, )], "stack", 0),
            ([(4, ), (4, )], "hstack", 0),
            ([(4, ), (4, )], "vstack", 0),
            ([(4, ), (4, )], "concatenate", 1),
            ([(4, ), (4, )], "stack", 1),
            ([(4, ), (4, )], "hstack", 1),
            ([(4, ), (4, )], "vstack", 1),
            ([(2, 4), (2, 4)], "stack", 0),
            ([(2, 4), (3, 4)], "vstack", 0),
            ([(2, 4), (2, 5)], "hstack", 0),
            ([(2, 4), (3, 4)], "vstack", 1),
            ([(2, 4), (2, 5)], "hstack", 1),
            ([(2, 4), (3, 4)], "vstack", 2),
            ([(2, 4), (2, 5)], "hstack", 2),
            ([(2, 4), (4, ), (3, 4)], "vstack", 0),
            ([(1, 4), (4, ), (1, 4)], "vstack", 0),
        ]))
    def testSparseConcatenate(self, shapes, func, n_batch):
        f = self.sparsify(getattr(jnp, func))
        rng = jtu.rand_some_zero(self.rng())
        arrs = [rng(shape, 'int32') for shape in shapes]
        sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs]
        self.assertArraysEqual(f(arrs), f(sparrs).todense())

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}",
            "shape": shape,
            "new_shape": new_shape,
            "n_batch": n_batch,
            "n_dense": n_dense
        } for shape, new_shape, n_batch, n_dense in [
            [(6, ), (2, 3), 0, 0],
            [(1, 4), (2, 2), 0, 0],
            [(12, 2), (2, 3, 4), 0, 0],
            [(1, 3, 2), (2, 3), 0, 0],
            [(1, 6), (2, 3, 1), 0, 0],
            [(2, 3, 4), (3, 8), 0, 0],
            [(2, 3, 4), (1, 2, 12), 1, 0],
            [(2, 3, 4), (6, 2, 2), 2, 0],
        ]))
    def testSparseReshapeMethod(self, shape, new_shape, n_batch, n_dense):
        rng = jtu.rand_some_zero(self.rng())
        arr = rng(shape, 'int32')
        arr_sparse = BCOO.fromdense(arr, n_batch=n_batch, n_dense=n_dense)

        arr2 = arr.reshape(new_shape)
        arr2_sparse = arr_sparse.reshape(new_shape)

        self.assertArraysEqual(arr2, arr2_sparse.todense())

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}_dimensions={dimensions}",
            "shape": shape,
            "new_shape": new_shape,
            "n_batch": n_batch,
            "n_dense": n_dense,
            "dimensions": dimensions
        } for shape, new_shape, n_batch, n_dense, dimensions in [
            [(2, 3, 4), (24, ), 0, 0, None],
            [(2, 3, 4), (24, ), 0, 0, (0, 1, 2)],
            [(2, 3, 4), (24, ), 0, 0, (0, 2, 1)],
            [(2, 3, 4), (24, ), 0, 0, (1, 0, 2)],
            [(2, 3, 4), (24, ), 0, 0, (1, 2, 0)],
            [(2, 3, 4), (24, ), 0, 0, (2, 0, 1)],
            [(2, 3, 4), (24, ), 0, 0, (2, 1, 0)],
            [(4, 2, 3), (2, 2, 6), 1, 0, (0, 1, 2)],
            [(4, 2, 3), (2, 2, 6), 1, 0, (0, 2, 1)],
            [(2, 3, 4), (6, 4), 2, 0, (0, 1, 2)],
            [(2, 3, 4), (6, 4), 2, 0, (1, 0, 2)],
        ]))
    def testSparseReshapeWithDimensions(self, shape, new_shape, n_batch,
                                        n_dense, dimensions):
        rng = jtu.rand_some_zero(self.rng())
        arr = rng(shape, 'int32')
        arr_sparse = BCOO.fromdense(arr, n_batch=n_batch, n_dense=n_dense)

        f = self.sparsify(
            lambda x: lax.reshape(x, new_shape, dimensions=dimensions))

        arr2 = f(arr)
        arr2_sparse = f(arr_sparse)

        self.assertArraysEqual(arr2, arr2_sparse.todense())

    def testSparseWhileLoop(self):
        def cond_fun(params):
            i, A = params
            return i < 5

        def body_fun(params):
            i, A = params
            return i + 1, 2 * A

        def f(A):
            return lax.while_loop(cond_fun, body_fun, (0, A))

        A = jnp.arange(4)
        out_dense = f(A)

        Asp = BCOO.fromdense(A)
        out_sparse = self.sparsify(f)(Asp)

        self.assertEqual(len(out_dense), 2)
        self.assertEqual(len(out_sparse), 2)
        self.assertArraysEqual(out_dense[0], out_sparse[0])
        self.assertArraysEqual(out_dense[1], out_sparse[1].todense())

    def testSparseWhileLoopDuplicateIndices(self):
        def cond_fun(params):
            i, A, B = params
            return i < 5

        def body_fun(params):
            i, A, B = params
            # TODO(jakevdp): track shared indices through while loop & use this
            #   version of the test, which requires shared indices in order for
            #   the nse of the result to remain the same.
            # return i + 1, A, A + B

            # This version is fine without shared indices, and tests that we're
            # flattening non-shared indices consistently.
            return i + 1, B, A

        def f(A):
            return lax.while_loop(cond_fun, body_fun, (0, A, A))

        A = jnp.arange(4).reshape((2, 2))
        out_dense = f(A)

        Asp = BCOO.fromdense(A)
        out_sparse = self.sparsify(f)(Asp)

        self.assertEqual(len(out_dense), 3)
        self.assertEqual(len(out_sparse), 3)
        self.assertArraysEqual(out_dense[0], out_sparse[0])
        self.assertArraysEqual(out_dense[1], out_sparse[1].todense())
        self.assertArraysEqual(out_dense[2], out_sparse[2].todense())

    def testSparsifyDenseXlaCall(self):
        # Test handling of dense xla_call within jaxpr interpreter.
        out = self.sparsify(jit(lambda x: x + 1))(0.0)
        self.assertEqual(out, 1.0)

    def testSparsifySparseXlaCall(self):
        # Test sparse lowering of XLA call
        def func(M):
            return 2 * M

        M = jnp.arange(6).reshape(2, 3)
        Msp = BCOO.fromdense(M)

        out_dense = func(M)
        out_sparse = self.sparsify(jit(func))(Msp)
        self.assertArraysEqual(out_dense, out_sparse.todense())

    def testSparseForiLoop(self):
        def func(M, x):
            body_fun = lambda i, val: (M @ val) / M.shape[1]
            return lax.fori_loop(0, 2, body_fun, x)

        x = jnp.arange(5.0)
        M = jnp.arange(25).reshape(5, 5)
        M_bcoo = BCOO.fromdense(M)

        with jax.numpy_dtype_promotion('standard'):
            result_dense = func(M, x)
            result_sparse = self.sparsify(func)(M_bcoo, x)

        self.assertArraysAllClose(result_dense, result_sparse)

    def testSparseCondSimple(self):
        def func(x):
            return lax.cond(False, lambda x: x, lambda x: 2 * x, x)

        x = jnp.arange(5.0)
        result_dense = func(x)

        x_bcoo = BCOO.fromdense(x)
        result_sparse = self.sparsify(func)(x_bcoo)

        self.assertArraysAllClose(result_dense, result_sparse.todense())

    def testSparseCondMismatchError(self):
        @self.sparsify
        def func(x, y):
            return lax.cond(False, lambda x: x[0], lambda x: x[1], (x, y))

        x = jnp.arange(5.0)
        y = jnp.arange(5.0)

        x_bcoo = BCOO.fromdense(x)
        y_bcoo = BCOO.fromdense(y)

        func(x, y)  # No error
        func(x_bcoo, y_bcoo)  # No error

        with self.assertRaisesRegex(
                TypeError, "sparsified true_fun and false_fun output.*"):
            func(x_bcoo, y)

    def testToDense(self):
        M = jnp.arange(4)
        Msp = BCOO.fromdense(M)

        @self.sparsify
        def func(M):
            return todense(M) + 1

        self.assertArraysEqual(func(M), M + 1)
        self.assertArraysEqual(func(Msp), M + 1)
        self.assertArraysEqual(jit(func)(M), M + 1)
        self.assertArraysEqual(jit(func)(Msp), M + 1)

    def testWeakTypes(self):
        # Regression test for https://github.com/google/jax/issues/8267
        M = jnp.arange(12, dtype='int32').reshape(3, 4)
        Msp = BCOO.fromdense(M)
        self.assertArraysEqual(
            operator.mul(2, M),
            self.sparsify(operator.mul)(2, Msp).todense(),
            check_dtypes=True,
        )
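
# --- Usage sketch (added, not part of the original example) -----------------
# A minimal, hedged illustration of the sparsify/BCOO pattern exercised by the
# tests above: a sparsified function accepts BCOO inputs in place of dense
# arrays.  `mat_vec` is a hypothetical name chosen only for this sketch.
import jax.numpy as jnp
from jax.experimental import sparse

@sparse.sparsify
def mat_vec(mat, vec):
    return mat @ vec  # lowered to a (sparse) dot_general under sparsify

M = jnp.arange(12.0).reshape(3, 4)
v = jnp.ones(4)
M_sp = sparse.BCOO.fromdense(M)
assert jnp.allclose(mat_vec(M_sp, v), mat_vec(M, v))  # sparse path agrees with dense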
Example No. 11
class ImageTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_target={}_method={}_antialias={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method,
          antialias),
        "dtype": dtype, "image_shape": image_shape,
        "target_shape": target_shape,
        "method": method, "antialias": antialias}
       for dtype in float_dtypes
       for target_shape, image_shape in itertools.combinations_with_replacement(
        [[2, 3, 2, 4], [2, 6, 4, 4], [2, 33, 17, 4], [2, 50, 38, 4]], 2)
       for method in ["nearest", "bilinear", "lanczos3", "lanczos5", "bicubic"]
       for antialias in [False, True]))
  @unittest.skipIf(not tf, "Test requires TensorFlow")
  def testResizeAgainstTensorFlow(self, dtype, image_shape, target_shape, method,
                                  antialias):
    # TODO(phawkins): debug this. There is a small mismatch between TF and JAX
    # for some cases of non-antialiased bicubic downscaling; we would expect
    # exact equality.
    if method == "bicubic" and any(x < y for x, y in
                                   zip(target_shape, image_shape)):
      raise unittest.SkipTest("non-antialiased bicubic downscaling mismatch")
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(image_shape, dtype),)
    def tf_fn(x):
      out = tf.image.resize(
        x.astype(np.float64), tf.constant(target_shape[1:-1]),
        method=method, antialias=antialias).numpy().astype(dtype)
      return out
    jax_fn = partial(image.resize, shape=target_shape, method=method,
                     antialias=antialias)
    self._CheckAgainstNumpy(tf_fn, jax_fn, args_maker, check_dtypes=True,
                            tol={np.float16: 2e-2, jnp.bfloat16: 1e-1,
                                 np.float32: 1e-4, np.float64: 1e-4})


  @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_target={}_method={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method),
        "dtype": dtype, "image_shape": image_shape,
        "target_shape": target_shape,
        "method": method}
       for dtype in [np.float32]

       for target_shape, image_shape in itertools.combinations_with_replacement(
        [[3, 2], [6, 4], [33, 17], [50, 39]], 2)
       for method in ["nearest", "bilinear", "lanczos3", "bicubic"]))
  @unittest.skipIf(not PIL_Image, "Test requires PIL")
  def testResizeAgainstPIL(self, dtype, image_shape, target_shape, method):
    rng = jtu.rand_uniform(self.rng())
    args_maker = lambda: (rng(image_shape, dtype),)
    def pil_fn(x):
      pil_methods = {
        "nearest": PIL_Image.NEAREST,
        "bilinear": PIL_Image.BILINEAR,
        "bicubic": PIL_Image.BICUBIC,
        "lanczos3": PIL_Image.LANCZOS,
      }
      img = PIL_Image.fromarray(x.astype(np.float32))
      out = np.asarray(img.resize(target_shape[::-1], pil_methods[method]),
                       dtype=dtype)
      return out
    jax_fn = partial(image.resize, shape=target_shape, method=method,
                     antialias=True)
    self._CheckAgainstNumpy(pil_fn, jax_fn, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_target={}_method={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method),
        "dtype": dtype, "image_shape": image_shape, "target_shape": target_shape,
        "method": method}
       for dtype in inexact_dtypes
       for image_shape, target_shape in [
         ([3, 1, 2], [6, 1, 4]),
         ([1, 3, 2, 1], [1, 6, 4, 1]),
       ]
       for method in ["nearest", "linear", "lanczos3", "lanczos5", "cubic"]))
  def testResizeUp(self, dtype, image_shape, target_shape, method):
    data = [64, 32, 32, 64, 50, 100]
    expected_data = {}
    expected_data["nearest"] = [
        64.0, 64.0, 32.0, 32.0, 64.0, 64.0, 32.0, 32.0, 32.0, 32.0, 64.0, 64.0,
        32.0, 32.0, 64.0, 64.0, 50.0, 50.0, 100.0, 100.0, 50.0, 50.0, 100.0,
        100.0
    ]
    expected_data["linear"] = [
        64.0, 56.0, 40.0, 32.0, 56.0, 52.0, 44.0, 40.0, 40.0, 44.0, 52.0, 56.0,
        36.5, 45.625, 63.875, 73.0, 45.5, 56.875, 79.625, 91.0, 50.0, 62.5,
        87.5, 100.0
    ]
    expected_data["lanczos3"] = [
        75.8294, 59.6281, 38.4313, 22.23, 60.6851, 52.0037, 40.6454, 31.964,
        35.8344, 41.0779, 47.9383, 53.1818, 24.6968, 43.0769, 67.1244, 85.5045,
        35.7939, 56.4713, 83.5243, 104.2017, 44.8138, 65.1949, 91.8603, 112.2413
    ]
    expected_data["lanczos5"] = [
        77.5699, 60.0223, 40.6694, 23.1219, 61.8253, 51.2369, 39.5593, 28.9709,
        35.7438, 40.8875, 46.5604, 51.7041, 21.5942, 43.5299, 67.7223, 89.658,
        32.1213, 56.784, 83.984, 108.6467, 44.5802, 66.183, 90.0082, 111.6109
    ]
    expected_data["cubic"] = [
        70.1453, 59.0252, 36.9748, 25.8547, 59.3195, 53.3386, 41.4789, 35.4981,
        36.383, 41.285, 51.0051, 55.9071, 30.2232, 42.151, 65.8032, 77.731,
        41.6492, 55.823, 83.9288, 98.1026, 47.0363, 62.2744, 92.4903, 107.7284
    ]
    x = np.array(data, dtype=dtype).reshape(image_shape)
    output = image.resize(x, target_shape, method)
    expected = np.array(expected_data[method], dtype=dtype).reshape(target_shape)
    self.assertAllClose(output, expected, atol=1e-04)

  @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_target={}_method={}_antialias={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method,
          antialias),
        "dtype": dtype, "image_shape": image_shape,
        "target_shape": target_shape,
        "method": method, "antialias": antialias}
       for dtype in [np.float32]
       for target_shape, image_shape in itertools.combinations_with_replacement(
        [[2, 3, 2, 4], [2, 6, 4, 4], [2, 33, 17, 4], [2, 50, 38, 4]], 2)
       for method in ["bilinear", "lanczos3", "lanczos5", "bicubic"]
       for antialias in [False, True]))
  def testResizeGradients(self, dtype, image_shape, target_shape, method,
                           antialias):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(image_shape, dtype),)
    jax_fn = partial(image.resize, shape=target_shape, method=method,
                     antialias=antialias)
    jtu.check_grads(jax_fn, args_maker(), order=2, rtol=1e-2, eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_target={}_method={}_antialias={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method,
          antialias),
        "dtype": dtype, "image_shape": image_shape,
        "target_shape": target_shape,
        "method": method, "antialias": antialias}
       for dtype in [np.float32]
       for image_shape, target_shape in [
         ([1], [0]),
         ([5, 5], [5, 0]),
         ([5, 5], [0, 1]),
         ([5, 5], [0, 0])
       ]
       for method in ["nearest", "linear", "lanczos3", "lanczos5", "cubic"]
       for antialias in [False, True]))
  def testResizeEmpty(self, dtype, image_shape, target_shape, method, antialias):
    # Regression test for https://github.com/google/jax/issues/7586
    image = np.ones(image_shape, dtype)
    out = jax.image.resize(image, shape=target_shape, method=method, antialias=antialias)
    self.assertArraysEqual(out, jnp.zeros(target_shape, dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_target={}_method={}".format(
         jtu.format_shape_dtype_string(image_shape, dtype),
         jtu.format_shape_dtype_string(target_shape, dtype), method),
       "dtype": dtype, "image_shape": image_shape,
       "target_shape": target_shape,
       "scale": scale, "translation": translation, "method": method}
      for dtype in inexact_dtypes
      for image_shape, target_shape, scale, translation in [
        ([3, 1, 2], [6, 1, 4], [2.0, 1.0, 2.0], [1.0, 0.0, -1.0]),
        ([1, 3, 2, 1], [1, 6, 4, 1], [1.0, 2.0, 2.0, 1.0], [0.0, 1.0, -1.0, 0.0])]
      for method in ["linear", "lanczos3", "lanczos5", "cubic"]))
  def testScaleAndTranslateUp(self, dtype, image_shape, target_shape, scale,
                              translation, method):
    data = [64, 32, 32, 64, 50, 100]
    # Note zeros occur in the output because the sampling location is outside
    # the boundaries of the input image.
    expected_data = {}
    expected_data["linear"] = [
        0.0, 0.0, 0.0, 0.0, 56.0, 40.0, 32.0, 0.0, 52.0, 44.0, 40.0, 0.0, 44.0,
        52.0, 56.0, 0.0, 45.625, 63.875, 73.0, 0.0, 56.875, 79.625, 91.0, 0.0
    ]
    expected_data["lanczos3"] = [
        0.0, 0.0, 0.0, 0.0, 59.6281, 38.4313, 22.23, 0.0, 52.0037, 40.6454,
        31.964, 0.0, 41.0779, 47.9383, 53.1818, 0.0, 43.0769, 67.1244, 85.5045,
        0.0, 56.4713, 83.5243, 104.2017, 0.0
    ]
    expected_data["lanczos5"] = [
        0.0, 0.0, 0.0, 0.0, 60.0223, 40.6694, 23.1219, 0.0, 51.2369, 39.5593,
        28.9709, 0.0, 40.8875, 46.5604, 51.7041, 0.0, 43.5299, 67.7223, 89.658,
        0.0, 56.784, 83.984, 108.6467, 0.0
    ]
    expected_data["cubic"] = [
        0.0, 0.0, 0.0, 0.0, 59.0252, 36.9748, 25.8547, 0.0, 53.3386, 41.4789,
        35.4981, 0.0, 41.285, 51.0051, 55.9071, 0.0, 42.151, 65.8032, 77.731,
        0.0, 55.823, 83.9288, 98.1026, 0.0
    ]
    x = np.array(data, dtype=dtype).reshape(image_shape)
    # Should we test different float types here?
    scale_a = jnp.array(scale, dtype=jnp.float32)
    translation_a = jnp.array(translation, dtype=jnp.float32)
    output = image.scale_and_translate(x, target_shape, range(len(image_shape)),
                                       scale_a, translation_a,
                                       method)

    expected = np.array(
        expected_data[method], dtype=dtype).reshape(target_shape)
    self.assertAllClose(output, expected, atol=2e-03)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_method={}_antialias={}".format(
         jtu.dtype_str(dtype), method, antialias),
       "dtype": dtype, "method": method, "antialias": antialias}
      for dtype in inexact_dtypes
      for method in ["linear", "lanczos3", "lanczos5", "cubic"]
      for antialias in [True, False]))
  def testScaleAndTranslateDown(self, dtype, method, antialias):
    image_shape = [1, 6, 7, 1]
    target_shape = [1, 3, 3, 1]

    data = [
        51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92,
        41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89,
        71, 32, 23, 23, 35, 93
    ]
    if antialias:
      expected_data = {}
      expected_data["linear"] = [
          43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0
      ]
      expected_data["lanczos3"] = [
          43.2884, 57.9091, 54.6439, 48.5856, 58.2427, 53.7551, 0, 0, 0
      ]
      expected_data["lanczos5"] = [
          43.9209, 57.6360, 54.9575, 48.9272, 58.1865, 53.1948, 0, 0, 0
      ]
      expected_data["cubic"] = [
          42.9935, 59.1687, 54.2138, 48.2640, 58.2678, 54.4088, 0, 0, 0
      ]
    else:
      expected_data = {}
      expected_data["linear"] = [
          43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0
      ]
      expected_data["lanczos3"] = [
          44.1390, 87.8786, 63.3111, 25.1161, 20.8795, 53.6165, 0, 0, 0
      ]
      expected_data["lanczos5"] = [
          44.8835, 85.5896, 66.7231, 16.9983, 19.8891, 47.1446, 0, 0, 0
      ]
      expected_data["cubic"] = [
          43.6426, 88.8854, 60.6638, 31.4685, 22.1204, 58.3457, 0, 0, 0
      ]
    x = np.array(data, dtype=dtype).reshape(image_shape)

    expected = np.array(
        expected_data[method], dtype=dtype).reshape(target_shape)
    scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
    translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)

    output = image.scale_and_translate(
        x, target_shape, (0,1,2,3),
        scale_a, translation_a, method, antialias=antialias)
    self.assertAllClose(output, expected, atol=2e-03)

    # Test that running with just the subset of dimensions that have
    # non-trivial scale and translation gives the same result.
    output = image.scale_and_translate(
        x, target_shape, (1,2),
        scale_a[1:3], translation_a[1:3], method, antialias=antialias)
    self.assertAllClose(output, expected, atol=2e-03)


  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "antialias={}".format(antialias),
       "antialias": antialias}
      for antialias in [True, False]))
  def testScaleAndTranslateJITs(self, antialias):
    image_shape = [1, 6, 7, 1]
    target_shape = [1, 3, 3, 1]

    data = [
        51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92,
        41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89,
        71, 32, 23, 23, 35, 93
    ]
    if antialias:
      expected_data = [
          43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0
      ]
    else:
      expected_data = [43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0]
    x = jnp.array(data, dtype=jnp.float32).reshape(image_shape)

    expected = jnp.array(expected_data, dtype=jnp.float32).reshape(target_shape)
    scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
    translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)

    def jit_fn(in_array, s, t):
      return jax.image.scale_and_translate(
          in_array, target_shape, (0, 1, 2, 3), s, t,
          "linear", antialias, precision=jax.lax.Precision.HIGHEST)

    output = jax.jit(jit_fn)(x, scale_a, translation_a)
    self.assertAllClose(output, expected, atol=2e-03)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "antialias={}".format(antialias),
       "antialias": antialias}
      for antialias in [True, False]))
  def testScaleAndTranslateGradFinite(self, antialias):
    image_shape = [1, 6, 7, 1]
    target_shape = [1, 3, 3, 1]

    data = [
        51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92,
        41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89,
        71, 32, 23, 23, 35, 93
    ]

    x = jnp.array(data, dtype=jnp.float32).reshape(image_shape)
    scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
    translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)

    def scale_fn(s):
      return jnp.sum(jax.image.scale_and_translate(
        x, target_shape, (0, 1, 2, 3), s, translation_a, "linear", antialias,
        precision=jax.lax.Precision.HIGHEST))

    scale_out = jax.grad(scale_fn)(scale_a)
    self.assertTrue(jnp.all(jnp.isfinite(scale_out)))

    def translate_fn(t):
      return jnp.sum(jax.image.scale_and_translate(
        x, target_shape, (0, 1, 2, 3), scale_a, t, "linear", antialias,
        precision=jax.lax.Precision.HIGHEST))

    translate_out = jax.grad(translate_fn)(translation_a)
    self.assertTrue(jnp.all(jnp.isfinite(translate_out)))

  def testScaleAndTranslateNegativeDims(self):
    data = jnp.full((3, 3), 0.5)
    actual = jax.image.scale_and_translate(
      data, (5, 5), (-2, -1), jnp.ones(2), jnp.zeros(2), "linear")
    expected = jax.image.scale_and_translate(
      data, (5, 5), (0, 1), jnp.ones(2), jnp.zeros(2), "linear")
    self.assertAllClose(actual, expected)

  def testResizeWithUnusualShapes(self):
    x = jnp.ones((3, 4))
    # Array shapes are accepted
    self.assertEqual((10, 17),
                     jax.image.resize(x, jnp.array((10, 17)), "nearest").shape)
    with self.assertRaises(TypeError):
      # Fractional shapes are disallowed
      jax.image.resize(x, [10.5, 17], "bicubic")
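
# --- Usage sketch (added, not part of the original example) -----------------
# A minimal, hedged illustration of the jax.image.resize call that the tests
# above compare against TensorFlow and PIL.  The input, target shape, and
# method are arbitrary demonstration values.
import jax
import jax.numpy as jnp

img = jnp.arange(12.0).reshape(3, 4)
up = jax.image.resize(img, shape=(6, 8), method="bilinear", antialias=True)
print(up.shape)  # (6, 8)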
Example No. 12
class AnnTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_qy={}_db={}_k={}_recall={}".format(
                jtu.format_shape_dtype_string(qy_shape, dtype),
                jtu.format_shape_dtype_string(db_shape, dtype), k, recall),
            "qy_shape":
            qy_shape,
            "db_shape":
            db_shape,
            "dtype":
            dtype,
            "k":
            k,
            "recall":
            recall
        } for qy_shape in [(200, 128), (128, 128)]
                            for db_shape in [(128, 500), (128, 3000)]
                            for dtype in jtu.dtypes.all_floating
                            for k in [1, 10, 50] for recall in [0.9, 0.95]))
    def test_approx_max_k(self, qy_shape, db_shape, dtype, k, recall):
        rng = jtu.rand_default(self.rng())
        qy = rng(qy_shape, dtype)
        db = rng(db_shape, dtype)
        scores = lax.dot(qy, db)
        _, gt_args = lax.top_k(scores, k)
        _, ann_args = lax.approx_max_k(scores, k, recall_target=recall)
        self.assertEqual(k, len(ann_args[0]))
        ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
        self.assertGreater(ann_recall, recall)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_qy={}_db={}_k={}_recall={}".format(
                jtu.format_shape_dtype_string(qy_shape, dtype),
                jtu.format_shape_dtype_string(db_shape, dtype), k, recall),
            "qy_shape":
            qy_shape,
            "db_shape":
            db_shape,
            "dtype":
            dtype,
            "k":
            k,
            "recall":
            recall
        } for qy_shape in [(200, 128), (128, 128)]
                            for db_shape in [(128, 500), (128, 3000)]
                            for dtype in jtu.dtypes.all_floating
                            for k in [1, 10, 50] for recall in [0.9, 0.95]))
    def test_approx_min_k(self, qy_shape, db_shape, dtype, k, recall):
        rng = jtu.rand_default(self.rng())
        qy = rng(qy_shape, dtype)
        db = rng(db_shape, dtype)
        scores = lax.dot(qy, db)
        _, gt_args = lax.top_k(-scores, k)
        _, ann_args = lax.approx_min_k(scores, k, recall_target=recall)
        self.assertEqual(k, len(ann_args[0]))
        ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
        self.assertGreater(ann_recall, recall)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_k={}_max_k={}".format(
                jtu.format_shape_dtype_string(shape, dtype), k, is_max_k),
            "shape":
            shape,
            "dtype":
            dtype,
            "k":
            k,
            "is_max_k":
            is_max_k
        } for dtype in [np.float32] for shape in [(4, ), (5, 5), (2, 1, 4)]
                            for k in [1, 3] for is_max_k in [True, False]))
    def test_autodiff(self, shape, dtype, k, is_max_k):
        vals = np.arange(prod(shape), dtype=dtype)
        vals = self.rng().permutation(vals).reshape(shape)
        if is_max_k:
            fn = lambda vs: lax.approx_max_k(vs, k=k)[0]
        else:
            fn = lambda vs: lax.approx_min_k(vs, k=k)[0]
        jtu.check_grads(fn, (vals, ), 2, ["fwd", "rev"], eps=1e-2)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_qy={}_db={}_k={}_recall={}".format(
                jtu.format_shape_dtype_string(qy_shape, dtype),
                jtu.format_shape_dtype_string(db_shape, dtype), k, recall),
            "qy_shape":
            qy_shape,
            "db_shape":
            db_shape,
            "dtype":
            dtype,
            "k":
            k,
            "recall":
            recall
        } for qy_shape in [(200, 128), (128, 128)]
                            for db_shape in [(2048, 128)]
                            for dtype in jtu.dtypes.all_floating
                            for k in [1, 10] for recall in [0.9, 0.95]))
    def test_pmap(self, qy_shape, db_shape, dtype, k, recall):
        num_devices = jax.device_count()
        rng = jtu.rand_default(self.rng())
        qy = rng(qy_shape, dtype)
        db = rng(db_shape, dtype)
        db_size = db.shape[0]
        gt_scores = lax.dot_general(qy, db, (([1], [1]), ([], [])))
        _, gt_args = lax.top_k(-gt_scores, k)  # negate the score to get min-k
        db_per_device = db_size // num_devices
        sharded_db = db.reshape(num_devices, db_per_device, 128)
        db_offsets = np.arange(num_devices, dtype=np.int32) * db_per_device

        def parallel_topk(qy, db, db_offset):
            scores = lax.dot_general(qy, db, (([1], [1]), ([], [])))
            ann_vals, ann_args = lax.approx_min_k(
                scores,
                k=k,
                reduction_dimension=1,
                recall_target=recall,
                reduction_input_size_override=db_size,
                aggregate_to_topk=False)
            return (ann_vals, ann_args + db_offset)

        # shape = qy_size, num_devices, approx_dp
        ann_vals, ann_args = jax.pmap(parallel_topk,
                                      in_axes=(None, 0, 0),
                                      out_axes=(1, 1))(qy, sharded_db,
                                                       db_offsets)
        # collapse num_devices and approx_dp
        ann_vals = lax.collapse(ann_vals, 1, 3)
        ann_args = lax.collapse(ann_args, 1, 3)
        ann_vals, ann_args = lax.sort_key_val(ann_vals, ann_args, dimension=1)
        ann_args = lax.slice_in_dim(ann_args,
                                    start_index=0,
                                    limit_index=k,
                                    axis=1)
        ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
        self.assertGreater(ann_recall, recall)
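
# --- Recall sketch (added, not part of the original example) ----------------
# The compute_recall helper used above is defined elsewhere in this example.
# Below is a hedged, hypothetical sketch of such a recall computation: the
# fraction of exact top-k neighbors recovered by the approximate search,
# averaged over all queries.
import numpy as np

def recall_sketch(ann_args, gt_args):
    # ann_args, gt_args: integer index arrays of shape (num_queries, k).
    hits = sum(
        len(set(np.asarray(ann_row)) & set(np.asarray(gt_row)))
        for ann_row, gt_row in zip(ann_args, gt_args))
    return hits / (gt_args.shape[0] * gt_args.shape[1])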
Example No. 13
class NdimageTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_{}_coordinates={}_order={}_mode={}_cval={}_impl={}_round={}".
                format(
                    jtu.format_shape_dtype_string(shape, dtype),
                    jtu.format_shape_dtype_string(coords_shape, coords_dtype),
                    order, mode, cval, impl, round_),
                "rng_factory":
                rng_factory,
                "shape":
                shape,
                "coords_shape":
                coords_shape,
                "dtype":
                dtype,
                "coords_dtype":
                coords_dtype,
                "order":
                order,
                "mode":
                mode,
                "cval":
                cval,
                "impl":
                impl,
                "round_":
                round_
            } for shape in [(5, ), (3, 4), (3, 4, 5)]
            for coords_shape in [(7, ), (2, 3, 4)]
            for dtype in float_dtypes + int_dtypes
            for coords_dtype in float_dtypes for order in [0, 1]
            for mode in ['wrap', 'constant', 'nearest', 'mirror', 'reflect']
            for cval in ([0, -1] if mode == 'constant' else [0])
            for impl, rng_factory in [
                ("original", partial(jtu.rand_uniform, low=0, high=1)),
                ("fixed", partial(jtu.rand_uniform, low=-0.75, high=1.75)),
            ] for round_ in [True, False]))
    def testMapCoordinates(self, shape, dtype, coords_shape, coords_dtype,
                           order, mode, cval, impl, round_, rng_factory):
        def args_maker():
            x = np.arange(prod(shape), dtype=dtype).reshape(shape)
            coords = [(size - 1) * rng(coords_shape, coords_dtype)
                      for size in shape]
            if round_:
                coords = [c.round().astype(int) for c in coords]
            return x, coords

        rng = rng_factory(self.rng())
        lsp_op = lambda x, c: lsp_ndimage.map_coordinates(
            x, c, order=order, mode=mode, cval=cval)
        impl_fun = (osp_ndimage.map_coordinates
                    if impl == "original" else _fixed_ref_map_coordinates)
        osp_op = lambda x, c: impl_fun(x, c, order=order, mode=mode, cval=cval)
        if dtype in float_dtypes:
            epsilon = max(
                dtypes.finfo(dtypes.canonicalize_dtype(d)).eps
                for d in [dtype, coords_dtype])
            self._CheckAgainstNumpy(osp_op,
                                    lsp_op,
                                    args_maker,
                                    tol=100 * epsilon)
        else:
            self._CheckAgainstNumpy(osp_op, lsp_op, args_maker, tol=0)

    def testMapCoordinatesErrors(self):
        x = np.arange(5.0)
        c = [np.linspace(0, 5, num=3)]
        with self.assertRaisesRegex(NotImplementedError, 'requires order<=1'):
            lsp_ndimage.map_coordinates(x, c, order=2)
        with self.assertRaisesRegex(NotImplementedError,
                                    'does not yet support mode'):
            lsp_ndimage.map_coordinates(x, c, order=1, mode='grid-wrap')
        with self.assertRaisesRegex(ValueError, 'sequence of length'):
            lsp_ndimage.map_coordinates(x, [c, c], order=1)

    def testMapCoordinateDocstring(self):
        self.assertIn("Only nearest neighbor",
                      lsp_ndimage.map_coordinates.__doc__)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_{np.dtype(dtype)}_order={order}",
            "dtype": dtype,
            "order": order
        } for dtype in float_dtypes + int_dtypes for order in [0, 1]))
    def testMapCoordinatesRoundHalf(self, dtype, order):
        x = np.arange(-3, 3, dtype=dtype)
        c = np.array([[.5, 1.5, 2.5, 3.5]])

        def args_maker():
            return x, c

        lsp_op = lambda x, c: lsp_ndimage.map_coordinates(x, c, order=order)
        osp_op = lambda x, c: osp_ndimage.map_coordinates(x, c, order=order)
        self._CheckAgainstNumpy(osp_op, lsp_op, args_maker)

    def testContinuousGradients(self):
        # regression test for https://github.com/google/jax/issues/3024

        def loss(delta):
            x = np.arange(100.0)
            border = 10
            indices = np.arange(x.size) + delta
            # linear interpolation of the linear function y=x should be exact
            shifted = lsp_ndimage.map_coordinates(x, [indices], order=1)
            return ((x - shifted)**2)[border:-border].mean()

        # analytical gradient of (x - (x - delta)) ** 2 is 2 * delta
        self.assertAllClose(grad(loss)(0.5), 1.0, check_dtypes=False)
        self.assertAllClose(grad(loss)(1.0), 2.0, check_dtypes=False)
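
# --- Usage sketch (added, not part of the original example) -----------------
# A minimal, hedged illustration of jax.scipy.ndimage.map_coordinates, the
# function exercised above.  The sample locations are arbitrary; order=1
# performs linear interpolation, so sampling the identity y = x is exact.
import jax.numpy as jnp
from jax.scipy import ndimage as jndimage

x = jnp.arange(10.0)                      # y = x
coords = [jnp.array([0.5, 2.25, 7.0])]    # one coordinate array per input dim
print(jndimage.map_coordinates(x, coords, order=1))  # -> [0.5, 2.25, 7.0]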
Example No. 14
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
    """Tests for LAX-backed scipy.stats implementations"""
    @genNamedParametersNArgs(3)
    def testPoissonLogPmf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.poisson.logpmf
        lax_fun = lsp_stats.poisson.logpmf

        def args_maker():
            k, mu, loc = map(rng, shapes, dtypes)
            k = np.floor(k)
            # clipping to ensure that rate parameter is strictly positive
            mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
            loc = np.floor(loc)
            return [k, mu, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})

    @genNamedParametersNArgs(3)
    def testPoissonPmf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.poisson.pmf
        lax_fun = lsp_stats.poisson.pmf

        def args_maker():
            k, mu, loc = map(rng, shapes, dtypes)
            k = np.floor(k)
            # clipping to ensure that rate parameter is strictly positive
            mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
            loc = np.floor(loc)
            return [k, mu, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testPoissonCdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.poisson.cdf
        lax_fun = lsp_stats.poisson.cdf

        def args_maker():
            k, mu, loc = map(rng, shapes, dtypes)
            # clipping to ensure that rate parameter is strictly positive
            mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
            return [k, mu, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testBernoulliLogPmf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.bernoulli.logpmf
        lax_fun = lsp_stats.bernoulli.logpmf

        def args_maker():
            x, logit, loc = map(rng, shapes, dtypes)
            x = np.floor(x)
            p = expit(logit)
            loc = np.floor(loc)
            return [x, p, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testGeomLogPmf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.geom.logpmf
        lax_fun = lsp_stats.geom.logpmf

        def args_maker():
            x, logit, loc = map(rng, shapes, dtypes)
            x = np.floor(x)
            p = expit(logit)
            loc = np.floor(loc)
            return [x, p, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(5)
    def testBetaLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.beta.logpdf
        lax_fun = lsp_stats.beta.logpdf

        def args_maker():
            x, a, b, loc, scale = map(rng, shapes, dtypes)
            return [x, a, b, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun,
                              args_maker,
                              rtol={
                                  np.float32: 2e-3,
                                  np.float64: 1e-4
                              })

    def testBetaLogPdfZero(self):
        # Regression test for https://github.com/google/jax/issues/7645
        a = b = 1.
        x = np.array([0., 1.])
        self.assertAllClose(osp_stats.beta.pdf(x, a, b),
                            lsp_stats.beta.pdf(x, a, b),
                            atol=1E-6)

    @genNamedParametersNArgs(3)
    def testCauchyLogPdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.cauchy.logpdf
        lax_fun = lsp_stats.cauchy.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            jtu.format_test_name_suffix("", [x_shape, alpha_shape], dtypes),
            "shapes": [x_shape, alpha_shape],
            "dtypes":
            dtypes
        } for x_shape in one_and_two_dim_shapes for alpha_shape in [(
            x_shape[0], ), (
                x_shape[0] +
                1, )] for dtypes in itertools.combinations_with_replacement(
                    jtu.dtypes.floating, 2)))
    def testDirichletLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())

        def _normalize(x, alpha):
            x_norm = x.sum(0) + (0.0 if x.shape[0] == alpha.shape[0] else 0.1)
            return (x / x_norm).astype(x.dtype), alpha

        def lax_fun(x, alpha):
            return lsp_stats.dirichlet.logpdf(*_normalize(x, alpha))

        def scipy_fun(x, alpha):
            # scipy validates the x normalization using float64 arithmetic, so we must
            # cast x to float64 before normalization to ensure this passes.
            x, alpha = _normalize(x.astype('float64'), alpha)

            result = osp_stats.dirichlet.logpdf(x, alpha)
            # if x.shape is (N, 1), scipy flattens the output, while JAX returns arrays
            # of a consistent rank. This check ensures the results have the same shape.
            return result if x.ndim == 1 else np.atleast_1d(result)

        def args_maker():
            # Don't normalize here, because we want normalization to happen at 64-bit
            # precision in the scipy version.
            x, alpha = map(rng, shapes, dtypes)
            return x, alpha

        tol = {np.float32: 1E-3, np.float64: 1e-5}
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=tol)
        self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol)

    @genNamedParametersNArgs(3)
    def testExponLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.expon.logpdf
        lax_fun = lsp_stats.expon.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(4)
    def testGammaLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.gamma.logpdf
        lax_fun = lsp_stats.gamma.logpdf

        def args_maker():
            x, a, loc, scale = map(rng, shapes, dtypes)
            return [x, a, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=5e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    def testGammaLogPdfZero(self):
        # Regression test for https://github.com/google/jax/issues/7256
        self.assertAllClose(osp_stats.gamma.pdf(0.0, 1.0),
                            lsp_stats.gamma.pdf(0.0, 1.0),
                            atol=1E-6)

    @genNamedParametersNArgs(4)
    def testNBinomLogPmf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.nbinom.logpmf
        lax_fun = lsp_stats.nbinom.logpmf

        def args_maker():
            k, n, logit, loc = map(rng, shapes, dtypes)
            k = np.floor(np.abs(k))
            n = np.ceil(np.abs(n))
            p = expit(logit)
            loc = np.floor(loc)
            return [k, n, p, loc]

        tol = {np.float32: 1e-6, np.float64: 1e-8}
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=5e-4)
        self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

    @genNamedParametersNArgs(3)
    def testLaplaceLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.laplace.logpdf
        lax_fun = lsp_stats.laplace.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(scale, a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testLaplaceCdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.laplace.cdf
        lax_fun = lsp_stats.laplace.cdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # ensure that scale is not too low
            scale = np.clip(scale, a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol={
                                    np.float32: 1e-5,
                                    np.float64: 1e-6
                                })
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(1)
    def testLogisticCdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.logistic.cdf
        lax_fun = lsp_stats.logistic.cdf

        def args_maker():
            return list(map(rng, shapes, dtypes))

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-6)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(1)
    def testLogisticLogpdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.logistic.logpdf
        lax_fun = lsp_stats.logistic.logpdf

        def args_maker():
            return list(map(rng, shapes, dtypes))

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(1)
    def testLogisticPpf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.logistic.ppf
        lax_fun = lsp_stats.logistic.ppf

        def args_maker():
            return list(map(rng, shapes, dtypes))

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(1)
    def testLogisticSf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.logistic.sf
        lax_fun = lsp_stats.logistic.sf

        def args_maker():
            return list(map(rng, shapes, dtypes))

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-6)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testNormLogPdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.norm.logpdf
        lax_fun = lsp_stats.norm.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testNormLogCdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.norm.logcdf
        lax_fun = lsp_stats.norm.logcdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testNormCdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.norm.cdf
        lax_fun = lsp_stats.norm.cdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-6)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testNormPpf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.norm.ppf
        lax_fun = lsp_stats.norm.ppf

        def args_maker():
            q, loc, scale = map(rng, shapes, dtypes)
            # ensure probability is between 0 and 1:
            q = np.clip(np.abs(q / 3), a_min=None, a_max=1)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [q, loc, scale]

        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)

    @genNamedParametersNArgs(4)
    def testParetoLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.pareto.logpdf
        lax_fun = lsp_stats.pareto.logpdf

        def args_maker():
            x, b, loc, scale = map(rng, shapes, dtypes)
            return [x, b, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(4)
    def testTLogPdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.t.logpdf
        lax_fun = lsp_stats.t.logpdf

        def args_maker():
            x, df, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [x, df, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun,
                              args_maker,
                              rtol={np.float64: 1e-14},
                              atol={np.float64: 1e-14})

    @genNamedParametersNArgs(3)
    def testUniformLogPdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.uniform.logpdf
        lax_fun = lsp_stats.uniform.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            return [x, loc, np.abs(scale)]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(4)
    def testChi2LogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.chi2.logpdf
        lax_fun = lsp_stats.chi2.logpdf

        def args_maker():
            x, df, loc, scale = map(rng, shapes, dtypes)
            return [x, df, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=5e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(5)
    def testBetaBinomLogPmf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        lax_fun = lsp_stats.betabinom.logpmf

        def args_maker():
            k, n, a, b, loc = map(rng, shapes, dtypes)
            k = np.floor(k)
            n = np.ceil(n)
            a = np.clip(a, a_min=0.1, a_max=None)
            b = np.clip(b, a_min=0.1, a_max=None)
            loc = np.floor(loc)
            return [k, n, a, b, loc]

        if scipy_version >= (1, 4):
            scipy_fun = osp_stats.betabinom.logpmf
            self._CheckAgainstNumpy(scipy_fun,
                                    lax_fun,
                                    args_maker,
                                    check_dtypes=False,
                                    tol=5e-4)
        self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)

    def testIssue972(self):
        self.assertAllClose(np.ones((4, ), np.float32),
                            lsp_stats.norm.cdf(
                                np.full((4, ), np.inf, np.float32)),
                            check_dtypes=False)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_x={}_mean={}_cov={}".format(
                jtu.format_shape_dtype_string(x_shape, x_dtype),
                jtu.format_shape_dtype_string(mean_shape, mean_dtype)
                if mean_shape is not None else None,
                jtu.format_shape_dtype_string(cov_shape, cov_dtype)
                if cov_shape is not None else None),
            "x_shape":
            x_shape,
            "x_dtype":
            x_dtype,
            "mean_shape":
            mean_shape,
            "mean_dtype":
            mean_dtype,
            "cov_shape":
            cov_shape,
            "cov_dtype":
            cov_dtype
        } for x_shape, mean_shape, cov_shape in [
            # # These test cases cover default values for mean/cov, but we don't
            # # support those yet (and they seem not very valuable).
            # [(), None, None],
            # [(), (), None],
            # [(2,), None, None],
            # [(2,), (), None],
            # [(2,), (2,), None],
            # [(3, 2), (3, 2,), None],
            # [(5, 3, 2), (5, 3, 2,), None],
            [(), (), ()],
            [(3, ), (), ()],
            [(3, ), (3, ), ()],
            [(3, ), (3, ), (3, 3)],
            [(3, 4), (4, ), (4, 4)],

            # # These test cases are where scipy flattens things, which has
            # # different batch semantics than some might expect
            # [(5, 3, 2), (5, 3, 2,), ()],
            # [(5, 3, 2), (5, 3, 2,), (5, 3, 2, 2)],
            # [(5, 3, 2), (3, 2,), (5, 3, 2, 2)],
            # [(5, 3, 2), (3, 2,), (2, 2)],
        ] for x_dtype, mean_dtype, cov_dtype in
                            itertools.combinations_with_replacement(
                                jtu.dtypes.floating, 3)
                            if (mean_shape is not None
                                or mean_dtype == np.float32) and
                            (cov_shape is not None or cov_dtype == np.float32))
    )
    def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape,
                                     mean_dtype, cov_shape, cov_dtype):
        rng = jtu.rand_default(self.rng())

        def args_maker():
            args = [rng(x_shape, x_dtype)]
            if mean_shape is not None:
                args.append(5 * rng(mean_shape, mean_dtype))
            if cov_shape is not None:
                if cov_shape == ():
                    args.append(0.1 + rng(cov_shape, cov_dtype)**2)
                else:
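                    # Build the covariance as factor @ factor.T with twice as
                    # many columns as rows in the last two dims, which is
                    # (almost surely) symmetric positive definite for a random
                    # factor.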
                    factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
                    factor = rng(factor_shape, cov_dtype)
                    args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
            return args

        self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf,
                                lsp_stats.multivariate_normal.logpdf,
                                args_maker,
                                tol=1e-3)
        self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf,
                              args_maker,
                              rtol=1e-4,
                              atol=1e-4)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_ndim={}_nbatch={}_dtype={}".format(ndim, nbatch, dtype.__name__),
            "ndim":
            ndim,
            "nbatch":
            nbatch,
            "dtype":
            dtype
        } for ndim in [2, 3] for nbatch in [1, 3, 5]
                            for dtype in jtu.dtypes.floating))
    def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype):
        # Regression test for #5570
        rng = jtu.rand_default(self.rng())
        x = rng((nbatch, ndim), dtype)
        mean = 5 * rng((nbatch, ndim), dtype)
        factor = rng((nbatch, ndim, 2 * ndim), dtype)
        cov = factor @ factor.transpose(0, 2, 1)

        result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
        result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
        self.assertArraysEqual(result1, result2)
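
A minimal standalone sketch of the batching behaviour that the regression test above checks, assuming only numpy and jax as imported below: multivariate_normal.logpdf applied to arrays with a leading batch dimension should agree with mapping the per-example call over that dimension with jax.vmap.

import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
nbatch, ndim = 3, 2
x = jnp.asarray(rng.normal(size=(nbatch, ndim)))
mean = jnp.asarray(rng.normal(size=(nbatch, ndim)))
factor = rng.normal(size=(nbatch, ndim, 2 * ndim))
# factor @ factor.T in the trailing dims gives a batch of SPD covariances.
cov = jnp.asarray(factor @ np.swapaxes(factor, -1, -2))

batched = multivariate_normal.logpdf(x, mean, cov)
mapped = jax.vmap(multivariate_normal.logpdf)(x, mean, cov)
print(np.allclose(batched, mapped))  # expected: True
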
Exemplo n.º 15
0
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]


  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
          jtu.format_shape_dtype_string(shapes, dtype),
          axis, keepdims, return_sign, use_b),
       # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
       "shapes": shapes, "dtype": dtype,
       "axis": axis, "keepdims": keepdims,
       "return_sign": return_sign, "use_b": use_b}
      for shape_group in compatible_shapes for dtype in float_dtypes + complex_dtypes + int_dtypes
      for use_b in [False, True]
      for shapes in itertools.product(*(
        (shape_group, shape_group) if use_b else (shape_group,)))
      for axis in range(-max(len(shape) for shape in shapes),
                         max(len(shape) for shape in shapes))
      for keepdims in [False, True]
      for return_sign in [False, True]))
  @jtu.ignore_warning(category=RuntimeWarning, message="invalid value encountered in .*")
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testLogSumExp(self, shapes, dtype, axis,
                    keepdims, return_sign, use_b):
    if jtu.device_under_test() != "cpu":
      rng = jtu.rand_some_inf_and_nan(self.rng())
    else:
      rng = jtu.rand_default(self.rng())
    # TODO(mattjj): test autodiff
    if use_b:
      def scipy_fun(array_to_reduce, scale_array):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      def lax_fun(array_to_reduce, scale_array):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
    else:
      def scipy_fun(array_to_reduce):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      def lax_fun(array_to_reduce):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      args_maker = lambda: [rng(shapes[0], dtype)]
    tol = {np.float32: 1E-6, np.float64: 1E-14}
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

  def testLogSumExpZeros(self):
    # Regression test for https://github.com/google/jax/issues/5370
    scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
    lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b)
    args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker)

  def testLogSumExpOnes(self):
    # Regression test for https://github.com/google/jax/issues/7390
    args_maker = lambda: [np.ones(4, dtype='float32')]
    with jax.debug_infs(True):
      self._CheckAgainstNumpy(osp_special.logsumexp, lsp_special.logsumexp, args_maker)
      self._CompileAndCheck(lsp_special.logsumexp, args_maker)

  def testLogSumExpNans(self):
    # Regression test for https://github.com/google/jax/issues/7634
    with jax.debug_nans(True):
      with jax.disable_jit():
        result = lsp_special.logsumexp(1.0)
        self.assertEqual(result, 1.0)

        result = lsp_special.logsumexp(1.0, b=1.0)
        self.assertEqual(result, 1.0)

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.test_name, shapes, dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
         "test_autodiff": rec.test_autodiff,
         "nondiff_argnums": rec.nondiff_argnums,
         "scipy_op": getattr(osp_special, rec.name),
         "lax_op": getattr(lsp_special, rec.name)}
        for shapes in itertools.combinations_with_replacement(all_shapes, rec.nargs)
        for dtypes in (itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
          if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)))
      for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                          test_autodiff, nondiff_argnums):
    if (jtu.device_under_test() == "cpu" and
        (lax_op is lsp_special.gammainc or lax_op is lsp_special.gammaincc)):
      # TODO(b/173608403): re-enable test when LLVM bug is fixed.
      raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
    rng = rng_factory(self.rng())
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

    if test_autodiff:
      def partial_lax_op(*vals):
        list_args = list(vals)
        for i in nondiff_argnums:
          list_args.insert(i, args[i])
        return lax_op(*list_args)

      assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
      diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums]
      jtu.check_grads(partial_lax_op, diff_args, order=1,
                      atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                      rtol=.1, eps=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_d={}".format(
          jtu.format_shape_dtype_string(shape, dtype), d),
       "shape": shape, "dtype": dtype, "d": d}
      for shape in all_shapes
      for dtype in float_dtypes
      for d in [1, 2, 5]))
  @jax.numpy_rank_promotion('raise')
  def testMultigammaln(self, shape, dtype, d):
    def scipy_fun(a):
      return osp_special.multigammaln(a, d)

    def lax_fun(a):
      return lsp_special.multigammaln(a, d)

    rng = jtu.rand_positive(self.rng())
    args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(
        lax_fun, args_maker, rtol={
            np.float32: 3e-07,
            np.float64: 4e-15
        })

  def testIssue980(self):
    x = np.full((4,), -1e20, dtype=np.float32)
    self.assertAllClose(np.zeros((4,), dtype=np.float32),
                        lsp_special.expit(x))

  @jax.numpy_rank_promotion('raise')
  def testIssue3758(self):
    x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
    q = np.array([1., 40., 30.], dtype=np.float32)
    self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32), lsp_special.zeta(x, q))

  def testXlogyShouldReturnZero(self):
    self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

  def testGradOfXlogyAtZero(self):
    partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
    self.assertAllClose(jax.grad(partial_xlogy)(0.), 0., check_dtypes=False)

  def testXlog1pyShouldReturnZero(self):
    self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False)

  def testGradOfXlog1pyAtZero(self):
    partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
    self.assertAllClose(jax.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_lmax={}".format(
        jtu.format_shape_dtype_string(shape, dtype), l_max),
       "l_max": l_max, "shape": shape, "dtype": dtype}
       for l_max in [1, 2, 3]
       for shape in [(5,), (10,)]
       for dtype in float_dtypes))
  def testLpmn(self, l_max, shape, dtype):
    rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
    args_maker = lambda: [rng(shape, dtype)]

    lax_fun = partial(lsp_special.lpmn, l_max, l_max)

    def scipy_fun(z, m=l_max, n=l_max):
      # scipy only supports scalar inputs for z, so we must loop here.
      vals, derivs = zip(*(osp_special.lpmn(m, n, zi) for zi in z))
      return np.dstack(vals), np.dstack(derivs)

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-6, atol=1e-6)
    self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_lmax={}".format(
        jtu.format_shape_dtype_string(shape, dtype), l_max),
       "l_max": l_max, "shape": shape, "dtype": dtype}
       for l_max in [3, 4, 6, 32]
       for shape in [(2,), (3,), (4,), (64,)]
       for dtype in float_dtypes))
  def testNormalizedLpmnValues(self, l_max, shape, dtype):
    rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
    args_maker = lambda: [rng(shape, dtype)]

    # Note: we test only the normalized values, not the derivatives.
    lax_fun = partial(lsp_special.lpmn_values, l_max, l_max, is_normalized=True)

    def scipy_fun(z, m=l_max, n=l_max):
      # scipy only supports scalar inputs for z, so we must loop here.
      vals, _ = zip(*(osp_special.lpmn(m, n, zi) for zi in z))
      a = np.dstack(vals)

      # apply the normalization
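      # Fully normalized values differ from scipy's lpmn output by the factor
      # sqrt((2l + 1) * (l - m)! / (4 * pi * (l + m)!)), computed below as c2.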
      num_m, num_l, _ = a.shape
      a_normalized = np.zeros_like(a)
      for m in range(num_m):
        for l in range(num_l):
          c0 = (2.0 * l + 1.0) * osp_special.factorial(l - m)
          c1 = (4.0 * np.pi) * osp_special.factorial(l + m)
          c2 = np.sqrt(c0 / c1)
          a_normalized[m, l] = c2 * a[m, l]
      return a_normalized

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-5, atol=1e-5)
    self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)

  def testSphHarmAccuracy(self):
    m = jnp.arange(-3, 3)[:, None]
    n = jnp.arange(3, 6)
    n_max = 5
    theta = 0.0
    phi = jnp.pi

    expected = lsp_special.sph_harm(m, n, theta, phi, n_max)

    actual = osp_special.sph_harm(m, n, theta, phi)

    self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)

  def testSphHarmOrderZeroDegreeZero(self):
    """Tests the spherical harmonics of order zero and degree zero."""
    theta = jnp.array([0.3])
    phi = jnp.array([2.3])
    n_max = 0

    expected = jnp.array([1.0 / jnp.sqrt(4.0 * np.pi)])
    actual = jnp.real(
        lsp_special.sph_harm(jnp.array([0]), jnp.array([0]), theta, phi, n_max))

    self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)

  def testSphHarmOrderZeroDegreeOne(self):
    """Tests the spherical harmonics of order one and degree zero."""
    theta = jnp.array([2.0])
    phi = jnp.array([3.1])
    n_max = 1

    expected = jnp.sqrt(3.0 / (4.0 * np.pi)) * jnp.cos(phi)
    actual = jnp.real(
        lsp_special.sph_harm(jnp.array([0]), jnp.array([1]), theta, phi, n_max))

    self.assertAllClose(actual, expected, rtol=2e-7, atol=6e-8)

  def testSphHarmOrderOneDegreeOne(self):
    """Tests the spherical harmonics of order one and degree one."""
    theta = jnp.array([2.0])
    phi = jnp.array([2.5])
    n_max = 1

    expected = (-1.0 / 2.0 * jnp.sqrt(3.0 / (2.0 * np.pi)) *
                jnp.sin(phi) * jnp.exp(1j * theta))
    actual = lsp_special.sph_harm(
        jnp.array([1]), jnp.array([1]), theta, phi, n_max)

    self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)

  @parameterized.named_parameters(jtu.cases_from_list(
      {'testcase_name': '_maxdegree={}_inputsize={}_dtype={}'.format(
        l_max, num_z, dtype),
       'l_max': l_max, 'num_z': num_z, 'dtype': dtype}
      for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
      for dtype in jtu.dtypes.all_integer))
  def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
    """Tests against JIT compatibility and Numpy."""
    n_max = l_max
    shape = (num_z,)
    rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)

    lsp_special_fn = partial(lsp_special.sph_harm, n_max=n_max)

    def args_maker():
      m = rng(shape, dtype)
      n = abs(m)
      theta = jnp.linspace(-4.0, 5.0, num_z)
      phi = jnp.linspace(-2.0, 1.0, num_z)
      return m, n, theta, phi

    with self.subTest('Test JIT compatibility'):
      self._CompileAndCheck(lsp_special_fn, args_maker)

    with self.subTest('Test against numpy.'):
      self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn, args_maker)

  def testSphHarmCornerCaseWithWrongNmax(self):
    """Tests the corner case where `n_max` is not the maximum value of `n`."""
    m = jnp.array([2])
    n = jnp.array([10])
    n_clipped = jnp.array([6])
    n_max = 6
    theta = jnp.array([0.9])
    phi = jnp.array([0.2])

    expected = lsp_special.sph_harm(m, n, theta, phi, n_max)

    actual = lsp_special.sph_harm(m, n_clipped, theta, phi, n_max)

    self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)

  @parameterized.named_parameters(jtu.cases_from_list(
      {'testcase_name':
        '_shape={}'
        '_n_zero_sv={}_degeneracy={}_geometric_spectrum={}'
        '_max_sv={}_method={}_side={}'
        '_nonzero_condition_number={}_seed={}'.format(
          jtu.format_shape_dtype_string(
            shape, jnp.dtype(dtype).name).replace(" ", ""),
          n_zero_sv, degeneracy, geometric_spectrum, max_sv,
          method, side, nonzero_condition_number, seed
        ),
        'n_zero_sv': n_zero_sv, 'degeneracy': degeneracy,
        'geometric_spectrum': geometric_spectrum,
        'max_sv': max_sv, 'shape': shape, 'method': method,
        'side': side, 'nonzero_condition_number': nonzero_condition_number,
        'dtype': dtype, 'seed': seed}
      for n_zero_sv in n_zero_svs
      for degeneracy in degeneracies
      for geometric_spectrum in geometric_spectra
      for max_sv in max_svs
      for shape in polar_shapes
      for method in methods
      for side in sides
      for nonzero_condition_number in nonzero_condition_numbers
      for dtype in jtu.dtypes.floating
      for seed in seeds))
  @jtu.skip_on_devices("gpu")  # Fails on A100.
  def testPolar(
    self, n_zero_sv, degeneracy, geometric_spectrum, max_sv, shape, method,
      side, nonzero_condition_number, dtype, seed):
    """ Tests jax.scipy.linalg.polar."""
    if jtu.device_under_test() != "cpu":
      if jnp.dtype(dtype).name in ("bfloat16", "float16"):
        raise unittest.SkipTest("Skip half precision off CPU.")
      if method == "svd":
        raise unittest.SkipTest("Can't use SVD mode on TPU/GPU.")

    matrix, _ = _initialize_polar_test(self.rng(),
      shape, n_zero_sv, degeneracy, geometric_spectrum, max_sv,
      nonzero_condition_number, dtype)
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      self.assertRaises(
        NotImplementedError, jsp.linalg.polar, matrix, method=method,
        side=side)
      return

    unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side)
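    # Polar decomposition: for side='right' the input is approximately
    # unitary @ posdef, for side='left' it is posdef @ unitary, where `unitary`
    # has orthonormal columns (or rows, for wide inputs) and `posdef` is
    # Hermitian positive semidefinite; the subtests below check each property.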
    if shape[0] >= shape[1]:
      should_be_eye = np.matmul(unitary.conj().T, unitary)
    else:
      should_be_eye = np.matmul(unitary, unitary.conj().T)
    tol = 10 * jnp.finfo(matrix.dtype).eps
    eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
    with self.subTest('Test unitarity.'):
      self.assertAllClose(
        eye_mat, should_be_eye, atol=tol * min(shape))

    with self.subTest('Test Hermiticity.'):
      self.assertAllClose(
        posdef, posdef.conj().T, atol=tol * jnp.linalg.norm(posdef))

    ev, _ = np.linalg.eigh(posdef)
    ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)]
    negative_ev = jnp.sum(ev < 0.)
    with self.subTest('Test positive definiteness.'):
      assert negative_ev == 0.

    if side == "right":
      recon = jnp.matmul(unitary, posdef, precision=lax.Precision.HIGHEST)
    elif side == "left":
      recon = jnp.matmul(posdef, unitary, precision=lax.Precision.HIGHEST)
    with self.subTest('Test reconstruction.'):
      self.assertAllClose(
        matrix, recon, atol=tol * jnp.linalg.norm(matrix))

  @parameterized.named_parameters(jtu.cases_from_list(
    {'testcase_name':
      '_linear_size_={}_seed={}_dtype={}'.format(
        linear_size, seed, jnp.dtype(dtype).name
      ),
      'linear_size': linear_size, 'seed': seed, 'dtype': dtype}
    for linear_size in linear_sizes
    for seed in seeds
    for dtype in jtu.dtypes.floating))
  def test_spectral_dac_eigh(self, linear_size, seed, dtype):
    if jtu.device_under_test() != "cpu":
      raise unittest.SkipTest("Skip eigh off CPU for now.")
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      if jtu.device_under_test() != "cpu":
        raise unittest.SkipTest("Skip half precision off CPU.")

    rng = self.rng()
    H = rng.randn(linear_size, linear_size)
    H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      self.assertRaises(
        NotImplementedError, jax._src.scipy.eigh.eigh, H)
      return
    evs, V = jax._src.scipy.eigh.eigh(H)
    ev_exp, eV_exp = jnp.linalg.eigh(H)
    HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
    vV = evs * V
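    # For exact eigenpairs H @ V equals V @ diag(evs); broadcasting evs * V
    # scales each column of V by its eigenvalue, so HV and vV should agree up
    # to the norm-scaled tolerance below.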
    eps = jnp.finfo(H.dtype).eps
    atol = jnp.linalg.norm(H) * eps
    self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
    self.assertAllClose(HV, vV, atol=30 * atol)

  @parameterized.named_parameters(jtu.cases_from_list(
    {'testcase_name':
      '_linear_size_={}_seed={}_dtype={}'.format(
        linear_size, seed, jnp.dtype(dtype).name
      ),
      'linear_size': linear_size, 'seed': seed, 'dtype': dtype}
    for linear_size in linear_sizes
    for seed in seeds
    for dtype in jtu.dtypes.floating))
  @jtu.skip_on_devices("gpu")  # Fails on A100.
  def test_spectral_dac_svd(self, linear_size, seed, dtype):
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      if jtu.device_under_test() != "cpu":
        raise unittest.SkipTest("Skip half precision off CPU.")

    rng = self.rng()
    A = rng.randn(linear_size, linear_size).astype(dtype)
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      self.assertRaises(
        NotImplementedError, jax._src.scipy.eigh.svd, A)
      return
    S_expected = np.linalg.svd(A, compute_uv=False)
    U, S, V = jax._src.scipy.eigh.svd(A)
    recon = jnp.dot((U * jnp.expand_dims(S, 0)), V,
                    precision=lax.Precision.HIGHEST)
    eps = jnp.finfo(dtype).eps
    eps = eps * jnp.linalg.norm(A) * 15
    self.assertAllClose(np.sort(S), np.sort(S_expected), atol=eps)
    self.assertAllClose(A, recon, atol=eps)

    # U is unitary.
    u_unitary_delta = jnp.dot(U.conj().T, U, precision=lax.Precision.HIGHEST)
    u_eye = jnp.eye(u_unitary_delta.shape[0], dtype=dtype)
    self.assertAllClose(u_unitary_delta, u_eye, atol=eps)

    # V is unitary.
    v_unitary_delta = jnp.dot(V.conj().T, V, precision=lax.Precision.HIGHEST)
    v_eye = jnp.eye(v_unitary_delta.shape[0], dtype=dtype)
    self.assertAllClose(v_unitary_delta, v_eye, atol=eps)
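
A minimal standalone sketch of the logsumexp semantics exercised by testLogSumExp above, assuming only numpy, scipy, and jax: the optional b argument scales each term before exponentiation, and return_sign=True additionally returns the sign of the (possibly negative) weighted sum, with the first output holding the log of its absolute value.

import numpy as np
from scipy.special import logsumexp as scipy_logsumexp
from jax.scipy.special import logsumexp as jax_logsumexp

a = np.array([[-1000.0, -2.0], [0.5, 1.5]])
b = np.array([[1.0, 1.0], [1.0, -1.0]])  # second row has a negative weight

out_scipy, sign_scipy = scipy_logsumexp(a, axis=1, b=b, return_sign=True)
out_jax, sign_jax = jax_logsumexp(a, axis=1, b=b, return_sign=True)
print(np.allclose(out_scipy, np.asarray(out_jax)))       # expected: True
print(np.array_equal(sign_scipy, np.asarray(sign_jax)))  # expected: True
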
Exemplo n.º 16
0
class SparsifyTest(jtu.JaxTestCase):

  def assertBcooIdentical(self, x, y):
    self.assertIsInstance(x, BCOO)
    self.assertIsInstance(y, BCOO)
    self.assertEqual(x.shape, y.shape)
    self.assertArraysEqual(x.data, y.data)
    self.assertArraysEqual(x.indices, y.indices)

  def testArgSpec(self):
    X = jnp.arange(5)
    X_BCOO = BCOO.fromdense(X)

    args = (X, X_BCOO, X_BCOO)

    # Independent index
    spenv = SparseEnv()
    argspecs = arrays_to_argspecs(spenv, args)
    self.assertEqual(len(argspecs), len(args))
    self.assertEqual(spenv.size(), 5)
    self.assertEqual(argspecs,
        (ArgSpec(X.shape, 0, None), ArgSpec(X.shape, 1, 2), ArgSpec(X.shape, 3, 4)))

    args_out = argspecs_to_arrays(spenv, argspecs)
    self.assertEqual(len(args_out), len(args))
    self.assertArraysEqual(args[0], args_out[0])
    self.assertBcooIdentical(args[1], args_out[1])
    self.assertBcooIdentical(args[2], args_out[2])

    # Shared index
    argspecs = (ArgSpec(X.shape, 0, None), ArgSpec(X.shape, 1, 2), ArgSpec(X.shape, 3, 2))
    spenv = SparseEnv([X, X_BCOO.data, X_BCOO.indices, X_BCOO.data])

    args_out = argspecs_to_arrays(spenv, argspecs)
    self.assertEqual(len(args_out), len(args))
    self.assertArraysEqual(args[0], args_out[0])
    self.assertBcooIdentical(args[1], args_out[1])
    self.assertBcooIdentical(args[2], args_out[2])

  def testUnitHandling(self):
    x = BCOO.fromdense(jnp.arange(5))
    f = jit(lambda x, y: x)
    result = sparsify(jit(f))(x, core.unit)
    self.assertBcooIdentical(result, x)

  def testDropvar(self):
    def inner(x):
      return x * 2, x * 3

    def f(x):
      _, y = jit(inner)(x)
      return y * 4

    x_dense = jnp.arange(5)
    x_sparse = BCOO.fromdense(x_dense)
    self.assertArraysEqual(sparsify(f)(x_sparse).todense(), f(x_dense))

  def testPytreeInput(self):
    f = sparsify(lambda x: x)
    args = (jnp.arange(4), BCOO.fromdense(jnp.arange(4)))
    out = f(args)
    self.assertLen(out, 2)
    self.assertArraysEqual(args[0], out[0])
    self.assertBcooIdentical(args[1], out[1])

  def testSparsify(self):
    M_dense = jnp.arange(24).reshape(4, 6)
    M_sparse = BCOO.fromdense(M_dense)
    v = jnp.arange(M_dense.shape[0])

    @sparsify
    def func(x, v):
      return -jnp.sin(jnp.pi * x).T @ (v + 1)

    result_dense = func(M_dense, v)
    result_sparse = func(M_sparse, v)

    self.assertAllClose(result_sparse, result_dense)

  def testSparsifyWithConsts(self):
    M_dense = jnp.arange(24).reshape(4, 6)
    M_sparse = BCOO.fromdense(M_dense)

    @sparsify
    def func(x):
      return jit(lambda x: jnp.sum(x, 1))(x)

    result_dense = func(M_dense)
    result_sparse = func(M_sparse)

    self.assertAllClose(result_sparse.todense(), result_dense)

  def testSparseMatmul(self):
    X = jnp.arange(16).reshape(4, 4)
    Xsp = BCOO.fromdense(X)
    Y = jnp.ones(4)
    Ysp = BCOO.fromdense(Y)

    # dot_general
    result_sparse = sparsify(operator.matmul)(Xsp, Y)
    result_dense = operator.matmul(X, Y)
    self.assertAllClose(result_sparse, result_dense)

    # rdot_general
    result_sparse = sparsify(operator.matmul)(Y, Xsp)
    result_dense = operator.matmul(Y, X)
    self.assertAllClose(result_sparse, result_dense)

    # spdot_general
    result_sparse = sparsify(operator.matmul)(Xsp, Ysp)
    result_dense = operator.matmul(X, Y)
    self.assertAllClose(result_sparse.todense(), result_dense)

  def testSparseAdd(self):
    x = BCOO.fromdense(jnp.arange(5))
    y = BCOO.fromdense(2 * jnp.arange(5))

    # Distinct indices
    out = sparsify(operator.add)(x, y)
    self.assertEqual(out.nse, 8)  # uses concatenation.
    self.assertArraysEqual(out.todense(), 3 * jnp.arange(5))

    # Shared indices – requires lower level call
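    # Both ArgSpecs below point at slot 0 of the SparseEnv for their indices,
    # so the two operands share a single set of sparse indices and the result
    # can reuse them instead of concatenating as in the case above.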
    argspecs = [
      ArgSpec(x.shape, 1, 0),
      ArgSpec(y.shape, 2, 0)
    ]
    spenv = SparseEnv([x.indices, x.data, y.data])

    result = sparsify_raw(operator.add)(spenv, *argspecs)
    args_out, _ = result
    out, = argspecs_to_arrays(spenv, args_out)

    self.assertAllClose(out.todense(), x.todense() + y.todense())

  def testSparseMul(self):
    x = BCOO.fromdense(jnp.arange(5))
    y = BCOO.fromdense(2 * jnp.arange(5))

    # Scalar multiplication
    out = sparsify(operator.mul)(x, 2.5)
    self.assertArraysEqual(out.todense(), x.todense() * 2.5)

    # Shared indices – requires lower level call
    argspecs = [
      ArgSpec(x.shape, 1, 0),
      ArgSpec(y.shape, 2, 0)
    ]
    spenv = SparseEnv([x.indices, x.data, y.data])

    result = sparsify_raw(operator.mul)(spenv, *argspecs)
    args_out, _ = result
    out, = argspecs_to_arrays(spenv, args_out)

    self.assertAllClose(out.todense(), x.todense() * y.todense())

  def testSparseSum(self):
    x = jnp.arange(20).reshape(4, 5)
    xsp = BCOO.fromdense(x)

    def f(x):
      return x.sum(), x.sum(0), x.sum(1), x.sum((0, 1))

    result_dense = f(x)
    result_sparse = sparsify(f)(xsp)

    assert len(result_dense) == len(result_sparse)

    for res_dense, res_sparse in zip(result_dense, result_sparse):
      if isinstance(res_sparse, BCOO):
        res_sparse = res_sparse.todense()
      self.assertArraysAllClose(res_dense, res_sparse)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dimensions={}_nbatch={}, ndense={}".format(
          jtu.format_shape_dtype_string(shape, np.float32), dimensions, n_batch, n_dense),
       "shape": shape, "dimensions": dimensions, "n_batch": n_batch, "n_dense": n_dense}
      for shape, dimensions in [
          [(1,), (0,)],
          [(1,), (-1,)],
          [(2, 1, 4), (1,)],
          [(2, 1, 3, 1), (1,)],
          [(2, 1, 3, 1), (1, 3)],
          [(2, 1, 3, 1), (3,)],
      ]
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) - n_batch + 1)))
  def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense):
    rng = jtu.rand_default(self.rng())

    M_dense = rng(shape, np.float32)
    M_sparse = BCOO.fromdense(M_dense, n_batch=n_batch, n_dense=n_dense)
    func = sparsify(partial(lax.squeeze, dimensions=dimensions))

    result_dense = func(M_dense)
    result_sparse = func(M_sparse).todense()

    self.assertAllClose(result_sparse, result_dense)

  def testSparseWhileLoop(self):
    def cond_fun(params):
      i, A = params
      return i < 5

    def body_fun(params):
      i, A = params
      return i + 1, 2 * A

    def f(A):
      return lax.while_loop(cond_fun, body_fun, (0, A))

    A = jnp.arange(4)
    out_dense = f(A)

    Asp = BCOO.fromdense(A)
    out_sparse = sparsify(f)(Asp)

    self.assertEqual(len(out_dense), 2)
    self.assertEqual(len(out_sparse), 2)
    self.assertArraysEqual(out_dense[0], out_sparse[0])
    self.assertArraysEqual(out_dense[1], out_sparse[1].todense())

  def testSparseWhileLoopDuplicateIndices(self):
    def cond_fun(params):
      i, A, B = params
      return i < 5

    def body_fun(params):
      i, A, B = params
      # TODO(jakevdp): track shared indices through while loop & use this
      #   version of the test, which requires shared indices in order for
      #   the nse of the result to remain the same.
      # return i + 1, A, A + B

      # This version is fine without shared indices, and tests that we're
      # flattening non-shared indices consistently.
      return i + 1, B, A

    def f(A):
      return lax.while_loop(cond_fun, body_fun, (0, A, A))

    A = jnp.arange(4).reshape((2, 2))
    out_dense = f(A)

    Asp = BCOO.fromdense(A)
    out_sparse = sparsify(f)(Asp)

    self.assertEqual(len(out_dense), 3)
    self.assertEqual(len(out_sparse), 3)
    self.assertArraysEqual(out_dense[0], out_sparse[0])
    self.assertArraysEqual(out_dense[1], out_sparse[1].todense())
    self.assertArraysEqual(out_dense[2], out_sparse[2].todense())

  def testSparsifyDenseXlaCall(self):
    # Test handling of dense xla_call within jaxpr interpreter.
    out = sparsify(jit(lambda x: x + 1))(0.0)
    self.assertEqual(out, 1.0)

  def testSparsifySparseXlaCall(self):
    # Test sparse lowering of XLA call
    def func(M):
      return 2 * M

    M = jnp.arange(6).reshape(2, 3)
    Msp = BCOO.fromdense(M)

    out_dense = func(M)
    out_sparse = sparsify(jit(func))(Msp)
    self.assertArraysEqual(out_dense, out_sparse.todense())

  def testSparseForiLoop(self):
    def func(M, x):
      body_fun = lambda i, val: (M @ val) / M.shape[1]
      return lax.fori_loop(0, 2, body_fun, x)

    x = jnp.arange(5.0)
    M = jnp.arange(25).reshape(5, 5)
    M_bcoo = BCOO.fromdense(M)

    result_dense = func(M, x)
    result_sparse = sparsify(func)(M_bcoo, x)

    self.assertArraysAllClose(result_dense, result_sparse)

  def testSparseCondSimple(self):
    def func(x):
      return lax.cond(False, lambda x: x, lambda x: 2 * x, x)

    x = jnp.arange(5.0)
    result_dense = func(x)

    x_bcoo = BCOO.fromdense(x)
    result_sparse = sparsify(func)(x_bcoo)

    self.assertArraysAllClose(result_dense, result_sparse.todense())

  def testSparseCondMismatchError(self):
    @sparsify
    def func(x, y):
      return lax.cond(False, lambda x: x[0], lambda x: x[1], (x, y))

    x = jnp.arange(5.0)
    y = jnp.arange(5.0)

    x_bcoo = BCOO.fromdense(x)
    y_bcoo = BCOO.fromdense(y)

    func(x, y)  # No error
    func(x_bcoo, y_bcoo)  # No error

    with self.assertRaisesRegex(TypeError, "sparsified true_fun and false_fun output.*"):
      func(x_bcoo, y)
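
A minimal standalone sketch of the sparsify transform these tests exercise; BCOO and sparsify are assumed importable from jax.experimental.sparse, though the exact import path may differ across JAX versions. A function written for dense jnp arrays is reinterpreted so that BCOO operands flow through it, and the result can be compared against the dense computation.

import jax.numpy as jnp
from jax.experimental.sparse import BCOO, sparsify  # path may vary by version

def f(mat, vec):
  # Ordinary dense computation: matrix-vector product followed by a reduction.
  return (mat @ vec).sum()

mat_dense = jnp.arange(12.0).reshape(3, 4)
vec = jnp.ones(4)

mat_sparse = BCOO.fromdense(mat_dense)        # sparse copy of the same matrix
dense_result = f(mat_dense, vec)
sparse_result = sparsify(f)(mat_sparse, vec)  # same function, sparse operand
print(jnp.allclose(dense_result, sparse_result))  # expected: True
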
Exemplo n.º 17
0
class FftTest(jtu.JaxTestCase):
    def testNotImplemented(self):
        for name in jnp.fft._NOT_IMPLEMENTED:
            func = getattr(jnp.fft, name)
            with self.assertRaises(NotImplementedError):
                func()

    def testLaxFftAcceptsStringTypes(self):
        rng = jtu.rand_default(self.rng())
        x = rng((10, ), np.complex64)
        self.assertAllClose(
            np.fft.fft(x).astype(np.complex64),
            lax.fft(x, "FFT", fft_lengths=(10, )))

    @parameterized.parameters((np.float32, ), (np.float64, ))
    def testLaxIrfftDoesNotMutateInputs(self, dtype):
        if dtype == np.float64 and not config.x64_enabled:
            raise self.skipTest("float64 requires jax_enable_x64=true")
        x = (1 + 1j) * jnp.array([[1.0, 2.0], [3.0, 4.0]],
                                 dtype=dtypes._to_complex_dtype(dtype))
        y = np.asarray(jnp.fft.irfft2(x))
        z = np.asarray(jnp.fft.irfft2(x))
        self.assertAllClose(y, z)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_inverse={}_real={}_shape={}_axes={}_s={}_norm={}".format(
                    inverse, real, jtu.format_shape_dtype_string(shape, dtype),
                    axes, s, norm),
                "axes":
                axes,
                "shape":
                shape,
                "dtype":
                dtype,
                "inverse":
                inverse,
                "real":
                real,
                "s":
                s,
                "norm":
                norm
            } for inverse in [False, True] for real in [False, True]
            for dtype in (real_dtypes if real and not inverse else all_dtypes)
            for shape in [(10, ), (10, 10), (9, ), (2, 3, 4), (2, 3, 4, 5)]
            for axes in _get_fftn_test_axes(shape)
            for s in _get_fftn_test_s(shape, axes) for norm in FFT_NORMS))
    def testFftn(self, inverse, real, shape, dtype, axes, s, norm):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng(shape, dtype), )
        jnp_op = _get_fftn_func(jnp.fft, inverse, real)
        np_op = _get_fftn_func(np.fft, inverse, real)
        jnp_fn = lambda a: jnp_op(a, axes=axes, norm=norm)
        np_fn = lambda a: (np_op(a, axes=axes, norm=norm)
                           if axes is None or axes else a)
        # Numpy promotes to complex128 aggressively.
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(jnp_fn, args_maker)
        # Test gradient for differentiable types.
        if (config.x64_enabled and dtype
                in (float_dtypes if real and not inverse else inexact_dtypes)):
            # TODO(skye): can we be more precise?
            tol = 0.15
            jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

        # check dtypes
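        # The inverse real transforms (irfftn and friends) return real output,
        # while all other transforms return complex output, so the expected
        # dtype promotes the input dtype with float or complex accordingly.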
        dtype = jnp_fn(rng(shape, dtype)).dtype
        expected_dtype = jnp.promote_types(
            float if inverse and real else complex, dtype)
        self.assertEqual(dtype, expected_dtype)

    def testIrfftTranspose(self):
        # regression test for https://github.com/google/jax/issues/6223
        def build_matrix(linear_func, size):
            return jax.vmap(linear_func)(jnp.eye(size, size))

        def func(x):
            x, = _promote_dtypes_complex(x)
            return jnp.fft.irfft(
                jnp.concatenate(
                    [jnp.zeros_like(x, shape=1), x[:2] + 1j * x[2:]]))

        def func_transpose(x):
            return jax.linear_transpose(func, x)(x)[0]

        matrix = build_matrix(func, 4)
        matrix2 = build_matrix(func_transpose, 4).T
        self.assertAllClose(matrix, matrix2)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_inverse={inverse}_real={real}",
            "inverse": inverse,
            "real": real
        } for inverse in [False, True] for real in [False, True]))
    def testFftnErrors(self, inverse, real):
        rng = jtu.rand_default(self.rng())
        name = 'fftn'
        if real:
            name = 'r' + name
        if inverse:
            name = 'i' + name
        func = _get_fftn_func(jnp.fft, inverse, real)
        self.assertRaisesRegex(
            ValueError, "jax.numpy.fft.{} only supports 1D, 2D, and 3D FFTs. "
            "Got axes None with input rank 4.".format(name),
            lambda: func(rng([2, 3, 4, 5], dtype=np.float64), axes=None))
        self.assertRaisesRegex(
            ValueError,
            f"jax.numpy.fft.{name} does not support repeated axes. Got axes \\[1, 1\\].",
            lambda: func(rng([2, 3], dtype=np.float64), axes=[1, 1]))
        self.assertRaises(
            ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[2]))
        self.assertRaises(
            ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3]))

    def testFftEmpty(self):
        out = jnp.fft.fft(jnp.zeros((0, ), jnp.complex64)).block_until_ready()
        self.assertArraysEqual(jnp.zeros((0, ), jnp.complex64), out)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_inverse={}_real={}_hermitian={}_shape={}_n={}_axis={}".
                format(inverse, real, hermitian,
                       jtu.format_shape_dtype_string(shape, dtype), n, axis),
                "axis":
                axis,
                "shape":
                shape,
                "dtype":
                dtype,
                "inverse":
                inverse,
                "real":
                real,
                "hermitian":
                hermitian,
                "n":
                n
            } for inverse in [False, True] for real in [False, True]
            for hermitian in [False, True]
            for dtype in (real_dtypes if (real and not inverse) or (
                hermitian and inverse) else all_dtypes) for shape in [(10, )]
            for n in [None, 1, 7, 13, 20] for axis in [-1, 0]))
    def testFft(self, inverse, real, hermitian, shape, dtype, n, axis):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng(shape, dtype), )
        name = 'fft'
        if real:
            name = 'r' + name
        elif hermitian:
            name = 'h' + name
        if inverse:
            name = 'i' + name
        jnp_op = getattr(jnp.fft, name)
        np_op = getattr(np.fft, name)
        jnp_fn = lambda a: jnp_op(a, n=n, axis=axis)
        np_fn = lambda a: np_op(a, n=n, axis=axis)
        # Numpy promotes to complex128 aggressively.
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(jnp_op, args_maker)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_inverse={inverse}_real={real}_hermitian={hermitian}",
            "inverse": inverse,
            "real": real,
            "hermitian": hermitian
        } for inverse in [False, True] for real in [False, True]
                            for hermitian in [False, True]))
    def testFftErrors(self, inverse, real, hermitian):
        rng = jtu.rand_default(self.rng())
        name = 'fft'
        if real:
            name = 'r' + name
        elif hermitian:
            name = 'h' + name
        if inverse:
            name = 'i' + name
        func = getattr(jnp.fft, name)

        self.assertRaisesRegex(
            ValueError,
            f"jax.numpy.fft.{name} does not support multiple axes. "
            f"Please use jax.numpy.fft.{name}n. Got axis = \\[1, 1\\].",
            lambda: func(rng([2, 3], dtype=np.float64), axis=[1, 1]))
        self.assertRaisesRegex(
            ValueError,
            f"jax.numpy.fft.{name} does not support multiple axes. "
            f"Please use jax.numpy.fft.{name}n. Got axis = \\(1, 1\\).",
            lambda: func(rng([2, 3], dtype=np.float64), axis=(1, 1)))
        self.assertRaises(
            ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[2]))
        self.assertRaises(
            ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[-3]))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_inverse={}_real={}_shape={}_axes={}_norm={}".format(
                    inverse, real, jtu.format_shape_dtype_string(shape, dtype),
                    axes, norm),
                "axes":
                axes,
                "shape":
                shape,
                "dtype":
                dtype,
                "inverse":
                inverse,
                "real":
                real,
                "norm":
                norm
            } for inverse in [False, True] for real in [False, True]
            for dtype in (real_dtypes if real and not inverse else all_dtypes)
            for shape in [(16, 8, 4, 8), (16, 8, 4, 8, 4)]
            for axes in [(-2, -1), (0, 1), (1, 3), (-1, 2)]
            for norm in FFT_NORMS))
    def testFft2(self, inverse, real, shape, dtype, axes, norm):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng(shape, dtype), )
        name = 'fft2'
        if real:
            name = 'r' + name
        if inverse:
            name = 'i' + name
        jnp_op = getattr(jnp.fft, name)
        np_op = getattr(np.fft, name)
        jnp_fn = lambda a: jnp_op(a, axes=axes, norm=norm)
        np_fn = lambda a: (np_op(a, axes=axes, norm=norm)
                           if axes is None or axes else a)
        # Numpy promotes to complex128 aggressively.
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(jnp_op, args_maker)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_inverse={inverse}_real={real}",
            "inverse": inverse,
            "real": real
        } for inverse in [False, True] for real in [False, True]))
    def testFft2Errors(self, inverse, real):
        rng = jtu.rand_default(self.rng())
        name = 'fft2'
        if real:
            name = 'r' + name
        if inverse:
            name = 'i' + name
        func = getattr(jnp.fft, name)

        self.assertRaisesRegex(
            ValueError, "jax.numpy.fft.{} only supports 2 axes. "
            "Got axes = \\[0\\].".format(name),
            lambda: func(rng([2, 3], dtype=np.float64), axes=[0]))
        self.assertRaisesRegex(
            ValueError, "jax.numpy.fft.{} only supports 2 axes. "
            "Got axes = \\(0, 1, 2\\).".format(name),
            lambda: func(rng([2, 3, 3], dtype=np.float64), axes=(0, 1, 2)))
        self.assertRaises(
            ValueError,
            lambda: func(rng([2, 3], dtype=np.float64), axes=[2, 3]))
        self.assertRaises(
            ValueError,
            lambda: func(rng([2, 3], dtype=np.float64), axes=[-3, -4]))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_size={}_d={}".format(
                jtu.format_shape_dtype_string([size], dtype), d),
            "dtype":
            dtype,
            "size":
            size,
            "d":
            d
        } for dtype in all_dtypes for size in [9, 10, 101, 102]
                            for d in [0.1, 2.]))
    def testFftfreq(self, size, d, dtype):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng([size], dtype), )
        jnp_op = jnp.fft.fftfreq
        np_op = np.fft.fftfreq
        jnp_fn = lambda a: jnp_op(size, d=d)
        np_fn = lambda a: np_op(size, d=d)
        # Numpy promotes to complex128 aggressively.
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(jnp_fn, args_maker)
        # Test gradient for differentiable types.
        if dtype in inexact_dtypes:
            tol = 0.15  # TODO(skye): can we be more precise?
            jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_n={n}",
            "n": n
        } for n in [[0, 1, 2]]))
    def testFftfreqErrors(self, n):
        name = 'fftfreq'
        func = jnp.fft.fftfreq
        self.assertRaisesRegex(
            ValueError,
            "The n argument of jax.numpy.fft.{} only takes an int. "
            "Got n = \\[0, 1, 2\\].".format(name), lambda: func(n=n))
        self.assertRaisesRegex(
            ValueError,
            "The d argument of jax.numpy.fft.{} only takes a single value. "
            "Got d = \\[0, 1, 2\\].".format(name), lambda: func(n=10, d=n))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_size={}_d={}".format(
                jtu.format_shape_dtype_string([size], dtype), d),
            "dtype":
            dtype,
            "size":
            size,
            "d":
            d
        } for dtype in all_dtypes for size in [9, 10, 101, 102]
                            for d in [0.1, 2.]))
    def testRfftfreq(self, size, d, dtype):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng([size], dtype), )
        jnp_op = jnp.fft.rfftfreq
        np_op = np.fft.rfftfreq
        jnp_fn = lambda a: jnp_op(size, d=d)
        np_fn = lambda a: np_op(size, d=d)
        # Numpy promotes to complex128 aggressively.
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(jnp_fn, args_maker)
        # Test gradient for differentiable types.
        if dtype in inexact_dtypes:
            tol = 0.15  # TODO(skye): can we be more precise?
            jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
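
    # Illustrative sketch (not part of the original suite; the method name is
    # hypothetical): rfftfreq(n, d) keeps only the non-negative frequencies
    # [0, 1, ..., n//2] / (d * n) that correspond to the output of rfft.
    def testRfftfreqSmallExample(self):
        # For n=4, d=0.5 the bins are [0, 1, 2] / 2.
        expected = np.array([0., 0.5, 1.])
        actual = jnp.fft.rfftfreq(4, d=0.5)
        self.assertAllClose(expected, actual, check_dtypes=False)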

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_n={n}",
            "n": n
        } for n in [[0, 1, 2]]))
    def testRfftfreqErrors(self, n):
        name = 'rfftfreq'
        func = jnp.fft.rfftfreq
        self.assertRaisesRegex(
            ValueError,
            "The n argument of jax.numpy.fft.{} only takes an int. "
            "Got n = \\[0, 1, 2\\].".format(name), lambda: func(n=n))
        self.assertRaisesRegex(
            ValueError,
            "The d argument of jax.numpy.fft.{} only takes a single value. "
            "Got d = \\[0, 1, 2\\].".format(name), lambda: func(n=10, d=n))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "dtype={}_axes={}".format(
                    jtu.format_shape_dtype_string(shape, dtype), axes),
                "dtype":
                dtype,
                "shape":
                shape,
                "axes":
                axes
            } for dtype in all_dtypes
            for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]]
            for axes in _get_fftn_test_axes(shape)))
    def testFftshift(self, shape, dtype, axes):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng(shape, dtype), )
        jnp_fn = lambda arg: jnp.fft.fftshift(arg, axes=axes)
        np_fn = lambda arg: np.fft.fftshift(arg, axes=axes)
        self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "dtype={}_axes={}".format(
                    jtu.format_shape_dtype_string(shape, dtype), axes),
                "dtype":
                dtype,
                "shape":
                shape,
                "axes":
                axes
            } for dtype in all_dtypes
            for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]]
            for axes in _get_fftn_test_axes(shape)))
    def testIfftshift(self, shape, dtype, axes):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng(shape, dtype), )
        jnp_fn = lambda arg: jnp.fft.ifftshift(arg, axes=axes)
        np_fn = lambda arg: np.fft.ifftshift(arg, axes=axes)
        self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)
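
    # Illustrative sketch (not part of the original suite; the method name is
    # hypothetical): fftshift rolls the zero-frequency bin to the center of the
    # spectrum and ifftshift undoes it, so the two should round-trip exactly.
    def testFftshiftIfftshiftRoundTrip(self):
        x = np.arange(5.0)
        shifted = jnp.fft.fftshift(x)
        self.assertAllClose(np.array([3., 4., 0., 1., 2.]), shifted,
                            check_dtypes=False)
        self.assertAllClose(x, jnp.fft.ifftshift(shifted), check_dtypes=False)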
Example no. 18
0
class TestPolynomial(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_dtype={}_leading={}_trailing={}".format(
                jtu.format_shape_dtype_string((
                    length + leading + trailing, ), dtype), leading, trailing),
            "dtype":
            dtype,
            "length":
            length,
            "leading":
            leading,
            "trailing":
            trailing
        } for dtype in all_dtypes for length in [0, 3, 9, 10, 17]
                            for leading in [0, 1, 2, 3, 5, 7, 10]
                            for trailing in [0, 1, 2, 3, 5, 7, 10]))
    # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
    @jtu.skip_on_devices("gpu", "tpu")
    def testRoots(self, dtype, length, leading, trailing):
        rng = jtu.rand_default(np.random.RandomState(0))

        def args_maker():
            p = rng((length, ), dtype)
            return jnp.concatenate(
                [jnp.zeros(leading, p.dtype), p,
                 jnp.zeros(trailing, p.dtype)]),

        jnp_fn = lambda arg: jnp.sort(jnp.roots(arg))
        np_fn = lambda arg: np.sort(np.roots(arg))
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=3e-6)
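
    # Illustrative sketch (not part of the original suite; the method name is
    # hypothetical): jnp.roots, like np.roots, takes coefficients in decreasing
    # powers, so [1, -3, 2] encodes x**2 - 3x + 2 = (x - 1)(x - 2).
    @jtu.skip_on_devices("gpu", "tpu")
    def testRootsKnownPolynomial(self):
        coeffs = np.array([1., -3., 2.])
        actual = jnp.sort(jnp.roots(coeffs))
        self.assertAllClose(np.sort(np.roots(coeffs)), actual,
                            check_dtypes=False, atol=1e-5, rtol=1e-5)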

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_dtype={}_trailing={}".format(
                jtu.format_shape_dtype_string((length +
                                               trailing, ), dtype), trailing),
            "dtype":
            dtype,
            "length":
            length,
            "trailing":
            trailing
        } for dtype in all_dtypes for length in [0, 1, 3, 10]
                            for trailing in [0, 1, 3, 7]))
    # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
    @jtu.skip_on_devices("gpu", "tpu")
    def testRootsNostrip(self, length, dtype, trailing):
        rng = jtu.rand_default(np.random.RandomState(0))

        def args_maker():
            p = rng((length, ), dtype)
            if length != 0:
                return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]),
            else:
                # with length == 0, appending trailing zeros would make the
                # polynomial start with zero coefficients, i.e. invalid input
                return p,

        jnp_fn = lambda arg: jnp.sort(jnp.roots(arg, strip_zeros=False))
        np_fn = lambda arg: np.sort(np.roots(arg))
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-6)
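
    # Illustrative sketch (not part of the original suite; the method name is
    # hypothetical): trailing zero coefficients multiply the polynomial by a
    # power of x, so they only add roots at zero, which is why these tests may
    # pad trailing zeros but never leading ones.
    @jtu.skip_on_devices("gpu", "tpu")
    def testRootsTrailingZerosAddZeroRoots(self):
        p = np.array([1., -3., 2.])                   # roots 1 and 2
        p_padded = np.concatenate([p, np.zeros(2)])   # same roots plus two zeros
        expected = np.sort(np.roots(p_padded))        # approximately [0, 0, 1, 2]
        actual = jnp.sort(jnp.roots(jnp.asarray(p_padded)))
        self.assertAllClose(expected, actual, check_dtypes=False,
                            atol=1e-5, rtol=1e-5)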

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_dtype={}_trailing={}".format(
                jtu.format_shape_dtype_string((length +
                                               trailing, ), dtype), trailing),
            "dtype":
            dtype,
            "length":
            length,
            "trailing":
            trailing
        } for dtype in all_dtypes for length in [0, 1, 3, 10]
                            for trailing in [0, 1, 3, 7]))
    # TODO: enable when there is an eigendecomposition implementation
    # for GPU/TPU.
    @jtu.skip_on_devices("gpu", "tpu")
    def testRootsJit(self, length, dtype, trailing):
        rng = jtu.rand_default(np.random.RandomState(0))

        def args_maker():
            p = rng((length, ), dtype)
            if length != 0:
                return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]),
            else:
                # with length == 0, appending trailing zeros would make the
                # polynomial start with zero coefficients, i.e. invalid input
                return p,

        roots_compiled = jit(partial(jnp.roots, strip_zeros=False))
        jnp_fn = lambda arg: jnp.sort(roots_compiled(arg))
        np_fn = lambda arg: np.sort(np.roots(arg))
        # Using strip_zeros=False makes the algorithm less efficient
        # and leads to slightly different values compared to NumPy.
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-6)
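
    # Illustrative sketch (not part of the original suite; the method name is
    # hypothetical): strip_zeros=True has to inspect concrete coefficient values
    # to count the zeros to strip, so only strip_zeros=False can be traced
    # through jit, as in testRootsJit above.
    @jtu.skip_on_devices("gpu", "tpu")
    def testRootsJitKnownPolynomial(self):
        coeffs = np.array([1., -3., 2.])
        fast_roots = jit(partial(jnp.roots, strip_zeros=False))
        actual = jnp.sort(fast_roots(jnp.asarray(coeffs)))
        self.assertAllClose(np.sort(np.roots(coeffs)), actual,
                            check_dtypes=False, atol=1e-5, rtol=1e-5)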

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_dtype={}_zeros={}_nonzeros={}".format(
                jtu.format_shape_dtype_string((
                    zeros + nonzeros, ), dtype), zeros, nonzeros),
            "zeros":
            zeros,
            "nonzeros":
            nonzeros,
            "dtype":
            dtype
        } for dtype in all_dtypes for zeros in [1, 2, 5]
                            for nonzeros in [0, 3]))
    @jtu.skip_on_devices("gpu")
    @unittest.skip("getting segfaults on MKL")  # TODO(#3711)
    def testRootsInvalid(self, zeros, nonzeros, dtype):
        rng = jtu.rand_default(np.random.RandomState(0))

        # The polynomial coefficients here start with zero and would have to
        # be stripped before computing eigenvalues of the companion matrix.
        # Setting strip_zeros=False skips this check,
        # allowing jit transformation but yielding NaNs for these inputs.
        p = jnp.concatenate(
            [jnp.zeros(zeros, dtype),
             rng((nonzeros, ), dtype)])

        if p.size == 1:
            # a constant polynomial has no roots
            self.assertTrue(jnp.roots(p, strip_zeros=False).size == 0)
        else:
            self.assertTrue(jnp.any(jnp.isnan(jnp.roots(p,
                                                        strip_zeros=False))))
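
    # Illustrative sketch (not part of the original suite), kept as a comment
    # because eig on a NaN/Inf-filled companion matrix can segfault on MKL
    # (see #3711):
    #
    #   p = jnp.array([0., 1., -3., 2.])     # leading zero coefficient
    #   jnp.roots(p, strip_zeros=False)      # expected: NaN-filled result
    #   jnp.roots(p)                         # zero is stripped; roots 1 and 2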