def test_jax_implemented(self, harness: primitive_harness.Harness):
        """Runs all harnesses just with JAX to verify the jax_unimplemented field.
    """
        jax_unimpl = [
            l for l in harness.jax_unimplemented
            if l.filter(device=jtu.device_under_test(), dtype=harness.dtype)
        ]
        if any(lim.skip_run for lim in jax_unimpl):
            logging.info(
                f"Skipping run with expected JAX limitations: "
                f"{[u.description for u in jax_unimpl]} in harness {harness.fullname}"
            )
            return
        try:
            harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
        except Exception as e:
            if jax_unimpl:
                logging.info(
                    f"Found expected JAX error {e} with expected JAX limitations: "
                    f"{[u.description for u in jax_unimpl]} in harness {harness.fullname}"
                )
                return
            else:
                raise e

        if jax_unimpl:
            msg = (
                "Found no JAX error but expected JAX limitations: "
                f"{[u.description for u in jax_unimpl]} in harness: {harness.fullname}"
            )
            logging.warning(msg)
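
A note on the limitation records used above: the test only relies on three attributes of each entry in harness.jax_unimplemented (filter, skip_run, description). A hypothetical, minimal sketch of such a record, inferred purely from those attributes and not from the actual primitive_harness class:

# Hypothetical sketch, inferred only from the attributes used in the test above;
# not the real primitive_harness limitation class.
import dataclasses
from typing import Any, Sequence

@dataclasses.dataclass
class JaxLimitationSketch:
    description: str                                # reason logged by the test
    devices: Sequence[str] = ("cpu", "gpu", "tpu")  # devices it applies to (assumed field)
    dtypes: Sequence[Any] = ()                      # dtypes it applies to (assumed field)
    skip_run: bool = False                          # if True, the harness is not run at all

    def filter(self, *, device=None, dtype=None) -> bool:
        # The limitation is active when both the device and the dtype match.
        return device in self.devices and (not self.dtypes or dtype in self.dtypes)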
Example #2
  def test_qr(self, harness: primitive_harness.Harness):
    # See jax.lib.lapack.geqrf for the list of compatible types
    if (harness.params["dtype"] in [jnp.float32, jnp.float64] or
        harness.params["dtype"] == jnp.float16 and jtu.device_under_test() == "tpu"):
      self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                             atol=1e-5, rtol=1e-5)
    elif harness.params["dtype"] in [jnp.complex64, jnp.complex128]:
      if (jtu.device_under_test() == "tpu" and
          harness.params["dtype"] in [jnp.complex64]):
        raise unittest.SkipTest("QR for c64 not implemented on TPU")

      # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
      # - check_compiled=True breaks for complex types;
      # - for now, the performance of the HLO QR implementation called when
      #   compiling with TF is expected to have worse performance than the
      #   custom calls made in JAX.
      self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                             expect_tf_exceptions=True, atol=1e-5, rtol=1e-5)
    else:
      # TODO(necula): fix QR bug on TPU
      if (jtu.device_under_test() == "tpu" and
          harness.params["dtype"] in (jnp.bfloat16, jnp.int32, jnp.uint32)):
        raise unittest.SkipTest("QR bug on TPU for certain types: error not raised")
      if (jtu.device_under_test() == "tpu" and
          harness.params["dtype"] in (jnp.bool_,)):
        raise unittest.SkipTest("QR bug on TPU for certain types: invalid cast")

      expected_error = ValueError if jtu.device_under_test() == "gpu" else NotImplementedError
      with self.assertRaisesRegex(expected_error, "Unsupported dtype"):
        harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
Example #3
 def test_fft(self, harness: primitive_harness.Harness):
     if len(harness.params["fft_lengths"]) > 3:
         if jtu.device_under_test() == "gpu":
             with self.assertRaisesRegex(RuntimeError,
                                         "FFT only supports ranks 1-3"):
                 harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
         else:
             raise unittest.SkipTest("TF does not support >3D FFTs.")
     elif (jtu.device_under_test() == "tpu"
           and len(harness.params["fft_lengths"]) > 1):
         # TODO(b/140351181): FFT is mostly unimplemented on TPU, even for JAX
         with self.assertRaisesRegex(RuntimeError,
                                     "only 1D FFT is currently supported."):
             harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
     else:
         tol = None
         if jtu.device_under_test() in ("cpu", "gpu"):
             if harness.params["dtype"] in jtu.dtypes.boolean:
                 tol = 0.01
             else:
                 tol = 1e-3
         self.ConvertAndCompare(harness.dyn_fun,
                                *harness.dyn_args_maker(self.rng()),
                                atol=tol,
                                rtol=tol)
Example #4
  def test_svd(self, harness: primitive_harness.Harness):
    if jtu.device_under_test() == "tpu":
      raise unittest.SkipTest("TODO: test crashes the XLA compiler for some TPU variants")
    expect_tf_exceptions = False
    if harness.params["dtype"] in [jnp.float16, dtypes.bfloat16]:
      if jtu.device_under_test() == "tpu":
        # TODO: SVD on TPU for bfloat16 seems to work for JAX but fails for TF
        expect_tf_exceptions = True
      else:
        # Does not work in JAX
        with self.assertRaisesRegex(NotImplementedError, "Unsupported dtype"):
          harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
        return

    if harness.params["dtype"] in [jnp.complex64, jnp.complex128]:
      if jtu.device_under_test() == "tpu":
        # TODO: on JAX on TPU there is no SVD implementation for complex
        with self.assertRaisesRegex(RuntimeError,
                                    "Binary op compare with different element types"):
          harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
        return
      else:
        # TODO: on CPU and GPU "No registered 'Svd' OpKernel for XLA_CPU_JIT devices".
        # Works on JAX because JAX uses a custom implementation.
        expect_tf_exceptions = True

    def _custom_assert(r_jax, r_tf, atol=1e-6, rtol=1e-6):
      def _reconstruct_operand(result, is_tf: bool):
        # Reconstructing operand as documented in numpy.linalg.svd (see
        # https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html)
        s, u, v = result
        if is_tf:
          s = s.numpy()
          u = u.numpy()
          v = v.numpy()
        U = u[..., :s.shape[-1]]
        V = v[..., :s.shape[-1], :]
        S = s[..., None, :]
        return jnp.matmul(U * S, V), s.shape, u.shape, v.shape

      if harness.params["compute_uv"]:
        r_jax_reconstructed = _reconstruct_operand(r_jax, False)
        r_tf_reconstructed = _reconstruct_operand(r_tf, True)
        self.assertAllClose(r_jax_reconstructed, r_tf_reconstructed,
                            atol=atol, rtol=rtol)
      else:
        self.assertAllClose(r_jax, r_tf, atol=atol, rtol=rtol)

    tol = 1e-4
    custom_assert = partial(_custom_assert, atol=tol, rtol=tol)

    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           atol=tol, rtol=tol,
                           expect_tf_exceptions=expect_tf_exceptions,
                           custom_assert=custom_assert,
                           always_custom_assert=True)
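
The _custom_assert above compares reconstructed operands rather than the raw SVD factors, since the factors are only unique up to sign/phase. A minimal standalone NumPy sketch of the same reconstruction identity, on a small random matrix (not part of the test suite):

import numpy as np

a = np.random.randn(5, 3)
u, s, vh = np.linalg.svd(a, full_matrices=True)   # u: (5, 5), s: (3,), vh: (3, 3)
k = s.shape[-1]
U = u[..., :k]           # first k columns of u
Vh = vh[..., :k, :]      # first k rows of vh
S = s[..., None, :]      # broadcastable singular values, as in _reconstruct_operand
assert np.allclose(a, np.matmul(U * S, Vh), atol=1e-6)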
Example #5
    def test_svd(self, harness: primitive_harness.Harness):
        if harness.params["dtype"] in [np.float16, dtypes.bfloat16]:
            if jtu.device_under_test() != "tpu":
                # Does not work in JAX
                with self.assertRaisesRegex(NotImplementedError,
                                            "Unsupported dtype"):
                    harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
                return

        if harness.params["dtype"] in [np.complex64, np.complex128]:
            if jtu.device_under_test() == "tpu":
                # TODO: on JAX on TPU there is no SVD implementation for complex
                with self.assertRaisesRegex(
                        RuntimeError,
                        "Binary op compare with different element types"):
                    harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
                return

        def _custom_assert(r_jax, r_tf, atol=1e-6, rtol=1e-6):
            def _reconstruct_operand(result, is_tf: bool):
                # Reconstructing operand as documented in numpy.linalg.svd (see
                # https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html)
                s, u, v = result
                if is_tf:
                    s = s.numpy()
                    u = u.numpy()
                    v = v.numpy()
                U = u[..., :s.shape[-1]]
                V = v[..., :s.shape[-1], :]
                S = s[..., None, :]
                return jnp.matmul(U * S, V), s.shape, u.shape, v.shape

            if harness.params["compute_uv"]:
                r_jax_reconstructed = _reconstruct_operand(r_jax, False)
                r_tf_reconstructed = _reconstruct_operand(r_tf, True)
                self.assertAllClose(r_jax_reconstructed,
                                    r_tf_reconstructed,
                                    atol=atol,
                                    rtol=rtol)
            else:
                self.assertAllClose(r_jax, r_tf, atol=atol, rtol=rtol)

        tol = 1e-4
        custom_assert = partial(_custom_assert, atol=tol, rtol=tol)

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               atol=tol,
                               rtol=tol,
                               custom_assert=custom_assert,
                               always_custom_assert=True)
Example #6
  def test_prim(self, harness: Harness):
    args = harness.dyn_args_maker(self.rng())
    poly_axes = harness.params["poly_axes"]
    assert len(args) == len(poly_axes)
    # Make the polymorphic_shapes and input_signature
    polymorphic_shapes: List[Optional[str]] = []
    input_signature: List[tf.TensorSpec] = []
    for arg, poly_axis in zip(args, poly_axes):
      if poly_axis is None:
        polymorphic_shapes.append(None)
        input_signature.append(tf.TensorSpec(np.shape(arg), arg.dtype))
      else:
        polymorphic_shapes.append(
            ", ".join([str(d) if i != poly_axis else "b"
                       for i, d in enumerate(arg.shape)]))
        input_signature.append(tf.TensorSpec([d if i != poly_axis else None
                                              for i, d in enumerate(arg.shape)],
                                             arg.dtype))

    res_jax = harness.dyn_fun(*args)
    f_tf = self.CheckShapePolymorphism(
        harness.dyn_fun,
        input_signature=input_signature,
        polymorphic_shapes=polymorphic_shapes,
        expected_output_signature=None)

    if harness.params["check_result"]:
      self.assertAllClose(res_jax, f_tf(*args))
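
For a concrete argument, the loop above replaces the single polymorphic axis with the symbolic dimension "b" in polymorphic_shapes and with None in the tf.TensorSpec. A standalone sketch of what it builds for a hypothetical input of shape (3, 4, 5) with poly_axis == 1 (the TensorSpec is left as a comment to keep the sketch dependency-free):

arg_shape = (3, 4, 5)
poly_axis = 1
polymorphic_shape = ", ".join(str(d) if i != poly_axis else "b"
                              for i, d in enumerate(arg_shape))
spec_dims = [d if i != poly_axis else None for i, d in enumerate(arg_shape)]
print(polymorphic_shape)  # 3, b, 5
print(spec_dims)          # [3, None, 5]  -> tf.TensorSpec([3, None, 5], arg.dtype)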
Example #7
 def test_top_k(self, harness: primitive_harness.Harness):
   if (harness.params["k"] > harness.params["shape"][-1] or
       harness.params["k"] < 0):
     with self.assertRaisesRegex(ValueError, "k argument to top_k must be"):
       harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
   elif harness.params["dtype"] in jtu.dtypes.complex:
     # TODO(necula): fix top_k complex bug on TPU
     if jtu.device_under_test() == "tpu":
       raise unittest.SkipTest("top_k complex on TPU raises different error")
     with self.assertRaisesRegex(RuntimeError, "Unimplemented: complex comparison"):
       harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
   # TODO: TF and JAX sort [inf, nan] differently.
   elif harness.name.startswith("nan_"):
     raise unittest.SkipTest("inconsistent [nan, inf] sorting")
   else:
     self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
Example #8
 def test_top_k(self, harness: primitive_harness.Harness):
     if (harness.params["k"] > harness.params["shape"][-1]
             or harness.params["k"] < 0):
         with self.assertRaisesRegex(ValueError,
                                     "k argument to top_k must be"):
             harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
     # TODO: figure out what's up with bfloat16
     elif harness.params["dtype"] is dtypes.bfloat16:
         raise unittest.SkipTest("bfloat16 support not implemented")
     elif harness.params["dtype"] in jtu.dtypes.complex:
         with self.assertRaisesRegex(RuntimeError,
                                     "Unimplemented: complex comparison"):
             harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
     # TODO: TF and JAX sort [inf, nan] differently.
     elif harness.name.startswith("nan_"):
         raise unittest.SkipTest("inconsistent [nan, inf] sorting")
     else:
         self.ConvertAndCompare(harness.dyn_fun,
                                *harness.dyn_args_maker(self.rng()))
Example #9
 def test_qr(self, harness: primitive_harness.Harness):
     # See jax.lib.lapack.geqrf for the list of compatible types
     if harness.params["dtype"] in [jnp.float32, jnp.float64]:
         self.ConvertAndCompare(harness.dyn_fun,
                                *harness.dyn_args_maker(self.rng()),
                                atol=1e-5,
                                rtol=1e-5)
     elif harness.params["dtype"] in [jnp.complex64, jnp.complex128]:
         # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
         # - check_compiled=True breaks for complex types;
         # - for now, the performance of the HLO QR implementation called when
         #   compiling with TF is expected to have worse performance than the
         #   custom calls made in JAX.
         self.ConvertAndCompare(harness.dyn_fun,
                                *harness.dyn_args_maker(self.rng()),
                                expect_tf_exceptions=True,
                                atol=1e-5,
                                rtol=1e-5)
     else:
          expected_error = (ValueError if jtu.device_under_test() == "gpu"
                            else NotImplementedError)
         with self.assertRaisesRegex(expected_error, "Unsupported dtype"):
             harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
Example #10
    def test_top_k(self, harness: primitive_harness.Harness):
        custom_assert = None
        k, dtype = harness.params["k"], harness.params["dtype"]
        if k > harness.params["shape"][-1] or k < 0:
            with self.assertRaisesRegex(ValueError,
                                        "k argument to top_k must be"):
                harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
            return
        if dtype in jtu.dtypes.complex:
            # TODO(necula): fix top_k complex bug on TPU
            if jtu.device_under_test() == "tpu":
                raise unittest.SkipTest(
                    "top_k complex on TPU raises different error")
            with self.assertRaisesRegex(RuntimeError,
                                        "Unimplemented: complex comparison"):
                harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
            return
        if dtype in jtu.dtypes.all_inexact:

            def custom_assert(result_jax, result_tf):
                assert len(result_jax) == len(result_tf)
                # TODO: TF and JAX sort [inf, nan] differently.
                first_arr_jax = result_jax[0]
                first_arr_tf = result_tf[0].numpy()
                if np.all(first_arr_jax == first_arr_tf):
                    for arr_jax, arr_tf in zip(result_jax, result_tf):
                        self.assertArraysEqual(arr_jax, arr_tf)
                else:
                    mask_jax = np.isnan(first_arr_jax)
                    mask_tf = np.isnan(first_arr_tf)
                    self.assertArraysEqual(first_arr_jax[~mask_jax],
                                           first_arr_tf[~mask_tf])

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               custom_assert=custom_assert)
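
The custom_assert above falls back to comparing only the non-NaN entries when the value arrays differ, because TF and JAX may order [inf, nan] differently. A minimal NumPy sketch of that masked comparison, on hypothetical values:

import numpy as np

vals_jax = np.array([np.inf, np.nan, 3.0])
vals_tf = np.array([np.nan, np.inf, 3.0])   # same values, NaN placed differently
mask_jax, mask_tf = np.isnan(vals_jax), np.isnan(vals_tf)
# Compare only the entries that are not NaN on each side.
np.testing.assert_array_equal(vals_jax[~mask_jax], vals_tf[~mask_tf])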
Example #11
  def test_prim(self, harness: Harness):
    args = harness.dyn_args_maker(self.rng())
    poly_axes = harness.params["poly_axes"]  # type: Sequence[Sequence[int]]
    assert len(args) == len(poly_axes)
    # Make the polymorphic_shapes and input_signature
    polymorphic_shapes: List[Optional[str]] = []
    input_signature: List[tf.TensorSpec] = []
    for arg, poly_axis in zip(args, poly_axes):
      if poly_axis is None:
        polymorphic_shapes.append(None)
        input_signature.append(tf.TensorSpec(np.shape(arg), arg.dtype))
      else:
        def make_arg_polymorphic_shapes(poly_axis: Sequence[int]) -> Tuple[str, tf.TensorSpec]:
          idx = -1
          dims = []
          tensorspec_dims: List[Optional[int]] = []
          for i, d in enumerate(arg.shape):
            if i in poly_axis:
              idx += 1
              dims.append(f"b{idx}")
              tensorspec_dims.append(None)
            else:
              dims.append(str(d))
              tensorspec_dims.append(d)
          return ", ".join(dims), tf.TensorSpec(tensorspec_dims, arg.dtype)

        arg_polymorphic_shapes, arg_tensorspec = make_arg_polymorphic_shapes(poly_axis)
        polymorphic_shapes.append(arg_polymorphic_shapes)
        input_signature.append(arg_tensorspec)

    res_jax = harness.dyn_fun(*args)
    f_tf = self.CheckShapePolymorphism(
        harness.dyn_fun,
        input_signature=input_signature,
        polymorphic_shapes=polymorphic_shapes,
        expected_output_signature=None)

    if harness.params["check_result"]:
      tol = harness.params["tol"]
      self.assertAllClose(res_jax, f_tf(*args), atol=tol, rtol=tol)
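
make_arg_polymorphic_shapes generalizes the earlier single-axis case: each polymorphic axis gets its own symbolic name b0, b1, ... A standalone sketch of its output for a hypothetical shape (3, 4, 5) with poly_axis == (0, 2):

arg_shape = (3, 4, 5)
poly_axis = (0, 2)
idx, dims, tensorspec_dims = -1, [], []
for i, d in enumerate(arg_shape):
    if i in poly_axis:
        idx += 1
        dims.append(f"b{idx}")
        tensorspec_dims.append(None)
    else:
        dims.append(str(d))
        tensorspec_dims.append(d)
print(", ".join(dims))    # b0, 4, b1
print(tensorspec_dims)    # [None, 4, None]  -> tf.TensorSpec([None, 4, None], arg.dtype)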