def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
  """Check gradients of `lax.reduce` for the given reduction operator.

  Builds a random operand of `shape`/`dtype`, reduces it over `dims` with
  `op` starting from `init_val`, and hands the reduction to `check_grads`
  with order 2 in both "fwd" and "rev" modes.

  Raises:
    SkipTest: on TPU when `op` is `lax.mul`.
  """
  rng = rng_factory(self.rng())
  if jtu.device_under_test() == "tpu" and op is lax.mul:
    raise SkipTest("unimplemented case")

  # Per-dtype tolerance table; the same table is passed to check_grads twice.
  tolerances = {
      dtypes.bfloat16: 2e-1,
      onp.float16: 1e-1,
      onp.float32: 1e-1,
      onp.float64: 1e-3,
      onp.complex64: 1e-1,
  }
  operand = rng(shape, dtype)
  init_val = onp.asarray(init_val, dtype=dtype)

  def reduce_operand(x):
    return lax.reduce(x, init_val, op, dims)

  # Pick the step size handed to check_grads; None leaves the choice to it.
  float_bits = dtypes.finfo(dtype).bits
  if float_bits == 16 and op is lax.add:
    eps = 1.0
  elif dtype == dtypes.bfloat16:
    eps = 1e-1
  elif float_bits == 32:
    eps = 1e-2
  else:
    eps = None

  check_grads(reduce_operand, (operand,), 2, ["fwd", "rev"],
              tolerances, tolerances, eps)
def testConvGeneralDilatedBatching(self, lhs_shape, rhs_shape, dtype, strides,
                                   padding, lhs_dil, rhs_dil,
                                   dimension_numbers, perms,
                                   feature_group_count, batch_group_count,
                                   lhs_bdim, rhs_bdim):
  """Check batching of `lax.conv_general_dilated` via `_CheckBatching`.

  The lhs/rhs shapes are first permuted with the index permutations in
  `perms`; the convolution is then partially applied with all of its
  configuration and checked with batch dimensions `lhs_bdim`/`rhs_bdim`.
  """
  rng = jtu.rand_default(self.rng())

  # Looser tolerance for dtypes of at most 32 float bits.
  if dtypes.finfo(dtype).bits <= 32:
    tol = 1e-1
  else:
    tol = 1e-3

  # Reorder the shape entries according to the supplied permutations.
  lhs_perm, rhs_perm = perms
  permuted_lhs = list(np.take(lhs_shape, lhs_perm))
  permuted_rhs = list(np.take(rhs_shape, rhs_perm))

  conv_config = dict(
      window_strides=strides,
      padding=padding,
      lhs_dilation=lhs_dil,
      rhs_dilation=rhs_dil,
      dimension_numbers=dimension_numbers,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count,
      precision=lax.Precision.HIGHEST,
  )
  conv = partial(lax.conv_general_dilated, **conv_config)

  self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim),
                      (permuted_lhs, permuted_rhs), (dtype, dtype), rng,
                      rtol=tol, atol=tol)
def eig_abstract_eval(operand, *, compute_left_eigenvectors,
                      compute_right_eigenvectors):
  """Abstract evaluation rule for the nonsymmetric eigendecomposition.

  Args:
    operand: abstract value of the input; must be a `ShapedArray` whose last
      two dimensions are equal.
    compute_left_eigenvectors: whether to include the left eigenvectors.
    compute_right_eigenvectors: whether to include the right eigenvectors.

  Returns:
    A tuple starting with the eigenvalue `ShapedArray` of shape
    `batch + (n,)`, followed by one `batch + (n, n)` `ShapedArray` per
    requested set of eigenvectors (left first, then right). All outputs use
    a complex dtype.

  Raises:
    NotImplementedError: if `operand` is not a `ShapedArray`.
    ValueError: if `operand` is not of shape `[..., n, n]`.
  """
  if not isinstance(operand, ShapedArray):
    raise NotImplementedError

  if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
    raise ValueError(
        "Argument to nonsymmetric eigendecomposition must have "
        "shape [..., n, n], got shape {}".format(operand.shape))

  leading_dims = operand.shape[:-2]
  size = operand.shape[-1]

  # 32-bit inputs map to complex64, everything else to complex128.
  if dtypes.finfo(operand.dtype).bits == 32:
    complex_dtype = np.complex64
  else:
    complex_dtype = np.complex128
  complex_dtype = dtypes.canonicalize_dtype(complex_dtype)

  # Left and right eigenvectors share a single abstract value.
  eigenvectors = ShapedArray(leading_dims + (size, size), complex_dtype)
  eigenvalues = ShapedArray(leading_dims + (size,), complex_dtype)

  results = (eigenvalues,)
  if compute_left_eigenvectors:
    results += (eigenvectors,)
  if compute_right_eigenvectors:
    results += (eigenvectors,)
  return results
