Example #1
    def test_scatter(self, harness: primitive_harness.Harness):
        f_name = harness.params['f_lax'].__name__
        dtype = harness.params['dtype']

        if jtu.device_under_test() == 'tpu':
            if dtype is np.complex64 and f_name in [
                    'scatter_min', 'scatter_max'
            ]:
                raise unittest.SkipTest(
                    f"TODO: complex {f_name} on TPU fails in JAX")

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()))
Example #2
 def test_triangular_solve(self, harness: primitive_harness.Harness):
     dtype = harness.params["dtype"]
     if dtype == np.float16 and jtu.device_under_test() == "gpu":
         raise unittest.SkipTest(
             f"Triangular solve is not implemented in JAX for dtype {dtype}"
         )
     atol = rtol = None
     if dtype == np.float32:
         atol = rtol = 1e-5
     self.ConvertAndCompare(harness.dyn_fun,
                            *harness.dyn_args_maker(self.rng()),
                            atol=atol,
                            rtol=rtol)
Example #3
 def test_fft(self, harness: primitive_harness.Harness):
     if len(harness.params["fft_lengths"]) > 3:
         with self.assertRaisesRegex(RuntimeError,
                                     "FFT only supports ranks 1-3"):
             harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
     elif (jtu.device_under_test() == "tpu"
           and len(harness.params["fft_lengths"]) > 1):
         # TODO(b/140351181): FFT is mostly unimplemented on TPU, even for JAX
         with self.assertRaisesRegex(RuntimeError,
                                     "only 1D FFT is currently supported."):
             harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
     else:
         tol = None
         if jtu.device_under_test() == "gpu":
             if harness.params["dtype"] in jtu.dtypes.boolean:
                 tol = 0.01
             else:
                 tol = 1e-3
         self.ConvertAndCompare(harness.dyn_fun,
                                *harness.dyn_args_maker(self.rng()),
                                atol=tol,
                                rtol=tol)
Example #4
    def test_qr(self, harness: primitive_harness.Harness):
        # See jax.lib.lapack.geqrf for the list of compatible types
        if (harness.params["dtype"] in [jnp.float32, jnp.float64]
                or (harness.params["dtype"] == jnp.float16
                    and jtu.device_under_test() == "tpu")):
            self.ConvertAndCompare(harness.dyn_fun,
                                   *harness.dyn_args_maker(self.rng()),
                                   atol=1e-5,
                                   rtol=1e-5)
        elif harness.params["dtype"] in [jnp.complex64, jnp.complex128]:
            if (jtu.device_under_test() == "tpu"
                    and harness.params["dtype"] in [jnp.complex64]):
                raise unittest.SkipTest("QR for c64 not implemented on TPU")

            # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
            # - check_compiled=True breaks for complex types;
            # - for now, the performance of the HLO QR implementation called when
            #   compiling with TF is expected to have worse performance than the
            #   custom calls made in JAX.
            self.ConvertAndCompare(harness.dyn_fun,
                                   *harness.dyn_args_maker(self.rng()),
                                   expect_tf_exceptions=True,
                                   atol=1e-5,
                                   rtol=1e-5)
        else:
            # TODO(necula): fix QR bug on TPU
            if (jtu.device_under_test() == "tpu" and harness.params["dtype"]
                    in (jnp.bfloat16, jnp.int32, jnp.uint32)):
                raise unittest.SkipTest(
                    "QR bug on TPU for certain types: error not raised")
            if (jtu.device_under_test() == "tpu"
                    and harness.params["dtype"] in (jnp.bool_, )):
                raise unittest.SkipTest(
                    "QR bug on TPU for certain types: invalid cast")

            expected_error = (ValueError if jtu.device_under_test() == "gpu"
                              else NotImplementedError)
            with self.assertRaisesRegex(expected_error, "Unsupported dtype"):
                harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
Example #5
 def test_prim(self, harness: primitive_harness.Harness):
     limitations = Jax2TfLimitation.limitations_for_harness(harness)
     device = jtu.device_under_test()
     limitations = tuple(
         filter(lambda l: l.filter(device=device, dtype=harness.dtype),
                limitations))
     func_jax = harness.dyn_fun
     args = harness.dyn_args_maker(self.rng())
     enable_xla = harness.params.get("enable_xla", True)
     self.ConvertAndCompare(func_jax,
                            *args,
                            limitations=limitations,
                            enable_xla=enable_xla)
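The filtering step above keeps only the limitations that apply to the current device and dtype before handing them to ConvertAndCompare. The sketch below illustrates that filter-by-predicate idiom in isolation, using a hypothetical Limitation stand-in rather than the real Jax2TfLimitation class; only the pattern is taken from the example.

from dataclasses import dataclass
from typing import Sequence

@dataclass
class Limitation:  # hypothetical stand-in for Jax2TfLimitation
    description: str
    devices: Sequence[str]
    dtypes: Sequence[type]

    def filter(self, device=None, dtype=None) -> bool:
        # A limitation is relevant only if both the device and the dtype match.
        return device in self.devices and dtype in self.dtypes

all_limitations = (
    Limitation("complex min/max unsupported", devices=("tpu",), dtypes=(complex,)),
    Limitation("float16 needs larger tolerance", devices=("gpu",), dtypes=(float,)),
)
device, dtype = "tpu", complex
limitations = tuple(filter(lambda l: l.filter(device=device, dtype=dtype),
                           all_limitations))
assert [l.description for l in limitations] == ["complex min/max unsupported"]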
Example #6
    def test_min_max(self, harness: primitive_harness.Harness):
        # TODO(bchetioui): discrepancies between TF & JAX when comparing with NaN;
        # JAX always returns NaN, while TF returns the value NaN is compared with.
        def custom_assert(result_jax, result_tf):
            mask = np.isnan(result_jax)
            self.assertAllClose(result_jax[~mask], result_tf[~mask])

        # TODO(bchetioui): figure out why we need always_custom_assert=True
        always_custom_assert = True
        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               custom_assert=custom_assert,
                               always_custom_assert=always_custom_assert)
Example #7
    def test_unary_elementwise(self, harness: primitive_harness.Harness):
        dtype = harness.params["dtype"]
        if dtype is dtypes.bfloat16:
            raise unittest.SkipTest("bfloat16 not implemented")
        arg, = harness.dyn_args_maker(self.rng())
        custom_assert = None
        if harness.params["lax_name"] == "digamma":
            # digamma is not defined at 0 and -1
            def custom_assert(result_jax, result_tf):
                # lax.digamma returns NaN and tf.math.digamma returns inf
                special_cases = (arg == 0.) | (arg == -1.)
                nr_special_cases = np.count_nonzero(special_cases)
                self.assertAllClose(
                    np.full((nr_special_cases, ), dtype(np.nan)),
                    result_jax[special_cases])
                self.assertAllClose(
                    np.full((nr_special_cases, ), dtype(np.inf)),
                    result_tf[special_cases])
                # non-special cases are equal
                self.assertAllClose(result_jax[~special_cases],
                                    result_tf[~special_cases])

        if harness.params["lax_name"] == "erf_inv":
            # TODO(necula): fix bug with erf_inv/f16
            if dtype is np.float16:
                raise unittest.SkipTest("TODO: fix bug")
            # erf_inv is not defined for arg <= -1 or arg >= 1
            def custom_assert(result_jax, result_tf):  # noqa: F811
                # for arg < -1 or arg > 1,
                # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
                special_cases = (arg < -1.) | (arg > 1.)
                nr_special_cases = np.count_nonzero(special_cases)
                self.assertAllClose(
                    np.full((nr_special_cases, ), dtype(np.nan)),
                    result_jax[special_cases])
                signs = np.where(arg[special_cases] < 0., -1., 1.)
                self.assertAllClose(
                    np.full((nr_special_cases, ), signs * dtype(np.inf)),
                    result_tf[special_cases])
                # non-special cases are equal
                self.assertAllClose(result_jax[~special_cases],
                                    result_tf[~special_cases])

        atol = None
        if jtu.device_under_test() == "gpu":
            # TODO(necula): revisit once we fix the GPU tests
            atol = 1e-3
        self.ConvertAndCompare(harness.dyn_fun,
                               arg,
                               custom_assert=custom_assert,
                               atol=atol)
Example #8
  def test_unary_elementwise(self, harness: primitive_harness.Harness):
    dtype = harness.params["dtype"]
    lax_name = harness.params["lax_name"]
    arg, = harness.dyn_args_maker(self.rng())
    custom_assert = None
    if lax_name == "digamma":
      # TODO(necula): fix bug with digamma/(f32|f16) on TPU
      if dtype in [np.float16, np.float32] and jtu.device_under_test() == "tpu":
        raise unittest.SkipTest("TODO: fix bug: nan vs not-nan")

      # In the bfloat16 case, TF and lax both return NaN in undefined cases.
      if dtype is not dtypes.bfloat16:
        # digamma is not defined at 0 and -1
        def custom_assert(result_jax, result_tf):
          # lax.digamma returns NaN and tf.math.digamma returns inf
          special_cases = (arg == 0.) | (arg == -1.)
          nr_special_cases = np.count_nonzero(special_cases)
          self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
                              result_jax[special_cases])
          self.assertAllClose(np.full((nr_special_cases,), dtype(np.inf)),
                              result_tf[special_cases])
          # non-special cases are equal
          self.assertAllClose(result_jax[~ special_cases],
                              result_tf[~ special_cases])
    if lax_name == "erf_inv":
      # TODO(necula): fix erf_inv bug on TPU
      if jtu.device_under_test() == "tpu":
        raise unittest.SkipTest("erf_inv bug on TPU: nan vs non-nan")
      # TODO: investigate: in the (b)float16 cases, TF and lax both return the same
      # result in undefined cases.
      if dtype not in [np.float16, dtypes.bfloat16]:
        # erf_inv is not defined for arg <= -1 or arg >= 1
        def custom_assert(result_jax, result_tf):  # noqa: F811
          # for arg < -1 or arg > 1
          # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
          special_cases = (arg < -1.) | (arg > 1.)
          nr_special_cases = np.count_nonzero(special_cases)
          self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
                              result_jax[special_cases])
          signs = np.where(arg[special_cases] < 0., -1., 1.)
          self.assertAllClose(np.full((nr_special_cases,), signs * dtype(np.inf)),
                              result_tf[special_cases])
          # non-special cases are equal
          self.assertAllClose(result_jax[~ special_cases],
                              result_tf[~ special_cases])
    atol = None
    if jtu.device_under_test() == "gpu":
      # TODO(necula): revisit once we fix the GPU tests
      atol = 1e-3
    self.ConvertAndCompare(harness.dyn_fun, arg, custom_assert=custom_assert,
                           atol=atol)
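The custom_assert functions above check the special cases (inputs where the primitive is undefined) separately and require close agreement everywhere else. A self-contained NumPy sketch of that masking idiom, with made-up result arrays standing in for the JAX and TF outputs of digamma:

import numpy as np

arg = np.array([-1.0, 0.0, 0.5, 2.0], dtype=np.float32)
result_jax = np.array([np.nan, np.nan, -1.9635, 0.4228], dtype=np.float32)  # NaN at the poles
result_tf = np.array([np.inf, np.inf, -1.9635, 0.4228], dtype=np.float32)   # inf at the poles

special_cases = (arg == 0.) | (arg == -1.)
# The two backends disagree only where digamma is undefined ...
assert np.all(np.isnan(result_jax[special_cases]))
assert np.all(np.isinf(result_tf[special_cases]))
# ... and must agree everywhere else.
np.testing.assert_allclose(result_jax[~special_cases], result_tf[~special_cases])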
Example #9
  def test_betainc(self, harness: primitive_harness.Harness):
    dtype = harness.params["dtype"]
    # TODO: https://www.tensorflow.org/api_docs/python/tf/math/betainc only
    # supports float32/64.
    # TODO(bchetioui): investigate why the test actually fails in JAX.
    if dtype in [np.float16, dtypes.bfloat16]:
      raise unittest.SkipTest("(b)float16 not implemented in TF")

    tol = None
    if dtype is np.float64:
      tol = 1e-14

    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           atol=tol, rtol=tol)
Example #10
    def test_conv_general_dilated(self, harness: primitive_harness.Harness):
        dtype, device = harness.params["dtype"], jtu.device_under_test()
        if device == "gpu" and dtype in [np.complex64, np.complex128]:
            raise unittest.SkipTest("TODO: crash on GPU in TF")

        tol = None
        if device == "gpu":
            tol = 1e-4
        elif device == "tpu":
            tol = 1e-3
        # TODO(bchetioui): significant discrepancies in some float16 cases.
        if dtype == np.float16:
            tol = 1.
        # TODO(bchetioui): slight occasional discrepancy in float32 cases.
        elif dtype == np.float32:
            tol = 0.5 if device == "tpu" else (
                1e-3 if device == "gpu" else 1e-4)
        elif dtype == np.complex64 and device == "tpu":
            tol = 0.1
        # TODO(bchetioui): slight discrepancy when going through the path using
        # tf.nn.convolution.
        elif dtype == np.float64 and device == "cpu":
            tol = 1e-13

        # TODO(bchetioui): unidentified bug in compiled mode. The test that fails is
        #
        # test_conv_general_dilated_tf_conversion_path_3d_lhs=float32[1,4,28,28,1]_rhs=float32[2,3,3,1,16]_windowstrides=(1,1,1)_padding=VALID_lhsdilation=(1,1,1)_rhsdilation=(1,1,2)_dimensionnumbers=('NDHWC','DHWIO','NDHWC')_featuregroupcount=1_batchgroupcount=1_precision=None_enablexla=False
        #
        # with the following assertion error in TensorFlowTrace.process_primitive:
        #
        # AssertionError: conv_general_dilated: out.aval = ShapedArray(float32[1,3,24,26,16]); expected ShapedArray(float32[1,3,26,24,16])
        #
        # Deactivating this assertion is enough to pass the test, which suggests
        # that the end shape is indeed the correct one (i.e. (1,3,26,24,16)).
        # Further investigation is required to really understand this behavior,
        # which we have not managed to reproduce as a pure TF test.
        #
        # This bug is low priority since it only occurs when using a non-TFXLA
        # conversion path in compiled mode, i.e. in a context where using the
        # TFXLA path is possible.
        if harness.name == "_tf_conversion_path_3d_lhs=float32[1,4,28,28,1]_rhs=float32[2,3,3,1,16]_windowstrides=(1,1,1)_padding=VALID_lhsdilation=(1,1,1)_rhsdilation=(1,1,2)_dimensionnumbers=('NDHWC','DHWIO','NDHWC')_featuregroupcount=1_batchgroupcount=1_precision=None_enablexla=False":
            raise unittest.SkipTest(
                "TODO: known but unidentified bug in compiled "
                "mode")

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               atol=tol,
                               rtol=tol,
                               enable_xla=harness.params["enable_xla"])
Example #11
 def test_conv_general_dilated(self, harness: primitive_harness.Harness):
     if jtu.device_under_test() == "gpu":
         raise unittest.SkipTest("TODO: test failures on GPU")
     tol = None
     # TODO(bchetioui): significant discrepancies in some float16 cases.
     if harness.params["dtype"] is np.float16:
         tol = 1.
     # TODO(bchetioui): slight occasional discrepancy in float32 cases.
     elif harness.params["dtype"] is np.float32:
         tol = 1e-5
     self.ConvertAndCompare(harness.dyn_fun,
                            *harness.dyn_args_maker(self.rng()),
                            atol=tol,
                            rtol=tol)
Example #12
 def test_top_k(self, harness: primitive_harness.Harness):
     if (harness.params["k"] > harness.params["shape"][-1]
             or harness.params["k"] < 0):
         with self.assertRaisesRegex(ValueError,
                                     "k argument to top_k must be"):
             harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
     # TODO: figure out what's up with bfloat16
     elif harness.params["dtype"] is dtypes.bfloat16:
         raise unittest.SkipTest("bfloat16 support not implemented")
     elif harness.params["dtype"] in jtu.dtypes.complex:
         # TODO(necula): fix top_k complex bug on TPU
         if jtu.device_under_test() == "tpu":
             raise unittest.SkipTest(
                 "top_k complex on TPU raises different error")
         with self.assertRaisesRegex(RuntimeError,
                                     "Unimplemented: complex comparison"):
             harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
     # TODO: TF and JAX sort [inf, nan] differently.
     elif harness.name.startswith("nan_"):
         raise unittest.SkipTest("inconsistent [nan, inf] sorting")
     else:
         self.ConvertAndCompare(harness.dyn_fun,
                                *harness.dyn_args_maker(self.rng()))
Example #13
 def test_qr(self, harness: primitive_harness.Harness):
     # See jax.lib.lapack.geqrf for the list of compatible types
     if harness.params["dtype"] in [jnp.float32, jnp.float64]:
         self.ConvertAndCompare(harness.dyn_fun,
                                *harness.dyn_args_maker(self.rng()),
                                atol=1e-5,
                                rtol=1e-5)
     elif harness.params["dtype"] in [jnp.complex64, jnp.complex128]:
         # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
         # - check_compiled=True breaks for complex types;
         # - for now, the performance of the HLO QR implementation called when
         #   compiling with TF is expected to have worse performance than the
         #   custom calls made in JAX.
         self.ConvertAndCompare(harness.dyn_fun,
                                *harness.dyn_args_maker(self.rng()),
                                expect_tf_exceptions=True,
                                atol=1e-5,
                                rtol=1e-5)
     else:
         expected_error = (ValueError if jtu.device_under_test() == "gpu"
                           else NotImplementedError)
         with self.assertRaisesRegex(expected_error, "Unsupported dtype"):
             harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
Example #14
    def test_select_and_gather_add(self, harness: primitive_harness.Harness):
        dtype = harness.params["dtype"]

        max_bits = 64
        if jtu.device_under_test() == "tpu":
            max_bits = 32

        expect_tf_exceptions = False
        if dtypes.finfo(dtype).bits * 2 > max_bits:
            # TODO: getting an exception "XLA encountered an HLO for which this rewriting is not implemented"
            expect_tf_exceptions = True

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               expect_tf_exceptions=expect_tf_exceptions)
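The condition above flags dtypes whose doubled bit width exceeds what the platform can pack into a single value (64 bits in general, 32 on TPU in this test). A quick illustration of that arithmetic, assuming np.finfo reports the same widths as jax's dtypes.finfo for these standard float types:

import numpy as np

max_bits = 32  # TPU case from the test above; 64 elsewhere
for dt in (np.float16, np.float32, np.float64):
    bits = np.finfo(dt).bits
    expect_tf_exceptions = bits * 2 > max_bits
    print(f"{dt.__name__}: {bits} bits -> expect_tf_exceptions={expect_tf_exceptions}")
# float16: 16 bits -> expect_tf_exceptions=False
# float32: 32 bits -> expect_tf_exceptions=True
# float64: 64 bits -> expect_tf_exceptions=True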
Example #15
    def test_add_mul(self, harness: primitive_harness.Harness):
        expect_tf_exceptions = False
        dtype = harness.params["dtype"]
        f_name = harness.params["f_jax"].__name__

        if dtype in [np.uint32, np.uint64]:
            # TODO(bchetioui): tf.math.multiply is not defined for the above types.
            expect_tf_exceptions = True
        elif dtype is np.uint16 and f_name == "add":
            # TODO(bchetioui): tf.math.add is defined for the same types as multiply,
            # except uint16.
            expect_tf_exceptions = True
        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               expect_tf_exceptions=expect_tf_exceptions)
Example #16
    def test_top_k(self, harness: primitive_harness.Harness):
        custom_assert = None
        k, dtype = harness.params["k"], harness.params["dtype"]
        if k > harness.params["shape"][-1] or k < 0:
            with self.assertRaisesRegex(ValueError,
                                        "k argument to top_k must be"):
                harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
            return
        if dtype in jtu.dtypes.complex:
            # TODO(necula): fix top_k complex bug on TPU
            if jtu.device_under_test() == "tpu":
                raise unittest.SkipTest(
                    "top_k complex on TPU raises different error")
            with self.assertRaisesRegex(RuntimeError,
                                        "Unimplemented: complex comparison"):
                harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
            return
        if dtype in jtu.dtypes.all_inexact:

            def custom_assert(result_jax, result_tf):
                assert len(result_jax) == len(result_tf)
                # TODO: TF and JAX sort [inf, nan] differently.
                first_arr_jax = result_jax[0]
                first_arr_tf = result_tf[0].numpy()
                if np.all(first_arr_jax == first_arr_tf):
                    for arr_jax, arr_tf in zip(result_jax, result_tf):
                        self.assertArraysEqual(arr_jax, arr_tf)
                else:
                    mask_jax = np.isnan(first_arr_jax)
                    mask_tf = np.isnan(first_arr_tf)
                    self.assertArraysEqual(first_arr_jax[~mask_jax],
                                           first_arr_tf[~mask_tf])

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               custom_assert=custom_assert)
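The custom_assert above tolerates TF and JAX ordering [inf, nan] differently in top_k: when the value arrays are not identical, only the non-NaN entries are compared. A standalone NumPy sketch of that fallback comparison, with made-up arrays in place of the two backends' results:

import numpy as np

values_jax = np.array([np.inf, np.nan, 3.0], dtype=np.float32)
values_tf = np.array([np.nan, np.inf, 3.0], dtype=np.float32)

if np.all(values_jax == values_tf):  # False here: NaN compares unequal to everything
    np.testing.assert_array_equal(values_jax, values_tf)
else:
    mask_jax, mask_tf = np.isnan(values_jax), np.isnan(values_tf)
    # Both backends keep the same non-NaN values, just ordered differently around NaN.
    np.testing.assert_array_equal(values_jax[~mask_jax], values_tf[~mask_tf])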
Example #17
    def test_min_max(self, harness: primitive_harness.Harness):
        expect_tf_exceptions = False
        dtype = harness.params["dtype"]

        if dtype in [
                np.bool_, np.int8, np.uint16, np.uint32, np.uint64,
                np.complex64, np.complex128
        ]:
            # TODO(bchetioui): tf.math.maximum and tf.math.minimum are not defined for
            # the above types.
            expect_tf_exceptions = True

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               expect_tf_exceptions=expect_tf_exceptions)
Example #18
 def test_dot_general(self, harness: primitive_harness.Harness):
   tol, dtype = None, harness.params["dtype"]
   if dtype == dtypes.bfloat16:
     tol = 0.3
   elif dtype in [np.complex64, np.float32]:
     if jtu.device_under_test() == "tpu":
       tol = 0.1 if dtype == np.float32 else 0.3
     else:
       tol = 1e-5
   elif dtype == np.float16:
     if jtu.device_under_test() == "gpu":
       tol = 0.1
     else:
       tol = 0.01
   self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                          atol=tol, rtol=tol)
Example #19
    def test_lu(self, harness: primitive_harness.Harness):
        dtype = harness.params["dtype"]
        if dtype in [np.float16, dtypes.bfloat16]:
            raise unittest.SkipTest(
                f"LU is not implemented in JAX for dtype {dtype}.")
        tol = None
        if dtype in [np.float32, np.complex64]:
            if jtu.device_under_test() == "tpu":
                tol = 0.1
            else:
                tol = 1e-5
        if dtype in [np.float64, np.complex128]:
            tol = 1e-13
        operand, = harness.dyn_args_maker(self.rng())

        def custom_assert(result_jax, result_tf):
            lu, pivots, perm = tuple(map(lambda t: t.numpy(), result_tf))
            batch_dims = operand.shape[:-2]
            m, n = operand.shape[-2], operand.shape[-1]

            def _make_permutation_matrix(perm):
                result = []
                for idx in itertools.product(*map(range, operand.shape[:-1])):
                    result += [0 if c != perm[idx] else 1 for c in range(m)]
                result = np.reshape(np.array(result, dtype=dtype),
                                    [*batch_dims, m, m])
                return result

            k = min(m, n)
            l = jnp.tril(lu, -1)[..., :, :k] + jnp.eye(m, k, dtype=dtype)
            u = jnp.triu(lu)[..., :k, :]
            p_mat = _make_permutation_matrix(perm)

            self.assertArraysEqual(
                lax_linalg.lu_pivots_to_permutation(pivots, m), perm)
            self.assertAllClose(jnp.matmul(p_mat, operand),
                                jnp.matmul(l, u),
                                atol=tol,
                                rtol=tol)

        self.ConvertAndCompare(harness.dyn_fun,
                               operand,
                               atol=tol,
                               rtol=tol,
                               custom_assert=custom_assert,
                               always_custom_assert=True)
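The custom_assert above rebuilds a permutation matrix from the permutation returned by TF and checks P @ A ≈ L @ U. For the unbatched case, the permutation-matrix construction reduces to indexing rows of the identity, as this small NumPy sketch with a made-up permutation shows:

import numpy as np

perm = np.array([2, 0, 1])        # row i of the permuted matrix is row perm[i] of A
p_mat = np.eye(3)[perm]           # one-hot rows: P @ A == A[perm]
a = np.arange(9.0).reshape(3, 3)
np.testing.assert_array_equal(p_mat @ a, a[perm])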
Example #20
 def test_cumreduce(self, harness: primitive_harness.Harness):
     f_jax, dtype = harness.params["f_jax"], harness.params["dtype"]
     dut = jtu.device_under_test()
     if (dtype == np.complex64 and f_jax in [
             lax_control_flow.cummin, lax_control_flow.cummax,
             lax_control_flow.cumprod, lax_control_flow.cumsum
     ] and dut == "tpu"):
         raise unittest.SkipTest(
             "TODO(bchetioui): cum{min,max,prod,sum} fails "
             "in JAX for complex64 on TPU")
     tol = None
     if f_jax == lax_control_flow.cumsum:
         tol = 0.1 if dtype == np.float16 else (
             0.5 if dtype == dtypes.bfloat16 else tol)
     self.ConvertAndCompare(harness.dyn_fun,
                            *harness.dyn_args_maker(self.rng()),
                            atol=tol,
                            rtol=tol)
Example #21
    def test_reduce_window(self, harness: primitive_harness.Harness):
        computation = harness.params['computation'].__name__
        init_value = harness.params['init_value']
        dtype = harness.params['dtype']

        safe_computations = [('sum', dtype(0))]

        if dtype in jtu.dtypes.all_floating:
            # Only in this case can np.inf be safely cast to a meaningful value.
            safe_computations += [('max', dtype(-np.inf)),
                                  ('min', dtype(np.inf))]

        if (computation, init_value) not in safe_computations:
            raise unittest.SkipTest(
                'TODO: only specific instances of max/min/sum are supported for now.'
            )

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()))
Example #22
  def test_jax_implemented(self, harness: primitive_harness.Harness):
    """Runs all harnesses just with JAX to verify the jax_unimplemented field.
    """
    jax_unimpl = [l for l in harness.jax_unimplemented
                  if l.filter(jtu.device_under_test())]
    try:
      harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
    except Exception as e:
      if jax_unimpl:
        logging.info(
          f"Found expected JAX error {e} with expected JAX limitations: "
          f"{[u.description for u in jax_unimpl]} in harness {harness.fullname}")
        return
      else:
        raise e

    if jax_unimpl:
      msg = ("Found no JAX error but expected JAX limitations: "
             f"{[u.description for u in jax_unimpl]} in harness: {harness.fullname}")
      logging.warning(msg)
Example #23
    def test_prim_vmap(self, harness: primitive_harness.Harness):
        if harness.group_name in _VMAP_NOT_POLY_YET:
            raise unittest.SkipTest(
                f"TODO: vmap({harness.group_name}) not yet supported")
        func_jax = harness.dyn_fun
        args = harness.dyn_args_maker(self.rng())
        if len(args) == 0:
            # vmap not defined for functions with no args
            return

        res_jax = func_jax(*args)

        # Replicate all arguments
        batch_size = 3
        batched_args = [np.stack([a] * batch_size) for a in args]
        func_jax_vmap = jax.vmap(func_jax, in_axes=0, out_axes=0)
        # Check that batching works
        res_jax_vmap = func_jax_vmap(*batched_args)

        def arr_to_shape_spec(a):
            return "b, " + ", ".join(str(d) for d in a.shape)

        func_jax_vmap_polymorphic_shapes = jax.tree_map(
            arr_to_shape_spec, tuple(args))

        def arr_to_tf_tensor_spec(a):
            return tf.TensorSpec((None, ) + a.shape, a.dtype)

        func_jax_vmap_input_signature = jax.tree_map(arr_to_tf_tensor_spec,
                                                     tuple(args))
        func_jax_vmap_output_signature = jax.tree_map(arr_to_tf_tensor_spec,
                                                      res_jax)
        f_tf = self.CheckShapePolymorphism(
            func_jax_vmap,
            input_signature=func_jax_vmap_input_signature,
            polymorphic_shapes=func_jax_vmap_polymorphic_shapes,
            expected_output_signature=func_jax_vmap_output_signature)

        limitations = _get_jax2tf_limitations(jtu.device_under_test(), harness)
        # Compare TF and JAX results only when no limitation calls for a custom
        # assertion or skips the comparison.
        if not any(l.custom_assert or l.skip_comparison for l in limitations):
            self.assertAllClose(res_jax_vmap, f_tf(*batched_args))
Example #24
  def test_prim(self, harness: Harness):
    args = harness.dyn_args_maker(self.rng())
    poly_axes = harness.params["poly_axes"]  # type: Sequence[Sequence[int]]
    assert len(args) == len(poly_axes)
    # Make the polymorphic_shapes and input_signature
    polymorphic_shapes: List[Optional[str]] = []
    input_signature: List[tf.TensorSpec] = []
    for arg, poly_axis in zip(args, poly_axes):
      if poly_axis is None:
        polymorphic_shapes.append(None)
        input_signature.append(tf.TensorSpec(np.shape(arg), arg.dtype))
      else:
        def make_arg_polymorphic_shapes(poly_axis: Sequence[int]) -> Tuple[str, tf.TensorSpec]:
          idx = -1
          dims = []
          tensorspec_dims: List[Optional[int]] = []
          for i, d in enumerate(arg.shape):
            if i in poly_axis:
              idx += 1
              dims.append(f"b{idx}")
              tensorspec_dims.append(None)
            else:
              dims.append(str(d))
              tensorspec_dims.append(d)
          return ", ".join(dims), tf.TensorSpec(tensorspec_dims, arg.dtype)

        arg_polymorphic_shapes, arg_tensorspec = make_arg_polymorphic_shapes(poly_axis)
        polymorphic_shapes.append(arg_polymorphic_shapes)
        input_signature.append(arg_tensorspec)

    res_jax = harness.dyn_fun(*args)
    f_tf = self.CheckShapePolymorphism(
        harness.dyn_fun,
        input_signature=input_signature,
        polymorphic_shapes=polymorphic_shapes,
        expected_output_signature=None)

    if harness.params["check_result"]:
      tol = harness.params["tol"]
      self.assertAllClose(res_jax, f_tf(*args), atol=tol, rtol=tol)
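The per-argument helper above turns a poly_axes entry into a polymorphic shape string and a matching tf.TensorSpec. As a concrete worked check with hypothetical inputs, an argument of shape (8, 3, 5) with poly_axis = (0, 2) should yield the spec string "b0, 3, b1" and a TensorSpec whose first and last dimensions are unknown:

import numpy as np
import tensorflow as tf

arg = np.zeros((8, 3, 5), dtype=np.float32)
poly_axis = (0, 2)

idx, dims, tensorspec_dims = -1, [], []
for i, d in enumerate(arg.shape):
    if i in poly_axis:
        idx += 1
        dims.append(f"b{idx}")        # symbolic dimension name
        tensorspec_dims.append(None)  # unknown to TF at trace time
    else:
        dims.append(str(d))
        tensorspec_dims.append(d)

assert ", ".join(dims) == "b0, 3, b1"
assert tf.TensorSpec(tensorspec_dims, arg.dtype).shape.as_list() == [None, 3, None]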
Example #25
 def test_sort(self, harness: primitive_harness.Harness):
   if harness.params["dtype"] is dtypes.bfloat16 or harness.params["dtype"] in jtu.dtypes.complex:
     # TODO: implement bfloat16/complex support in XlaSort
     raise unittest.SkipTest("bfloat16/complex support not implemented")
   if harness.params["dtype"] is dtypes.bool_ and len(harness.arg_descriptors) == 4:
     # TODO: _sort uses tfxla.key_value_sort to handle 2 operands, but the operation is not compatible with boolean keys.
     raise unittest.SkipTest("key/value sort with boolean keys not implemented")
   if harness.params["is_stable"]:
     # TODO: implement stable sort support in XlaSort
     raise unittest.SkipTest("stable sort not implemented")
   if harness.params["dimension"] != len(harness.params["shape"]) - 1:
     # TODO: implement sort on all axes
     raise unittest.SkipTest("conversion not implemented for axis != -1")
   if len(harness.arg_descriptors) > 4:
     # TODO: implement variable number of operands to XlaSort
     raise unittest.SkipTest("conversion not implemented for #operands > 2")
   if (jtu.device_under_test() == "gpu" and
       len(harness.arg_descriptors) == 4 and
       not harness.params["is_stable"]):
     # TODO: fix the TF GPU test
     raise unittest.SkipTest("GPU tests are running TF on CPU")
   self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
Example #26
  def test_cholesky(self, harness: primitive_harness.Harness):
    dtype = harness.params["dtype"]
    if dtype in [dtypes.bfloat16, np.float16]:
      raise unittest.SkipTest("Cholesky decomposition not supported for "
                              "(b)float16 in JAX.")
    operand = harness.dyn_args_maker(self.rng())[0]
    operand = np.matmul(operand, jnp.conj(np.swapaxes(operand, -1, -2)))
    tol = None
    # TODO(bchetioui): very high discrepancy in the float32/complex64 case
    if dtype in [np.float32, np.complex64]:
      tol = 1e-2
    # TODO(bchetioui): also high discrepancy in the float64/complex128 case
    elif dtype in [np.float64, np.complex128]:
      tol = 1e-11

    def custom_assert(result_jax, result_tf):
      # cholesky_p returns garbage in the strictly upper triangular part of the
      # result, so we can safely ignore that part.
      self.assertAllClose(jnp.tril(result_jax), result_tf, atol=tol)

    self.ConvertAndCompare(harness.dyn_fun, operand,
                           custom_assert=custom_assert,
                           always_custom_assert=True)
Example #27
    def test_reduce_window(self, harness: primitive_harness.Harness):
        f_name = harness.params['computation'].__name__
        dtype = harness.params['dtype']

        expect_tf_exceptions = False

        if (jtu.device_under_test() == 'tpu' and dtype is np.complex64):
            raise unittest.SkipTest(
                'TODO: JAX reduce_window on TPU does not handle complex64')

        if ((f_name == 'min' or f_name == 'max') and dtype not in [
                dtypes.bfloat16, np.float16, np.float32, np.float64, np.uint8,
                np.int16, np.int32, np.int64
        ]):
            # See https://www.tensorflow.org/api_docs/python/tf/math/minimum for a list of
            # the types supported by tf.math.minimum/tf.math.maximum.
            expect_tf_exceptions = True
        elif (f_name == 'add' and dtype not in [
                dtypes.bfloat16, np.float16, np.float32, np.float64, np.uint8,
                np.int8, np.int16, np.int32, np.int64, np.complex64,
                np.complex128
        ]):
            # See https://www.tensorflow.org/api_docs/python/tf/math/add for a list of the
            # types supported by tf.math.add.
            expect_tf_exceptions = True
        elif (f_name == 'mul' and dtype not in [
                dtypes.bfloat16, np.float16, np.float32, np.float64, np.uint8,
                np.int8, np.uint16, np.int16, np.int32, np.int64, np.complex64,
                np.complex128
        ]):
            # See https://www.tensorflow.org/api_docs/python/tf/math/multiply for a list of
            # the types supported by tf.math.multiply.
            expect_tf_exceptions = True

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               expect_tf_exceptions=expect_tf_exceptions)
Example #28
    def test_conv_general_dilated(self, harness: primitive_harness.Harness):
        if jtu.device_under_test() == "gpu" and harness.params["dtype"] in [
                np.complex64, np.complex128
        ]:
            raise unittest.SkipTest("TODO: crash on GPU in TF")

        tol = None
        if jtu.device_under_test() == "gpu":
            tol = 1e-4
        elif jtu.device_under_test() == "tpu":
            tol = 1e-3
        # TODO(bchetioui): significant discrepancies in some float16 cases.
        if harness.params["dtype"] == np.float16:
            tol = 1.
        # TODO(bchetioui): slight occasional discrepancy in float32 cases.
        elif harness.params["dtype"] == np.float32:
            tol = 0.5 if jtu.device_under_test() == "tpu" else 1e-4
        elif (harness.params["dtype"] == np.complex64
              and jtu.device_under_test() == "tpu"):
            tol = 0.1
        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               atol=tol,
                               rtol=tol)
Example #29
  def test_qr(self, harness: primitive_harness.Harness):
    # See jax.lib.lapack.geqrf for the list of compatible types

    dtype = harness.params["dtype"]
    dut = jtu.device_under_test()
    # These cases are not implemented in JAX
    if dtype in (jtu.dtypes.all_integer + [jnp.bfloat16]):
      unimplemented_jax = True
    elif dtype is np.complex64 and dut == "tpu":
      unimplemented_jax = True
    elif dtype is np.float16 and dut in ("cpu", "gpu"):
      unimplemented_jax = True
    else:
      unimplemented_jax = False

    if unimplemented_jax:
      raise unittest.SkipTest(f"QR not implemented in JAX for {dtype} on {dut}")

    # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
    # - for now, the performance of the HLO QR implementation called when
    #   compiling with TF is expected to have worse performance than the
    #   custom calls made in JAX.
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           atol=1e-5, rtol=1e-5)
Example #30
  def test_conv_general_dilated(self, harness: primitive_harness.Harness):
    dtype, device = harness.params["dtype"], jtu.device_under_test()
    if device == "gpu" and dtype in [np.complex64, np.complex128]:
      raise unittest.SkipTest("TODO: crash on GPU in TF")

    tol = None
    if device == "gpu":
      tol = 1e-4
    elif device == "tpu":
      tol = 1e-3
    # TODO(bchetioui): significant discrepancies in some float16 cases.
    if dtype == np.float16:
      tol = 1.
    # TODO(bchetioui): slight occasional discrepancy in float32 cases.
    elif dtype == np.float32:
      tol = 0.5 if device == "tpu" else (1e-3 if device == "gpu" else 1e-4)
    elif dtype == np.complex64 and device == "tpu":
      tol = 0.1
    # TODO(bchetioui): slight discrepancy when going through the path using
    # tf.nn.convolution.
    elif dtype == np.float64 and device == "cpu":
      tol = 1e-13
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           atol=tol, rtol=tol)