def test_prim(self, harness: primitive_harness.Harness):
    """Run one primitive harness through `self.ConvertAndCompare`.

    Only the limitations whose `filter` accepts the current device and the
    harness dtype are forwarded to the comparison.
    """
    all_limitations = Jax2TfLimitation.limitations_for_harness(harness)
    current_device = jtu.device_under_test()
    # Drop the limitations that do not apply to this device/dtype combination.
    limitations = tuple(
        lim
        for lim in all_limitations
        if lim.filter(device=current_device, dtype=harness.dtype)
    )
    jax_fn = harness.dyn_fun
    inputs = harness.dyn_args_maker(self.rng())
    self.ConvertAndCompare(jax_fn, *inputs, limitations=limitations)
def _get_jax2tf_limitations(
        device, h: primitive_harness.Harness) -> Sequence[Jax2TfLimitation]:
    """Return the jax2tf limitations of harness `h` that apply in graph mode.

    Args:
      device: the device to filter the limitations for.
      h: the harness whose limitations are looked up; its dtype is used for
        filtering.

    Returns:
      A tuple with the limitations whose `filter` accepts `device`, the
      harness dtype and mode "graph".
    """
    # The CheckShapePolymorphism uses tf.function, so we care about "graph"
    return tuple(
        lim
        for lim in Jax2TfLimitation.limitations_for_harness(h)
        if lim.filter(device=device, dtype=h.dtype, mode="graph")
    )
def _get_jax2tf_limitations(
        device, h: primitive_harness.Harness) -> Sequence[Jax2TfLimitation]:
    """Return the graph-mode jax2tf limitations of `h` that expect a TF error.

    Args:
      device: the device to filter the limitations for.
      h: the harness whose limitations are looked up; its dtype is used for
        filtering.

    Returns:
      A tuple with the limitations whose `filter` accepts `device`, the
      harness dtype and mode "graph", and whose `expect_tf_error` is truthy.
    """
    selected = []
    for lim in Jax2TfLimitation.limitations_for_harness(h):
        # Skip the limitations that do not apply to this device/dtype in
        # "graph" mode; `expect_tf_error` is only consulted for the others.
        if not lim.filter(device=device, dtype=h.dtype, mode="graph"):
            continue
        if lim.expect_tf_error:
            selected.append(lim)
    return tuple(selected)
def test_prim(self, harness: primitive_harness.Harness):
    """Run one primitive harness through `self.ConvertAndCompare`.

    The limitations are narrowed to those whose `filter` accepts the current
    device and the harness dtype. Two optional harness parameters are
    honoured: "enable_xla" (default True) is forwarded to the comparison, and
    "associative_scan_reductions" (default False) is passed to the
    `jax.jax2tf_associative_scan_reductions` context that wraps it.
    """
    all_limitations = Jax2TfLimitation.limitations_for_harness(harness)
    current_device = jtu.device_under_test()
    limitations = tuple(
        lim
        for lim in all_limitations
        if lim.filter(device=current_device, dtype=harness.dtype)
    )
    jax_fn = harness.dyn_fun
    inputs = harness.dyn_args_maker(self.rng())

    # Per-harness overrides, with the defaults used when the key is absent.
    use_xla = harness.params.get("enable_xla", True)
    scan_reductions = harness.params.get("associative_scan_reductions", False)

    with jax.jax2tf_associative_scan_reductions(scan_reductions):
        self.ConvertAndCompare(
            jax_fn, *inputs, limitations=limitations, enable_xla=use_xla)