def testMapCoordinates(self, shape, dtype, coords_shape, coords_dtype, order,
                       mode, cval, impl, round_, rng_factory):
  """Compare `lsp_ndimage.map_coordinates` against a reference implementation.

  The reference is `osp_ndimage.map_coordinates` when `impl` is "original",
  otherwise `_fixed_ref_map_coordinates`. Float inputs are compared with a
  tolerance of 100x the larger machine epsilon of the input and coordinate
  dtypes; all other inputs are compared with a tolerance of 0.
  """
  rng = rng_factory(self.rng())

  def args_maker():
    image = np.arange(prod(shape), dtype=dtype).reshape(shape)
    # One coordinate array per input axis, scaled by (axis size - 1).
    coordinates = [(size - 1) * rng(coords_shape, coords_dtype)
                   for size in shape]
    if round_:
      coordinates = [c.round().astype(int) for c in coordinates]
    return image, coordinates

  def lsp_op(x, c):
    return lsp_ndimage.map_coordinates(
        x, c, order=order, mode=mode, cval=cval)

  if impl == "original":
    impl_fun = osp_ndimage.map_coordinates
  else:
    impl_fun = _fixed_ref_map_coordinates

  def osp_op(x, c):
    return impl_fun(x, c, order=order, mode=mode, cval=cval)

  if dtype in float_dtypes:
    epsilon = max(dtypes.finfo(dtypes.canonicalize_dtype(d)).eps
                  for d in (dtype, coords_dtype))
    tol = 100 * epsilon
  else:
    tol = 0

  self._CheckAgainstNumpy(osp_op, lsp_op, args_maker, tol=tol)
def test_select_and_gather_add(self, harness: primitive_harness.Harness):
  """Convert and compare `select_and_gather_add` for the harness dtype.

  When twice the float bit width of the dtype exceeds the platform limit
  (32 on TPU, 64 otherwise) the conversion is expected to raise an error
  matching the "rewriting is not implemented" message; otherwise the
  conversion is expected to succeed.

  Raises:
    unittest.SkipTest: when the harness dtype is bfloat16.
  """
  dtype = harness.params["dtype"]
  if dtype is dtypes.bfloat16:
    raise unittest.SkipTest("bfloat16 not implemented")

  max_bits = 32 if jtu.device_under_test() == "tpu" else 64

  def convert_and_compare():
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()))

  if dtypes.finfo(dtype).bits * 2 <= max_bits:
    convert_and_compare()
    return

  with self.assertRaisesRegex(
      BaseException,
      "XLA encountered an HLO for which this rewriting is not implemented"):
    convert_and_compare()
def test_select_and_gather_add(self, harness: primitive_harness.Harness):
  """Convert and compare `select_and_gather_add` for the harness dtype.

  TF exceptions are expected exactly when twice the float bit width of the
  dtype exceeds the platform limit (32 on TPU, 64 otherwise).
  """
  dtype = harness.params["dtype"]
  max_bits = 32 if jtu.device_under_test() == "tpu" else 64

  # TODO: getting an exception "XLA encountered an HLO for which this
  # rewriting is not implemented" when the doubled width exceeds the limit.
  exceeds_limit = dtypes.finfo(dtype).bits * 2 > max_bits

  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         expect_tf_exceptions=exceeds_limit)
