def test_scatter(self, harness: primitive_harness.Harness):
  f_name = harness.params['f_lax'].__name__
  dtype = harness.params['dtype']
  if jtu.device_under_test() == 'tpu':
    if dtype is np.complex64 and f_name in ['scatter_min', 'scatter_max']:
      raise unittest.SkipTest(f"TODO: complex {f_name} on TPU fails in JAX")
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
def test_triangular_solve(self, harness: primitive_harness.Harness):
  dtype = harness.params["dtype"]
  if dtype == np.float16 and jtu.device_under_test() == "gpu":
    raise unittest.SkipTest(
        f"Triangular solve is not implemented in JAX for dtype {dtype}")
  atol = rtol = None
  if dtype == np.float32:
    atol = rtol = 1e-5
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                         atol=atol, rtol=rtol)
def test_fft(self, harness: primitive_harness.Harness):
  if len(harness.params["fft_lengths"]) > 3:
    with self.assertRaisesRegex(RuntimeError, "FFT only supports ranks 1-3"):
      harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
  elif (jtu.device_under_test() == "tpu" and
        len(harness.params["fft_lengths"]) > 1):
    # TODO(b/140351181): FFT is mostly unimplemented on TPU, even for JAX
    with self.assertRaisesRegex(RuntimeError,
                                "only 1D FFT is currently supported."):
      harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
  else:
    tol = None
    if jtu.device_under_test() == "gpu":
      if harness.params["dtype"] in jtu.dtypes.boolean:
        tol = 0.01
      else:
        tol = 1e-3
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           atol=tol, rtol=tol)
def test_qr(self, harness: primitive_harness.Harness):
  # See jax.lib.lapack.geqrf for the list of compatible types
  if (harness.params["dtype"] in [jnp.float32, jnp.float64] or
      harness.params["dtype"] == jnp.float16 and
      jtu.device_under_test() == "tpu"):
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           atol=1e-5, rtol=1e-5)
  elif harness.params["dtype"] in [jnp.complex64, jnp.complex128]:
    if (jtu.device_under_test() == "tpu" and
        harness.params["dtype"] in [jnp.complex64]):
      raise unittest.SkipTest("QR for c64 not implemented on TPU")
    # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
    # - check_compiled=True breaks for complex types;
    # - for now, the HLO QR implementation called when compiling with TF is
    #   expected to be slower than the custom calls made in JAX.
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           expect_tf_exceptions=True, atol=1e-5, rtol=1e-5)
  else:
    # TODO(necula): fix QR bug on TPU
    if (jtu.device_under_test() == "tpu" and
        harness.params["dtype"] in (jnp.bfloat16, jnp.int32, jnp.uint32)):
      raise unittest.SkipTest(
          "QR bug on TPU for certain types: error not raised")
    if (jtu.device_under_test() == "tpu" and
        harness.params["dtype"] in (jnp.bool_,)):
      raise unittest.SkipTest("QR bug on TPU for certain types: invalid cast")
    expected_error = (ValueError if jtu.device_under_test() == "gpu"
                      else NotImplementedError)
    with self.assertRaisesRegex(expected_error, "Unsupported dtype"):
      harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
def test_prim(self, harness: primitive_harness.Harness):
  limitations = Jax2TfLimitation.limitations_for_harness(harness)
  device = jtu.device_under_test()
  limitations = tuple(filter(
      lambda l: l.filter(device=device, dtype=harness.dtype), limitations))
  func_jax = harness.dyn_fun
  args = harness.dyn_args_maker(self.rng())
  enable_xla = harness.params.get("enable_xla", True)
  self.ConvertAndCompare(func_jax, *args, limitations=limitations,
                         enable_xla=enable_xla)
def test_min_max(self, harness: primitive_harness.Harness):
  # TODO(bchetioui): discrepancies between TF & JAX when comparing with NaN;
  # JAX always returns NaN, while TF returns the value NaN is compared with.
  def custom_assert(result_jax, result_tf):
    mask = np.isnan(result_jax)
    self.assertAllClose(result_jax[~mask], result_tf[~mask])

  # TODO(bchetioui): figure out why we need always_custom_assert=True
  always_custom_assert = True
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                         custom_assert=custom_assert,
                         always_custom_assert=always_custom_assert)
def test_unary_elementwise(self, harness: primitive_harness.Harness):
  dtype = harness.params["dtype"]
  if dtype is dtypes.bfloat16:
    raise unittest.SkipTest("bfloat16 not implemented")
  arg, = harness.dyn_args_maker(self.rng())
  custom_assert = None
  if harness.params["lax_name"] == "digamma":
    # digamma is not defined at 0 and -1
    def custom_assert(result_jax, result_tf):
      # lax.digamma returns NaN and tf.math.digamma returns inf
      special_cases = (arg == 0.) | (arg == -1.)
      nr_special_cases = np.count_nonzero(special_cases)
      self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
                          result_jax[special_cases])
      self.assertAllClose(np.full((nr_special_cases,), dtype(np.inf)),
                          result_tf[special_cases])
      # non-special cases are equal
      self.assertAllClose(result_jax[~special_cases],
                          result_tf[~special_cases])
  if harness.params["lax_name"] == "erf_inv":
    # TODO(necula): fix bug with erf_inv/f16
    if dtype is np.float16:
      raise unittest.SkipTest("TODO: fix bug")
    # erf_inv is not defined for arg <= -1 or arg >= 1
    def custom_assert(result_jax, result_tf):  # noqa: F811
      # for arg < -1 or arg > 1
      # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
      special_cases = (arg < -1.) | (arg > 1.)
      nr_special_cases = np.count_nonzero(special_cases)
      self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
                          result_jax[special_cases])
      signs = np.where(arg[special_cases] < 0., -1., 1.)
      self.assertAllClose(np.full((nr_special_cases,), signs * dtype(np.inf)),
                          result_tf[special_cases])
      # non-special cases are equal
      self.assertAllClose(result_jax[~special_cases],
                          result_tf[~special_cases])
  atol = None
  if jtu.device_under_test() == "gpu":
    # TODO(necula): revisit once we fix the GPU tests
    atol = 1e-3
  self.ConvertAndCompare(harness.dyn_fun, arg, custom_assert=custom_assert,
                         atol=atol)
def test_unary_elementwise(self, harness: primitive_harness.Harness):
  dtype = harness.params["dtype"]
  lax_name = harness.params["lax_name"]
  arg, = harness.dyn_args_maker(self.rng())
  custom_assert = None
  if lax_name == "digamma":
    # TODO(necula): fix bug with digamma/(f32|f16) on TPU
    if dtype in [np.float16, np.float32] and jtu.device_under_test() == "tpu":
      raise unittest.SkipTest("TODO: fix bug: nan vs not-nan")
    # In the bfloat16 case, TF and lax both return NaN in undefined cases.
    if dtype is not dtypes.bfloat16:
      # digamma is not defined at 0 and -1
      def custom_assert(result_jax, result_tf):
        # lax.digamma returns NaN and tf.math.digamma returns inf
        special_cases = (arg == 0.) | (arg == -1.)
        nr_special_cases = np.count_nonzero(special_cases)
        self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
                            result_jax[special_cases])
        self.assertAllClose(np.full((nr_special_cases,), dtype(np.inf)),
                            result_tf[special_cases])
        # non-special cases are equal
        self.assertAllClose(result_jax[~special_cases],
                            result_tf[~special_cases])
  if lax_name == "erf_inv":
    # TODO(necula): fix erf_inv bug on TPU
    if jtu.device_under_test() == "tpu":
      raise unittest.SkipTest("erf_inv bug on TPU: nan vs non-nan")
    # TODO: investigate: in the (b)float16 cases, TF and lax both return the
    # same result in undefined cases.
    if dtype not in [np.float16, dtypes.bfloat16]:
      # erf_inv is not defined for arg <= -1 or arg >= 1
      def custom_assert(result_jax, result_tf):  # noqa: F811
        # for arg < -1 or arg > 1
        # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
        special_cases = (arg < -1.) | (arg > 1.)
        nr_special_cases = np.count_nonzero(special_cases)
        self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
                            result_jax[special_cases])
        signs = np.where(arg[special_cases] < 0., -1., 1.)
        self.assertAllClose(np.full((nr_special_cases,),
                                    signs * dtype(np.inf)),
                            result_tf[special_cases])
        # non-special cases are equal
        self.assertAllClose(result_jax[~special_cases],
                            result_tf[~special_cases])
  atol = None
  if jtu.device_under_test() == "gpu":
    # TODO(necula): revisit once we fix the GPU tests
    atol = 1e-3
  self.ConvertAndCompare(harness.dyn_fun, arg, custom_assert=custom_assert,
                         atol=atol)
def test_betainc(self, harness: primitive_harness.Harness):
  dtype = harness.params["dtype"]
  # TODO: https://www.tensorflow.org/api_docs/python/tf/math/betainc only
  # supports float32/64 tests.
  # TODO(bchetioui): investigate why the test actually fails in JAX.
  if dtype in [np.float16, dtypes.bfloat16]:
    raise unittest.SkipTest("(b)float16 not implemented in TF")
  tol = None
  if dtype is np.float64:
    tol = 1e-14
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol)
def test_conv_general_dilated(self, harness: primitive_harness.Harness):
  dtype, device = harness.params["dtype"], jtu.device_under_test()
  if device == "gpu" and dtype in [np.complex64, np.complex128]:
    raise unittest.SkipTest("TODO: crash on GPU in TF")
  tol = None
  if device == "gpu":
    tol = 1e-4
  elif device == "tpu":
    tol = 1e-3
  # TODO(bchetioui): significant discrepancies in some float16 cases.
  if dtype == np.float16:
    tol = 1.
  # TODO(bchetioui): slight occasional discrepancy in float32 cases.
  elif dtype == np.float32:
    tol = 0.5 if device == "tpu" else (1e-3 if device == "gpu" else 1e-4)
  elif dtype == np.complex64 and device == "tpu":
    tol = 0.1
  # TODO(bchetioui): slight discrepancy when going through the path using
  # tf.nn.convolution.
  elif dtype == np.float64 and device == "cpu":
    tol = 1e-13
  # TODO(bchetioui): unidentified bug in compiled mode. The test that fails is
  #
  # test_conv_general_dilated_tf_conversion_path_3d_lhs=float32[1,4,28,28,1]_rhs=float32[2,3,3,1,16]_windowstrides=(1,1,1)_padding=VALID_lhsdilation=(1,1,1)_rhsdilation=(1,1,2)_dimensionnumbers=('NDHWC','DHWIO','NDHWC')_featuregroupcount=1_batchgroupcount=1_precision=None_enablexla=False
  #
  # with the following assertion error in TensorFlowTrace.process_primitive:
  #
  # AssertionError: conv_general_dilated: out.aval = ShapedArray(float32[1,3,24,26,16]); expected ShapedArray(float32[1,3,26,24,16])
  #
  # Deactivating this assertion is enough to pass the test, which suggests
  # that the end shape is indeed the correct one (i.e. (1,3,26,24,16)).
  # Further investigation is required to really understand this behavior,
  # which we have not managed to reproduce as a pure TF test.
  #
  # This bug is low priority since it only occurs when using a non-TFXLA
  # conversion path in compiled mode, i.e. in a context where using the
  # TFXLA path is possible.
  if harness.name == "_tf_conversion_path_3d_lhs=float32[1,4,28,28,1]_rhs=float32[2,3,3,1,16]_windowstrides=(1,1,1)_padding=VALID_lhsdilation=(1,1,1)_rhsdilation=(1,1,2)_dimensionnumbers=('NDHWC','DHWIO','NDHWC')_featuregroupcount=1_batchgroupcount=1_precision=None_enablexla=False":
    raise unittest.SkipTest("TODO: known but unidentified bug in compiled "
                            "mode")
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol,
                         enable_xla=harness.params["enable_xla"])
def test_conv_general_dilated(self, harness: primitive_harness.Harness):
  if jtu.device_under_test() == "gpu":
    raise unittest.SkipTest("TODO: test failures on GPU")
  tol = None
  # TODO(bchetioui): significant discrepancies in some float16 cases.
  if harness.params["dtype"] is np.float16:
    tol = 1.
  # TODO(bchetioui): slight occasional discrepancy in float32 cases.
  elif harness.params["dtype"] is np.float32:
    tol = 1e-5
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol)
def test_top_k(self, harness: primitive_harness.Harness):
  if (harness.params["k"] > harness.params["shape"][-1] or
      harness.params["k"] < 0):
    with self.assertRaisesRegex(ValueError, "k argument to top_k must be"):
      harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
  # TODO: figure out what's up with bfloat16
  elif harness.params["dtype"] is dtypes.bfloat16:
    raise unittest.SkipTest("bfloat16 support not implemented")
  elif harness.params["dtype"] in jtu.dtypes.complex:
    # TODO(necula): fix top_k complex bug on TPU
    if jtu.device_under_test() == "tpu":
      raise unittest.SkipTest("top_k complex on TPU raises different error")
    with self.assertRaisesRegex(RuntimeError,
                                "Unimplemented: complex comparison"):
      harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
  # TODO: TF and JAX sort [inf, nan] differently.
  elif harness.name.startswith("nan_"):
    raise unittest.SkipTest("inconsistent [nan, inf] sorting")
  else:
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()))
def test_qr(self, harness: primitive_harness.Harness):
  # See jax.lib.lapack.geqrf for the list of compatible types
  if harness.params["dtype"] in [jnp.float32, jnp.float64]:
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           atol=1e-5, rtol=1e-5)
  elif harness.params["dtype"] in [jnp.complex64, jnp.complex128]:
    # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
    # - check_compiled=True breaks for complex types;
    # - for now, the HLO QR implementation called when compiling with TF is
    #   expected to be slower than the custom calls made in JAX.
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           expect_tf_exceptions=True, atol=1e-5, rtol=1e-5)
  else:
    expected_error = (ValueError if jtu.device_under_test() == "gpu"
                      else NotImplementedError)
    with self.assertRaisesRegex(expected_error, "Unsupported dtype"):
      harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
def test_select_and_gather_add(self, harness: primitive_harness.Harness):
  dtype = harness.params["dtype"]
  max_bits = 64
  if jtu.device_under_test() == "tpu":
    max_bits = 32
  expect_tf_exceptions = False
  if dtypes.finfo(dtype).bits * 2 > max_bits:
    # TODO: getting an exception "XLA encountered an HLO for which this
    # rewriting is not implemented"
    expect_tf_exceptions = True
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                         expect_tf_exceptions=expect_tf_exceptions)
def test_add_mul(self, harness: primitive_harness.Harness):
  expect_tf_exceptions = False
  dtype = harness.params["dtype"]
  f_name = harness.params["f_jax"].__name__
  if dtype in [np.uint32, np.uint64]:
    # TODO(bchetioui): tf.math.multiply is not defined for the above types.
    expect_tf_exceptions = True
  elif dtype is np.uint16 and f_name == "add":
    # TODO(bchetioui): tf.math.add is defined for the same types as multiply,
    # except uint16.
    expect_tf_exceptions = True
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                         expect_tf_exceptions=expect_tf_exceptions)
def test_top_k(self, harness: primitive_harness.Harness):
  custom_assert = None
  k, dtype = harness.params["k"], harness.params["dtype"]
  if k > harness.params["shape"][-1] or k < 0:
    with self.assertRaisesRegex(ValueError, "k argument to top_k must be"):
      harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
    return
  if dtype in jtu.dtypes.complex:
    # TODO(necula): fix top_k complex bug on TPU
    if jtu.device_under_test() == "tpu":
      raise unittest.SkipTest("top_k complex on TPU raises different error")
    with self.assertRaisesRegex(RuntimeError,
                                "Unimplemented: complex comparison"):
      harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
    return
  if dtype in jtu.dtypes.all_inexact:
    def custom_assert(result_jax, result_tf):
      assert len(result_jax) == len(result_tf)
      # TODO: TF and JAX sort [inf, nan] differently.
      first_arr_jax, first_arr_tf = result_jax[0], result_tf[0].numpy()
      if np.all(first_arr_jax == first_arr_tf):
        for arr_jax, arr_tf in zip(result_jax, result_tf):
          self.assertArraysEqual(arr_jax, arr_tf)
      else:
        mask_jax, mask_tf = np.isnan(first_arr_jax), np.isnan(first_arr_tf)
        self.assertArraysEqual(first_arr_jax[~mask_jax],
                               first_arr_tf[~mask_tf])
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                         custom_assert=custom_assert)
def test_min_max(self, harness: primitive_harness.Harness):
  expect_tf_exceptions = False
  dtype = harness.params["dtype"]
  if dtype in [np.bool_, np.int8, np.uint16, np.uint32, np.uint64,
               np.complex64, np.complex128]:
    # TODO(bchetioui): tf.math.maximum and tf.math.minimum are not defined
    # for the above types.
    expect_tf_exceptions = True
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                         expect_tf_exceptions=expect_tf_exceptions)
def test_dot_general(self, harness: primitive_harness.Harness):
  tol, dtype = None, harness.params["dtype"]
  if dtype == dtypes.bfloat16:
    tol = 0.3
  elif dtype in [np.complex64, np.float32]:
    if jtu.device_under_test() == "tpu":
      tol = 0.1 if dtype == np.float32 else 0.3
    else:
      tol = 1e-5
  elif dtype == np.float16:
    if jtu.device_under_test() == "gpu":
      tol = 0.1
    else:
      tol = 0.01
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol)
def test_lu(self, harness: primitive_harness.Harness):
  dtype = harness.params["dtype"]
  if dtype in [np.float16, dtypes.bfloat16]:
    raise unittest.SkipTest(f"LU is not implemented in JAX for dtype {dtype}.")
  tol = None
  if dtype in [np.float32, np.complex64]:
    if jtu.device_under_test() == "tpu":
      tol = 0.1
    else:
      tol = 1e-5
  if dtype in [np.float64, np.complex128]:
    tol = 1e-13
  operand, = harness.dyn_args_maker(self.rng())

  def custom_assert(result_jax, result_tf):
    # Instead of comparing the raw LU output element-wise, check that the TF
    # result satisfies the factorization property P @ operand == L @ U and
    # that the pivots map to the returned permutation.
    lu, pivots, perm = tuple(map(lambda t: t.numpy(), result_tf))
    batch_dims = operand.shape[:-2]
    m, n = operand.shape[-2], operand.shape[-1]

    def _make_permutation_matrix(perm):
      result = []
      for idx in itertools.product(*map(range, operand.shape[:-1])):
        result += [0 if c != perm[idx] else 1 for c in range(m)]
      result = np.reshape(np.array(result, dtype=dtype), [*batch_dims, m, m])
      return result

    k = min(m, n)
    l = jnp.tril(lu, -1)[..., :, :k] + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(lu)[..., :k, :]
    p_mat = _make_permutation_matrix(perm)

    self.assertArraysEqual(lax_linalg.lu_pivots_to_permutation(pivots, m),
                           perm)
    self.assertAllClose(jnp.matmul(p_mat, operand), jnp.matmul(l, u),
                        atol=tol, rtol=tol)

  self.ConvertAndCompare(harness.dyn_fun, operand, atol=tol, rtol=tol,
                         custom_assert=custom_assert,
                         always_custom_assert=True)
def test_cumreduce(self, harness: primitive_harness.Harness):
  f_jax, dtype = harness.params["f_jax"], harness.params["dtype"]
  dut = jtu.device_under_test()
  if (dtype == np.complex64 and
      f_jax in [lax_control_flow.cummin, lax_control_flow.cummax,
                lax_control_flow.cumprod, lax_control_flow.cumsum] and
      dut == "tpu"):
    raise unittest.SkipTest("TODO(bchetioui): cum{min,max,prod,sum} fails "
                            "in JAX for complex64 on TPU")
  tol = None
  if f_jax == lax_control_flow.cumsum:
    tol = 0.1 if dtype == np.float16 else (
        0.5 if dtype == dtypes.bfloat16 else tol)
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol)
def test_reduce_window(self, harness: primitive_harness.Harness):
  computation = harness.params['computation'].__name__
  init_value = harness.params['init_value']
  dtype = harness.params['dtype']
  safe_computations = [('sum', dtype(0))]
  if dtype in jtu.dtypes.all_floating:
    # Only for floating-point dtypes can np.inf be safely cast to a
    # meaningful value.
    safe_computations += [('max', dtype(-np.inf)), ('min', dtype(np.inf))]
  if (computation, init_value) not in safe_computations:
    raise unittest.SkipTest(
        'TODO: only specific instances of max/min/sum are supported for now.')
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
def test_jax_implemented(self, harness: primitive_harness.Harness):
  """Runs all harnesses just with JAX to verify the jax_unimplemented field.
  """
  jax_unimpl = [l for l in harness.jax_unimplemented
                if l.filter(jtu.device_under_test())]
  try:
    harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
  except Exception as e:
    if jax_unimpl:
      logging.info(
          f"Found expected JAX error {e} with expected JAX limitations: "
          f"{[u.description for u in jax_unimpl]} in harness {harness.fullname}")
      return
    else:
      raise e
  if jax_unimpl:
    msg = ("Found no JAX error but expected JAX limitations: "
           f"{[u.description for u in jax_unimpl]} in harness: {harness.fullname}")
    logging.warning(msg)
def test_prim_vmap(self, harness: primitive_harness.Harness):
  if harness.group_name in _VMAP_NOT_POLY_YET:
    raise unittest.SkipTest(
        f"TODO: vmap({harness.group_name}) not yet supported")
  func_jax = harness.dyn_fun
  args = harness.dyn_args_maker(self.rng())
  if len(args) == 0:
    # vmap not defined for functions with no args
    return
  res_jax = func_jax(*args)

  # Replicate all arguments
  batch_size = 3
  batched_args = [np.stack([a] * batch_size) for a in args]
  func_jax_vmap = jax.vmap(func_jax, in_axes=0, out_axes=0)
  # Check that batching works
  res_jax_vmap = func_jax_vmap(*batched_args)

  def arr_to_shape_spec(a):
    return "b, " + ", ".join(str(d) for d in a.shape)

  func_jax_vmap_polymorphic_shapes = jax.tree_map(arr_to_shape_spec,
                                                  tuple(args))

  def arr_to_tf_tensor_spec(a):
    return tf.TensorSpec((None,) + a.shape, a.dtype)

  func_jax_vmap_input_signature = jax.tree_map(arr_to_tf_tensor_spec,
                                               tuple(args))
  func_jax_vmap_output_signature = jax.tree_map(arr_to_tf_tensor_spec,
                                                res_jax)
  f_tf = self.CheckShapePolymorphism(
      func_jax_vmap,
      input_signature=func_jax_vmap_input_signature,
      polymorphic_shapes=func_jax_vmap_polymorphic_shapes,
      expected_output_signature=func_jax_vmap_output_signature)

  limitations = _get_jax2tf_limitations(jtu.device_under_test(), harness)
  # Only compare numerical results when no limitation calls for a custom or
  # skipped comparison.
  if not any(l.custom_assert or l.skip_comparison for l in limitations):
    self.assertAllClose(res_jax_vmap, f_tf(*batched_args))
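# Illustrative sketch (added, not part of the original test) of what the two
# helpers in test_prim_vmap above produce; the shape (3, 4) is an arbitrary
# example:
#
#   a = np.zeros((3, 4), dtype=np.float32)
#   arr_to_shape_spec(a)       # -> "b, 3, 4"
#   arr_to_tf_tensor_spec(a)   # -> tf.TensorSpec((None, 3, 4), np.float32)
#
# i.e. only the leading (batch) dimension is left symbolic for the
# shape-polymorphic conversion, while the remaining dimensions stay concrete.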
def test_prim(self, harness: Harness):
  args = harness.dyn_args_maker(self.rng())
  poly_axes = harness.params["poly_axes"]  # type: Sequence[Sequence[int]]
  assert len(args) == len(poly_axes)
  # Make the polymorphic_shapes and input_signature
  polymorphic_shapes: List[Optional[str]] = []
  input_signature: List[tf.TensorSpec] = []
  for arg, poly_axis in zip(args, poly_axes):
    if poly_axis is None:
      polymorphic_shapes.append(None)
      input_signature.append(tf.TensorSpec(np.shape(arg), arg.dtype))
    else:
      def make_arg_polymorphic_shapes(
          poly_axis: Sequence[int]) -> Tuple[str, tf.TensorSpec]:
        idx = -1
        dims = []
        tensorspec_dims: List[Optional[int]] = []
        for i, d in enumerate(arg.shape):
          if i in poly_axis:
            idx += 1
            dims.append(f"b{idx}")
            tensorspec_dims.append(None)
          else:
            dims.append(str(d))
            tensorspec_dims.append(d)
        return ", ".join(dims), tf.TensorSpec(tensorspec_dims, arg.dtype)

      arg_polymorphic_shapes, arg_tensorspec = make_arg_polymorphic_shapes(
          poly_axis)
      polymorphic_shapes.append(arg_polymorphic_shapes)
      input_signature.append(arg_tensorspec)

  res_jax = harness.dyn_fun(*args)
  f_tf = self.CheckShapePolymorphism(
      harness.dyn_fun,
      input_signature=input_signature,
      polymorphic_shapes=polymorphic_shapes,
      expected_output_signature=None)

  if harness.params["check_result"]:
    tol = harness.params["tol"]
    self.assertAllClose(res_jax, f_tf(*args), atol=tol, rtol=tol)
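# Illustrative sketch (added, not part of the original test): for an argument
# of shape (8, 3, 5) with a poly_axes entry of (0, 2),
# make_arg_polymorphic_shapes in test_prim above returns the spec "b0, 3, b1"
# together with tf.TensorSpec([None, 3, None], arg.dtype), so each
# polymorphic axis gets its own dimension variable while static axes keep
# their concrete sizes.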
def test_sort(self, harness: primitive_harness.Harness):
  if (harness.params["dtype"] is dtypes.bfloat16 or
      harness.params["dtype"] in jtu.dtypes.complex):
    # TODO: implement bfloat16/complex support in XlaSort
    raise unittest.SkipTest("bfloat16/complex support not implemented")
  if (harness.params["dtype"] is dtypes.bool_ and
      len(harness.arg_descriptors) == 4):
    # TODO: _sort uses tfxla.key_value_sort to handle 2 operands, but the
    # operation is not compatible with boolean keys.
    raise unittest.SkipTest("boolean key key value sort not implemented")
  if harness.params["is_stable"]:
    # TODO: implement stable sort support in XlaSort
    raise unittest.SkipTest("stable sort not implemented")
  if harness.params["dimension"] != len(harness.params["shape"]) - 1:
    # TODO: implement sort on all axes
    raise unittest.SkipTest("conversion not implemented for axis != -1")
  if len(harness.arg_descriptors) > 4:
    # TODO: implement variable number of operands to XlaSort
    raise unittest.SkipTest("conversion not implemented for #operands > 2")
  if (jtu.device_under_test() == "gpu" and
      len(harness.arg_descriptors) == 4 and
      not harness.params["is_stable"]):
    # TODO: fix the TF GPU test
    raise unittest.SkipTest("GPU tests are running TF on CPU")
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
def test_cholesky(self, harness: primitive_harness.Harness):
  dtype = harness.params["dtype"]
  if dtype in [dtypes.bfloat16, np.float16]:
    raise unittest.SkipTest("Cholesky decomposition not supported for "
                            "(b)float16 in JAX.")
  operand = harness.dyn_args_maker(self.rng())[0]
  # Make the operand Hermitian positive semi-definite (operand @ operand^H)
  # so that the Cholesky decomposition is well defined.
  operand = np.matmul(operand, jnp.conj(np.swapaxes(operand, -1, -2)))
  tol = None
  # TODO(bchetioui): very high discrepancy in the float32/complex64 case
  if dtype in [np.float32, np.complex64]:
    tol = 1e-2
  # TODO(bchetioui): also high discrepancy in the float64/complex128 case
  elif dtype in [np.float64, np.complex128]:
    tol = 1e-11

  def custom_assert(result_jax, result_tf):
    # cholesky_p returns garbage in the strictly upper triangular part of the
    # result, so we can safely ignore that part.
    self.assertAllClose(jnp.tril(result_jax), result_tf, atol=tol)

  self.ConvertAndCompare(harness.dyn_fun, operand,
                         custom_assert=custom_assert,
                         always_custom_assert=True)
def test_reduce_window(self, harness: primitive_harness.Harness):
  f_name = harness.params['computation'].__name__
  dtype = harness.params['dtype']
  expect_tf_exceptions = False
  if jtu.device_under_test() == 'tpu' and dtype is np.complex64:
    raise unittest.SkipTest(
        'TODO: JAX reduce_window on TPU does not handle complex64')
  if ((f_name == 'min' or f_name == 'max') and
      dtype not in [dtypes.bfloat16, np.float16, np.float32, np.float64,
                    np.uint8, np.int16, np.int32, np.int64]):
    # See https://www.tensorflow.org/api_docs/python/tf/math/minimum for a
    # list of the types supported by tf.math.minimum/tf.math.maximum.
    expect_tf_exceptions = True
  elif (f_name == 'add' and
        dtype not in [dtypes.bfloat16, np.float16, np.float32, np.float64,
                      np.uint8, np.int8, np.int16, np.int32, np.int64,
                      np.complex64, np.complex128]):
    # See https://www.tensorflow.org/api_docs/python/tf/math/add for a list
    # of the types supported by tf.math.add.
    expect_tf_exceptions = True
  elif (f_name == 'mul' and
        dtype not in [dtypes.bfloat16, np.float16, np.float32, np.float64,
                      np.uint8, np.int8, np.uint16, np.int16, np.int32,
                      np.int64, np.complex64, np.complex128]):
    # See https://www.tensorflow.org/api_docs/python/tf/math/multiply for a
    # list of the types supported by tf.math.multiply.
    expect_tf_exceptions = True
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                         expect_tf_exceptions=expect_tf_exceptions)
def test_conv_general_dilated(self, harness: primitive_harness.Harness):
  if jtu.device_under_test() == "gpu" and harness.params["dtype"] in [
      np.complex64, np.complex128]:
    raise unittest.SkipTest("TODO: crash on GPU in TF")
  tol = None
  if jtu.device_under_test() == "gpu":
    tol = 1e-4
  elif jtu.device_under_test() == "tpu":
    tol = 1e-3
  # TODO(bchetioui): significant discrepancies in some float16 cases.
  if harness.params["dtype"] == np.float16:
    tol = 1.
  # TODO(bchetioui): slight occasional discrepancy in float32 cases.
  elif harness.params["dtype"] == np.float32:
    tol = 0.5 if jtu.device_under_test() == "tpu" else 1e-4
  elif (harness.params["dtype"] == np.complex64 and
        jtu.device_under_test() == "tpu"):
    tol = 0.1
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol)
def test_qr(self, harness: primitive_harness.Harness):
  # See jax.lib.lapack.geqrf for the list of compatible types
  dtype = harness.params["dtype"]
  dut = jtu.device_under_test()

  # These cases are not implemented in JAX
  if dtype in (jtu.dtypes.all_integer + [jnp.bfloat16]):
    unimplemented_jax = True
  elif dtype is np.complex64 and dut == "tpu":
    unimplemented_jax = True
  elif dtype is np.float16 and dut in ("cpu", "gpu"):
    unimplemented_jax = True
  else:
    unimplemented_jax = False

  if unimplemented_jax:
    raise unittest.SkipTest(f"QR not implemented in JAX for {dtype} on {dut}")

  # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
  # - for now, the HLO QR implementation called when compiling with TF is
  #   expected to be slower than the custom calls made in JAX.
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                         atol=1e-5, rtol=1e-5)
def test_conv_general_dilated(self, harness: primitive_harness.Harness):
  dtype, device = harness.params["dtype"], jtu.device_under_test()
  if device == "gpu" and dtype in [np.complex64, np.complex128]:
    raise unittest.SkipTest("TODO: crash on GPU in TF")
  tol = None
  if device == "gpu":
    tol = 1e-4
  elif device == "tpu":
    tol = 1e-3
  # TODO(bchetioui): significant discrepancies in some float16 cases.
  if dtype == np.float16:
    tol = 1.
  # TODO(bchetioui): slight occasional discrepancy in float32 cases.
  elif dtype == np.float32:
    tol = 0.5 if device == "tpu" else (1e-3 if device == "gpu" else 1e-4)
  elif dtype == np.complex64 and device == "tpu":
    tol = 0.1
  # TODO(bchetioui): slight discrepancy when going through the path using
  # tf.nn.convolution.
  elif dtype == np.float64 and device == "cpu":
    tol = 1e-13
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol)