def test_fft(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of FFT, or the expected failure for unsupported ranks.

  * More than 3 FFT dimensions: JAX must raise on GPU; elsewhere the test is
    skipped because TF does not support it.
  * On TPU, multi-dimensional FFTs must raise (only 1D is supported).
  * Otherwise the converted function is compared with JAX, with a relaxed
    tolerance on CPU/GPU.
  """
  device = jtu.device_under_test()
  num_fft_dims = len(harness.params["fft_lengths"])

  if num_fft_dims > 3:
    if device != "gpu":
      raise unittest.SkipTest("TF does not support >3D FFTs.")
    with self.assertRaisesRegex(RuntimeError, "FFT only supports ranks 1-3"):
      harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
    return

  if device == "tpu" and num_fft_dims > 1:
    # TODO(b/140351181): FFT is mostly unimplemented on TPU, even for JAX
    with self.assertRaisesRegex(RuntimeError,
                                "only 1D FFT is currently supported."):
      harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
    return

  if device in ("cpu", "gpu"):
    is_bool = harness.params["dtype"] in jtu.dtypes.boolean
    tol = 0.01 if is_bool else 1e-3
  else:
    tol = None
  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol)
def test_reduce_window(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of reduce_window; complex64 is skipped on TPU."""
  dtype = harness.params['dtype']
  if dtype is np.complex64 and jtu.device_under_test() == 'tpu':
    raise unittest.SkipTest(
        'TODO: JAX reduce_window on TPU does not handle complex64')
  operands = harness.dyn_args_maker(self.rng())
  self.ConvertAndCompare(harness.dyn_fun, *operands)
def test_pad(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of lax.pad under tf.function.

  bfloat16 and pads with a negative low/high edge are skipped.
  """
  if harness.params["dtype"] is dtypes.bfloat16:
    raise unittest.SkipTest("bfloat16 not implemented")
  # TODO: implement (or decide not to) pads with negative edge padding
  negative_edge = any(lo < 0 or hi < 0
                      for lo, hi, _interior in harness.params["pads"])
  if negative_edge:
    raise unittest.SkipTest("pad with negative pad not supported")
  operands = harness.dyn_args_maker(self.rng())
  self.ConvertAndCompare(harness.dyn_fun, *operands, with_function=True)
def test_prim(self, harness: primitive_harness.Harness):
  """Converts the harness function and compares it with JAX.

  Only the Jax2TfLimitation entries whose filter matches the current device
  and the harness dtype are forwarded to ConvertAndCompare.
  """
  candidates = Jax2TfLimitation.limitations_for_harness(harness)
  device = jtu.device_under_test()
  applicable = tuple(
      lim for lim in candidates
      if lim.filter(device=device, dtype=harness.dtype))
  fun = harness.dyn_fun
  operands = harness.dyn_args_maker(self.rng())
  self.ConvertAndCompare(fun, *operands, limitations=applicable)
def test_pad(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of lax.pad.

  bfloat16 and pads with a negative low/high edge are skipped.
  """
  # TODO: figure out the bfloat16 story
  if harness.params["dtype"] is dtypes.bfloat16:
    raise unittest.SkipTest("bfloat16 not implemented")
  # TODO: fix pad with negative padding in XLA (fixed on 06/16/2020)
  negative_edge = any(lo < 0 or hi < 0
                      for lo, hi, _interior in harness.params["pads"])
  if negative_edge:
    raise unittest.SkipTest("pad with negative pad not supported")
  operands = harness.dyn_args_maker(self.rng())
  self.ConvertAndCompare(harness.dyn_fun, *operands)
def test_eig(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of eig on CPU using a custom comparison.

  Eigenvalues are not compared element-wise. Instead, every JAX eigenvalue
  must have a close TF eigenvalue and vice versa, presumably because the two
  implementations may order them differently. Any TF eigenvectors that were
  requested are checked against the eigenvector equations.
  Skipped off CPU and for (b)float16.
  """
  operand = harness.dyn_args_maker(self.rng())[0]
  compute_left_eigenvectors = harness.params["compute_left_eigenvectors"]
  compute_right_eigenvectors = harness.params["compute_right_eigenvectors"]
  dtype = harness.params["dtype"]

  if jtu.device_under_test() != "cpu":
    raise unittest.SkipTest("eig only supported on CPU in JAX")

  if dtype in [np.float16, dtypes.bfloat16]:
    raise unittest.SkipTest("eig unsupported with (b)float16 in JAX")

  def custom_assert(result_jax, result_tf):
    # TF outputs are tensors; convert them to numpy for the checks below.
    result_tf = tuple(map(lambda e: e.numpy(), result_tf))

    inner_dimension = operand.shape[-1]

    # Test ported from tests.lax_test.testEig
    # Norm, adjusted for dimension and type.
    def norm(x):
      # Matrix norm over the last two axes, expressed in units of
      # (n + 1) * machine epsilon for `dtype`.
      norm = np.linalg.norm(x, axis=(-2, -1))
      return norm / ((inner_dimension + 1) * jnp.finfo(dtype).eps)

    def check_right_eigenvectors(a, w, vr):
      # Residual of a @ vr == vr * w (each column scaled by its eigenvalue)
      # must stay below 100 eps-scaled units.
      self.assertTrue(
          np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))

    def check_left_eigenvectors(a, w, vl):
      # Left eigenvectors of `a` are checked as right eigenvectors of its
      # conjugate transpose (last two axes swapped) with conjugated eigenvalues.
      rank = len(a.shape)
      aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
      wC = jnp.conj(w)
      check_right_eigenvectors(aH, wC, vl)

    def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
      # Asserts that some element of `eigenvalues_array` is within `tol` of
      # `eigenvalue`.
      tol = None
      # TODO(bchetioui): numerical discrepancies
      if dtype in [np.float32, np.complex64]:
        tol = 1e-4
      elif dtype in [np.float64, np.complex128]:
        tol = 1e-13
      closest_diff = min(abs(eigenvalues_array - eigenvalue))
      self.assertAllClose(closest_diff, np.array(0., closest_diff.dtype),
                          atol=tol)

    all_w_jax, all_w_tf = result_jax[0], result_tf[0]
    # Iterate over every batch index (all but the last two operand axes) and
    # check the eigenvalue sets in both directions.
    for idx in itertools.product(*map(range, operand.shape[:-2])):
      w_jax, w_tf = all_w_jax[idx], all_w_tf[idx]
      for i in range(inner_dimension):
        check_eigenvalue_is_in_array(w_jax[i], w_tf)
        check_eigenvalue_is_in_array(w_tf[i], w_jax)

    if compute_left_eigenvectors:
      check_left_eigenvectors(operand, all_w_tf, result_tf[1])
    if compute_right_eigenvectors:
      # The bool is used as an int offset: right eigenvectors are at index 2
      # when left eigenvectors are also returned, otherwise at index 1.
      check_right_eigenvectors(operand, all_w_tf,
                               result_tf[1 + compute_left_eigenvectors])

  self.ConvertAndCompare(harness.dyn_fun, operand, custom_assert=custom_assert)
def test_sort(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of lax.sort, skipping known-failing configurations."""
  device = jtu.device_under_test()
  if device == "gpu":
    four_operands = len(harness.arg_descriptors) == 4
    if four_operands and not harness.params["is_stable"]:
      # TODO: fix the TF GPU test
      raise unittest.SkipTest("GPU tests are running TF on CPU")
  if device == "tpu" and harness.params["dtype"] in jtu.dtypes.complex:
    raise unittest.SkipTest("JAX sort is not implemented on TPU for complex")
  operands = harness.dyn_args_maker(self.rng())
  self.ConvertAndCompare(harness.dyn_fun, *operands)
def test_scatter(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of scatter ops.

  complex64 scatter_min/scatter_max is skipped on TPU.
  """
  f_name = harness.params['f_lax'].__name__
  dtype = harness.params['dtype']
  complex_min_max = (dtype is np.complex64 and
                     f_name in ('scatter_min', 'scatter_max'))
  if complex_min_max and jtu.device_under_test() == 'tpu':
    raise unittest.SkipTest(f"TODO: complex {f_name} on TPU fails in JAX")
  operands = harness.dyn_args_maker(self.rng())
  self.ConvertAndCompare(harness.dyn_fun, *operands)
def test_linear_solve(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of custom_linear_solve.

  For symmetric harnesses the matrix is symmetrized as a + a.T. Tolerance is
  relaxed for float32 on TPU.
  """
  a, b = harness.dyn_args_maker(self.rng())
  if harness.params["symmetric"]:
    a = a + a.T
  f32_on_tpu = (harness.params["dtype"] == np.float32 and
                jtu.device_under_test() == "tpu")
  tol = 0.01 if f32_on_tpu else None
  self.ConvertAndCompare(harness.dyn_fun, a, b, atol=tol, rtol=tol)
def test_triangular_solve(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of triangular_solve.

  float16 on GPU is skipped, and float32 uses a 1e-5 tolerance.
  """
  dtype = harness.params["dtype"]
  if dtype == np.float16 and jtu.device_under_test() == "gpu":
    raise unittest.SkipTest(
        f"Triangular solve is not implemented in JAX for dtype {dtype}")
  tol = 1e-5 if dtype == np.float32 else None
  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol)
def test_min_max(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of min/max, ignoring positions where JAX yields NaN."""
  # TODO(bchetioui): discrepancies between TF & JAX when comparing with NaN;
  # JAX always returns NaN, while TF returns the value NaN is compared with.
  def custom_assert(result_jax, result_tf):
    keep = ~np.isnan(result_jax)
    self.assertAllClose(result_jax[keep], result_tf[keep])

  # TODO(bchetioui): figure out why we need always_custom_assert=True
  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         custom_assert=custom_assert,
                         always_custom_assert=True)
def test_div(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of division.

  Integer lax.div is skipped when the divisor contains a zero.
  """
  dividend, divisor = harness.dyn_args_maker(self.rng())
  prim = harness.params["prim"]
  integer_div_by_zero = (
      dtypes.issubdtype(dividend.dtype, np.integer) and
      prim is lax.div_p and
      np.any(divisor == np.array(0, dtype=divisor.dtype)))
  if integer_div_by_zero:
    raise unittest.SkipTest(
        "Divisor contains a 0, and TF returns an error value in compiled "
        "mode instead of failing like in eager and graph mode for dtype "
        f"{divisor.dtype}")
  self.ConvertAndCompare(harness.dyn_fun, dividend, divisor)
def test_jax_implemented(self, harness: primitive_harness.Harness):
  """Runs all harnesses just with JAX to verify the jax_unimplemented field.

  A JAX error is accepted (and logged at INFO) only when at least one
  `jax_unimplemented` limitation applies to the current device; otherwise it
  is re-raised. If JAX succeeds while some limitation applies, a warning is
  logged because the recorded limitation did not materialize.
  """
  device = jtu.device_under_test()
  jax_unimpl = [lim for lim in harness.jax_unimplemented if lim.filter(device)]
  try:
    harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
  except Exception as e:
    if not jax_unimpl:
      # Bare `raise` re-raises with the original traceback untouched.
      raise
    logging.info(
        "Found expected JAX error %s with expected JAX limitations: "
        "%s in harness %s",
        e, [u.description for u in jax_unimpl], harness.fullname)
    return

  if jax_unimpl:
    logging.warning(
        "Found no JAX error but expected JAX limitations: %s in harness: %s",
        [u.description for u in jax_unimpl], harness.fullname)
def test_svd(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of SVD.

  When U and V are computed, results are compared by reconstructing the
  operand (plus the shapes of s, u and v) rather than comparing factors
  directly. Otherwise the singular values are compared directly. Unsupported
  dtypes either assert the JAX error or mark TF exceptions as expected.
  """
  device = jtu.device_under_test()
  if device == "tpu":
    raise unittest.SkipTest("TODO: test crashes the XLA compiler for some TPU variants")

  # NOTE: because of the TPU skip above, the `device == "tpu"` branches below
  # are currently unreachable; they are kept for when the skip is removed.
  dtype = harness.params["dtype"]
  expect_tf_exceptions = False

  if dtype in [jnp.float16, dtypes.bfloat16]:
    if device != "tpu":
      # Does not work in JAX
      with self.assertRaisesRegex(NotImplementedError, "Unsupported dtype"):
        harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
      return
    # TODO: SVD on TPU for bfloat16 seems to work for JAX but fails for TF
    expect_tf_exceptions = True

  if dtype in [jnp.complex64, jnp.complex128]:
    if device == "tpu":
      # TODO: on JAX on TPU there is no SVD implementation for complex
      with self.assertRaisesRegex(RuntimeError,
                                  "Binary op compare with different element types"):
        harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
      return
    # TODO: on CPU and GPU "No registered 'Svd' OpKernel for XLA_CPU_JIT devices".
    # Works on JAX because JAX uses a custom implementation.
    expect_tf_exceptions = True

  def _reconstruct_operand(result, is_tf: bool):
    # Reconstructing operand as documented in numpy.linalg.svd (see
    # https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html)
    s, u, v = result
    if is_tf:
      s, u, v = s.numpy(), u.numpy(), v.numpy()
    k = s.shape[-1]
    reconstructed = jnp.matmul(u[..., :k] * s[..., None, :], v[..., :k, :])
    return reconstructed, s.shape, u.shape, v.shape

  def _custom_assert(r_jax, r_tf, atol=1e-6, rtol=1e-6):
    if not harness.params["compute_uv"]:
      self.assertAllClose(r_jax, r_tf, atol=atol, rtol=rtol)
      return
    self.assertAllClose(_reconstruct_operand(r_jax, False),
                        _reconstruct_operand(r_tf, True),
                        atol=atol, rtol=rtol)

  tol = 1e-4
  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol,
                         expect_tf_exceptions=expect_tf_exceptions,
                         custom_assert=partial(_custom_assert, atol=tol, rtol=tol),
                         always_custom_assert=True)
def test_unary_elementwise(self, harness: primitive_harness.Harness):
  """Converts a unary elementwise lax op and compares it with JAX.

  For digamma and erf_inv, lax returns NaN on inputs where the function is
  undefined while TF returns (signed) infinity. Those positions are checked
  against explicit expectations and the remaining ones are compared directly.
  Known TPU bugs are skipped.
  """
  dtype = harness.params["dtype"]
  lax_name = harness.params["lax_name"]
  arg, = harness.dyn_args_maker(self.rng())
  device = jtu.device_under_test()

  def compare_with_special_cases(result_jax, result_tf, special_cases,
                                 expected_jax, expected_tf):
    # Special positions must match the explicit expectations...
    self.assertAllClose(expected_jax, result_jax[special_cases])
    self.assertAllClose(expected_tf, result_tf[special_cases])
    # ...and non-special cases are equal.
    regular = ~special_cases
    self.assertAllClose(result_jax[regular], result_tf[regular])

  custom_assert = None
  if lax_name == "digamma":
    # TODO(necula): fix bug with digamma/(f32|f16) on TPU
    if dtype in [np.float16, np.float32] and device == "tpu":
      raise unittest.SkipTest("TODO: fix bug: nan vs not-nan")
    # In the bfloat16 case, TF and lax both return NaN in undefined cases.
    if dtype is not dtypes.bfloat16:
      def custom_assert(result_jax, result_tf):
        # digamma is not defined at 0 and -1: lax.digamma returns NaN there
        # and tf.math.digamma returns inf.
        special_cases = (arg == 0.) | (arg == -1.)
        count = np.count_nonzero(special_cases)
        compare_with_special_cases(
            result_jax, result_tf, special_cases,
            expected_jax=np.full((count,), dtype(np.nan)),
            expected_tf=np.full((count,), dtype(np.inf)))
  elif lax_name == "erf_inv":
    # TODO(necula): fix erf_inv bug on TPU
    if device == "tpu":
      raise unittest.SkipTest("erf_inv bug on TPU: nan vs non-nan")
    # TODO: investigate: in the (b)float16 cases, TF and lax both return the
    # same result in undefined cases.
    if dtype not in [np.float16, dtypes.bfloat16]:
      def custom_assert(result_jax, result_tf):  # noqa: F811
        # For arg < -1 or arg > 1, lax.erf_inv returns NaN while
        # tf.math.erf_inv returns +/- inf with the sign of arg.
        special_cases = (arg < -1.) | (arg > 1.)
        count = np.count_nonzero(special_cases)
        signs = np.where(arg[special_cases] < 0., -1., 1.)
        compare_with_special_cases(
            result_jax, result_tf, special_cases,
            expected_jax=np.full((count,), dtype(np.nan), dtype=dtype),
            expected_tf=np.full((count,), signs * dtype(np.inf), dtype=dtype))

  # TODO(necula): revisit once we fix the GPU tests
  atol = 1e-3 if device == "gpu" else None
  self.ConvertAndCompare(harness.dyn_fun, arg, custom_assert=custom_assert,
                         atol=atol)
def test_prim(self, harness: primitive_harness.Harness):
  """Converts the harness function and compares it with JAX.

  Only limitations that match the current device and harness dtype are
  forwarded. The optional harness params `enable_xla` (default True) and
  `associative_scan_reductions` (default False) configure the conversion.
  """
  candidates = Jax2TfLimitation.limitations_for_harness(harness)
  device = jtu.device_under_test()
  applicable = tuple(
      lim for lim in candidates
      if lim.filter(device=device, dtype=harness.dtype))
  fun = harness.dyn_fun
  operands = harness.dyn_args_maker(self.rng())
  params = harness.params
  enable_xla = params.get("enable_xla", True)
  use_assoc_scan = params.get("associative_scan_reductions", False)
  with jax.jax2tf_associative_scan_reductions(use_assoc_scan):
    self.ConvertAndCompare(fun, *operands, limitations=applicable,
                           enable_xla=enable_xla)
def test_betainc(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of betainc, with a tight tolerance for float64."""
  # TODO: https://www.tensorflow.org/api_docs/python/tf/math/betainc only
  # supports float32/64 tests.
  is_f64 = harness.params["dtype"] is np.float64
  tol = 1e-14 if is_f64 else None
  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol)
def test_prim(self, harness: Harness):
  """Checks shape-polymorphic conversion of a harness.

  For each argument, `poly_axes` gives either None (fully static shape) or the
  axes to make polymorphic. Polymorphic axes are named b0, b1, ... per
  argument and become `None` dims in the TF input signature. If the harness
  requests it, the converted result is also compared with JAX.
  """
  args = harness.dyn_args_maker(self.rng())
  poly_axes = harness.params["poly_axes"]  # type: Sequence[Optional[Sequence[int]]]
  assert len(args) == len(poly_axes)

  def _shape_spec(arg, axes: Optional[Sequence[int]]) -> Tuple[Optional[str], tf.TensorSpec]:
    # Returns (polymorphic shape string or None, matching tf.TensorSpec).
    if axes is None:
      return None, tf.TensorSpec(np.shape(arg), arg.dtype)
    dim_specs: List[str] = []
    tensorspec_dims: List[Optional[int]] = []
    next_var = 0
    for i, d in enumerate(arg.shape):
      if i in axes:
        dim_specs.append(f"b{next_var}")
        next_var += 1
        tensorspec_dims.append(None)
      else:
        dim_specs.append(str(d))
        tensorspec_dims.append(d)
    return ", ".join(dim_specs), tf.TensorSpec(tensorspec_dims, arg.dtype)

  specs = [_shape_spec(arg, axes) for arg, axes in zip(args, poly_axes)]
  polymorphic_shapes: List[Optional[str]] = [shape for shape, _ in specs]
  input_signature: List[tf.TensorSpec] = [spec for _, spec in specs]

  res_jax = harness.dyn_fun(*args)
  f_tf = self.CheckShapePolymorphism(
      harness.dyn_fun,
      input_signature=input_signature,
      polymorphic_shapes=polymorphic_shapes,
      expected_output_signature=None)

  if harness.params["check_result"]:
    tol = harness.params["tol"]
    self.assertAllClose(res_jax, f_tf(*args), atol=tol, rtol=tol)
def test_qr(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of QR, or the JAX error for unsupported dtypes."""
  # See jax.lib.lapack.geqrf for the list of compatible types
  dtype = harness.params["dtype"]
  device = jtu.device_under_test()

  supported_real = (dtype in [jnp.float32, jnp.float64] or
                    (dtype == jnp.float16 and device == "tpu"))
  if supported_real:
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           atol=1e-5, rtol=1e-5)
    return

  if dtype in [jnp.complex64, jnp.complex128]:
    if device == "tpu" and dtype in [jnp.complex64]:
      raise unittest.SkipTest("QR for c64 not implemented on TPU")
    # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
    # - check_compiled=True breaks for complex types;
    # - for now, the performance of the HLO QR implementation called when
    #   compiling with TF is expected to have worse performance than the
    #   custom calls made in JAX.
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           expect_tf_exceptions=True, atol=1e-5, rtol=1e-5)
    return

  if device == "tpu":
    # TODO(necula): fix QR bug on TPU
    if dtype in (jnp.bfloat16, jnp.int32, jnp.uint32):
      raise unittest.SkipTest(
          "QR bug on TPU for certain types: error not raised")
    if dtype in (jnp.bool_,):
      raise unittest.SkipTest(
          "QR bug on TPU for certain types: invalid cast")
  expected_error = ValueError if device == "gpu" else NotImplementedError
  with self.assertRaisesRegex(expected_error, "Unsupported dtype"):
    harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
def test_unary_elementwise(self, harness: primitive_harness.Harness):
  """Converts a unary elementwise lax op and compares it with JAX.

  bfloat16 is skipped. For digamma and erf_inv, positions where lax returns
  NaN and TF returns (signed) infinity are checked against explicit
  expectations, and the remaining positions are compared directly.
  """
  dtype = harness.params["dtype"]
  if dtype is dtypes.bfloat16:
    raise unittest.SkipTest("bfloat16 not implemented")
  arg, = harness.dyn_args_maker(self.rng())
  custom_assert = None
  if harness.params["lax_name"] == "digamma":

    # digamma is not defined at 0 and -1
    def custom_assert(result_jax, result_tf):
      # lax.digamma returns NaN and tf.math.digamma returns inf
      special_cases = (arg == 0.) | (arg == -1.)
      nr_special_cases = np.count_nonzero(special_cases)
      self.assertAllClose(
          np.full((nr_special_cases, ), dtype(np.nan)),
          result_jax[special_cases])
      self.assertAllClose(
          np.full((nr_special_cases, ), dtype(np.inf)),
          result_tf[special_cases])
      # non-special cases are equal
      self.assertAllClose(result_jax[~special_cases],
                          result_tf[~special_cases])

  if harness.params["lax_name"] == "erf_inv":
    # TODO(necula): fix bug with erf_inv/f16
    if dtype is np.float16:
      raise unittest.SkipTest("TODO: fix bug")

    # erf_inv is not defined for arg < -1 or arg > 1
    def custom_assert(result_jax, result_tf):  # noqa: F811
      # for arg < -1 or arg > 1
      # lax.erf_inv returns NaN; tf.math.erf_inv return +/- inf
      special_cases = (arg < -1.) | (arg > 1.)
      nr_special_cases = np.count_nonzero(special_cases)
      # `dtype=dtype` keeps the expected arrays in the test dtype; without it,
      # the float64 `signs` array would promote the expectation to float64.
      self.assertAllClose(
          np.full((nr_special_cases, ), dtype(np.nan), dtype=dtype),
          result_jax[special_cases])
      signs = np.where(arg[special_cases] < 0., -1., 1.)
      self.assertAllClose(
          np.full((nr_special_cases, ), signs * dtype(np.inf), dtype=dtype),
          result_tf[special_cases])
      # non-special cases are equal
      self.assertAllClose(result_jax[~special_cases],
                          result_tf[~special_cases])

  atol = None
  if jtu.device_under_test() == "gpu":
    # TODO(necula): revisit once we fix the GPU tests
    atol = 1e-3
  self.ConvertAndCompare(harness.dyn_fun, arg, custom_assert=custom_assert,
                         atol=atol)
def test_qr(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of QR, or the JAX error for unsupported dtypes."""
  # See jax.lib.lapack.geqrf for the list of compatible types
  dtype = harness.params["dtype"]

  if dtype in [jnp.float32, jnp.float64]:
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           atol=1e-5, rtol=1e-5)
    return

  if dtype in [jnp.complex64, jnp.complex128]:
    # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
    # - check_compiled=True breaks for complex types;
    # - for now, the performance of the HLO QR implementation called when
    #   compiling with TF is expected to have worse performance than the
    #   custom calls made in JAX.
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           expect_tf_exceptions=True, atol=1e-5, rtol=1e-5)
    return

  on_gpu = jtu.device_under_test() == "gpu"
  expected_error = ValueError if on_gpu else NotImplementedError
  with self.assertRaisesRegex(expected_error, "Unsupported dtype"):
    harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
def test_conv_general_dilated(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of conv_general_dilated.

  The tolerance depends on dtype and device. Complex dtypes on GPU and one
  known-bad non-XLA conversion case are skipped.
  """
  dtype, device = harness.params["dtype"], jtu.device_under_test()
  if device == "gpu" and dtype in [np.complex64, np.complex128]:
    raise unittest.SkipTest("TODO: crash on GPU in TF")

  def _tolerance():
    if dtype == np.float16:
      # TODO(bchetioui): significant discrepancies in some float16 cases.
      return 1.
    if dtype == np.float32:
      # TODO(bchetioui): slight occasional discrepancy in float32 cases.
      if device == "tpu":
        return 0.5
      return 1e-3 if device == "gpu" else 1e-4
    if dtype == np.complex64 and device == "tpu":
      return 0.1
    if dtype == np.float64 and device == "cpu":
      # TODO(bchetioui): slight discrepancy when going through the path using
      # tf.nn.convolution.
      return 1e-13
    # Device-level defaults for every other dtype.
    if device == "gpu":
      return 1e-4
    if device == "tpu":
      return 1e-3
    return None

  tol = _tolerance()

  # TODO(bchetioui): unidentified bug in compiled mode. The test that fails is
  #
  # test_conv_general_dilated_tf_conversion_path_3d_lhs=float32[1,4,28,28,1]_rhs=float32[2,3,3,1,16]_windowstrides=(1,1,1)_padding=VALID_lhsdilation=(1,1,1)_rhsdilation=(1,1,2)_dimensionnumbers=('NDHWC','DHWIO','NDHWC')_featuregroupcount=1_batchgroupcount=1_precision=None_enablexla=False
  #
  # with the following assertion error in TensorFlowTrace.process_primitive:
  #
  # AssertionError: conv_general_dilated: out.aval = ShapedArray(float32[1,3,24,26,16]); expected ShapedArray(float32[1,3,26,24,16])
  #
  # Deactivating this assertion is enough to pass the test, which suggests
  # that the end shape is indeed the correct one (i.e. (1,3,26,24,16)).
  # Further investigation is required to really understand this behavior,
  # which we have not managed to reproduce as a pure TF test.
  #
  # This bug is low priority since it only occurs when using a non-TFXLA
  # conversion path in compiled mode, i.e. in a context where using the
  # TFXLA path is possible.
  known_bad_name = "_tf_conversion_path_3d_lhs=float32[1,4,28,28,1]_rhs=float32[2,3,3,1,16]_windowstrides=(1,1,1)_padding=VALID_lhsdilation=(1,1,1)_rhsdilation=(1,1,2)_dimensionnumbers=('NDHWC','DHWIO','NDHWC')_featuregroupcount=1_batchgroupcount=1_precision=None_enablexla=False"
  if harness.name == known_bad_name:
    raise unittest.SkipTest(
        "TODO: known but unidentified bug in compiled "
        "mode")

  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol,
                         enable_xla=harness.params["enable_xla"])
def test_conv_general_dilated(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of conv_general_dilated (skipped on GPU)."""
  if jtu.device_under_test() == "gpu":
    raise unittest.SkipTest("TODO: test failures on GPU")
  dtype = harness.params["dtype"]
  if dtype is np.float16:
    # TODO(bchetioui): significant discrepancies in some float16 cases.
    tol = 1.
  elif dtype is np.float32:
    # TODO(bchetioui): slight occasional discrepancy in float32 cases.
    tol = 1e-5
  else:
    tol = None
  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol)
def test_betainc(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of betainc.

  (b)float16 is skipped, and float64 uses a tight tolerance.
  """
  dtype = harness.params["dtype"]
  # TODO: https://www.tensorflow.org/api_docs/python/tf/math/betainc only
  # supports float32/64 tests.
  # TODO(bchetioui): investigate why the test actually fails in JAX.
  if dtype in [np.float16, dtypes.bfloat16]:
    raise unittest.SkipTest("(b)float16 not implemented in TF")
  tol = 1e-14 if dtype is np.float64 else None
  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol)
def test_add_mul(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of add/mul, expecting TF errors for unsupported dtypes."""
  dtype = harness.params["dtype"]
  f_name = harness.params["f_jax"].__name__
  # TODO(bchetioui): tf.math.multiply is not defined for uint32/uint64.
  # TODO(bchetioui): tf.math.add is defined for the same types as multiply,
  # except uint16.
  expect_tf_exceptions = (dtype in [np.uint32, np.uint64] or
                          (dtype is np.uint16 and f_name == "add"))
  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         expect_tf_exceptions=expect_tf_exceptions)
def test_svd(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of SVD.

  When U and V are computed, results are compared by reconstructing the
  operand (plus the shapes of s, u and v). Otherwise the singular values are
  compared directly. (b)float16 off TPU and complex on TPU assert the
  expected JAX error instead.
  """
  dtype = harness.params["dtype"]

  if dtype in [np.float16, dtypes.bfloat16]:
    if jtu.device_under_test() != "tpu":
      # Does not work in JAX
      with self.assertRaisesRegex(NotImplementedError, "Unsupported dtype"):
        harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
      return

  if dtype in [np.complex64, np.complex128]:
    if jtu.device_under_test() == "tpu":
      # TODO: on JAX on TPU there is no SVD implementation for complex
      with self.assertRaisesRegex(
          RuntimeError, "Binary op compare with different element types"):
        harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
      return

  def _reconstruct_operand(result, is_tf: bool):
    # Reconstructing operand as documented in numpy.linalg.svd (see
    # https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html)
    s, u, v = result
    if is_tf:
      s, u, v = s.numpy(), u.numpy(), v.numpy()
    k = s.shape[-1]
    reconstructed = jnp.matmul(u[..., :k] * s[..., None, :], v[..., :k, :])
    return reconstructed, s.shape, u.shape, v.shape

  def _custom_assert(r_jax, r_tf, atol=1e-6, rtol=1e-6):
    if not harness.params["compute_uv"]:
      self.assertAllClose(r_jax, r_tf, atol=atol, rtol=rtol)
      return
    self.assertAllClose(_reconstruct_operand(r_jax, False),
                        _reconstruct_operand(r_tf, True),
                        atol=atol, rtol=rtol)

  tol = 1e-4
  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol,
                         custom_assert=partial(_custom_assert, atol=tol, rtol=tol),
                         always_custom_assert=True)
def test_min_max(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of min/max, expecting TF errors for unsupported dtypes."""
  # TODO(bchetioui): tf.math.maximum and tf.math.minimum are not defined for
  # the following types.
  tf_unsupported = [
      np.bool_, np.int8, np.uint16, np.uint32, np.uint64, np.complex64,
      np.complex128
  ]
  expect_tf_exceptions = harness.params["dtype"] in tf_unsupported
  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         expect_tf_exceptions=expect_tf_exceptions)
def test_select_and_gather_add(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of select_and_gather_add.

  TF exceptions are expected when twice the dtype's bit width exceeds the
  device limit (32 bits on TPU, 64 bits elsewhere).
  """
  dtype = harness.params["dtype"]
  max_bits = 32 if jtu.device_under_test() == "tpu" else 64
  # TODO: getting an exception "XLA encountered an HLO for which this rewriting is not implemented"
  expect_tf_exceptions = dtypes.finfo(dtype).bits * 2 > max_bits
  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         expect_tf_exceptions=expect_tf_exceptions)
def test_dot_general(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of dot_general with dtype/device-dependent tolerances."""
  dtype = harness.params["dtype"]
  device = jtu.device_under_test()

  def _tolerance():
    if dtype == dtypes.bfloat16:
      return 0.3
    if dtype in [np.complex64, np.float32]:
      if device == "tpu":
        return 0.1 if dtype == np.float32 else 0.3
      return 1e-5
    if dtype == np.float16:
      return 0.1 if device == "gpu" else 0.01
    return None

  tol = _tolerance()
  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol)
def test_lu(self, harness: primitive_harness.Harness):
  """Checks jax2tf conversion of LU decomposition using a custom assertion.

  The TF result (lu, pivots, perm) is validated structurally. The pivots must
  map to `perm`, and P @ operand must equal L @ U, where L and U are unpacked
  from the combined `lu` matrix. (b)float16 is skipped.
  """
  dtype = harness.params["dtype"]
  if dtype in [np.float16, dtypes.bfloat16]:
    raise unittest.SkipTest(
        f"LU is not implemented in JAX for dtype {dtype}.")
  tol = None
  if dtype in [np.float32, np.complex64]:
    if jtu.device_under_test() == "tpu":
      tol = 0.1
    else:
      tol = 1e-5
  if dtype in [np.float64, np.complex128]:
    tol = 1e-13
  operand, = harness.dyn_args_maker(self.rng())

  def custom_assert(result_jax, result_tf):
    # Only the TF result is inspected; JAX's result is not compared here.
    lu, pivots, perm = tuple(map(lambda t: t.numpy(), result_tf))
    batch_dims = operand.shape[:-2]
    m, n = operand.shape[-2], operand.shape[-1]

    def _make_permutation_matrix(perm):
      # Builds a one-hot row for every (batch..., row) index: row r of the
      # result has a 1 in column perm[..., r], so P @ operand selects the
      # operand rows in `perm` order.
      result = []
      for idx in itertools.product(*map(range, operand.shape[:-1])):
        result += [0 if c != perm[idx] else 1 for c in range(m)]
      result = np.reshape(np.array(result, dtype=dtype), [*batch_dims, m, m])
      return result

    k = min(m, n)
    # `lu` packs both factors: L is its strict lower triangle plus an
    # implicit unit diagonal (m x k); U is its upper triangle (k x n).
    l = jnp.tril(lu, -1)[..., :, :k] + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(lu)[..., :k, :]
    p_mat = _make_permutation_matrix(perm)

    self.assertArraysEqual(
        lax_linalg.lu_pivots_to_permutation(pivots, m), perm)
    self.assertAllClose(jnp.matmul(p_mat, operand), jnp.matmul(l, u),
                        atol=tol, rtol=tol)

  self.ConvertAndCompare(harness.dyn_fun, operand, atol=tol, rtol=tol,
                         custom_assert=custom_assert,
                         always_custom_assert=True)