def eig_abstract_eval(operand):
  """Abstract evaluation rule for the nonsymmetric eigendecomposition.

  Args:
    operand: abstract value of the input; must be a `ShapedArray` whose last
      two dimensions are equal.

  Returns:
    A triple `(w, vl, vr)`: `w` has shape `batch + (n,)`, while `vl` and `vr`
    are the same `batch + (n, n)` abstract value. All use a complex dtype.

  Raises:
    NotImplementedError: if `operand` is not a `ShapedArray`.
    ValueError: if `operand` is not of shape `[..., n, n]`.
  """
  if not isinstance(operand, ShapedArray):
    raise NotImplementedError

  if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
    raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                     "shape [..., n, n], got shape {}".format(operand.shape))

  leading_dims = operand.shape[:-2]
  size = operand.shape[-1]

  # 32-bit inputs map to complex64, everything else to complex128.
  if dtypes.finfo(operand.dtype).bits == 32:
    complex_dtype = onp.complex64
  else:
    complex_dtype = onp.complex128
  complex_dtype = dtypes.canonicalize_dtype(complex_dtype)

  # Left and right eigenvectors share a single abstract value.
  eigenvectors = ShapedArray(leading_dims + (size, size), complex_dtype)
  eigenvalues = ShapedArray(leading_dims + (size,), complex_dtype)
  return eigenvalues, eigenvectors, eigenvectors
