Example #1
 def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
   rng = rng_factory(self.rng())
   if jtu.device_under_test() == "tpu" and op is lax.mul:
     raise SkipTest("unimplemented case")
   tol = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 1e-1,
          onp.float64: 1e-3, onp.complex64: 1e-1}
   operand = rng(shape, dtype)
   init_val = onp.asarray(init_val, dtype=dtype)
   reduce = lambda operand: lax.reduce(operand, init_val, op, dims)
   eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else
          1e-1 if dtype == dtypes.bfloat16 else
          1e-2 if dtypes.finfo(dtype).bits == 32 else None)
   check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)
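The eps and tol selection above keys off the bit width reported by dtypes.finfo. A minimal sketch of that selection spelled out as a plain helper (grad_check_eps is a hypothetical name; the thresholds mirror the snippet, and dtypes is assumed to be JAX's dtypes module as imported in these examples):

import numpy as np
from jax import dtypes

def grad_check_eps(dtype, is_add):
    # Mirrors the conditional expression above: a coarse step for 16-bit
    # floats under addition, 1e-1 for bfloat16, 1e-2 for 32-bit floats,
    # and check_grads' default step for 64-bit types.
    bits = dtypes.finfo(dtype).bits
    if bits == 16 and is_add:
        return 1.0
    if dtype == dtypes.bfloat16:
        return 1e-1
    if bits == 32:
        return 1e-2
    return None

print(grad_check_eps(np.float32, is_add=False))  # 0.01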
Example #2
    def testConvGeneralDilatedBatching(self, lhs_shape, rhs_shape, dtype,
                                       strides, padding, lhs_dil, rhs_dil,
                                       dimension_numbers, perms,
                                       feature_group_count, batch_group_count,
                                       lhs_bdim, rhs_bdim):
        rng = jtu.rand_default(self.rng())
        tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3

        # permute shapes to match dim_spec, scale by feature_group_count
        lhs_perm, rhs_perm = perms
        lhs_shape = list(np.take(lhs_shape, lhs_perm))
        rhs_shape = list(np.take(rhs_shape, rhs_perm))

        conv = partial(lax.conv_general_dilated,
                       window_strides=strides,
                       padding=padding,
                       lhs_dilation=lhs_dil,
                       rhs_dilation=rhs_dil,
                       dimension_numbers=dimension_numbers,
                       feature_group_count=feature_group_count,
                       batch_group_count=batch_group_count,
                       precision=lax.Precision.HIGHEST)
        self._CheckBatching(conv,
                            5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape),
                            (dtype, dtype),
                            rng,
                            rtol=tol,
                            atol=tol)
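For reference, a minimal standalone call to lax.conv_general_dilated with an explicit NCHW/OIHW layout; the shapes below are illustrative and not taken from the test parameters:

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 3, 10, 10), dtype=jnp.float32)  # batch, in-features, H, W
rhs = jnp.ones((8, 3, 3, 3), dtype=jnp.float32)    # out-features, in-features, kH, kW
out = lax.conv_general_dilated(
    lhs, rhs,
    window_strides=(1, 1),
    padding="SAME",
    lhs_dilation=(1, 1),
    rhs_dilation=(1, 1),
    dimension_numbers=("NCHW", "OIHW", "NCHW"))
print(out.shape)  # (1, 8, 10, 10)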
Example #3
def eig_abstract_eval(operand, *, compute_left_eigenvectors,
                      compute_right_eigenvectors):
    if isinstance(operand, ShapedArray):
        if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
            raise ValueError(
                "Argument to nonsymmetric eigendecomposition must have "
                "shape [..., n, n], got shape {}".format(operand.shape))

        batch_dims = operand.shape[:-2]
        n = operand.shape[-1]
        dtype = np.complex64 if dtypes.finfo(
            operand.dtype).bits == 32 else np.complex128
        dtype = dtypes.canonicalize_dtype(dtype)
        vl = vr = ShapedArray(batch_dims + (n, n), dtype)
        w = ShapedArray(batch_dims + (n, ), dtype)
    else:
        raise NotImplementedError

    output = [w]
    if compute_left_eigenvectors:
        output.append(vl)
    if compute_right_eigenvectors:
        output.append(vr)

    return tuple(output)
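A hedged illustration of the shapes this abstract-evaluation rule produces, assuming ShapedArray, dtypes and np are in scope as in the snippet (ShapedArray lives in jax.core in the JAX versions these examples target):

operand = ShapedArray((2, 3, 3), np.float32)  # a batch of two real 3x3 matrices
w, vl, vr = eig_abstract_eval(operand,
                              compute_left_eigenvectors=True,
                              compute_right_eigenvectors=True)
# w:  ShapedArray((2, 3),    complex64) -- eigenvalues
# vl: ShapedArray((2, 3, 3), complex64) -- left eigenvectors
# vr: ShapedArray((2, 3, 3), complex64) -- right eigenvectors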
Example #4
    def testMapCoordinates(self, shape, dtype, coords_shape, coords_dtype,
                           order, mode, cval, impl, round_, rng_factory):
        def args_maker():
            x = np.arange(prod(shape), dtype=dtype).reshape(shape)
            coords = [(size - 1) * rng(coords_shape, coords_dtype)
                      for size in shape]
            if round_:
                coords = [c.round().astype(int) for c in coords]
            return x, coords

        rng = rng_factory(self.rng())
        lsp_op = lambda x, c: lsp_ndimage.map_coordinates(
            x, c, order=order, mode=mode, cval=cval)
        impl_fun = (osp_ndimage.map_coordinates
                    if impl == "original" else _fixed_ref_map_coordinates)
        osp_op = lambda x, c: impl_fun(x, c, order=order, mode=mode, cval=cval)
        if dtype in float_dtypes:
            epsilon = max([
                dtypes.finfo(dtypes.canonicalize_dtype(d)).eps
                for d in [dtype, coords_dtype]
            ])
            self._CheckAgainstNumpy(osp_op,
                                    lsp_op,
                                    args_maker,
                                    tol=100 * epsilon)
        else:
            self._CheckAgainstNumpy(osp_op, lsp_op, args_maker, tol=0)
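Here lsp_ndimage is jax.scipy.ndimage and osp_ndimage is the scipy.ndimage reference the test compares against. A minimal standalone call with illustrative values (order=1 requests linear interpolation):

import jax.numpy as jnp
from jax.scipy import ndimage as lsp_ndimage

x = jnp.arange(12.0).reshape(3, 4)
# Sample the array at two fractional (row, col) positions.
coords = [jnp.array([0.5, 1.0]), jnp.array([0.5, 2.5])]
print(lsp_ndimage.map_coordinates(x, coords, order=1))  # [2.5 6.5]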
Example #5
  def test_select_and_gather_add(self, harness: primitive_harness.Harness):
    dtype = harness.params["dtype"]

    if dtype is dtypes.bfloat16:
      raise unittest.SkipTest("bfloat16 not implemented")

    max_bits = 64
    if jtu.device_under_test() == "tpu":
      max_bits = 32

    if dtypes.finfo(dtype).bits * 2 > max_bits:
      with self.assertRaisesRegex(BaseException, "XLA encountered an HLO for which this rewriting is not implemented"):
        self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
    else:
      self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
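The bits check above reflects the packing limitation documented in the categorize examples below: packing two n-bit values must fit in 64 bits on CPU/GPU and 32 bits on TPU. Spelled out as plain arithmetic, in a small sketch that assumes dtypes is JAX's dtypes module as imported in these snippets:

import numpy as np
from jax import dtypes

for dtype in (np.float16, np.float32, np.float64):
    packed = dtypes.finfo(dtype).bits * 2
    print(f"{np.dtype(dtype).name}: CPU/GPU={packed <= 64}, TPU={packed <= 32}")
# float16: CPU/GPU=True, TPU=True
# float32: CPU/GPU=True, TPU=False
# float64: CPU/GPU=False, TPU=False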
Example #6
    def test_select_and_gather_add(self, harness: primitive_harness.Harness):
        dtype = harness.params["dtype"]

        max_bits = 64
        if jtu.device_under_test() == "tpu":
            max_bits = 32

        expect_tf_exceptions = False
        if dtypes.finfo(dtype).bits * 2 > max_bits:
            # TODO: getting an exception "XLA encountered an HLO for which this rewriting is not implemented"
            expect_tf_exceptions = True

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               expect_tf_exceptions=expect_tf_exceptions)
Example #7
def eig_abstract_eval(operand):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                       "shape [..., n, n], got shape {}".format(operand.shape))

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    dtype = onp.complex64 if dtypes.finfo(operand.dtype).bits == 32 else onp.complex128
    dtype = dtypes.canonicalize_dtype(dtype)
    vl = vr = ShapedArray(batch_dims + (n, n), dtype)
    w = ShapedArray(batch_dims + (n,), dtype)
  else:
    raise NotImplementedError
  return w, vl, vr
Example #8
def categorize(prim: core.Primitive, *args, **kwargs) \
    -> List[Limitation]:
  """
  Given a primitive and a set of parameters one would like to pass to it,
  categorize identifies the potential limitations the call would encounter when
  converted to TF through jax2tf.

  Args:
    prim: the primitive to call.
    args: the arguments to pass to prim.
    kwargs: the keyword arguments to pass to prim.

  Returns:
    A list of limitations
  """
  limitations: List[Limitation] = []
  all_devices = ["CPU", "GPU", "TPU"]

  def _report_failure(error_type: str, msg: str,
                      affected_dtype: Optional[NpDType] = None,
                      devs: Sequence[str] = all_devices) -> None:
    affected_dtypes = (
      tuple([affected_dtype]) if affected_dtype is not None else tuple())
    limitations.append(Limitation(prim.name, error_type, msg,
                                  affected_dtypes, tuple(devs)))

  def tf_unimpl(np_dtype: Optional[NpDType] = None,
                additional_msg: Optional[str] = None,
                devs: Sequence[str] = all_devices) -> None:

    missing_tf_support = "Missing TF support"
    msg = "Primitive is unimplemented"
    if additional_msg:
      msg += '; ' + additional_msg
    _report_failure(missing_tf_support, msg, np_dtype, devs=devs)

  def _to_np_dtype(dtype) -> NpDType:
    try:
      dtype = to_jax_dtype(dtype)
    except:
      pass
    return np.dtype(dtype)

  if prim in [lax.min_p, lax.max_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.bool_, np.int8, np.uint16, np.uint32, np.uint64,
                    np.complex64, np.complex128]:
      tf_unimpl(np_dtype)

  if prim in [lax.rem_p, lax.atan2_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.float16, dtypes.bfloat16]:
      # b/158006398: TF kernels are missing for 'rem' and 'atan2'
      tf_unimpl(np_dtype)

  if prim is lax.nextafter_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.float16, dtypes.bfloat16]:
      tf_unimpl(np_dtype)

  if prim is lax_linalg.cholesky_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.complex64, np.complex128]:
      # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
      # experimental_compile=True breaks for complex types.
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax_linalg.qr_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.complex64, np.complex128]:
      # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
      # experimental_compile=True breaks for complex types.
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax_linalg.svd_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [dtypes.bfloat16]:
      # TODO: SVD on TPU for bfloat16 seems to work for JAX but fails for TF
      tf_unimpl(np_dtype, devs=["TPU"])
    elif np_dtype in [np.complex64, np.complex128]:
      # TODO: on CPU and GPU "No registered 'Svd' OpKernel for XLA_CPU_JIT
      # devices". Works on JAX because JAX uses a custom implementation.
      # There exists a XlaSvd operation that could replace tf.linalg.svd in
      # these cases but complex numbers support is not implemented in XLA yet,
      # and the API of XlaSvd is different than the one in JAX/TF, which also
      # limits its useability (e.g. no full_matrices argument, …).
      additional_msg = ("this works on JAX because JAX uses a custom "
                        "implementation")
      tf_unimpl(np_dtype, additional_msg=additional_msg, devs=["CPU", "GPU"])

  if prim is lax.select_and_gather_add_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    # TODO: the conversion is only supported for float16/float32 on CPU/GPU,
    # and float16 on TPU. This is because we do not implement a precision
    # reduction in the case where packing 2 n-bit values together results in
    # more than the maximum number of bits allowed on the platform (64 on
    # CPU/GPU, 32 on TPU). This could be fixed by implementing a variadic
    # reduce_window in tfxla, or we can require the user to reduce the
    # precision of their arrays manually based on the platform they run on.
    devices_and_max_bits = [ (["CPU", "GPU"], 64)
                           , (["TPU"], 32)
                           ]
    for devs, max_bits in devices_and_max_bits:
      if dtypes.finfo(np_dtype).bits * 2 > max_bits:
        # TODO: getting an exception "XLA encountered an HLO for which this
        # rewriting is not implemented"
        tf_unimpl(np_dtype, devs=devs)

  if prim in [lax.add_p, lax.reduce_window_sum_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.uint16, np.uint32, np.uint64]:
      # TODO(bchetioui): tf.math.add is not defined for the above types.
      tf_unimpl(np_dtype)

  if prim is lax.mul_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.uint32, np.uint64]:
      # TODO(bchetioui): tf.math.multiply is not defined for the above types.
      tf_unimpl(np_dtype)

  if prim in [lax.scatter_mul_p, lax.scatter_add_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype == np.complex64:
      tf_unimpl(np_dtype, devs=["TPU"])

  if prim is lax.sort_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.complex64, np.complex128]:
      tf_unimpl(np_dtype)
    if np_dtype == np.bool_ and len(args) == 2:
      tf_unimpl(np_dtype, additional_msg=(
        "sorting 2 arrays where the first one is an array of booleans is not "
        "supported for XlaSort"))
    if kwargs["is_stable"]:
      tf_unimpl(additional_msg="stable sort not implemented for XlaSort")
    if kwargs["dimension"] != len(np.shape(args[0])) - 1:
      tf_unimpl(additional_msg="only sorting on last dimension is supported "
                               "for XlaSort")
    if len(args) > 2:
      tf_unimpl(additional_msg=(
        "sorting more than 2 arrays is not supported for XlaSort"))

  if prim is lax.population_count_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.uint32, np.uint64]:
      tf_unimpl(np_dtype)

  if prim is lax.conv_general_dilated_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    batch_group_count = kwargs['batch_group_count']
    if batch_group_count != 1:
      tf_unimpl(additional_msg="batch_group_count != 1 unsupported")
    if np_dtype in [np.complex64, np.complex128]:
      tf_unimpl(np_dtype, additional_msg="likely bug in the HLO -> LLVM IR "
                                         "lowering of XlaConv")

  if prim in [lax.acosh_p, lax.asinh_p, lax.atanh_p, lax.bessel_i0e_p,
              lax.bessel_i1e_p, lax.digamma_p, lax.erf_p, lax.erf_inv_p,
              lax.erfc_p, lax.lgamma_p, lax.round_p, lax.rsqrt_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype == dtypes.bfloat16:
      tf_unimpl(np_dtype, devs=["CPU", "GPU"])

  if prim in [lax.sinh_p, lax.cosh_p, lax.atanh_p, lax.asinh_p, lax.acosh_p,
              lax.erf_inv_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype == np.float16:
      # b/158006398: float16 support missing from the kernel of the above
      # operations.
      tf_unimpl(np_dtype)

  if prim is lax.lax_fft.fft_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.float64, np.complex128]:
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax.top_k_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.float64, np.int64, np.uint64]:
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))
  return limitations
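The docstring describes the intended call pattern; a hedged usage sketch follows (the operand is hypothetical, and the module-level helpers the function relies on, such as Limitation and to_jax_dtype, come from the same jax2tf test utilities as the snippet):

import numpy as np
from jax import lax

# Which jax2tf limitations apply to lax.min on uint32 operands?
x = np.arange(4, dtype=np.uint32)
for lim in categorize(lax.min_p, x, x):
    print(lim)  # expect a "Missing TF support" entry for uint32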
Example #9
def num_float_bits(dtype):
    return dtypes.finfo(dtypes.canonicalize_dtype(dtype)).bits
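dtypes.finfo mirrors numpy.finfo but also understands JAX-specific floating types such as bfloat16. A few illustrative values, assuming dtypes is JAX's dtypes module as imported in these snippets (the float64 result assumes jax_enable_x64 is on; with the default config canonicalization maps float64 to float32):

import numpy as np
from jax import dtypes

print(dtypes.finfo(np.float32).bits)       # 32
print(dtypes.finfo(dtypes.bfloat16).bits)  # 16
print(dtypes.finfo(dtypes.bfloat16).eps)   # 0.0078125
print(num_float_bits(np.float64))          # 64 with x64 enabled, else 32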
Example #10
def categorize(prim: core.Primitive, *args, **kwargs) \
    -> List[Limitation]:
    """
  Given a primitive and a set of parameters one would like to pass to it,
  categorize identifies the potential limitations the call would encounter when
  converted to TF through jax2tf.

  Args:
    prim: the primitive to call.
    args: the arguments to pass to prim.
    kwargs: the keyword arguments to pass to prim.

  Returns:
    A list of limitations
  """
    limitations: List[Limitation] = []
    all_devices = ["CPU", "GPU", "TPU"]

    def _report_failure(error_type: str,
                        msg: str,
                        devs: Sequence[str] = all_devices) -> None:
        limitations.append(Limitation(prim.name, error_type, msg, tuple(devs)))

    def tf_unimpl(np_dtype: NpDType,
                  additional_msg: Optional[str] = None,
                  devs: Sequence[str] = all_devices) -> None:
        msg = f"{prim.name} is unimplemented for dtype {np_dtype}"
        if additional_msg:
            msg += '; ' + additional_msg
        _report_failure("Missing TF support", msg, devs=devs)

    def _to_np_dtype(dtype) -> NpDType:
        try:
            dtype = to_jax_dtype(dtype)
        except:
            pass
        return np.dtype(dtype)

    if prim in [lax.min_p, lax.max_p]:
        np_dtype = _to_np_dtype(args[0].dtype)
        if np_dtype in [
                np.bool_, np.int8, np.uint16, np.uint32, np.uint64,
                np.complex64, np.complex128
        ]:
            tf_unimpl(np_dtype)

    if prim in [lax.rem_p, lax.atan2_p]:
        np_dtype = _to_np_dtype(args[0].dtype)
        if np_dtype in [np.float16, dtypes.bfloat16]:
            # b/158006398: TF kernels are missing for 'rem' and 'atan2'
            tf_unimpl(np_dtype)

    if prim is lax.nextafter_p:
        np_dtype = _to_np_dtype(args[0].dtype)
        if np_dtype in [np.float16, dtypes.bfloat16]:
            tf_unimpl(np_dtype)

    if prim is lax_linalg.qr_p:
        np_dtype = _to_np_dtype(args[0].dtype)
        if np_dtype in [np.complex64, np.complex128]:
            # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
            # experimental_compile=True breaks for complex types.
            tf_unimpl(np_dtype)

    if prim is lax_linalg.svd_p:
        np_dtype = _to_np_dtype(args[0].dtype)
        if np_dtype in [np.float16, dtypes.bfloat16]:
            # TODO: SVD on TPU for bfloat16 seems to work for JAX but fails for TF
            tf_unimpl(np_dtype, devs=["TPU"])
        elif np_dtype in [np.complex64, np.complex128]:
            # TODO: on CPU and GPU "No registered 'Svd' OpKernel for XLA_CPU_JIT
            # devices". Works on JAX because JAX uses a custom implementation
            additional_msg = ("this works on JAX because JAX uses a custom "
                              "implementation")
            tf_unimpl(np_dtype,
                      additional_msg=additional_msg,
                      devs=["CPU", "GPU"])

    if prim is lax.select_and_gather_add_p:
        np_dtype = _to_np_dtype(args[0].dtype)
        devices_and_max_bits = [(["CPU", "GPU"], 64), (["TPU"], 32)]
        for devs, max_bits in devices_and_max_bits:
            if dtypes.finfo(np_dtype).bits * 2 > max_bits:
                # TODO: getting an exception "XLA encountered an HLO for which this
                # rewriting is not implemented"
                tf_unimpl(np_dtype, devs=devs)

    if prim in [lax.add_p, lax.reduce_window_sum_p]:
        np_dtype = _to_np_dtype(args[0].dtype)
        if np_dtype in [np.uint16, np.uint32, np.uint64]:
            # TODO(bchetioui): tf.math.add is not defined for the above types.
            tf_unimpl(np_dtype)

    if prim is lax.mul_p:
        np_dtype = _to_np_dtype(args[0].dtype)
        if np_dtype in [np.uint32, np.uint64]:
            # TODO(bchetioui): tf.math.multiply is not defined for the above types.
            tf_unimpl(np_dtype)

    if prim in [lax.scatter_mul_p, lax.scatter_add_p]:
        np_dtype = _to_np_dtype(args[0].dtype)
        if np_dtype == np.complex64:
            tf_unimpl(np_dtype, devs=["TPU"])

    if prim is lax.population_count_p:
        np_dtype = _to_np_dtype(args[0].dtype)
        if np_dtype == np.uint32:
            tf_unimpl(np_dtype)

    return limitations
Example #11
def categorize(prim: core.Primitive, *args, **kwargs) \
    -> List[Limitation]:
  """
  Given a primitive and a set of parameters one would like to pass to it,
  categorize identifies the potential limitations the call would encounter when
  converted to TF through jax2tf.

  Args:
    prim: the primitive to call.
    args: the arguments to pass to prim.
    kwargs: the keyword arguments to pass to prim.

  Returns:
    A list of limitations
  """
  limitations: List[Limitation] = []
  all_devices = ["CPU", "GPU", "TPU"]

  def _report_failure(error_type: str, msg: str,
                      affected_dtype: Optional[NpDType] = None,
                      devs: Sequence[str] = all_devices) -> None:
    affected_dtypes = (
      tuple([affected_dtype]) if affected_dtype is not None else tuple())
    limitations.append(Limitation(prim.name, error_type, msg,
                                  affected_dtypes, tuple(devs)))

  def tf_unimpl(np_dtype: Optional[NpDType] = None,
                additional_msg: Optional[str] = None,
                devs: Sequence[str] = all_devices) -> None:
    msg = "Primitive is unimplemented in TF"
    if additional_msg:
      msg += '; ' + additional_msg
    _report_failure(CATEGORY_MISSING_TF_SUPPORT, msg, np_dtype, devs=devs)

  def tf_possible_incorrect(np_dtype: Optional[NpDType] = None,
                            msg: str = "",
                            devs: Sequence[str] = all_devices) -> None:
    _report_failure(CATEGORY_POSSIBLE_INCORRECT_RESULTS, msg, np_dtype, devs=devs)

  def _to_np_dtype(dtype) -> NpDType:
    try:
      dtype = to_jax_dtype(dtype)
    except:
      pass
    return np.dtype(dtype)

  if args and args[0] is not core.unit:
    np_dtype = _to_np_dtype(args[0].dtype)
  else:
    np_dtype = None

  if prim is lax.regularized_incomplete_beta_p:
    if np_dtype in [np.float16, dtypes.bfloat16]:
      tf_unimpl(np_dtype)

  if prim in [lax.reduce_min_p, lax.reduce_max_p]:
    if np_dtype in [np.complex64, np.complex128]:
      tf_unimpl(np_dtype)

  if prim in [lax.min_p, lax.max_p, lax.reduce_window_min_p,
              lax.reduce_window_max_p]:
    if np_dtype in [np.bool_, np.int8, np.uint16, np.uint32, np.uint64,
                    np.complex64, np.complex128]:
      tf_unimpl(np_dtype)

  if prim is lax.div_p:
    if np_dtype in [np.uint8, np.uint16, np.uint32, np.uint64,
                    np.int8, np.int16]:
      tf_unimpl(np_dtype)
    elif dtypes.issubdtype(np_dtype, np.integer):
      tf_unimpl(np_dtype, additional_msg=("integer division fails if the "
                                          "divisor contains a 0"))

  if prim is lax.rem_p:
    if np_dtype in [np.uint8, np.uint16, np.uint32, np.uint64,
                    np.int8, np.int16, np.float16]:
      tf_unimpl(np_dtype)
    elif dtypes.issubdtype(np_dtype, np.integer):
      tf_unimpl(np_dtype, additional_msg=("integer division fails if the "
                                          "divisor contains a 0"))

  if prim is lax.atan2_p and np_dtype in [np.float16, dtypes.bfloat16]:
      # b/158006398: TF kernels are missing for 'rem' and 'atan2'
      tf_unimpl(np_dtype)

  if prim is lax.nextafter_p:
    if np_dtype in [np.float16, dtypes.bfloat16]:
      tf_unimpl(np_dtype)

  if prim is lax.linalg.cholesky_p:
    if np_dtype in [np.complex64, np.complex128]:
      # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
      # experimental_compile=True breaks for complex types.
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax.linalg.qr_p:
    if np_dtype in [np.complex64, np.complex128]:
      # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
      # experimental_compile=True breaks for complex types.
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax.linalg.eig_p:
    tf_unimpl(additional_msg=("this is a problem only in compiled mode "
                              "(experimental_compile=True))"))
    compute_left_eigenvectors = kwargs['compute_left_eigenvectors']
    compute_right_eigenvectors = kwargs['compute_right_eigenvectors']
    if compute_left_eigenvectors and compute_right_eigenvectors:
      tf_unimpl(additional_msg=("it is not possible to request both left and "
                                "right eigenvectors for now"))

  if prim is lax.linalg.eigh_p:
    if np_dtype in [np.complex64, np.complex128]:
      # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
      # experimental_compile=True breaks for complex types.
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax.linalg.lu_p:
    if np_dtype == np.complex64:
      tf_unimpl(np_dtype, devs=["TPU"])

  if prim is lax.linalg.triangular_solve_p:
    if np_dtype in [dtypes.bfloat16, np.float16]:
      tf_unimpl(np_dtype)

  if prim is lax.linalg.svd_p:
    if np_dtype in [dtypes.bfloat16]:
      # TODO: SVD on TPU for bfloat16 seems to work for JAX but fails for TF
      tf_unimpl(np_dtype, devs=["TPU"])
    elif np_dtype in [np.complex64, np.complex128]:
      # TODO: on CPU and GPU "No registered 'Svd' OpKernel for XLA_CPU_JIT
      # devices". Works on JAX because JAX uses a custom implementation.
      # There exists a XlaSvd operation that could replace tf.linalg.svd in
      # these cases but complex numbers support is not implemented in XLA yet,
      # and the API of XlaSvd is different than the one in JAX/TF, which also
      # limits its useability (e.g. no full_matrices argument, …).
      additional_msg = ("this works on JAX because JAX uses a custom "
                        "implementation")
      tf_unimpl(np_dtype, additional_msg=additional_msg, devs=["CPU", "GPU"])

  if prim is lax.select_and_scatter_add_p:
    if np_dtype in [np.uint64, np.uint32, np.uint16]:
      tf_unimpl(np_dtype)

  if prim is lax.select_and_gather_add_p:
    # TODO: the conversion is only supported for float16/float32 on CPU/GPU,
    # and float16 on TPU. This is because we do not implement a precision
    # reduction in the case where packing 2 n-bit values together results in
    # more than the maximum number of bits allowed on the platform (64 on
    # CPU/GPU, 32 on TPU). This could be fixed by implementing a variadic
    # reduce_window in tfxla, or we can require the user to reduce the
    # precision of their arrays manually based on the platform they run on.
    devices_and_max_bits = [ (["CPU", "GPU"], 64)
                           , (["TPU"], 32)
                           ]
    for devs, max_bits in devices_and_max_bits:
      if dtypes.finfo(np_dtype).bits * 2 > max_bits:
        # TODO: getting an exception "XLA encountered an HLO for which this
        # rewriting is not implemented"
        tf_unimpl(np_dtype, devs=devs)

  if prim in [ad_util.add_jaxvals_p, lax.add_p, lax.reduce_window_sum_p]:
    if np_dtype in [np.uint16, np.uint32, np.uint64]:
      # TODO(bchetioui): tf.math.add is not defined for the above types.
      tf_unimpl(np_dtype)

  if prim is lax.mul_p:
    if np_dtype in [np.uint32, np.uint64]:
      # TODO(bchetioui): tf.math.multiply is not defined for the above types.
      tf_unimpl(np_dtype)

  if prim is lax.sort_p:
    if np_dtype in [np.complex64, np.complex128]:
      tf_unimpl(np_dtype)
    if np_dtype == np.bool_ and len(args) == 2:
      tf_unimpl(np_dtype, additional_msg=(
        "sorting 2 arrays where the first one is an array of booleans is not "
        "supported for XlaSort"))
    if kwargs["is_stable"]:
      tf_unimpl(additional_msg="stable sort not implemented for XlaSort")
    if kwargs["dimension"] != len(np.shape(args[0])) - 1:
      tf_unimpl(additional_msg="only sorting on last dimension is supported "
                               "for XlaSort")
    if len(args) > 2:
      tf_unimpl(additional_msg=(
        "sorting more than 2 arrays is not supported for XlaSort"))

  if prim is lax.population_count_p:
    if np_dtype in [np.uint32, np.uint64]:
      tf_unimpl(np_dtype)

  if prim is lax.clamp_p:
    if np_dtype in [np.int8, np.uint16, np.uint32, np.uint64]:
      tf_unimpl(np_dtype)

  # Testing with matmul (TODO: comment out and test without matmul)
  if prim is lax.dot_general_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.bool_, np.uint8, np.uint16, np.uint32, np.uint64,
                    np.int8]:
      tf_unimpl(np_dtype)
    elif np_dtype == np.int16:
      # TODO(bchetioui): the path using 'einsum' is not compatible with int16
      # arguments on CPU/GPU, while the one using 'matmul' is (but not in
      # compiled mode).
      tf_unimpl(np_dtype, additional_msg=("only cases representable as 2D "
                                          "matrix multiplication can be "
                                          "converted properly"),
                devs=['CPU', 'GPU'])
      tf_unimpl(np_dtype, devs=['TPU'])
    elif np_dtype in [np.int16, np.int64]:
      devs = ['CPU'] if np_dtype == np.int16 else ['CPU', 'GPU']
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"),
                devs=devs)
  if prim is lax.conv_general_dilated_p:
    batch_group_count = kwargs['batch_group_count']
    if batch_group_count != 1:
      tf_unimpl(additional_msg="batch_group_count != 1 unsupported")
    if np_dtype in [np.complex64, np.complex128]:
      tf_unimpl(np_dtype, additional_msg="likely bug in the HLO -> LLVM IR "
                                         "lowering of XlaConv")

  if prim in [lax.acosh_p, lax.asinh_p, lax.atanh_p, lax.bessel_i0e_p,
              lax.bessel_i1e_p, lax.digamma_p, lax.erf_p, lax.erf_inv_p,
              lax.erfc_p, lax.lgamma_p, lax.round_p, lax.rsqrt_p]:
    if np_dtype == dtypes.bfloat16:
      tf_unimpl(np_dtype, devs=["CPU", "GPU"])

  if prim is lax.convert_element_type_p:
    if np_dtype == dtypes.bfloat16:
      tf_unimpl(np_dtype, devs=["CPU", "GPU"])

  if prim in [lax.sinh_p, lax.cosh_p, lax.atanh_p, lax.asinh_p, lax.acosh_p,
              lax.erf_inv_p]:
    if np_dtype == np.float16:
      # b/158006398: float16 support missing from the kernel of the above
      # operations.
      tf_unimpl(np_dtype)

  if prim is lax.integer_pow_p:
    if np_dtype in [np.uint8, np.uint16, np.uint32, np.uint64, np.int8,
                    np.int16]:
      tf_unimpl(np_dtype)

  if prim is lax.rev_p:
    if np_dtype in [np.uint32, np.uint64]:
      tf_unimpl(np_dtype)

  if prim is lax.sub_p:
    if np_dtype == np.uint64:
      tf_unimpl(np_dtype)

  if prim is lax.bitcast_convert_type_p:
    if np_dtype == np.bool_:
      tf_unimpl(np_dtype)

  if prim in [lax.le_p, lax.lt_p, lax.ge_p, lax.gt_p]:
    if np_dtype in [np.bool_, np.uint16, np.uint32, np.uint64]:
      tf_unimpl(np_dtype)

  if prim is lax.fft_p:
    if np_dtype in [np.float64, np.complex128]:
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax.top_k_p:
    if np_dtype in [np.float64, np.int64, np.uint64]:
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))
  return limitations