def ConvertAndCompare(self, func_jax: Callable, *args,
                      enable_xla: bool = True,
                      limitations: Sequence = ()):
  """Compares jax_func(*args) with convert(jax_func)(*args).

  It compares the result of JAX, TF ("eager" mode), TF with tf.function
  ("graph" mode), and TF with tf.function(jit_compile=True) ("compiled"
  mode). In each mode, either we expect to encounter a known limitation, or
  the value should match the value from the JAX execution.

  Args:
    func_jax: the function to invoke (``func_jax(*args)``)
    args: the arguments.
    enable_xla: if True, allows the use of XLA ops in jax2tf.convert
      (default: True).
    limitations: the set of limitations for this harness (not yet filtered
      by mode).
  """
  # Run JAX. Should not fail; we assume that the harness has already been
  # filtered by JAX unimplemented primitives.
  result_jax = func_jax(*args)  # JAX
  result_tf = None

  func_tf = jax2tf.convert(func_jax, enable_xla=enable_xla)

  unexpected_successes: List[str] = []
  # Run the "compiled" mode first; it is the most important.
  for mode in ("compiled", "eager", "graph"):

    def log_message(extra):
      return f"[{self._testMethodName}] mode={mode}: {extra}"

    jax2tf_limits = tuple(filter(lambda l: l.filter(mode=mode), limitations))

    skip_tf_run = [l for l in jax2tf_limits if l.skip_tf_run]
    if skip_tf_run:
      logging.info(log_message(f"Skip TF run due to limitations {skip_tf_run}"))
      continue

    try:
      result_tf = _run_tf_function(func_tf, *args, mode=mode)
      tf_exception = None
    except Exception as e:
      tf_exception = e

    expect_tf_error = [l for l in jax2tf_limits if l.expect_tf_error]
    if tf_exception:
      if expect_tf_error:
        logging.info(log_message(
            "Found expected TF error with enabled limitations "
            f"{expect_tf_error}; TF error is {tf_exception}"))
        continue
      else:
        raise tf_exception
    else:
      if expect_tf_error:
        # It is more ergonomic to print all successful modes once
        logging.warning(log_message(
            f"Unexpected success with known limitations {expect_tf_error}"))
        unexpected_successes.append(f"{mode}: {expect_tf_error}")

    if (jtu.device_under_test() == "gpu"
        and "dot_general_preferred" in self._testMethodName):
      logging.info(log_message(
          f"Arguments are {args}, JAX result is {result_jax}\nand TF result is {result_tf}"))

    skip_comparison = [l for l in jax2tf_limits if l.skip_comparison]
    if skip_comparison:
      logging.warning(log_message(f"Skip result comparison due to {skip_comparison}"))
      continue

    max_tol = None
    max_tol_lim = None if not jax2tf_limits else jax2tf_limits[0].get_max_tolerance_limitation(jax2tf_limits)
    if max_tol_lim is not None:
      max_tol = max_tol_lim.tol
      logging.info(log_message(f"Using tol={max_tol} due to {max_tol_lim}"))

    # Convert results to np.arrays
    result_tf = tf.nest.map_structure(lambda t: t.numpy(), result_tf)  # type: ignore

    custom_assert_lim = [l for l in jax2tf_limits if l.custom_assert]
    assert len(custom_assert_lim) <= 1, (
        f"Expecting at most one applicable limitation with custom_assert, "
        f"found {custom_assert_lim}")

    try:
      err_msg = f"TF mode {mode}."
      log_hlo_on_error = mode == "compiled" or jtu.device_under_test() == "tpu"
      if log_hlo_on_error:
        err_msg += " See the logs for JAX and TF HLO comparisons."
      if custom_assert_lim:
        logging.info(log_message(
            f"Running custom_assert with tol={max_tol} due to {custom_assert_lim[0]}"))
        custom_assert_lim[0].custom_assert(self, result_jax, result_tf,
                                           args=args, tol=max_tol,
                                           err_msg=err_msg)
      else:
        logging.info(log_message(f"Running default assert with tol={max_tol}"))
        self.assertAllClose(result_jax, result_tf, atol=max_tol, rtol=max_tol,
                            err_msg=err_msg)
    except AssertionError as e:
      # Print the HLO for comparison
      if not log_hlo_on_error:
        print(f"[{self._testMethodName}] Not logging HLO because the "
              f"mode was {mode}")
        raise

      logging.info(
          f"[{self._testMethodName}] Logging HLO for exception in mode {mode}: {e}")
      jax_comp = jax.xla_computation(func_jax)(*args)
      jax_hlo = jax_comp.as_hlo_text()
      logging.info(f"[{self._testMethodName}] JAX NON_OPT HLO\n{jax_hlo}")

      tf_args_signature = _make_tf_input_signature(*args)
      # If we give the signature, we cannot pass scalars
      tf_args_no_scalars = tuple(
          map(lambda a, sig: tf.convert_to_tensor(a, dtype=sig.dtype),
              args, tf_args_signature))
      tf_func_compiled = tf.function(func_tf, autograph=False,
                                     jit_compile=True,
                                     input_signature=tf_args_signature)
      tf_hlo = tf_func_compiled.experimental_get_compiler_ir(*tf_args_no_scalars)(stage="hlo")
      logging.info(f"[{self._testMethodName}] TF NON_OPT HLO\n{tf_hlo}")

      backend = jax.lib.xla_bridge.get_backend()
      modules = backend.compile(jax_comp).hlo_modules()
      jax_opt_hlo = modules[0].to_string()
      logging.info(f"[{self._testMethodName}] JAX OPT HLO\n{jax_opt_hlo}")

      # TODO(b/189265364): Remove this workaround
      if (jtu.device_under_test() == "gpu"
          and "dot_general" in self._testMethodName):
        print(f"[{self._testMethodName}] Not logging TF OPT HLO because of "
              f"crash in tf.experimental_get_compiler_ir (b/189265364)")
      else:
        tf_opt_hlo = tf_func_compiled.experimental_get_compiler_ir(*tf_args_no_scalars)(stage="optimized_hlo")
        logging.info(f"[{self._testMethodName}] TF OPT HLO\n{tf_opt_hlo}")
      raise

  # end "for mode"
  if unexpected_successes:
    msg = (f"[{self._testMethodName}] The following are unexpected "
           "successful modes:\n" + "\n".join(unexpected_successes))
    logging.warning(msg)
    # Uncomment the below if you want to see warnings as failures
    # self.assertEmpty(msg)
  return result_jax, result_tf
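# --- Illustrative usage sketch (not from the original file) ---
# A minimal example of how ConvertAndCompare is typically invoked from a
# parameterized harness test. It assumes `harness` is a
# primitive_harness.Harness and reuses Jax2TfLimitation.limitations_for_harness,
# which appears in test_generate_limitations_doc below.
def test_prim(self, harness: primitive_harness.Harness):
  limitations = Jax2TfLimitation.limitations_for_harness(harness)
  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         limitations=limitations)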
class SparseObjectTest(jtu.JaxTestCase):

  @parameterized.named_parameters(
      {"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
      for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO])
  def test_attrs(self, Obj, shape=(5, 8), dtype=np.float16):
    rng = rand_sparse(self.rng(), post=Obj.fromdense)
    M = rng(shape, dtype)

    assert isinstance(M, Obj)
    assert M.shape == shape
    assert M.dtype == dtype
    assert M.nnz == (M.todense() != 0).sum()
    assert M.data.dtype == dtype

    if isinstance(M, sparse_ops.CSR):
      assert len(M.data) == len(M.indices)
      assert len(M.indptr) == M.shape[0] + 1
    elif isinstance(M, sparse_ops.CSC):
      assert len(M.data) == len(M.indices)
      assert len(M.indptr) == M.shape[1] + 1
    elif isinstance(M, sparse_ops.COO):
      assert len(M.data) == len(M.row) == len(M.col)
    else:
      raise ValueError(f"Obj={Obj} not expected.")

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_Obj={}".format(
              jtu.format_shape_dtype_string(shape, dtype), Obj.__name__),
           "shape": shape, "dtype": dtype, "Obj": Obj}
          for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
          for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
      for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO]))
  def test_dense_round_trip(self, shape, dtype, Obj):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    Msparse = Obj.fromdense(M)
    self.assertArraysEqual(M, Msparse.todense())

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_Obj={}".format(
              jtu.format_shape_dtype_string(shape, dtype), Obj.__name__),
           "shape": shape, "dtype": dtype, "Obj": Obj}
          for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
          for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
      for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO]))
  def test_transpose(self, shape, dtype, Obj):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    Msparse = Obj.fromdense(M)
    self.assertArraysEqual(M.T, Msparse.T.todense())

  @unittest.skipIf(jtu.device_under_test() == "tpu",
                   "TPU has insufficient precision")
  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_Obj={}_bshape={}".format(
              jtu.format_shape_dtype_string(shape, dtype), Obj.__name__, bshape),
           "shape": shape, "dtype": dtype, "Obj": Obj, "bshape": bshape}
          for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
          for bshape in [shape[-1:] + s for s in [(), (3,), (4,)]]
          for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
      for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO]))
  def test_matmul(self, shape, dtype, Obj, bshape):
    rng = rand_sparse(self.rng(), post=jnp.array)
    rng_b = jtu.rand_default(self.rng())
    M = rng(shape, dtype)
    Msp = Obj.fromdense(M)
    x = rng_b(bshape, dtype)
    x = jnp.asarray(x)
    self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)
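# --- Illustrative sketch (assumption; `rand_sparse` is defined elsewhere in
# the test module) --- A standalone approximation of what such a factory does:
# draw a dense random matrix, zero out a fraction of the entries, and pipe the
# result through an optional `post` transform (e.g. Obj.fromdense above).
import numpy as np

def rand_sparse_sketch(np_rng, sparsity=0.5, post=lambda x: x):
  def _rand(shape, dtype):
    M = np_rng.standard_normal(shape).astype(dtype)
    M[np_rng.uniform(size=shape) < sparsity] = 0  # make it sparse
    return post(M)
  return _rand

# Example: rand_sparse_sketch(np.random.default_rng(0))((5, 8), np.float32)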
def test_generate_primitives_coverage_doc(self):
  harnesses = primitive_harness.all_harnesses
  print(f"Found {len(harnesses)} harnesses")

  harness_groups: Dict[str, Sequence[primitive_harness.Harness]] = collections.defaultdict(list)

  def unique_hash(h: primitive_harness.Harness, l: primitive_harness.Limitation):
    return (h.group_name, l.description, l.devices,
            tuple([np.dtype(d).name for d in l.dtypes]))

  unique_limitations: Dict[Any, Tuple[primitive_harness.Harness,
                                      primitive_harness.Limitation]] = {}

  for h in harnesses:
    harness_groups[h.group_name].append(h)
    for l in h.jax_unimplemented:
      if l.enabled:
        unique_limitations[hash(unique_hash(h, l))] = (h, l)

  primitive_coverage_table = ["""
| Primitive | Total test harnesses | dtypes supported on at least one device | dtypes NOT tested on any device |
| --- | --- | --- | --- |"""]
  all_dtypes = set(jtu.dtypes.all)

  for group_name in sorted(harness_groups.keys()):
    hlist = harness_groups[group_name]
    dtypes_tested = set()  # Tested on at least some device
    for h in hlist:
      dtypes_tested = dtypes_tested.union({h.dtype})

    primitive_coverage_table.append(
        f"| {group_name} | {len(hlist)} | "
        f"{primitive_harness.dtypes_to_str(dtypes_tested)} | "
        f"{primitive_harness.dtypes_to_str(all_dtypes - dtypes_tested)} |")

  print(f"Found {len(unique_limitations)} unique limitations")
  primitive_unimpl_table = ["""
| Affected primitive | Description of limitation | Affected dtypes | Affected devices |
| --- | --- | --- | --- |"""]
  for h, l in sorted(unique_limitations.values(),
                     key=lambda pair: unique_hash(*pair)):
    devices = ", ".join(l.devices)
    primitive_unimpl_table.append(
        f"|{h.group_name}|{l.description}|"
        f"{primitive_harness.dtypes_to_str(l.dtypes, empty_means_all=True)}|{devices}|")

  if not os.environ.get("JAX_OUTPUT_LIMITATIONS_DOC"):
    raise unittest.SkipTest(
        "Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the documentation")
  # The CPU/GPU have more supported types than TPU.
  self.assertEqual("cpu", jtu.device_under_test(),
                   "The documentation can be generated only on CPU")
  self.assertTrue(FLAGS.jax_enable_x64,
                  "The documentation must be generated with JAX_ENABLE_X64=1")

  with open(os.path.join(os.path.dirname(__file__),
                         '../g3doc/jax_primitives_coverage.md.template')) as f:
    template = f.read()
  output_file = os.path.join(os.path.dirname(__file__),
                             '../g3doc/jax_primitives_coverage.md')

  with open(output_file, "w") as f:
    f.write(template.replace("{{generation_date}}", str(datetime.date.today())) \
            .replace("{{nr_harnesses}}", str(len(harnesses))) \
            .replace("{{nr_primitives}}", str(len(harness_groups))) \
            .replace("{{primitive_unimpl_table}}", "\n".join(primitive_unimpl_table)) \
            .replace("{{primitive_coverage_table}}", "\n".join(primitive_coverage_table)))
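# --- Illustrative demo (not from the original file) --- The dedup idiom used
# above: dict keys computed from hash(unique_hash(...)) make limitations that
# share group name, description, devices, and dtypes overwrite each other, so
# exactly one representative pair survives per unique key.
pairs = [("sin", "f16"), ("sin", "f16"), ("cos", "f32")]
unique = {hash(p): p for p in pairs}
assert sorted(unique.values()) == [("cos", "f32"), ("sin", "f16")]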
def test_svd(self, harness: primitive_harness.Harness):
  if jtu.device_under_test() == "tpu":
    raise unittest.SkipTest("TODO: test crashes the XLA compiler for some TPU variants")

  expect_tf_exceptions = False
  if harness.params["dtype"] in [jnp.float16, dtypes.bfloat16]:
    if jtu.device_under_test() == "tpu":
      # TODO: SVD on TPU for bfloat16 seems to work for JAX but fails for TF
      expect_tf_exceptions = True
    else:
      # Does not work in JAX
      with self.assertRaisesRegex(NotImplementedError, "Unsupported dtype"):
        harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
      return

  if harness.params["dtype"] in [jnp.complex64, jnp.complex128]:
    if jtu.device_under_test() == "tpu":
      # TODO: on JAX on TPU there is no SVD implementation for complex
      with self.assertRaisesRegex(
          RuntimeError, "Binary op compare with different element types"):
        harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
      return
    else:
      # TODO: on CPU and GPU "No registered 'Svd' OpKernel for XLA_CPU_JIT devices".
      # Works on JAX because JAX uses a custom implementation.
      expect_tf_exceptions = True

  def _custom_assert(r_jax, r_tf, atol=1e-6, rtol=1e-6):
    def _reconstruct_operand(result, is_tf: bool):
      # Reconstructing operand as documented in numpy.linalg.svd (see
      # https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html)
      s, u, v = result
      if is_tf:
        s = s.numpy()
        u = u.numpy()
        v = v.numpy()
      U = u[..., :s.shape[-1]]
      V = v[..., :s.shape[-1], :]
      S = s[..., None, :]
      return jnp.matmul(U * S, V), s.shape, u.shape, v.shape

    if harness.params["compute_uv"]:
      r_jax_reconstructed = _reconstruct_operand(r_jax, False)
      r_tf_reconstructed = _reconstruct_operand(r_tf, True)
      self.assertAllClose(r_jax_reconstructed, r_tf_reconstructed,
                          atol=atol, rtol=rtol)
    else:
      self.assertAllClose(r_jax, r_tf, atol=atol, rtol=rtol)

  tol = 1e-4
  custom_assert = partial(_custom_assert, atol=tol, rtol=tol)

  self.ConvertAndCompare(harness.dyn_fun,
                         *harness.dyn_args_maker(self.rng()),
                         atol=tol, rtol=tol,
                         expect_tf_exceptions=expect_tf_exceptions,
                         custom_assert=custom_assert,
                         always_custom_assert=True)
op_record("square", 1, float_dtypes + complex_dtypes, jtu.rand_default), op_record("reciprocal", 1, float_dtypes + complex_dtypes, jtu.rand_positive), op_record("tan", 1, float_dtypes, jtu.rand_default, {np.float32: 3e-5}), op_record("asin", 1, float_dtypes, jtu.rand_small), # TODO(j-towns) fix: op_record("acos", 1, float_dtypes, jtu.rand_small), op_record("atan", 1, float_dtypes, jtu.rand_small), op_record("asinh", 1, float_dtypes, jtu.rand_default), op_record("acosh", 1, float_dtypes, jtu.rand_positive), # TODO(b/155331781): atanh has only ~float precision op_record("atanh", 1, float_dtypes, jtu.rand_small, {np.float64: 1e-9}), op_record("sinh", 1, float_dtypes + complex_dtypes, jtu.rand_default), op_record("cosh", 1, float_dtypes + complex_dtypes, jtu.rand_default), op_record( "lgamma", 1, float_dtypes, jtu.rand_positive, { np.float32: 1e-3 if jtu.device_under_test() == "tpu" else 1e-5, np.float64: 1e-14 }), op_record("digamma", 1, float_dtypes, jtu.rand_positive, {np.float64: 1e-14}), op_record("betainc", 3, float_dtypes, jtu.rand_positive, {np.float64: 1e-14}), op_record( "igamma", 2, [f for f in float_dtypes if f not in [dtypes.bfloat16, np.float16]], jtu.rand_positive, {np.float64: 1e-14}), op_record( "igammac", 2, [f for f in float_dtypes if f not in [dtypes.bfloat16, np.float16]], jtu.rand_positive, {np.float64: 1e-14}), op_record("erf", 1, float_dtypes, jtu.rand_small),
def test_unary_elementwise(self, harness: primitive_harness.Harness):
  dtype = harness.params["dtype"]
  lax_name = harness.params["lax_name"]
  arg, = harness.dyn_args_maker(self.rng())
  custom_assert = None
  if lax_name == "digamma":
    # TODO(necula): fix bug with digamma/(f32|f16) on TPU
    if dtype in [np.float16, np.float32] and jtu.device_under_test() == "tpu":
      raise unittest.SkipTest("TODO: fix bug: nan vs not-nan")

    # In the bfloat16 case, TF and lax both return NaN in undefined cases.
    if not dtype is dtypes.bfloat16:
      # digamma is not defined at 0 and -1
      def custom_assert(result_jax, result_tf):
        # lax.digamma returns NaN and tf.math.digamma returns inf
        special_cases = (arg == 0.) | (arg == -1.)
        nr_special_cases = np.count_nonzero(special_cases)
        self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
                            result_jax[special_cases])
        self.assertAllClose(np.full((nr_special_cases,), dtype(np.inf)),
                            result_tf[special_cases])
        # non-special cases are equal
        self.assertAllClose(result_jax[~special_cases],
                            result_tf[~special_cases])

  if lax_name == "erf_inv":
    # TODO(necula): fix erf_inv bug on TPU
    if jtu.device_under_test() == "tpu":
      raise unittest.SkipTest("erf_inv bug on TPU: nan vs non-nan")
    # TODO: investigate: in the (b)float16 cases, TF and lax both return the
    # same result in undefined cases.
    if not dtype in [np.float16, dtypes.bfloat16]:
      # erf_inv is not defined for arg <= -1 or arg >= 1
      def custom_assert(result_jax, result_tf):  # noqa: F811
        # for arg < -1 or arg > 1
        # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
        special_cases = (arg < -1.) | (arg > 1.)
        nr_special_cases = np.count_nonzero(special_cases)
        self.assertAllClose(
            np.full((nr_special_cases,), dtype(np.nan), dtype=dtype),
            result_jax[special_cases])
        signs = np.where(arg[special_cases] < 0., -1., 1.)
        self.assertAllClose(
            np.full((nr_special_cases,), signs * dtype(np.inf), dtype=dtype),
            result_tf[special_cases])
        # non-special cases are equal
        self.assertAllClose(result_jax[~special_cases],
                            result_tf[~special_cases])

  atol = None
  if jtu.device_under_test() == "gpu":
    # TODO(necula): revisit once we fix the GPU tests
    atol = 1e-3
  self.ConvertAndCompare(harness.dyn_fun, arg,
                         custom_assert=custom_assert, atol=atol)
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_axis={}_keepdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
       "rng_factory": jtu.rand_some_inf_and_nan
                      if jtu.device_under_test() != "cpu" else jtu.rand_default,
       "shape": shape, "dtype": dtype, "axis": axis, "keepdims": keepdims}
      for shape in all_shapes
      for dtype in float_dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True]))
  @jtu.skip_on_flag("jax_xla_backend", "xrt")
  def testLogSumExp(self, rng_factory, shape, dtype, axis, keepdims):
    rng = rng_factory()
    # TODO(mattjj): test autodiff
    def scipy_fun(array_to_reduce):
      return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    def lax_fun(array_to_reduce):
      return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, dtypes),
           "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
           "test_autodiff": rec.test_autodiff,
           "scipy_op": getattr(osp_special, rec.name),
           "lax_op": getattr(lsp_special, rec.name)}
          for shapes in CombosWithReplacement(all_shapes, rec.nargs)
          for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
      for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                          test_autodiff):
    rng = rng_factory()
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)

    if test_autodiff:
      jtu.check_grads(lax_op, args, order=1, atol=1e-3, rtol=3e-3, eps=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_d={}".format(
          jtu.format_shape_dtype_string(shape, dtype), d),
       "rng_factory": jtu.rand_positive, "shape": shape, "dtype": dtype, "d": d}
      for shape in all_shapes
      for dtype in float_dtypes
      for d in [1, 2, 5]))
  def testMultigammaln(self, rng_factory, shape, dtype, d):
    def scipy_fun(a):
      return osp_special.multigammaln(a, d)

    def lax_fun(a):
      return lsp_special.multigammaln(a, d)

    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  def testIssue980(self):
    x = onp.full((4,), -1e20, dtype=onp.float32)
    self.assertAllClose(onp.zeros((4,), dtype=onp.float32),
                        lsp_special.expit(x), check_dtypes=True)
def setUp(self):
  super().setUp()
  if jtu.device_under_test() not in ["tpu", "gpu"]:
    raise SkipTest
  if jtu.device_under_test() == "gpu":
    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
def default_tolerance():
  if jtu.device_under_test() != 'tpu':
    return jtu._default_tolerance
  tol = jtu._default_tolerance.copy()
  tol[onp.dtype(onp.float32)] = 5e-2
  return tol
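# --- Illustrative sketch (assumption; `tolerance_for` is a hypothetical
# helper, not part of the original file) --- Per-dtype tolerance tables like
# the one returned above are consumed by keying on the array's dtype:
import numpy as onp

def tolerance_for(x, tol_table, default=1e-6):
  return tol_table.get(onp.dtype(x.dtype), default)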
def testSoftplusGrad(self):
  check_grads(nn.softplus, (1e-8,), 4,
              rtol=1e-2 if jtu.device_under_test() == "tpu" else None)
def setUp(self):
  super().setUp()
  if jtu.device_under_test() not in ["tpu", "gpu"]:
    raise SkipTest
    # TODO(mattjj): make some-equal checks more robust, enable second-order
    # grad_test_spec(lax.max, nargs=2, order=1, rng_factory=jtu.rand_some_equal,
    #                dtypes=grad_float_dtypes, name="MaxSomeEqual"),
    # grad_test_spec(lax.min, nargs=2, order=1, rng_factory=jtu.rand_some_equal,
    #                dtypes=grad_float_dtypes, name="MinSomeEqual"),
]

GradSpecialValuesTestSpec = collections.namedtuple(
    "GradSpecialValuesTestSpec", ["op", "values", "tol"])

def grad_special_values_test_spec(op, values, tol=None):
  return GradSpecialValuesTestSpec(op, values, tol)

LAX_GRAD_SPECIAL_VALUE_TESTS = [
    grad_special_values_test_spec(
        lax.sinh, [0.],
        tol={np.float32: 1e-2} if jtu.device_under_test() == "tpu" else None),
    grad_special_values_test_spec(
        lax.cosh, [0.],
        tol={np.float32: 1e-2} if jtu.device_under_test() == "tpu" else None),
    grad_special_values_test_spec(lax.tanh, [0., 1000.]),
    grad_special_values_test_spec(lax.sin, [0., np.pi, np.pi/2., np.pi/4.]),
    grad_special_values_test_spec(lax.cos, [0., np.pi, np.pi/2., np.pi/4.]),
    grad_special_values_test_spec(lax.tan, [0.]),
    grad_special_values_test_spec(lax.asin, [0.]),
    grad_special_values_test_spec(lax.acos, [0.]),
    grad_special_values_test_spec(lax.atan, [0., 1000.]),
    grad_special_values_test_spec(lax.erf, [0., 10.]),
    grad_special_values_test_spec(lax.erfc, [0., 10.]),
]
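# --- Illustrative sketch (assumption; a hypothetical driver, not the real
# parameterized test) --- Each spec above pairs an op with the special input
# values at which its gradient is checked:
import numpy as np
from jax.test_util import check_grads

def run_grad_special_value_specs(specs):
  for spec in specs:
    # spec.tol may be None, a float, or a dict keyed by dtype (as above).
    rtol = spec.tol[np.float32] if isinstance(spec.tol, dict) else spec.tol
    for value in spec.values:
      check_grads(spec.op, (value,), 2, rtol=rtol)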
def setUp(self):
  if jtu.device_under_test() != "gpu":
    self.skipTest("__cuda_array_interface__ is only supported on GPU")
def setUp(self):
  if jtu.device_under_test() != "tpu":
    raise SkipTest
def test_eig(self, harness: primitive_harness.Harness):
  operand = harness.dyn_args_maker(self.rng())[0]
  compute_left_eigenvectors = harness.params["compute_left_eigenvectors"]
  compute_right_eigenvectors = harness.params["compute_right_eigenvectors"]
  dtype = harness.params["dtype"]

  if jtu.device_under_test() != "cpu":
    raise unittest.SkipTest("eig only supported on CPU in JAX")
  if dtype in [np.float16, dtypes.bfloat16]:
    raise unittest.SkipTest("eig unsupported with (b)float16 in JAX")

  def custom_assert(result_jax, result_tf):
    result_tf = tuple(map(lambda e: e.numpy(), result_tf))
    inner_dimension = operand.shape[-1]

    # Test ported from tests.lax_test.testEig
    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = np.linalg.norm(x, axis=(-2, -1))
      return norm / ((inner_dimension + 1) * jnp.finfo(dtype).eps)

    def check_right_eigenvectors(a, w, vr):
      self.assertTrue(
          np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))

    def check_left_eigenvectors(a, w, vl):
      rank = len(a.shape)
      aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
      wC = jnp.conj(w)
      check_right_eigenvectors(aH, wC, vl)

    def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
      tol = None
      # TODO(bchetioui): numerical discrepancies
      if dtype in [np.float32, np.complex64]:
        tol = 1e-4
      elif dtype in [np.float64, np.complex128]:
        tol = 1e-13
      closest_diff = min(abs(eigenvalues_array - eigenvalue))
      self.assertAllClose(closest_diff, np.array(0., closest_diff.dtype),
                          atol=tol)

    all_w_jax, all_w_tf = result_jax[0], result_tf[0]
    for idx in itertools.product(*map(range, operand.shape[:-2])):
      w_jax, w_tf = all_w_jax[idx], all_w_tf[idx]
      for i in range(inner_dimension):
        check_eigenvalue_is_in_array(w_jax[i], w_tf)
        check_eigenvalue_is_in_array(w_tf[i], w_jax)

    if compute_left_eigenvectors:
      check_left_eigenvectors(operand, all_w_tf, result_tf[1])
    if compute_right_eigenvectors:
      check_right_eigenvectors(operand, all_w_tf,
                               result_tf[1 + compute_left_eigenvectors])

  self.ConvertAndCompare(harness.dyn_fun, operand, custom_assert=custom_assert)
class cuSparseTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.indices, M.indptr)
    todense = lambda *args: sparse_ops.csr_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_csr = sparse.csr_matrix(M)

    nnz = M_csr.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse_ops.csr_fromdense(M, nnz=nnz, index_dtype=jnp.int32)

    data, indices, indptr = fromdense(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

    data, indices, indptr = jit(fromdense)(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(
          jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_csr_matvec(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.indices, M.indptr, v)
    matvec = lambda *args: sparse_ops.csr_matvec(*args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args))
    self.assertAllClose(op(M) @ v, jit(matvec)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(
          jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_csr_matmat(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.indices, M.indptr, B)
    matmat = lambda *args: sparse_ops.csr_matmat(*args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args))
    self.assertAllClose(op(M) @ B, jit(matmat)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.row, M.col)
    todense = lambda *args: sparse_ops.coo_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_coo = sparse.coo_matrix(M)

    nnz = M_coo.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse_ops.coo_fromdense(M, nnz=nnz, index_dtype=jnp.int32)

    data, row, col = fromdense(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

    data, row, col = jit(fromdense)(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(
          jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_coo_matvec(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.row, M.col, v)
    matvec = lambda *args: sparse_ops.coo_matvec(*args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args))
    self.assertAllClose(op(M) @ v, jit(matvec)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(
          jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_coo_matmat(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.row, M.col, B)
    matmat = lambda *args: sparse_ops.coo_matmat(*args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args))
    self.assertAllClose(op(M) @ B, jit(matmat)(*args))

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  def test_gpu_translation_rule(self):
    version = xla_bridge.get_backend().platform_version
    cuda_version = None if version == "<unknown>" else int(version.split()[-1])
    if cuda_version is None or cuda_version < 11000:
      self.assertNotIn(sparse_ops.csr_todense_p,
                       xla.backend_specific_translations["gpu"])
    else:
      self.assertIn(sparse_ops.csr_todense_p,
                    xla.backend_specific_translations["gpu"])

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
          jtu.format_shape_dtype_string(shape, dtype), mat_type),
       "shape": shape, "dtype": dtype, "mat_type": mat_type}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for mat_type in ['csr', 'coo']))
  def test_extra_nnz(self, shape, dtype, mat_type):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    nnz = (M != 0).sum() + 5
    fromdense = getattr(sparse_ops, f"{mat_type}_fromdense")
    todense = getattr(sparse_ops, f"{mat_type}_todense")
    args = fromdense(M, nnz=nnz, index_dtype=jnp.int32)
    M_out = todense(*args, shape=M.shape)
    self.assertArraysEqual(M, M_out)
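# --- Illustrative, standalone sketch (assumption; NumPy-only analogue, not the
# sparse_ops API) --- Requesting more `nnz` slots than the matrix has simply
# pads the data/index buffers, and padded zeros contribute nothing when
# densified, which is why test_extra_nnz round-trips exactly:
import numpy as np

def coo_todense_sketch(data, row, col, shape):
  out = np.zeros(shape, data.dtype)
  np.add.at(out, (row, col), data)  # duplicate padding entries at (0, 0) add 0
  return out

M = np.array([[1., 0.], [0., 2.]])
row, col = np.nonzero(M)
data = M[row, col]
pad = 3  # extra nnz slots beyond the true number of nonzeros
data = np.concatenate([data, np.zeros(pad, M.dtype)])
row = np.concatenate([row, np.zeros(pad, row.dtype)])
col = np.concatenate([col, np.zeros(pad, col.dtype)])
assert np.array_equal(coo_todense_sketch(data, row, col, M.shape), M)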
def test_eigh(self, harness: primitive_harness.Harness):
  operand = harness.dyn_args_maker(self.rng())[0]
  lower = harness.params["lower"]
  # Make operand self-adjoint
  operand = (operand + np.conj(np.swapaxes(operand, -1, -2))) / 2
  # Make operand lower/upper triangular
  triangular_operand = np.tril(operand) if lower else np.triu(operand)
  dtype = harness.params["dtype"]

  if (dtype in [np.complex64, np.complex128]
      and jtu.device_under_test() == "tpu"):
    raise unittest.SkipTest("TODO: complex eigh not supported on TPU in JAX")

  def custom_assert(result_jax, result_tf):
    result_tf = tuple(map(lambda e: e.numpy(), result_tf))
    inner_dimension = operand.shape[-1]

    def check_right_eigenvectors(a, w, vr):
      tol = 1e-16
      # TODO(bchetioui): tolerance needs to be very high in compiled mode,
      # specifically for eigenvectors.
      if dtype == np.float64:
        tol = 1e-6
      elif dtype == np.float32:
        tol = 1e-2
      elif dtype in [dtypes.bfloat16, np.complex64]:
        tol = 1e-3
      elif dtype == np.complex128:
        tol = 1e-13
      self.assertAllClose(np.matmul(a, vr) - w[..., None, :] * vr,
                          np.zeros(a.shape, dtype=vr.dtype), atol=tol)

    def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
      tol = None
      if dtype in [dtypes.bfloat16, np.float32, np.complex64]:
        tol = 1e-3
      elif dtype in [np.float64, np.complex128]:
        tol = 1e-11
      closest_diff = min(abs(eigenvalues_array - eigenvalue))
      self.assertAllClose(closest_diff, np.array(0., closest_diff.dtype),
                          atol=tol)

    _, all_w_jax = result_jax
    all_vr_tf, all_w_tf = result_tf

    for idx in itertools.product(*map(range, operand.shape[:-2])):
      w_jax, w_tf = all_w_jax[idx], all_w_tf[idx]
      for i in range(inner_dimension):
        check_eigenvalue_is_in_array(w_jax[i], w_tf)
        check_eigenvalue_is_in_array(w_tf[i], w_jax)

    check_right_eigenvectors(operand, all_w_tf, all_vr_tf)

  # On CPU and GPU, JAX makes custom calls
  always_custom_assert = True
  # On TPU, JAX calls xops.Eigh
  if jtu.device_under_test() == "tpu":
    always_custom_assert = False

  self.ConvertAndCompare(harness.dyn_fun, triangular_operand,
                         custom_assert=custom_assert,
                         always_custom_assert=always_custom_assert)
def testDtypeMatchesInput(self, dtype, fn):
  if dtype is jnp.float16 and jtu.device_under_test() == "tpu":
    self.skipTest("float16 not supported on TPU")
  x = jnp.zeros((), dtype=dtype)
  out = fn(x)
  self.assertEqual(out.dtype, dtype)
def test_binary_elementwise(self, harness):
  tol = None
  lax_name, dtype = harness.params["lax_name"], harness.params["dtype"]
  if lax_name in ("igamma", "igammac"):
    # TODO(necula): fix bug with igamma/f16
    if dtype in [np.float16, dtypes.bfloat16]:
      raise unittest.SkipTest("TODO: igamma(c) unsupported with (b)float16 in JAX")
    # TODO(necula): fix bug with igamma/f32 on TPU
    if dtype is np.float32 and jtu.device_under_test() == "tpu":
      raise unittest.SkipTest("TODO: fix bug: nan vs not-nan")
  arg1, arg2 = harness.dyn_args_maker(self.rng())
  custom_assert = None
  if lax_name == "igamma":
    # igamma is not defined when the first argument is <=0
    def custom_assert(result_jax, result_tf):
      # lax.igamma returns NaN when arg1 == arg2 == 0; tf.math.igamma returns 0
      special_cases = (arg1 == 0.) & (arg2 == 0.)
      nr_special_cases = np.count_nonzero(special_cases)
      self.assertAllClose(np.full((nr_special_cases,), np.nan, dtype=dtype),
                          result_jax[special_cases])
      self.assertAllClose(np.full((nr_special_cases,), 0., dtype=dtype),
                          result_tf[special_cases])
      # non-special cases are equal
      self.assertAllClose(result_jax[~special_cases],
                          result_tf[~special_cases])
  if lax_name == "igammac":
    # On GPU, tolerance also needs to be adjusted in compiled mode
    if dtype == np.float64 and jtu.device_under_test() == 'gpu':
      tol = 1e-14
    # igammac is not defined when the first argument is <=0
    def custom_assert(result_jax, result_tf):  # noqa: F811
      # lax.igammac returns 1. when arg1 <= 0; tf.math.igammac returns NaN
      special_cases = (arg1 <= 0.) | (arg2 <= 0)
      nr_special_cases = np.count_nonzero(special_cases)
      self.assertAllClose(np.full((nr_special_cases,), 1., dtype=dtype),
                          result_jax[special_cases])
      self.assertAllClose(np.full((nr_special_cases,), np.nan, dtype=dtype),
                          result_tf[special_cases])
      # On CPU, tolerance only needs to be adjusted in eager & graph modes
      tol = None
      if dtype == np.float64:
        tol = 1e-14
      # non-special cases are equal
      self.assertAllClose(result_jax[~special_cases],
                          result_tf[~special_cases], atol=tol, rtol=tol)
  self.ConvertAndCompare(harness.dyn_fun, arg1, arg2,
                         custom_assert=custom_assert, atol=tol, rtol=tol)
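# --- Illustrative, standalone sketch (assumption; a generic NumPy helper, not
# part of the original file) --- The custom_assert callbacks above all follow
# the same masked-comparison pattern: pin down the expected values on the
# special-case inputs, then require agreement everywhere else.
import numpy as np

def compare_with_special_cases(res_a, res_b, special_mask,
                               expected_a, expected_b):
  n = np.count_nonzero(special_mask)
  np.testing.assert_allclose(res_a[special_mask], np.full((n,), expected_a))
  np.testing.assert_allclose(res_b[special_mask], np.full((n,), expected_b))
  # non-special cases must agree between the two implementations
  np.testing.assert_allclose(res_a[~special_mask], res_b[~special_mask])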
def testSoftplusGradZero(self):
  check_grads(nn.softplus, (0.,), order=1,
              rtol=1e-2 if jtu.device_under_test() == "tpu" else None)
def setUp(self):
  super(DLPackTest, self).setUp()
  if jtu.device_under_test() == "tpu":
    self.skipTest("DLPack not supported on TPU")
def testSoftplusGradNan(self):
  check_grads(nn.softplus, (float('nan'),), order=1,
              rtol=1e-2 if jtu.device_under_test() == "tpu" else None)
def test_unary_elementwise(self, harness: primitive_harness.Harness):
  dtype = harness.params["dtype"]
  lax_name = harness.params["lax_name"]
  if dtype is dtypes.bfloat16:
    raise unittest.SkipTest("bfloat16 not implemented")
  if lax_name in ("sinh", "cosh", "atanh", "asinh", "acosh") and dtype is np.float16:
    raise unittest.SkipTest(
        "b/158006398: float16 support is missing from '%s' TF kernel" % lax_name)
  arg, = harness.dyn_args_maker(self.rng())
  custom_assert = None
  if lax_name == "digamma":
    # TODO(necula): fix bug with digamma/f32 on TPU
    if harness.params["dtype"] is np.float32 and jtu.device_under_test() == "tpu":
      raise unittest.SkipTest("TODO: fix bug: nan vs not-nan")
    if harness.params["dtype"] is np.float16 and jtu.device_under_test() == "tpu":
      raise unittest.SkipTest("TODO: fix bug: nans and infs")

    # digamma is not defined at 0 and -1
    def custom_assert(result_jax, result_tf):
      # lax.digamma returns NaN and tf.math.digamma returns inf
      special_cases = (arg == 0.) | (arg == -1.)
      nr_special_cases = np.count_nonzero(special_cases)
      self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
                          result_jax[special_cases])
      self.assertAllClose(np.full((nr_special_cases,), dtype(np.inf)),
                          result_tf[special_cases])
      # non-special cases are equal
      self.assertAllClose(result_jax[~special_cases],
                          result_tf[~special_cases])

  if lax_name == "erf_inv":
    # TODO(necula): fix bug with erf_inv/f16
    if dtype is np.float16:
      raise unittest.SkipTest("TODO: fix bug")
    # TODO(necula): fix erf_inv bug on TPU
    if jtu.device_under_test() == "tpu":
      raise unittest.SkipTest("erf_inv bug on TPU: nan vs non-nan")

    # erf_inv is not defined for arg <= -1 or arg >= 1
    def custom_assert(result_jax, result_tf):  # noqa: F811
      # for arg < -1 or arg > 1
      # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
      special_cases = (arg < -1.) | (arg > 1.)
      nr_special_cases = np.count_nonzero(special_cases)
      self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
                          result_jax[special_cases])
      signs = np.where(arg[special_cases] < 0., -1., 1.)
      self.assertAllClose(np.full((nr_special_cases,), signs * dtype(np.inf)),
                          result_tf[special_cases])
      # non-special cases are equal
      self.assertAllClose(result_jax[~special_cases],
                          result_tf[~special_cases])

  atol = None
  if jtu.device_under_test() == "gpu":
    # TODO(necula): revisit once we fix the GPU tests
    atol = 1e-3
  self.ConvertAndCompare(harness.dyn_fun, arg,
                         custom_assert=custom_assert, atol=atol)
def testReluGrad(self):
  rtol = 1e-2 if jtu.device_under_test() == "tpu" else None
  check_grads(nn.relu, (1.,), order=3, rtol=rtol)
  check_grads(nn.relu, (-1.,), order=3, rtol=rtol)
  jaxpr = jax.make_jaxpr(jax.grad(nn.relu))(0.)
  self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)
def test_generate_limitations_doc(self):
  """Generates primitives_with_limited_support.md.

  See the doc for instructions.
  """
  harnesses = [h for h in primitive_harness.all_harnesses
               if h.filter(h, include_jax_unimpl=True)]
  print(f"Found {len(harnesses)} test harnesses that work in JAX")

  def unique_hash(h: primitive_harness.Harness, l: Jax2TfLimitation):
    return (h.group_name, l.description, l.devices,
            tuple([np.dtype(d).name for d in l.dtypes]), l.modes)

  unique_limitations: Dict[Any, Tuple[primitive_harness.Harness,
                                      Jax2TfLimitation]] = {}
  for h in harnesses:
    for l in h.jax_unimplemented:
      if l.enabled:
        # Fake a Jax2TfLimitation from the Limitation
        tfl = Jax2TfLimitation(
            description="Not implemented in JAX: " + l.description,
            devices=l.devices,
            dtypes=l.dtypes,
            expect_tf_error=False,
            skip_tf_run=True)
        unique_limitations[hash(unique_hash(h, tfl))] = (h, tfl)
  for h in harnesses:
    for l in Jax2TfLimitation.limitations_for_harness(h):
      unique_limitations[hash(unique_hash(h, l))] = (h, l)

  print(f"Found {len(unique_limitations)} unique limitations")
  tf_error_table = ["""
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
| --- | --- | --- | --- | --- |"""]
  tf_numerical_discrepancies_table = list(tf_error_table)  # a copy

  for h, l in sorted(unique_limitations.values(),
                     key=lambda pair: unique_hash(*pair)):
    devices = ", ".join(sorted(l.devices))
    modes = ", ".join(sorted(l.modes))
    description = l.description
    if l.skip_comparison:
      description = "Numeric comparison disabled: " + description
    if l.expect_tf_error:
      description = "TF error: " + description
    if l.skip_tf_run:
      description = "TF test skipped: " + description

    if l.skip_tf_run or l.expect_tf_error:
      to_table = tf_error_table
    elif l.skip_comparison or l.custom_assert:
      to_table = tf_numerical_discrepancies_table
    else:
      continue

    to_table.append(
        f"| {h.group_name} | {description} | "
        f"{primitive_harness.dtypes_to_str(l.dtypes, empty_means_all=True)} | {devices} | {modes} |")

  if not os.environ.get("JAX_OUTPUT_LIMITATIONS_DOC"):
    raise unittest.SkipTest(
        "Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the documentation")
  # The CPU has more supported types, and more harnesses, than other devices.
  self.assertEqual("cpu", jtu.device_under_test())
  self.assertTrue(
      config.x64_enabled,
      "Documentation generation must be run with JAX_ENABLE_X64=1")

  with open(os.path.join(os.path.dirname(__file__),
                         "../g3doc/primitives_with_limited_support.md.template")) as f:
    template = f.read()
  output_file = os.path.join(os.path.dirname(__file__),
                             "../g3doc/primitives_with_limited_support.md")

  with open(output_file, "w") as f:
    f.write(template.replace("{{generation_date}}", str(datetime.date.today())) \
            .replace("{{tf_error_table}}", "\n".join(tf_error_table)) \
            .replace("{{tf_numerical_discrepancies_table}}",
                     "\n".join(tf_numerical_discrepancies_table)))
def setUp(self):
  super(ShardedJitTest, self).setUp()
  if jtu.device_under_test() != "tpu":
    raise SkipTest
class cuSparseTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.indices, M.indptr)
    todense = lambda *args: sparse_ops.csr_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_csr = sparse.csr_matrix(M)

    nnz = M_csr.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse_ops.csr_fromdense(M, nnz=nnz, index_dtype=jnp.int32)

    data, indices, indptr = fromdense(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

    data, indices, indptr = jit(fromdense)(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(
          jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_csr_matvec(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.indices, M.indptr, v)
    matvec = lambda *args: sparse_ops.csr_matvec(*args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(
          jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_csr_matmat(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.indices, M.indptr, B)
    matmat = lambda *args: sparse_ops.csr_matmat(*args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.row, M.col)
    todense = lambda *args: sparse_ops.coo_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_coo = sparse.coo_matrix(M)

    nnz = M_coo.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse_ops.coo_fromdense(M, nnz=nnz, index_dtype=jnp.int32)

    data, row, col = fromdense(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

    data, row, col = jit(fromdense)(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(
          jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_coo_matvec(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.row, M.col, v)
    matvec = lambda *args: sparse_ops.coo_matvec(*args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(
          jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_coo_matmat(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.row, M.col, B)
    matmat = lambda *args: sparse_ops.coo_matmat(*args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

    y, dy = jvp(lambda x: sparse_ops.coo_matmat(M.data, M.row, M.col, x,
                                                shape=shape, transpose=transpose).sum(),
                (B,), (jnp.ones_like(B),))
    self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)

    y, dy = jvp(lambda x: sparse_ops.coo_matmat(x, M.row, M.col, B,
                                                shape=shape, transpose=transpose).sum(),
                (M.data,), (jnp.ones_like(M.data),))
    self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  def test_gpu_translation_rule(self):
    version = xla_bridge.get_backend().platform_version
    cuda_version = None if version == "<unknown>" else int(version.split()[-1])
    if cuda_version is None or cuda_version < 11000:
      self.assertFalse(cusparse and cusparse.is_supported)
      self.assertNotIn(sparse_ops.csr_todense_p,
                       xla.backend_specific_translations["gpu"])
    else:
      self.assertTrue(cusparse and cusparse.is_supported)
      self.assertIn(sparse_ops.csr_todense_p,
                    xla.backend_specific_translations["gpu"])

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
          jtu.format_shape_dtype_string(shape, dtype), mat_type),
       "shape": shape, "dtype": dtype, "mat_type": mat_type}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for mat_type in ['csr', 'coo']))
  def test_extra_nnz(self, shape, dtype, mat_type):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    nnz = (M != 0).sum() + 5
    fromdense = getattr(sparse_ops, f"{mat_type}_fromdense")
    todense = getattr(sparse_ops, f"{mat_type}_todense")
    args = fromdense(M, nnz=nnz, index_dtype=jnp.int32)
    M_out = todense(*args, shape=M.shape)
    self.assertArraysEqual(M, M_out)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_todense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())
    f = lambda data: sparse_ops.coo_todense(data, row, col, shape=M.shape)

    # Forward-mode
    primals, tangents = api.jvp(f, [data], [jnp.ones_like(data)])
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(tangents, jnp.zeros_like(M).at[row, col].set(1))

    # Reverse-mode
    primals, vjp_fun = api.vjp(f, data)
    data_out, = vjp_fun(primals)
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(data_out, data)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_fromdense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    nnz = (M != 0).sum()
    f = lambda M: sparse_ops.coo_fromdense(M, nnz=nnz)

    # Forward-mode
    primals, tangents = api.jvp(f, [M], [jnp.ones_like(M)])
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(tangents[0], jnp.ones(nnz, dtype=dtype))
    self.assertEqual(tangents[1].dtype, dtypes.float0)
    self.assertEqual(tangents[2].dtype, dtypes.float0)

    # Reverse-mode
    primals, vjp_fun = api.vjp(f, M)
    M_out, = vjp_fun(primals)
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(M_out, M)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          jtu.format_shape_dtype_string(bshape, dtype)),
       "shape": shape, "dtype": dtype, "bshape": bshape}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for bshape in [shape[-1:] + s for s in [()]]  # TODO: matmul autodiff
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))  # TODO: other types
  def test_coo_matvec_ad(self, shape, dtype, bshape):
    tol = {np.float32: 1E-6, np.float64: 1E-13,
           np.complex64: 1E-6, np.complex128: 1E-13}

    rng = rand_sparse(self.rng(), post=jnp.array)
    rng_b = jtu.rand_default(self.rng())

    M = rng(shape, dtype)
    data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())
    x = rng_b(bshape, dtype)
    xdot = rng_b(bshape, dtype)

    # Forward-mode with respect to the vector
    f_dense = lambda x: M @ x
    f_sparse = lambda x: sparse_ops.coo_matvec(data, row, col, x, shape=M.shape)
    v_sparse, t_sparse = api.jvp(f_sparse, [x], [xdot])
    v_dense, t_dense = api.jvp(f_dense, [x], [xdot])
    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to the vector
    primals_dense, vjp_dense = api.vjp(f_dense, x)
    primals_sparse, vjp_sparse = api.vjp(f_sparse, x)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)

    # Forward-mode with respect to nonzero elements of the matrix
    f_sparse = lambda data: sparse_ops.coo_matvec(data, row, col, x, shape=M.shape)
    f_dense = lambda data: sparse_ops.coo_todense(data, row, col, shape=M.shape) @ x
    data = rng((len(data),), data.dtype)
    data_dot = rng((len(data),), data.dtype)
    v_sparse, t_sparse = api.jvp(f_sparse, [data], [data_dot])
    v_dense, t_dense = api.jvp(f_dense, [data], [data_dot])
    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to nonzero elements of the matrix
    primals_dense, vjp_dense = api.vjp(f_dense, data)
    primals_sparse, vjp_sparse = api.vjp(f_sparse, data)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)
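# --- Illustrative, standalone sketch (assumption; not part of the original
# file) --- The dense-vs-sparse AD checks above follow a generic pattern:
# verify that api.jvp and api.vjp of two implementations of the same linear
# map agree. A tiny dense-only demonstration:
import jax.numpy as jnp
from jax import jvp, vjp

M = jnp.array([[1., 2.], [3., 4.]])
f_a = lambda x: M @ x
f_b = lambda x: jnp.einsum("ij,j->i", M, x)  # a second implementation
x, xdot = jnp.ones(2), jnp.ones(2)
(y_a, t_a) = jvp(f_a, (x,), (xdot,))
(y_b, t_b) = jvp(f_b, (x,), (xdot,))
assert jnp.allclose(y_a, y_b) and jnp.allclose(t_a, t_b)
y_a, vjp_a = vjp(f_a, x)
y_b, vjp_b = vjp(f_b, x)
assert jnp.allclose(vjp_a(y_a)[0], vjp_b(y_b)[0])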
def test_unary_elementwise(self, harness: primitive_harness.Harness):
  dtype = harness.params["dtype"]
  lax_name = harness.params["lax_name"]
  if (lax_name in ("acosh", "asinh", "atanh", "bessel_i0e", "bessel_i1e",
                   "digamma", "erf", "erf_inv", "erfc", "lgamma", "round", "rsqrt")
      and dtype is dtypes.bfloat16
      and jtu.device_under_test() in ["cpu", "gpu"]):
    raise unittest.SkipTest(
        f"bfloat16 support is missing from '{lax_name}' TF kernel on {jtu.device_under_test()} devices.")
  # TODO(bchetioui): do they have bfloat16 support, though?
  if lax_name in ("sinh", "cosh", "atanh", "asinh", "acosh", "erf_inv") and dtype is np.float16:
    raise unittest.SkipTest(
        "b/158006398: float16 support is missing from '%s' TF kernel" % lax_name)
  arg, = harness.dyn_args_maker(self.rng())
  custom_assert = None
  if lax_name == "digamma":
    # TODO(necula): fix bug with digamma/(f32|f16) on TPU
    if dtype in [np.float16, np.float32] and jtu.device_under_test() == "tpu":
      raise unittest.SkipTest("TODO: fix bug: nan vs not-nan")

    # In the bfloat16 case, TF and lax both return NaN in undefined cases.
    if not dtype is dtypes.bfloat16:
      # digamma is not defined at 0 and -1
      def custom_assert(result_jax, result_tf):
        # lax.digamma returns NaN and tf.math.digamma returns inf
        special_cases = (arg == 0.) | (arg == -1.)
        nr_special_cases = np.count_nonzero(special_cases)
        self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
                            result_jax[special_cases])
        self.assertAllClose(np.full((nr_special_cases,), dtype(np.inf)),
                            result_tf[special_cases])
        # non-special cases are equal
        self.assertAllClose(result_jax[~special_cases],
                            result_tf[~special_cases])

  if lax_name == "erf_inv":
    # TODO(necula): fix erf_inv bug on TPU
    if jtu.device_under_test() == "tpu":
      raise unittest.SkipTest("erf_inv bug on TPU: nan vs non-nan")
    # TODO: investigate: in the (b)float16 cases, TF and lax both return the
    # same result in undefined cases.
    if not dtype in [np.float16, dtypes.bfloat16]:
      # erf_inv is not defined for arg <= -1 or arg >= 1
      def custom_assert(result_jax, result_tf):  # noqa: F811
        # for arg < -1 or arg > 1
        # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
        special_cases = (arg < -1.) | (arg > 1.)
        nr_special_cases = np.count_nonzero(special_cases)
        self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
                            result_jax[special_cases])
        signs = np.where(arg[special_cases] < 0., -1., 1.)
        self.assertAllClose(np.full((nr_special_cases,), signs * dtype(np.inf)),
                            result_tf[special_cases])
        # non-special cases are equal
        self.assertAllClose(result_jax[~special_cases],
                            result_tf[~special_cases])

  atol = None
  if jtu.device_under_test() == "gpu":
    # TODO(necula): revisit once we fix the GPU tests
    atol = 1e-3
  self.ConvertAndCompare(harness.dyn_fun, arg,
                         custom_assert=custom_assert, atol=atol)
def setUp(self):
  super(CudaArrayInterfaceTest, self).setUp()
  if jtu.device_under_test() != "gpu":
    self.skipTest("__cuda_array_interface__ is only supported on GPU")
def _add_vmap_primitive_harnesses():
  """For each harness group, pick a single dtype.

  Ignore harnesses that fail in graph mode in jax2tf.
  """
  all_h = primitive_harness.all_harnesses

  # Index by group
  harness_groups: Dict[str, Sequence[primitive_harness.Harness]] = collections.defaultdict(list)
  device = jtu.device_under_test()

  for h in all_h:
    # Drop the JAX limitations
    if not h.filter(device_under_test=device, include_jax_unimpl=False):
      continue
    # And the jax2tf limitations that are known to result in TF error.
    if any(l.expect_tf_error for l in _get_jax2tf_limitations(device, h)):
      continue
    harness_groups[h.group_name].append(h)

  selected_harnesses = []
  for group_name, hlist in harness_groups.items():
    # Pick the dtype with the most harnesses in this group. Some harness
    # groups only test different use cases at a few dtypes.
    c = collections.Counter([h.dtype for h in hlist])
    (dtype, _), = c.most_common(1)
    selected_harnesses.extend([h for h in hlist if h.dtype == dtype])

  # We do not yet support shape polymorphism for vmap for some primitives
  _NOT_SUPPORTED_YET = frozenset([
      # In the random._gamma_impl we do reshape(-1, 2) for the keys
      "random_gamma",
      # In linalg._lu_python we do reshape(-1, ...)
      "lu",
      "custom_linear_solve",
      # We do *= shapes in the batching rule for conv_general_dilated
      "conv_general_dilated",
      # vmap(clamp) fails in JAX
      "clamp",
      "iota",  # vmap does not make sense for 0-argument functions
  ])

  batch_size = 3
  for h in selected_harnesses:
    if h.group_name in _NOT_SUPPORTED_YET:
      continue

    def make_batched_arg_descriptor(
        ad: primitive_harness.ArgDescriptor) -> Optional[primitive_harness.ArgDescriptor]:
      if isinstance(ad, RandArg):
        return RandArg((batch_size,) + ad.shape, ad.dtype)
      elif isinstance(ad, CustomArg):
        def wrap_custom(rng):
          arg = ad.make(rng)
          return np.stack([arg] * batch_size)
        return CustomArg(wrap_custom)
      else:
        assert isinstance(ad, np.ndarray), ad
        return np.stack([ad] * batch_size)

    new_args = [make_batched_arg_descriptor(ad)
                for ad in h.arg_descriptors
                if not isinstance(ad, StaticArg)]

    # We do not check the result of harnesses that require custom assertions.
    check_result = all(not l.custom_assert and not l.skip_comparison and l.tol is None
                       for l in _get_jax2tf_limitations(device, h))
    vmap_harness = _make_harness(h.group_name, f"vmap_{h.name}",
                                 jax.vmap(h.dyn_fun, in_axes=0, out_axes=0),
                                 new_args,
                                 poly_axes=[0] * len(new_args),
                                 check_result=check_result,
                                 **h.params)
    _POLY_SHAPE_TEST_HARNESSES.append(vmap_harness)
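# --- Illustrative, standalone demo (not part of the original file) ---
# The batching performed by make_batched_arg_descriptor above: stack
# batch_size copies of an argument along a new leading axis so that the
# vmapped function maps over axis 0.
import numpy as np
import jax

batch_size = 3
arg = np.arange(4.0)
batched = np.stack([arg] * batch_size)  # shape (3, 4)
out = jax.vmap(lambda x: x * 2, in_axes=0)(batched)
assert out.shape == (batch_size,) + arg.shape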