def categorize(prim: core.Primitive, *args, **kwargs) \
    -> List[Limitation]:
  """
  Given a primitive and a set of parameters one would like to pass to it,
  categorize identifies the potential limitations the call would encounter when
  converted to TF through jax2tf.

  The checks are independent of one another: every matching rule appends one
  `Limitation` (primitive name, error type, message, affected dtypes, devices).
  The dtype inspected is always that of `args[0]`.

  Args:
    prim: the primitive to call.
    args: the arguments to pass to prim.
    kwargs: the keyword arguments to pass to prim. `is_stable` and `dimension`
      are read for `sort_p`, `batch_group_count` for `conv_general_dilated_p`.

  Returns:
    A list of limitations
  """
  limitations: List[Limitation] = []
  all_devices = ["CPU", "GPU", "TPU"]

  def _report_failure(error_type: str, msg: str,
                      affected_dtype: Optional[NpDType] = None,
                      devs: Sequence[str] = all_devices) -> None:
    # Record one limitation; a missing dtype becomes an empty dtype tuple.
    affected_dtypes = (
      tuple([affected_dtype]) if affected_dtype is not None else tuple())
    limitations.append(Limitation(prim.name, error_type, msg,
                                  affected_dtypes, tuple(devs)))

  def tf_unimpl(np_dtype: Optional[NpDType] = None,
                additional_msg: Optional[str] = None,
                devs: Sequence[str] = all_devices) -> None:
    # Report a "Missing TF support" limitation, optionally with extra detail.
    missing_tf_support = "Missing TF support"
    msg = "Primitive is unimplemented"
    if additional_msg:
      msg += '; ' + additional_msg
    _report_failure(missing_tf_support, msg, np_dtype, devs=devs)

  def _to_np_dtype(dtype) -> NpDType:
    # Best-effort conversion through to_jax_dtype; on failure the dtype is used
    # as given (presumably it was not a TF dtype -- confirm against
    # to_jax_dtype). Only `Exception` is caught so that KeyboardInterrupt and
    # SystemExit are not silently swallowed.
    try:
      dtype = to_jax_dtype(dtype)
    except Exception:
      pass
    return np.dtype(dtype)

  if prim in [lax.min_p, lax.max_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.bool_, np.int8, np.uint16, np.uint32, np.uint64,
                    np.complex64, np.complex128]:
      tf_unimpl(np_dtype)

  if prim in [lax.rem_p, lax.atan2_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.float16, dtypes.bfloat16]:
      # b/158006398: TF kernels are missing for 'rem' and 'atan2'
      tf_unimpl(np_dtype)

  if prim is lax.nextafter_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.float16, dtypes.bfloat16]:
      tf_unimpl(np_dtype)

  if prim is lax_linalg.cholesky_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.complex64, np.complex128]:
      # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
      # experimental_compile=True breaks for complex types.
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax_linalg.qr_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.complex64, np.complex128]:
      # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
      # experimental_compile=True breaks for complex types.
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax_linalg.svd_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [dtypes.bfloat16]:
      # TODO: SVD on TPU for bfloat16 seems to work for JAX but fails for TF
      tf_unimpl(np_dtype, devs=["TPU"])
    elif np_dtype in [np.complex64, np.complex128]:
      # TODO: on CPU and GPU "No registered 'Svd' OpKernel for XLA_CPU_JIT
      # devices". Works on JAX because JAX uses a custom implementation.
      # There exists a XlaSvd operation that could replace tf.linalg.svd in
      # these cases but complex numbers support is not implemented in XLA yet,
      # and the API of XlaSvd is different than the one in JAX/TF, which also
      # limits its usability (e.g. no full_matrices argument, ...).
      additional_msg = ("this works on JAX because JAX uses a custom "
                        "implementation")
      tf_unimpl(np_dtype, additional_msg=additional_msg, devs=["CPU", "GPU"])

  if prim is lax.select_and_gather_add_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    # TODO: the conversion is only supported for float16/float32 on CPU/GPU,
    # and float16 on TPU. This is because we do not implement a precision
    # reduction in the case where packing 2 n-bit values together results in
    # more than the maximum number of bits allowed on the platform (64 on
    # CPU/GPU, 32 on TPU). This could be fixed by implementing a variadic
    # reduce_window in tfxla, or we can require the user to reduce the
    # precision of their arrays manually based on the platform they run on.
    devices_and_max_bits = [
      (["CPU", "GPU"], 64),
      (["TPU"], 32),
    ]
    for devs, max_bits in devices_and_max_bits:
      if dtypes.finfo(np_dtype).bits * 2 > max_bits:
        # TODO: getting an exception "XLA encountered an HLO for which this
        # rewriting is not implemented"
        tf_unimpl(np_dtype, devs=devs)

  if prim in [lax.add_p, lax.reduce_window_sum_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.uint16, np.uint32, np.uint64]:
      # TODO(bchetioui): tf.math.add is not defined for the above types.
      tf_unimpl(np_dtype)

  if prim is lax.mul_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.uint32, np.uint64]:
      # TODO(bchetioui): tf.math.multiply is not defined for the above types.
      tf_unimpl(np_dtype)

  if prim in [lax.scatter_mul_p, lax.scatter_add_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype == np.complex64:
      tf_unimpl(np_dtype, devs=["TPU"])

  if prim is lax.sort_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.complex64, np.complex128]:
      tf_unimpl(np_dtype)
    if np_dtype == np.bool_ and len(args) == 2:
      tf_unimpl(np_dtype, additional_msg=(
        "sorting 2 arrays where the first one is an array of booleans is not "
        "supported for XlaSort"))
    if kwargs["is_stable"]:
      tf_unimpl(additional_msg="stable sort not implemented for XlaSort")
    if kwargs["dimension"] != len(np.shape(args[0])) - 1:
      tf_unimpl(additional_msg="only sorting on last dimension is supported "
                               "for XlaSort")
    if len(args) > 2:
      tf_unimpl(additional_msg=(
        "sorting more than 2 arrays is not supported for XlaSort"))

  if prim is lax.population_count_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.uint32, np.uint64]:
      tf_unimpl(np_dtype)

  if prim is lax.conv_general_dilated_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    batch_group_count = kwargs['batch_group_count']
    if batch_group_count != 1:
      tf_unimpl(additional_msg="batch_group_count != 1 unsupported")
    if np_dtype in [np.complex64, np.complex128]:
      tf_unimpl(np_dtype, additional_msg="likely bug in the HLO -> LLVM IR "
                                         "lowering of XlaConv")

  if prim in [lax.acosh_p, lax.asinh_p, lax.atanh_p, lax.bessel_i0e_p,
              lax.bessel_i1e_p, lax.digamma_p, lax.erf_p, lax.erf_inv_p,
              lax.erfc_p, lax.lgamma_p, lax.round_p, lax.rsqrt_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype == dtypes.bfloat16:
      tf_unimpl(np_dtype, devs=["CPU", "GPU"])

  if prim in [lax.sinh_p, lax.cosh_p, lax.atanh_p, lax.asinh_p, lax.acosh_p,
              lax.erf_inv_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype == np.float16:
      # b/158006398: float16 support missing from the kernel of the above
      # operations.
      tf_unimpl(np_dtype)

  if prim is lax.lax_fft.fft_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.float64, np.complex128]:
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax.top_k_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.float64, np.int64, np.uint64]:
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))
  return limitations
def num_float_bits(dtype):
  """Return the float bit width of `dtype` after dtype canonicalization."""
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  float_info = dtypes.finfo(canonical_dtype)
  return float_info.bits
