Exemplo n.º 1
0
 def test_prim(self, harness: primitive_harness.Harness):
     """Run ``ConvertAndCompare`` on the harness with its applicable limitations.

     A limitation is passed along only if its ``filter`` accepts the device
     under test and the dtype of the harness.
     """
     all_limitations = Jax2TfLimitation.limitations_for_harness(harness)
     device = jtu.device_under_test()
     applicable = tuple(
         lim for lim in all_limitations
         if lim.filter(device=device, dtype=harness.dtype))
     # Fetch the function before building its arguments, then splat the
     # freshly generated arguments into the comparison call.
     func_jax = harness.dyn_fun
     dyn_args = harness.dyn_args_maker(self.rng())
     self.ConvertAndCompare(func_jax, *dyn_args, limitations=applicable)
Exemplo n.º 2
0
def _get_jax2tf_limitations(
    device, h: primitive_harness.Harness) -> Sequence[Jax2TfLimitation]:
  """Return the jax2tf limitations of `h` whose filter accepts `device`.

  Limitations are matched against the dtype of the harness and the "graph"
  mode; the result is a tuple in the order the limitations were listed.
  """
  # The CheckShapePolymorphism uses tf.function, so we care about "graph"
  return tuple(
      lim for lim in Jax2TfLimitation.limitations_for_harness(h)
      if lim.filter(device=device, dtype=h.dtype, mode="graph"))
Exemplo n.º 3
0
def _get_jax2tf_limitations(
        device, h: primitive_harness.Harness) -> Sequence[Jax2TfLimitation]:
    """Return the jax2tf limitations of `h` that apply and expect a TF error.

    A limitation is kept when its filter accepts `device`, the dtype of the
    harness and the "graph" mode, and its ``expect_tf_error`` is truthy. The
    result is a tuple in the order the limitations were listed.
    """
    return tuple(
        lim for lim in Jax2TfLimitation.limitations_for_harness(h)
        if lim.filter(device=device, dtype=h.dtype, mode="graph")
        and lim.expect_tf_error)
Exemplo n.º 4
0
 def test_prim(self, harness: primitive_harness.Harness):
   """Run ``ConvertAndCompare`` on the harness with its applicable limitations.

   A limitation is passed along only if its ``filter`` accepts the device
   under test and the dtype of the harness. The harness params may carry
   "enable_xla" (default True) and "associative_scan_reductions" (default
   False); the latter is applied through a context manager around the call.
   """
   all_limitations = Jax2TfLimitation.limitations_for_harness(harness)
   device = jtu.device_under_test()
   applicable = tuple(
       lim for lim in all_limitations
       if lim.filter(device=device, dtype=harness.dtype))
   func_jax = harness.dyn_fun
   dyn_args = harness.dyn_args_maker(self.rng())
   params = harness.params
   enable_xla = params.get("enable_xla", True)
   use_scan_reductions = params.get("associative_scan_reductions", False)
   with jax.jax2tf_associative_scan_reductions(use_scan_reductions):
     self.ConvertAndCompare(
         func_jax, *dyn_args, limitations=applicable, enable_xla=enable_xla)
Exemplo n.º 5
0
 def applicable_jax2tf_limitation(l: Jax2TfLimitation) -> bool:
   """Tell whether `l` applies to the enclosing `device` and harness `h`."""
   # The CheckShapePolymorphism uses tf.function, so we care about "graph"
   applies = l.filter(device=device, dtype=h.dtype, mode="graph")
   return applies
Exemplo n.º 6
0
    def test_generate_limitations_doc(self):
        """Generates primitives_with_limited_support.md.

    See the doc for instructions.
    """

        harnesses = [
            h for h in primitive_harness.all_harnesses
            if h.filter(h, include_jax_unimpl=True)
        ]
        print(f"Found {len(harnesses)} test harnesses that work in JAX")

        def unique_hash(h: primitive_harness.Harness, l: Jax2TfLimitation):
            return (h.group_name, l.description, l.devices,
                    tuple([np.dtype(d).name for d in l.dtypes]), l.modes)

        unique_limitations: Dict[Any, Tuple[primitive_harness.Harness,
                                            Jax2TfLimitation]] = {}
        for h in harnesses:
            for l in h.jax_unimplemented:
                if l.enabled:
                    # Fake a Jax2TFLimitation from the Limitation
                    tfl = Jax2TfLimitation(
                        description="Not implemented in JAX: " + l.description,
                        devices=l.devices,
                        dtypes=l.dtypes,
                        expect_tf_error=False,
                        skip_tf_run=True)
                    unique_limitations[hash(unique_hash(h, tfl))] = (h, tfl)
        for h in harnesses:
            for l in Jax2TfLimitation.limitations_for_harness(h):
                unique_limitations[hash(unique_hash(h, l))] = (h, l)

        print(f"Found {len(unique_limitations)} unique limitations")
        tf_error_table = [
            """
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
| --- | --- | --- | --- | --- | ---|"""
        ]
        tf_numerical_discrepancies_table = list(tf_error_table)  # a copy
        for h, l in sorted(unique_limitations.values(),
                           key=lambda pair: unique_hash(*pair)):
            devices = ", ".join(sorted(l.devices))
            modes = ", ".join(sorted(l.modes))
            description = l.description
            if l.skip_comparison:
                description = "Numeric comparision disabled: " + description
            if l.expect_tf_error:
                description = "TF error: " + description
            if l.skip_tf_run:
                description = "TF test skipped: " + description

            if l.skip_tf_run or l.expect_tf_error:
                to_table = tf_error_table
            elif l.skip_comparison or l.custom_assert:
                to_table = tf_numerical_discrepancies_table
            else:
                continue

            to_table.append(
                f"| {h.group_name} | {description} | "
                f"{primitive_harness.dtypes_to_str(l.dtypes, empty_means_all=True)} | {devices} | {modes} |"
            )

        if not os.environ.get("JAX_OUTPUT_LIMITATIONS_DOC"):
            raise unittest.SkipTest(
                "Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the documentation"
            )
        # The CPU has more supported types, and harnesses
        self.assertEqual("cpu", jtu.device_under_test())
        self.assertTrue(
            config.x64_enabled,
            "Documentation generation must be run with JAX_ENABLE_X64=1")

        with open(
                os.path.join(
                    os.path.dirname(__file__),
                    "../g3doc/primitives_with_limited_support.md.template")
        ) as f:
            template = f.read()
        output_file = os.path.join(
            os.path.dirname(__file__),
            "../g3doc/primitives_with_limited_support.md")

        with open(output_file, "w") as f:
            f.write(template.replace("{{generation_date}}", str(datetime.date.today())) \
                    .replace("{{tf_error_table}}", "\n".join(tf_error_table)) \
                    .replace("{{tf_numerical_discrepancies_table}}", "\n".join(tf_numerical_discrepancies_table)) \
                    )
Exemplo n.º 7
0
 def applicable_jax2tf_limitation(l: Jax2TfLimitation) -> bool:
     """Tell whether `l` applies in "graph" mode and expects a TF error.

     Uses `device` and the harness `h` from the enclosing scope.
     """
     applies = l.filter(device=device, dtype=h.dtype, mode="graph")
     if not applies:
         # Hand back the filter's own falsy result, as `and` would.
         return applies
     return l.expect_tf_error