def test_generate_limitations_doc(self):
    """Generates primitives_with_limited_support.md.

    See the doc for instructions.

    The limitation tables are built first; only afterwards is the
    JAX_OUTPUT_LIMITATIONS_DOC environment variable checked. When it is unset
    the test is skipped at that point and no file is written. Writing the
    file additionally requires the device under test to be "cpu" and
    `config.x64_enabled` to be true.

    Raises:
      unittest.SkipTest: if JAX_OUTPUT_LIMITATIONS_DOC is not set.
    """
    harnesses = [
        # NOTE(review): `h` is passed as the first positional argument of its
        # own `filter`; confirm against Harness.filter that this is intended.
        h for h in primitive_harness.all_harnesses
        if h.filter(h, include_jax_unimpl=True)
    ]
    print(f"Found {len(harnesses)} test harnesses that work in JAX")

    def unique_hash(h: primitive_harness.Harness, l: Jax2TfLimitation):
        # Identity of a documented limitation; also used as the sort key.
        return (h.group_name, l.description, l.devices,
                tuple(np.dtype(d).name for d in l.dtypes), l.modes)

    # Keyed by the identity tuple itself rather than by its hash(), so that a
    # hash collision between two different limitations cannot drop one.
    unique_limitations: Dict[Any, Tuple[primitive_harness.Harness,
                                        Jax2TfLimitation]] = {}
    for h in harnesses:
        for unimpl in h.jax_unimplemented:
            if unimpl.enabled:
                # Fake a Jax2TFLimitation from the Limitation
                tfl = Jax2TfLimitation(
                    description="Not implemented in JAX: " + unimpl.description,
                    devices=unimpl.devices,
                    dtypes=unimpl.dtypes,
                    expect_tf_error=False,
                    skip_tf_run=True)
                unique_limitations[unique_hash(h, tfl)] = (h, tfl)
    for h in harnesses:
        for lim in Jax2TfLimitation.limitations_for_harness(h):
            unique_limitations[unique_hash(h, lim)] = (h, lim)

    print(f"Found {len(unique_limitations)} unique limitations")

    # The delimiter row must have as many cells as the header row (five),
    # otherwise the table is not recognized as a table by GFM renderers.
    tf_error_table = [
        """
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
| --- | --- | --- | --- | --- |"""
    ]
    tf_numerical_discrepancies_table = list(tf_error_table)  # a copy

    # The dict key is exactly unique_hash(h, lim), so sorting the items by key
    # yields the same order as sorting the values by unique_hash.
    for _, (h, lim) in sorted(unique_limitations.items(),
                              key=lambda item: item[0]):
        devices = ", ".join(sorted(lim.devices))
        modes = ", ".join(sorted(lim.modes))
        description = lim.description
        # The prefixes stack: the last matching flag ends up outermost.
        if lim.skip_comparison:
            description = "Numeric comparison disabled: " + description
        if lim.expect_tf_error:
            description = "TF error: " + description
        if lim.skip_tf_run:
            description = "TF test skipped: " + description

        if lim.skip_tf_run or lim.expect_tf_error:
            to_table = tf_error_table
        elif lim.skip_comparison or lim.custom_assert:
            to_table = tf_numerical_discrepancies_table
        else:
            # Limitations with none of these flags are not documented.
            continue

        to_table.append(
            f"| {h.group_name} | {description} | "
            f"{primitive_harness.dtypes_to_str(lim.dtypes, empty_means_all=True)} | "
            f"{devices} | {modes} |")

    if not os.environ.get("JAX_OUTPUT_LIMITATIONS_DOC"):
        raise unittest.SkipTest(
            "Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the documentation"
        )
    # The CPU has more supported types, and harnesses
    self.assertEqual("cpu", jtu.device_under_test())
    self.assertTrue(
        config.x64_enabled,
        "Documentation generation must be run with JAX_ENABLE_X64=1")

    doc_dir = os.path.dirname(__file__)
    template_file = os.path.join(
        doc_dir, "../g3doc/primitives_with_limited_support.md.template")
    # Explicit encoding so the result does not depend on the locale.
    with open(template_file, encoding="utf-8") as f:
        template = f.read()

    output_file = os.path.join(
        doc_dir, "../g3doc/primitives_with_limited_support.md")
    generated = (
        template
        .replace("{{generation_date}}", str(datetime.date.today()))
        .replace("{{tf_error_table}}", "\n".join(tf_error_table))
        .replace("{{tf_numerical_discrepancies_table}}",
                 "\n".join(tf_numerical_discrepancies_table))
    )
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(generated)