def categorize(prim: core.Primitive, *args, **kwargs) \
    -> List[Limitation]:
  """
  Given a primitive and a set of parameters one would like to pass to it,
  categorize identifies the potential limitations the call would encounter when
  converted to TF through jax2tf.

  The checks are independent of one another: every matching rule appends one
  `Limitation` (primitive name, error type, message, devices). The dtype
  inspected is always that of `args[0]`.

  Args:
    prim: the primitive to call.
    args: the arguments to pass to prim.
    kwargs: the keyword arguments to pass to prim.

  Returns:
    A list of limitations
  """
  limitations: List[Limitation] = []
  all_devices = ["CPU", "GPU", "TPU"]

  def _report_failure(error_type: str, msg: str,
                      devs: Sequence[str] = all_devices) -> None:
    # Record one limitation for the current primitive.
    limitations.append(Limitation(prim.name, error_type, msg, tuple(devs)))

  def tf_unimpl(np_dtype: NpDType, additional_msg: Optional[str] = None,
                devs: Sequence[str] = all_devices) -> None:
    # Report a "Missing TF support" limitation naming the offending dtype.
    msg = f"{prim.name} is unimplemented for dtype {np_dtype}"
    if additional_msg:
      msg += '; ' + additional_msg
    _report_failure("Missing TF support", msg, devs=devs)

  def _to_np_dtype(dtype) -> NpDType:
    # Best-effort conversion through to_jax_dtype; on failure the dtype is used
    # as given (presumably it was not a TF dtype -- confirm against
    # to_jax_dtype). Only `Exception` is caught so that KeyboardInterrupt and
    # SystemExit are not silently swallowed.
    try:
      dtype = to_jax_dtype(dtype)
    except Exception:
      pass
    return np.dtype(dtype)

  if prim in [lax.min_p, lax.max_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [
        np.bool_, np.int8, np.uint16, np.uint32, np.uint64, np.complex64,
        np.complex128
    ]:
      tf_unimpl(np_dtype)

  if prim in [lax.rem_p, lax.atan2_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.float16, dtypes.bfloat16]:
      # b/158006398: TF kernels are missing for 'rem' and 'atan2'
      tf_unimpl(np_dtype)

  if prim is lax.nextafter_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.float16, dtypes.bfloat16]:
      tf_unimpl(np_dtype)

  if prim is lax_linalg.qr_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.complex64, np.complex128]:
      # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
      # experimental_compile=True breaks for complex types.
      tf_unimpl(np_dtype)

  if prim is lax_linalg.svd_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.float16, dtypes.bfloat16]:
      # TODO: SVD on TPU for bfloat16 seems to work for JAX but fails for TF
      tf_unimpl(np_dtype, devs=["TPU"])
    elif np_dtype in [np.complex64, np.complex128]:
      # TODO: on CPU and GPU "No registered 'Svd' OpKernel for XLA_CPU_JIT
      # devices". Works on JAX because JAX uses a custom implementation
      additional_msg = ("this works on JAX because JAX uses a custom "
                        "implementation")
      tf_unimpl(np_dtype, additional_msg=additional_msg, devs=["CPU", "GPU"])

  if prim is lax.select_and_gather_add_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    # Twice the dtype's float bit width must not exceed the per-platform limit.
    devices_and_max_bits = [(["CPU", "GPU"], 64), (["TPU"], 32)]
    for devs, max_bits in devices_and_max_bits:
      if dtypes.finfo(np_dtype).bits * 2 > max_bits:
        # TODO: getting an exception "XLA encountered an HLO for which this
        # rewriting is not implemented"
        tf_unimpl(np_dtype, devs=devs)

  if prim in [lax.add_p, lax.reduce_window_sum_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.uint16, np.uint32, np.uint64]:
      # TODO(bchetioui): tf.math.add is not defined for the above types.
      tf_unimpl(np_dtype)

  if prim is lax.mul_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.uint32, np.uint64]:
      # TODO(bchetioui): tf.math.multiply is not defined for the above types.
      tf_unimpl(np_dtype)

  if prim in [lax.scatter_mul_p, lax.scatter_add_p]:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype == np.complex64:
      tf_unimpl(np_dtype, devs=["TPU"])

  if prim is lax.population_count_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype == np.uint32:
      tf_unimpl(np_dtype)

  return limitations
