Example #1
 def test_fft(self, harness: primitive_harness.Harness):
     if len(harness.params["fft_lengths"]) > 3:
         if jtu.device_under_test() == "gpu":
             with self.assertRaisesRegex(RuntimeError,
                                         "FFT only supports ranks 1-3"):
                 harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
         else:
             raise unittest.SkipTest("TF does not support >3D FFTs.")
     elif (jtu.device_under_test() == "tpu"
           and len(harness.params["fft_lengths"]) > 1):
         # TODO(b/140351181): FFT is mostly unimplemented on TPU, even for JAX
         with self.assertRaisesRegex(RuntimeError,
                                     "only 1D FFT is currently supported."):
             harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
     else:
         tol = None
         if jtu.device_under_test() in ("cpu", "gpu"):
             if harness.params["dtype"] in jtu.dtypes.boolean:
                 tol = 0.01
             else:
                 tol = 1e-3
         self.ConvertAndCompare(harness.dyn_fun,
                                *harness.dyn_args_maker(self.rng()),
                                atol=tol,
                                rtol=tol)
Example #2
    def test_reduce_window(self, harness: primitive_harness.Harness):
        dtype = harness.params['dtype']

        if (jtu.device_under_test() == 'tpu' and dtype is np.complex64):
            raise unittest.SkipTest(
                'TODO: JAX reduce_window on TPU does not handle complex64')

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()))
Example #3
 def test_pad(self, harness: primitive_harness.Harness):
     if harness.params["dtype"] is dtypes.bfloat16:
         raise unittest.SkipTest("bfloat16 not implemented")
     # TODO: implement (or decide not to) pads with negative edge padding
     if any([lo < 0 or hi < 0 for lo, hi, mid in harness.params["pads"]]):
         raise unittest.SkipTest("pad with negative pad not supported")
     self.ConvertAndCompare(harness.dyn_fun,
                            *harness.dyn_args_maker(self.rng()),
                            with_function=True)
Example #4
 def test_prim(self, harness: primitive_harness.Harness):
     limitations = Jax2TfLimitation.limitations_for_harness(harness)
     device = jtu.device_under_test()
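     # Keep only the limitations that apply to the current test device and the harness dtype.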
     limitations = tuple(
         filter(lambda l: l.filter(device=device, dtype=harness.dtype),
                limitations))
     func_jax = harness.dyn_fun
     args = harness.dyn_args_maker(self.rng())
     self.ConvertAndCompare(func_jax, *args, limitations=limitations)
Example #5
 def test_pad(self, harness: primitive_harness.Harness):
     # TODO: figure out the bfloat16 story
     if harness.params["dtype"] is dtypes.bfloat16:
         raise unittest.SkipTest("bfloat16 not implemented")
     # TODO: fix pad with negative padding in XLA (fixed on 06/16/2020)
     if any([lo < 0 or hi < 0 for lo, hi, mid in harness.params["pads"]]):
         raise unittest.SkipTest("pad with negative pad not supported")
     self.ConvertAndCompare(harness.dyn_fun,
                            *harness.dyn_args_maker(self.rng()))
Example #6
  def test_eig(self, harness: primitive_harness.Harness):
    operand = harness.dyn_args_maker(self.rng())[0]
    compute_left_eigenvectors = harness.params["compute_left_eigenvectors"]
    compute_right_eigenvectors = harness.params["compute_right_eigenvectors"]
    dtype = harness.params["dtype"]

    if jtu.device_under_test() != "cpu":
      raise unittest.SkipTest("eig only supported on CPU in JAX")

    if dtype in [np.float16, dtypes.bfloat16]:
      raise unittest.SkipTest("eig unsupported with (b)float16 in JAX")

    def custom_assert(result_jax, result_tf):
      result_tf = tuple(map(lambda e: e.numpy(), result_tf))
      inner_dimension = operand.shape[-1]
      # Test ported from tests.lax_test.testEig
      # Norm, adjusted for dimension and type.
      def norm(x):
        norm = np.linalg.norm(x, axis=(-2, -1))
        return norm / ((inner_dimension + 1) * jnp.finfo(dtype).eps)

      def check_right_eigenvectors(a, w, vr):
        self.assertTrue(
          np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))

      def check_left_eigenvectors(a, w, vl):
        rank = len(a.shape)
        aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
        wC = jnp.conj(w)
        check_right_eigenvectors(aH, wC, vl)

      def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
        tol = None
        # TODO(bchetioui): numerical discrepancies
        if dtype in [np.float32, np.complex64]:
          tol = 1e-4
        elif dtype in [np.float64, np.complex128]:
          tol = 1e-13
        closest_diff = min(abs(eigenvalues_array - eigenvalue))
        self.assertAllClose(closest_diff, np.array(0., closest_diff.dtype),
                            atol=tol)

      all_w_jax, all_w_tf = result_jax[0], result_tf[0]
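      # Eigenvalues may be returned in a different order, so check that every JAX
      # eigenvalue appears among the TF eigenvalues and vice versa.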
      for idx in itertools.product(*map(range, operand.shape[:-2])):
        w_jax, w_tf = all_w_jax[idx], all_w_tf[idx]
        for i in range(inner_dimension):
          check_eigenvalue_is_in_array(w_jax[i], w_tf)
          check_eigenvalue_is_in_array(w_tf[i], w_jax)

      if compute_left_eigenvectors:
        check_left_eigenvectors(operand, all_w_tf, result_tf[1])
      if compute_right_eigenvectors:
        check_right_eigenvectors(operand, all_w_tf,
                                 result_tf[1 + compute_left_eigenvectors])

    self.ConvertAndCompare(harness.dyn_fun, operand,
                           custom_assert=custom_assert)
Example #7
 def test_sort(self, harness: primitive_harness.Harness):
   if (jtu.device_under_test() == "gpu" and
       len(harness.arg_descriptors) == 4 and
       not harness.params["is_stable"]):
     # TODO: fix the TF GPU test
     raise unittest.SkipTest("GPU tests are running TF on CPU")
   if jtu.device_under_test() == "tpu" and harness.params["dtype"] in jtu.dtypes.complex:
     raise unittest.SkipTest("JAX sort is not implemented on TPU for complex")
   self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
Example #8
  def test_scatter(self, harness: primitive_harness.Harness):
    f_name = harness.params['f_lax'].__name__
    dtype = harness.params['dtype']

    if jtu.device_under_test() == 'tpu':
      if dtype is np.complex64 and f_name in ['scatter_min', 'scatter_max']:
        raise unittest.SkipTest(f"TODO: complex {f_name} on TPU fails in JAX")

    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
Example #9
    def test_linear_solve(self, harness: primitive_harness.Harness):
        a, b = harness.dyn_args_maker(self.rng())
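        # Symmetrize the operand when the harness expects a symmetric system.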
        if harness.params["symmetric"]:
            a = a + a.T
        tol = None
        if (harness.params["dtype"] == np.float32
                and jtu.device_under_test() == "tpu"):
            tol = 0.01

        self.ConvertAndCompare(harness.dyn_fun, a, b, atol=tol, rtol=tol)
Example #10
 def test_triangular_solve(self, harness: primitive_harness.Harness):
   dtype = harness.params["dtype"]
   if dtype == np.float16 and jtu.device_under_test() == "gpu":
     raise unittest.SkipTest(
       f"Triangular solve is not implemented in JAX for dtype {dtype}")
   atol = rtol = None
   if dtype == np.float32:
     atol = rtol = 1e-5
   self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                          atol=atol, rtol=rtol)
Example #11
 def test_min_max(self, harness: primitive_harness.Harness):
   # TODO(bchetioui): discrepancies between TF & JAX when comparing with NaN;
   # JAX always returns NaN, while TF returns the value NaN is compared with.
   def custom_assert(result_jax, result_tf):
     mask = np.isnan(result_jax)
     self.assertAllClose(result_jax[~ mask], result_tf[~ mask])
   # TODO(bchetioui): figure out why we need always_custom_assert=True
   always_custom_assert = True
   self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                          custom_assert=custom_assert,
                          always_custom_assert=always_custom_assert)
Example #12
 def test_div(self, harness: primitive_harness.Harness):
     dividend, divisor = harness.dyn_args_maker(self.rng())
     prim = harness.params["prim"]
     if dtypes.issubdtype(dividend.dtype, np.integer):
         if (prim is lax.div_p
                 and np.any(divisor == np.array(0, dtype=divisor.dtype))):
             raise unittest.SkipTest(
                 "Divisor contains a 0, and TF returns an error value in compiled "
                 "mode instead of failing like in eager and graph mode for dtype "
                 f"{divisor.dtype}")
     self.ConvertAndCompare(harness.dyn_fun, dividend, divisor)
Example #13
  def test_jax_implemented(self, harness: primitive_harness.Harness):
    """Runs all harnesses just with JAX to verify the jax_unimplemented field.
    """
    jax_unimpl = [l for l in harness.jax_unimplemented
                  if l.filter(jtu.device_under_test())]
    try:
      harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
    except Exception as e:
      if jax_unimpl:
        logging.info(
          f"Found expected JAX error {e} with expected JAX limitations: "
          f"{[u.description for u in jax_unimpl]} in harness {harness.fullname}")
        return
      else:
        raise e

    if jax_unimpl:
      msg = ("Found no JAX error but expected JAX limitations: "
             f"{[u.description for u in jax_unimpl]} in harness: {harness.fullname}")
      logging.warning(msg)
Example #14
  def test_svd(self, harness: primitive_harness.Harness):
    if jtu.device_under_test() == "tpu":
      raise unittest.SkipTest("TODO: test crashes the XLA compiler for some TPU variants")
    expect_tf_exceptions = False
    if harness.params["dtype"] in [jnp.float16, dtypes.bfloat16]:
      if jtu.device_under_test() == "tpu":
        # TODO: SVD on TPU for bfloat16 seems to work for JAX but fails for TF
        expect_tf_exceptions = True
      else:
        # Does not work in JAX
        with self.assertRaisesRegex(NotImplementedError, "Unsupported dtype"):
          harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
        return

    if harness.params["dtype"] in [jnp.complex64, jnp.complex128]:
      if jtu.device_under_test() == "tpu":
        # TODO: on JAX on TPU there is no SVD implementation for complex
        with self.assertRaisesRegex(RuntimeError,
                                    "Binary op compare with different element types"):
          harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
        return
      else:
        # TODO: on CPU and GPU "No registered 'Svd' OpKernel for XLA_CPU_JIT devices".
        # Works on JAX because JAX uses a custom implementation.
        expect_tf_exceptions = True

    def _custom_assert(r_jax, r_tf, atol=1e-6, rtol=1e-6):
      def _reconstruct_operand(result, is_tf: bool):
        # Reconstructing operand as documented in numpy.linalg.svd (see
        # https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html)
        s, u, v = result
        if is_tf:
          s = s.numpy()
          u = u.numpy()
          v = v.numpy()
        U = u[..., :s.shape[-1]]
        V = v[..., :s.shape[-1], :]
        S = s[..., None, :]
        return jnp.matmul(U * S, V), s.shape, u.shape, v.shape

      if harness.params["compute_uv"]:
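        # Singular vectors are not unique, so compare the operands reconstructed
        # from (s, u, v) instead of comparing u and v directly.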
        r_jax_reconstructed = _reconstruct_operand(r_jax, False)
        r_tf_reconstructed = _reconstruct_operand(r_tf, True)
        self.assertAllClose(r_jax_reconstructed, r_tf_reconstructed,
                            atol=atol, rtol=rtol)
      else:
        self.assertAllClose(r_jax, r_tf, atol=atol, rtol=rtol)

    tol = 1e-4
    custom_assert = partial(_custom_assert, atol=tol, rtol=tol)

    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           atol=tol, rtol=tol,
                           expect_tf_exceptions=expect_tf_exceptions,
                           custom_assert=custom_assert,
                           always_custom_assert=True)
Example #15
  def test_unary_elementwise(self, harness: primitive_harness.Harness):
    dtype = harness.params["dtype"]
    lax_name = harness.params["lax_name"]
    arg, = harness.dyn_args_maker(self.rng())
    custom_assert = None
    if lax_name == "digamma":
      # TODO(necula): fix bug with digamma/(f32|f16) on TPU
      if dtype in [np.float16, np.float32] and jtu.device_under_test() == "tpu":
        raise unittest.SkipTest("TODO: fix bug: nan vs not-nan")

      # In the bfloat16 case, TF and lax both return NaN in undefined cases.
      if dtype is not dtypes.bfloat16:
        # digamma is not defined at 0 and -1
        def custom_assert(result_jax, result_tf):
          # lax.digamma returns NaN and tf.math.digamma returns inf
          special_cases = (arg == 0.) | (arg == -1.)
          nr_special_cases = np.count_nonzero(special_cases)
          self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
                              result_jax[special_cases])
          self.assertAllClose(np.full((nr_special_cases,), dtype(np.inf)),
                              result_tf[special_cases])
          # non-special cases are equal
          self.assertAllClose(result_jax[~ special_cases],
                              result_tf[~ special_cases])
    if lax_name == "erf_inv":
      # TODO(necula): fix erf_inv bug on TPU
      if jtu.device_under_test() == "tpu":
        raise unittest.SkipTest("erf_inv bug on TPU: nan vs non-nan")
      # TODO: investigate: in the (b)float16 cases, TF and lax both return the
      # same result in undefined cases.
      if dtype not in [np.float16, dtypes.bfloat16]:
        # erf_inv is not defined for arg <= -1 or arg >= 1
        def custom_assert(result_jax, result_tf):  # noqa: F811
          # for arg < -1 or arg > 1
          # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
          special_cases = (arg < -1.) | (arg > 1.)
          nr_special_cases = np.count_nonzero(special_cases)
          self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan),
                                      dtype=dtype),
                              result_jax[special_cases])
          signs = np.where(arg[special_cases] < 0., -1., 1.)
          self.assertAllClose(np.full((nr_special_cases,),
                                      signs * dtype(np.inf), dtype=dtype),
                              result_tf[special_cases])
          # non-special cases are equal
          self.assertAllClose(result_jax[~ special_cases],
                              result_tf[~ special_cases])
    atol = None
    if jtu.device_under_test() == "gpu":
      # TODO(necula): revisit once we fix the GPU tests
      atol = 1e-3
    self.ConvertAndCompare(harness.dyn_fun, arg, custom_assert=custom_assert,
                           atol=atol)
Example #16
 def test_prim(self, harness: primitive_harness.Harness):
   limitations = Jax2TfLimitation.limitations_for_harness(harness)
   device = jtu.device_under_test()
   limitations = tuple(filter(lambda l: l.filter(device=device,
                                                 dtype=harness.dtype), limitations))
   func_jax = harness.dyn_fun
   args = harness.dyn_args_maker(self.rng())
   enable_xla = harness.params.get("enable_xla", True)
   associative_scan_reductions = harness.params.get("associative_scan_reductions", False)
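   # Run the conversion and comparison under the associative-scan-reductions setting requested by the harness.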
   with jax.jax2tf_associative_scan_reductions(associative_scan_reductions):
     self.ConvertAndCompare(func_jax, *args, limitations=limitations,
                            enable_xla=enable_xla)
Example #17
    def test_betainc(self, harness: primitive_harness.Harness):
        dtype = harness.params["dtype"]
        # TODO: https://www.tensorflow.org/api_docs/python/tf/math/betainc only
        # supports float32/64 tests.
        tol = None
        if dtype is np.float64:
            tol = 1e-14

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               atol=tol,
                               rtol=tol)
Example #18
  def test_prim(self, harness: Harness):
    args = harness.dyn_args_maker(self.rng())
    poly_axes = harness.params["poly_axes"]  # type: Sequence[Sequence[int]]
    assert len(args) == len(poly_axes)
    # Make the polymorphic_shapes and input_signature
    polymorphic_shapes: List[Optional[str]] = []
    input_signature: List[tf.TensorSpec] = []
    for arg, poly_axis in zip(args, poly_axes):
      if poly_axis is None:
        polymorphic_shapes.append(None)
        input_signature.append(tf.TensorSpec(np.shape(arg), arg.dtype))
      else:
        def make_arg_polymorphic_shapes(poly_axis: Sequence[int]) -> Tuple[str, tf.TensorSpec]:
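          # Axes listed in poly_axis become symbolic dimensions ('b0', 'b1', ...)
          # with None entries in the TensorSpec; other axes keep their concrete sizes.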
          idx = -1
          dims = []
          tensorspec_dims: List[Optional[int]] = []
          for i, d in enumerate(arg.shape):
            if i in poly_axis:
              idx += 1
              dims.append(f"b{idx}")
              tensorspec_dims.append(None)
            else:
              dims.append(str(d))
              tensorspec_dims.append(d)
          return ", ".join(dims), tf.TensorSpec(tensorspec_dims, arg.dtype)

        arg_polymorphic_shapes, arg_tensorspec = make_arg_polymorphic_shapes(poly_axis)
        polymorphic_shapes.append(arg_polymorphic_shapes)
        input_signature.append(arg_tensorspec)

    res_jax = harness.dyn_fun(*args)
    f_tf = self.CheckShapePolymorphism(
        harness.dyn_fun,
        input_signature=input_signature,
        polymorphic_shapes=polymorphic_shapes,
        expected_output_signature=None)

    if harness.params["check_result"]:
      tol = harness.params["tol"]
      self.assertAllClose(res_jax, f_tf(*args), atol=tol, rtol=tol)
Example #19
    def test_qr(self, harness: primitive_harness.Harness):
        # See jax.lib.lapack.geqrf for the list of compatible types
        if (harness.params["dtype"] in [jnp.float32, jnp.float64]
                or (harness.params["dtype"] == jnp.float16
                    and jtu.device_under_test() == "tpu")):
            self.ConvertAndCompare(harness.dyn_fun,
                                   *harness.dyn_args_maker(self.rng()),
                                   atol=1e-5,
                                   rtol=1e-5)
        elif harness.params["dtype"] in [jnp.complex64, jnp.complex128]:
            if (jtu.device_under_test() == "tpu"
                    and harness.params["dtype"] in [jnp.complex64]):
                raise unittest.SkipTest("QR for c64 not implemented on TPU")

            # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
            # - check_compiled=True breaks for complex types;
            # - for now, the performance of the HLO QR implementation called when
            #   compiling with TF is expected to have worse performance than the
            #   custom calls made in JAX.
            self.ConvertAndCompare(harness.dyn_fun,
                                   *harness.dyn_args_maker(self.rng()),
                                   expect_tf_exceptions=True,
                                   atol=1e-5,
                                   rtol=1e-5)
        else:
            # TODO(necula): fix QR bug on TPU
            if (jtu.device_under_test() == "tpu" and harness.params["dtype"]
                    in (jnp.bfloat16, jnp.int32, jnp.uint32)):
                raise unittest.SkipTest(
                    "QR bug on TPU for certain types: error not raised")
            if (jtu.device_under_test() == "tpu"
                    and harness.params["dtype"] in (jnp.bool_, )):
                raise unittest.SkipTest(
                    "QR bug on TPU for certain types: invalid cast")

            expected_error = (ValueError if jtu.device_under_test() == "gpu"
                              else NotImplementedError)
            with self.assertRaisesRegex(expected_error, "Unsupported dtype"):
                harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
Example #20
    def test_unary_elementwise(self, harness: primitive_harness.Harness):
        dtype = harness.params["dtype"]
        if dtype is dtypes.bfloat16:
            raise unittest.SkipTest("bfloat16 not implemented")
        arg, = harness.dyn_args_maker(self.rng())
        custom_assert = None
        if harness.params["lax_name"] == "digamma":
            # digamma is not defined at 0 and -1
            def custom_assert(result_jax, result_tf):
                # lax.digamma returns NaN and tf.math.digamma returns inf
                special_cases = (arg == 0.) | (arg == -1.)
                nr_special_cases = np.count_nonzero(special_cases)
                self.assertAllClose(
                    np.full((nr_special_cases, ), dtype(np.nan)),
                    result_jax[special_cases])
                self.assertAllClose(
                    np.full((nr_special_cases, ), dtype(np.inf)),
                    result_tf[special_cases])
                # non-special cases are equal
                self.assertAllClose(result_jax[~special_cases],
                                    result_tf[~special_cases])

        if harness.params["lax_name"] == "erf_inv":
            # TODO(necula): fix bug with erf_inv/f16
            if dtype is np.float16:
                raise unittest.SkipTest("TODO: fix bug")
            # erf_inv is not defined for arg <= -1 or arg >= 1
            def custom_assert(result_jax, result_tf):  # noqa: F811
                # for arg < -1 or arg > 1
                # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
                special_cases = (arg < -1.) | (arg > 1.)
                nr_special_cases = np.count_nonzero(special_cases)
                self.assertAllClose(
                    np.full((nr_special_cases, ), dtype(np.nan)),
                    result_jax[special_cases])
                signs = np.where(arg[special_cases] < 0., -1., 1.)
                self.assertAllClose(
                    np.full((nr_special_cases, ), signs * dtype(np.inf)),
                    result_tf[special_cases])
                # non-special cases are equal
                self.assertAllClose(result_jax[~special_cases],
                                    result_tf[~special_cases])

        atol = None
        if jtu.device_under_test() == "gpu":
            # TODO(necula): revisit once we fix the GPU tests
            atol = 1e-3
        self.ConvertAndCompare(harness.dyn_fun,
                               arg,
                               custom_assert=custom_assert,
                               atol=atol)
Example #21
 def test_qr(self, harness: primitive_harness.Harness):
     # See jax.lib.lapack.geqrf for the list of compatible types
     if harness.params["dtype"] in [jnp.float32, jnp.float64]:
         self.ConvertAndCompare(harness.dyn_fun,
                                *harness.dyn_args_maker(self.rng()),
                                atol=1e-5,
                                rtol=1e-5)
     elif harness.params["dtype"] in [jnp.complex64, jnp.complex128]:
         # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
         # - check_compiled=True breaks for complex types;
         # - for now, the performance of the HLO QR implementation called when
         #   compiling with TF is expected to have worse performance than the
         #   custom calls made in JAX.
         self.ConvertAndCompare(harness.dyn_fun,
                                *harness.dyn_args_maker(self.rng()),
                                expect_tf_exceptions=True,
                                atol=1e-5,
                                rtol=1e-5)
     else:
      expected_error = (ValueError if jtu.device_under_test() == "gpu"
                        else NotImplementedError)
         with self.assertRaisesRegex(expected_error, "Unsupported dtype"):
             harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
Example #22
    def test_conv_general_dilated(self, harness: primitive_harness.Harness):
        dtype, device = harness.params["dtype"], jtu.device_under_test()
        if device == "gpu" and dtype in [np.complex64, np.complex128]:
            raise unittest.SkipTest("TODO: crash on GPU in TF")

        tol = None
        if device == "gpu":
            tol = 1e-4
        elif device == "tpu":
            tol = 1e-3
        # TODO(bchetioui): significant discrepancies in some float16 cases.
        if dtype == np.float16:
            tol = 1.
        # TODO(bchetioui): slight occasional discrepancy in float32 cases.
        elif dtype == np.float32:
            tol = 0.5 if device == "tpu" else (
                1e-3 if device == "gpu" else 1e-4)
        elif dtype == np.complex64 and device == "tpu":
            tol = 0.1
        # TODO(bchetioui): slight discrepancy when going through the path using
        # tf.nn.convolution.
        elif dtype == np.float64 and device == "cpu":
            tol = 1e-13

        # TODO(bchetioui): unidentified bug in compiled mode. The test that fails is
        #
        # test_conv_general_dilated_tf_conversion_path_3d_lhs=float32[1,4,28,28,1]_rhs=float32[2,3,3,1,16]_windowstrides=(1,1,1)_padding=VALID_lhsdilation=(1,1,1)_rhsdilation=(1,1,2)_dimensionnumbers=('NDHWC','DHWIO','NDHWC')_featuregroupcount=1_batchgroupcount=1_precision=None_enablexla=False
        #
        # with the following assertion error in TensorFlowTrace.process_primitive:
        #
        # AssertionError: conv_general_dilated: out.aval = ShapedArray(float32[1,3,24,26,16]); expected ShapedArray(float32[1,3,26,24,16])
        #
        # Deactivating this assertion is enough to pass the test, which suggests
        # that the end shape is indeed the correct one (i.e. (1,3,26,24,16)).
        # Further investigation is required to really understand this behavior,
        # which we have not managed to reproduce as a pure TF test.
        #
        # This bug is low priority since it only occurs when using a non-TFXLA
        # conversion path in compiled mode, i.e. in a context where using the
        # TFXLA path is possible.
        if harness.name == "_tf_conversion_path_3d_lhs=float32[1,4,28,28,1]_rhs=float32[2,3,3,1,16]_windowstrides=(1,1,1)_padding=VALID_lhsdilation=(1,1,1)_rhsdilation=(1,1,2)_dimensionnumbers=('NDHWC','DHWIO','NDHWC')_featuregroupcount=1_batchgroupcount=1_precision=None_enablexla=False":
            raise unittest.SkipTest(
                "TODO: known but unidentified bug in compiled "
                "mode")

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               atol=tol,
                               rtol=tol,
                               enable_xla=harness.params["enable_xla"])
Example #23
 def test_conv_general_dilated(self, harness: primitive_harness.Harness):
     if jtu.device_under_test() == "gpu":
         raise unittest.SkipTest("TODO: test failures on GPU")
     tol = None
     # TODO(bchetioui): significant discrepancies in some float16 cases.
     if harness.params["dtype"] is np.float16:
         tol = 1.
     # TODO(bchetioui): slight occasional discrepancy in float32 cases.
     elif harness.params["dtype"] is np.float32:
         tol = 1e-5
     self.ConvertAndCompare(harness.dyn_fun,
                            *harness.dyn_args_maker(self.rng()),
                            atol=tol,
                            rtol=tol)
Example #24
  def test_betainc(self, harness: primitive_harness.Harness):
    dtype = harness.params["dtype"]
    # TODO: https://www.tensorflow.org/api_docs/python/tf/math/betainc only
    # supports float32/64 tests.
    # TODO(bchetioui): investigate why the test actually fails in JAX.
    if dtype in [np.float16, dtypes.bfloat16]:
      raise unittest.SkipTest("(b)float16 not implemented in TF")

    tol = None
    if dtype is np.float64:
      tol = 1e-14

    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           atol=tol, rtol=tol)
Example #25
    def test_add_mul(self, harness: primitive_harness.Harness):
        expect_tf_exceptions = False
        dtype = harness.params["dtype"]
        f_name = harness.params["f_jax"].__name__

        if dtype in [np.uint32, np.uint64]:
            # TODO(bchetioui): tf.math.multiply is not defined for the above types.
            expect_tf_exceptions = True
        elif dtype is np.uint16 and f_name == "add":
            # TODO(bchetioui): tf.math.add is defined for the same types as multiply,
            # except uint16.
            expect_tf_exceptions = True
        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               expect_tf_exceptions=expect_tf_exceptions)
Example #26
    def test_svd(self, harness: primitive_harness.Harness):
        if harness.params["dtype"] in [np.float16, dtypes.bfloat16]:
            if jtu.device_under_test() != "tpu":
                # Does not work in JAX
                with self.assertRaisesRegex(NotImplementedError,
                                            "Unsupported dtype"):
                    harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
                return

        if harness.params["dtype"] in [np.complex64, np.complex128]:
            if jtu.device_under_test() == "tpu":
                # TODO: on JAX on TPU there is no SVD implementation for complex
                with self.assertRaisesRegex(
                        RuntimeError,
                        "Binary op compare with different element types"):
                    harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
                return

        def _custom_assert(r_jax, r_tf, atol=1e-6, rtol=1e-6):
            def _reconstruct_operand(result, is_tf: bool):
                # Reconstructing operand as documented in numpy.linalg.svd (see
                # https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html)
                s, u, v = result
                if is_tf:
                    s = s.numpy()
                    u = u.numpy()
                    v = v.numpy()
                U = u[..., :s.shape[-1]]
                V = v[..., :s.shape[-1], :]
                S = s[..., None, :]
                return jnp.matmul(U * S, V), s.shape, u.shape, v.shape

            if harness.params["compute_uv"]:
                r_jax_reconstructed = _reconstruct_operand(r_jax, False)
                r_tf_reconstructed = _reconstruct_operand(r_tf, True)
                self.assertAllClose(r_jax_reconstructed,
                                    r_tf_reconstructed,
                                    atol=atol,
                                    rtol=rtol)
            else:
                self.assertAllClose(r_jax, r_tf, atol=atol, rtol=rtol)

        tol = 1e-4
        custom_assert = partial(_custom_assert, atol=tol, rtol=tol)

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               atol=tol,
                               rtol=tol,
                               custom_assert=custom_assert,
                               always_custom_assert=True)
Example #27
    def test_min_max(self, harness: primitive_harness.Harness):
        expect_tf_exceptions = False
        dtype = harness.params["dtype"]

        if dtype in [
                np.bool_, np.int8, np.uint16, np.uint32, np.uint64,
                np.complex64, np.complex128
        ]:
            # TODO(bchetioui): tf.math.maximum and tf.math.minimum are not defined for
            # the above types.
            expect_tf_exceptions = True

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               expect_tf_exceptions=expect_tf_exceptions)
Example #28
    def test_select_and_gather_add(self, harness: primitive_harness.Harness):
        dtype = harness.params["dtype"]

        max_bits = 64
        if jtu.device_under_test() == "tpu":
            max_bits = 32

        expect_tf_exceptions = False
        if dtypes.finfo(dtype).bits * 2 > max_bits:
            # TODO: getting an exception "XLA encountered an HLO for which this rewriting is not implemented"
            expect_tf_exceptions = True

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               expect_tf_exceptions=expect_tf_exceptions)
Example #29
 def test_dot_general(self, harness: primitive_harness.Harness):
   tol, dtype = None, harness.params["dtype"]
   if dtype == dtypes.bfloat16:
     tol = 0.3
   elif dtype in [np.complex64, np.float32]:
     if jtu.device_under_test() == "tpu":
       tol = 0.1 if dtype == np.float32 else 0.3
     else:
       tol = 1e-5
   elif dtype == np.float16:
     if jtu.device_under_test() == "gpu":
       tol = 0.1
     else:
       tol = 0.01
   self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                          atol=tol, rtol=tol)
Example #30
    def test_lu(self, harness: primitive_harness.Harness):
        dtype = harness.params["dtype"]
        if dtype in [np.float16, dtypes.bfloat16]:
            raise unittest.SkipTest(
                f"LU is not implemented in JAX for dtype {dtype}.")
        tol = None
        if dtype in [np.float32, np.complex64]:
            if jtu.device_under_test() == "tpu":
                tol = 0.1
            else:
                tol = 1e-5
        if dtype in [np.float64, np.complex128]:
            tol = 1e-13
        operand, = harness.dyn_args_maker(self.rng())

        def custom_assert(result_jax, result_tf):
            lu, pivots, perm = tuple(map(lambda t: t.numpy(), result_tf))
            batch_dims = operand.shape[:-2]
            m, n = operand.shape[-2], operand.shape[-1]

            def _make_permutation_matrix(perm):
                result = []
                for idx in itertools.product(*map(range, operand.shape[:-1])):
                    result += [0 if c != perm[idx] else 1 for c in range(m)]
                result = np.reshape(np.array(result, dtype=dtype),
                                    [*batch_dims, m, m])
                return result

            k = min(m, n)
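            # Rebuild the unit lower-triangular L and upper-triangular U from the
            # packed LU factor, then check that P @ A equals L @ U.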
            l = jnp.tril(lu, -1)[..., :, :k] + jnp.eye(m, k, dtype=dtype)
            u = jnp.triu(lu)[..., :k, :]
            p_mat = _make_permutation_matrix(perm)

            self.assertArraysEqual(
                lax_linalg.lu_pivots_to_permutation(pivots, m), perm)
            self.assertAllClose(jnp.matmul(p_mat, operand),
                                jnp.matmul(l, u),
                                atol=tol,
                                rtol=tol)

        self.ConvertAndCompare(harness.dyn_fun,
                               operand,
                               atol=tol,
                               rtol=tol,
                               custom_assert=custom_assert,
                               always_custom_assert=True)