def categorize(prim: core.Primitive, *args, **kwargs) \
    -> List[Limitation]:
  """
  Given a primitive and a set of parameters one would like to pass to it,
  categorize identifies the potential limitations the call would encounter when
  converted to TF through jax2tf.

  The checks are independent of one another: every matching rule appends one
  `Limitation` (primitive name, error category, message, affected dtypes,
  devices). The dtype inspected is that of `args[0]`, or None when there are
  no arguments or the first argument is `core.unit`.

  Args:
    prim: the primitive to call.
    args: the arguments to pass to prim.
    kwargs: the keyword arguments to pass to prim. `compute_left_eigenvectors`
      and `compute_right_eigenvectors` are read for `eig_p`, `is_stable` and
      `dimension` for `sort_p`, `batch_group_count` for
      `conv_general_dilated_p`.

  Returns:
    A list of limitations
  """
  limitations: List[Limitation] = []
  all_devices = ["CPU", "GPU", "TPU"]

  def _report_failure(error_type: str, msg: str,
                      affected_dtype: Optional[NpDType] = None,
                      devs: Sequence[str] = all_devices) -> None:
    # Record one limitation; a missing dtype becomes an empty dtype tuple.
    affected_dtypes = (
      tuple([affected_dtype]) if affected_dtype is not None else tuple())
    limitations.append(Limitation(prim.name, error_type, msg,
                                  affected_dtypes, tuple(devs)))

  def tf_unimpl(np_dtype: Optional[NpDType] = None,
                additional_msg: Optional[str] = None,
                devs: Sequence[str] = all_devices) -> None:
    # Report a missing-TF-support limitation, optionally with extra detail.
    msg = "Primitive is unimplemented in TF"
    if additional_msg:
      msg += '; ' + additional_msg
    _report_failure(CATEGORY_MISSING_TF_SUPPORT, msg, np_dtype, devs=devs)

  def tf_possible_incorrect(np_dtype: Optional[NpDType] = None,
                            msg: str = "",
                            devs: Sequence[str] = all_devices) -> None:
    # Report a possibly-incorrect-results limitation (unused by the rules below).
    _report_failure(CATEGORY_POSSIBLE_INCORRECT_RESULTS, msg, np_dtype,
                    devs=devs)

  def _to_np_dtype(dtype) -> NpDType:
    # Best-effort conversion through to_jax_dtype; on failure the dtype is used
    # as given (presumably it was not a TF dtype -- confirm against
    # to_jax_dtype). Only `Exception` is caught so that KeyboardInterrupt and
    # SystemExit are not silently swallowed.
    try:
      dtype = to_jax_dtype(dtype)
    except Exception:
      pass
    return np.dtype(dtype)

  if args and args[0] is not core.unit:
    np_dtype = _to_np_dtype(args[0].dtype)
  else:
    np_dtype = None

  if prim is lax.regularized_incomplete_beta_p:
    if np_dtype in [np.float16, dtypes.bfloat16]:
      tf_unimpl(np_dtype)

  if prim in [lax.reduce_min_p, lax.reduce_max_p]:
    if np_dtype in [np.complex64, np.complex128]:
      tf_unimpl(np_dtype)

  if prim in [lax.min_p, lax.max_p, lax.reduce_window_min_p,
              lax.reduce_window_max_p]:
    if np_dtype in [np.bool_, np.int8, np.uint16, np.uint32, np.uint64,
                    np.complex64, np.complex128]:
      tf_unimpl(np_dtype)

  if prim is lax.div_p:
    if np_dtype in [np.uint8, np.uint16, np.uint32, np.uint64,
                    np.int8, np.int16]:
      tf_unimpl(np_dtype)
    elif dtypes.issubdtype(np_dtype, np.integer):
      tf_unimpl(np_dtype, additional_msg=("integer division fails if the "
                                          "divisor contains a 0"))

  if prim is lax.rem_p:
    if np_dtype in [np.uint8, np.uint16, np.uint32, np.uint64,
                    np.int8, np.int16, np.float16]:
      tf_unimpl(np_dtype)
    elif dtypes.issubdtype(np_dtype, np.integer):
      tf_unimpl(np_dtype, additional_msg=("integer division fails if the "
                                          "divisor contains a 0"))

  if prim is lax.atan2_p and np_dtype in [np.float16, dtypes.bfloat16]:
    # b/158006398: TF kernels are missing for 'rem' and 'atan2'
    tf_unimpl(np_dtype)

  if prim is lax.nextafter_p:
    if np_dtype in [np.float16, dtypes.bfloat16]:
      tf_unimpl(np_dtype)

  if prim is lax.linalg.cholesky_p:
    if np_dtype in [np.complex64, np.complex128]:
      # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
      # experimental_compile=True breaks for complex types.
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax.linalg.qr_p:
    if np_dtype in [np.complex64, np.complex128]:
      # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
      # experimental_compile=True breaks for complex types.
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax.linalg.eig_p:
    tf_unimpl(additional_msg=("this is a problem only in compiled mode "
                              "(experimental_compile=True))"))
    compute_left_eigenvectors = kwargs['compute_left_eigenvectors']
    compute_right_eigenvectors = kwargs['compute_right_eigenvectors']
    if compute_left_eigenvectors and compute_right_eigenvectors:
      tf_unimpl(additional_msg=("it is not possible to request both left and "
                                "right eigenvectors for now"))

  if prim is lax.linalg.eigh_p:
    if np_dtype in [np.complex64, np.complex128]:
      # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
      # experimental_compile=True breaks for complex types.
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax.linalg.lu_p:
    if np_dtype == np.complex64:
      tf_unimpl(np_dtype, devs=["TPU"])

  if prim is lax.linalg.triangular_solve_p:
    if np_dtype in [dtypes.bfloat16, np.float16]:
      tf_unimpl(np_dtype)

  if prim is lax.linalg.svd_p:
    if np_dtype in [dtypes.bfloat16]:
      # TODO: SVD on TPU for bfloat16 seems to work for JAX but fails for TF
      tf_unimpl(np_dtype, devs=["TPU"])
    elif np_dtype in [np.complex64, np.complex128]:
      # TODO: on CPU and GPU "No registered 'Svd' OpKernel for XLA_CPU_JIT
      # devices". Works on JAX because JAX uses a custom implementation.
      # There exists a XlaSvd operation that could replace tf.linalg.svd in
      # these cases but complex numbers support is not implemented in XLA yet,
      # and the API of XlaSvd is different than the one in JAX/TF, which also
      # limits its usability (e.g. no full_matrices argument, ...).
      additional_msg = ("this works on JAX because JAX uses a custom "
                        "implementation")
      tf_unimpl(np_dtype, additional_msg=additional_msg, devs=["CPU", "GPU"])

  if prim is lax.select_and_scatter_add_p:
    if np_dtype in [np.uint64, np.uint32, np.uint16]:
      tf_unimpl(np_dtype)

  if prim is lax.select_and_gather_add_p:
    # TODO: the conversion is only supported for float16/float32 on CPU/GPU,
    # and float16 on TPU. This is because we do not implement a precision
    # reduction in the case where packing 2 n-bit values together results in
    # more than the maximum number of bits allowed on the platform (64 on
    # CPU/GPU, 32 on TPU). This could be fixed by implementing a variadic
    # reduce_window in tfxla, or we can require the user to reduce the
    # precision of their arrays manually based on the platform they run on.
    devices_and_max_bits = [
      (["CPU", "GPU"], 64),
      (["TPU"], 32),
    ]
    for devs, max_bits in devices_and_max_bits:
      if dtypes.finfo(np_dtype).bits * 2 > max_bits:
        # TODO: getting an exception "XLA encountered an HLO for which this
        # rewriting is not implemented"
        tf_unimpl(np_dtype, devs=devs)

  if prim in [ad_util.add_jaxvals_p, lax.add_p, lax.reduce_window_sum_p]:
    if np_dtype in [np.uint16, np.uint32, np.uint64]:
      # TODO(bchetioui): tf.math.add is not defined for the above types.
      tf_unimpl(np_dtype)

  if prim is lax.mul_p:
    if np_dtype in [np.uint32, np.uint64]:
      # TODO(bchetioui): tf.math.multiply is not defined for the above types.
      tf_unimpl(np_dtype)

  if prim is lax.sort_p:
    if np_dtype in [np.complex64, np.complex128]:
      tf_unimpl(np_dtype)
    if np_dtype == np.bool_ and len(args) == 2:
      tf_unimpl(np_dtype, additional_msg=(
        "sorting 2 arrays where the first one is an array of booleans is not "
        "supported for XlaSort"))
    if kwargs["is_stable"]:
      tf_unimpl(additional_msg="stable sort not implemented for XlaSort")
    if kwargs["dimension"] != len(np.shape(args[0])) - 1:
      tf_unimpl(additional_msg="only sorting on last dimension is supported "
                               "for XlaSort")
    if len(args) > 2:
      tf_unimpl(additional_msg=(
        "sorting more than 2 arrays is not supported for XlaSort"))

  if prim is lax.population_count_p:
    if np_dtype in [np.uint32, np.uint64]:
      tf_unimpl(np_dtype)

  if prim is lax.clamp_p:
    if np_dtype in [np.int8, np.uint16, np.uint32, np.uint64]:
      tf_unimpl(np_dtype)

  # Testing with matmul (TODO: comment out and test without matmul)
  if prim is lax.dot_general_p:
    # `np.bool_` rather than the deprecated `np.bool` alias, which is missing
    # from NumPy 1.24-1.26 and would raise AttributeError here.
    if np_dtype in [np.bool_, np.uint8, np.uint16, np.uint32, np.uint64,
                    np.int8]:
      tf_unimpl(np_dtype)
    elif np_dtype == np.int16:
      # TODO(bchetioui): the path using 'einsum' is not compatible with int16
      # arguments on CPU/GPU, while the one using 'matmul' is (but not in
      # compiled mode).
      tf_unimpl(np_dtype, additional_msg=("only cases representable as 2D "
                                          "matrix multiplication can be "
                                          "converted properly"),
                devs=['CPU', 'GPU'])
      tf_unimpl(np_dtype, devs=['TPU'])
    elif np_dtype in [np.int16, np.int64]:
      # NOTE(review): np.int16 is already handled by the branch above, so only
      # np.int64 can reach this point and the ['CPU'] choice is dead code.
      # Confirm whether int16 was also meant to get this compiled-mode entry.
      devs = ['CPU'] if np_dtype == np.int16 else ['CPU', 'GPU']
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"),
                devs=devs)

  if prim is lax.conv_general_dilated_p:
    batch_group_count = kwargs['batch_group_count']
    if batch_group_count != 1:
      tf_unimpl(additional_msg="batch_group_count != 1 unsupported")
    if np_dtype in [np.complex64, np.complex128]:
      tf_unimpl(np_dtype, additional_msg="likely bug in the HLO -> LLVM IR "
                                         "lowering of XlaConv")

  if prim in [lax.acosh_p, lax.asinh_p, lax.atanh_p, lax.bessel_i0e_p,
              lax.bessel_i1e_p, lax.digamma_p, lax.erf_p, lax.erf_inv_p,
              lax.erfc_p, lax.lgamma_p, lax.round_p, lax.rsqrt_p]:
    if np_dtype == dtypes.bfloat16:
      tf_unimpl(np_dtype, devs=["CPU", "GPU"])

  if prim is lax.convert_element_type_p:
    if np_dtype == dtypes.bfloat16:
      tf_unimpl(np_dtype, devs=["CPU", "GPU"])

  if prim in [lax.sinh_p, lax.cosh_p, lax.atanh_p, lax.asinh_p, lax.acosh_p,
              lax.erf_inv_p]:
    if np_dtype == np.float16:
      # b/158006398: float16 support missing from the kernel of the above
      # operations.
      tf_unimpl(np_dtype)

  if prim is lax.integer_pow_p:
    if np_dtype in [np.uint8, np.uint16, np.uint32, np.uint64, np.int8,
                    np.int16]:
      tf_unimpl(np_dtype)

  if prim is lax.rev_p:
    if np_dtype in [np.uint32, np.uint64]:
      tf_unimpl(np_dtype)

  if prim is lax.sub_p:
    if np_dtype == np.uint64:
      tf_unimpl(np_dtype)

  if prim is lax.bitcast_convert_type_p:
    if np_dtype == np.bool_:
      tf_unimpl(np_dtype)

  if prim in [lax.le_p, lax.lt_p, lax.ge_p, lax.gt_p]:
    if np_dtype in [np.bool_, np.uint16, np.uint32, np.uint64]:
      tf_unimpl(np_dtype)

  if prim is lax.fft_p:
    if np_dtype in [np.float64, np.complex128]:
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax.top_k_p:
    if np_dtype in [np.float64, np.int64, np.uint64]:
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))
  return limitations