Example #1
    def ConvertAndCompare(self,
                          func_jax: Callable,
                          *args,
                          enable_xla: bool = True,
                          limitations: Sequence = ()):
        """Compares jax_func(*args) with convert(jax_func)(*args).

    It compares the result of JAX, TF ("eager" mode),
    TF with tf.function ("graph" mode), and TF with
    tf.function(jit_compile=True) ("compiled" mode). In each mode,
    either we expect to encounter a known limitation, or the value should
    match the value from the JAX execution.

    Args:
      func_jax: the function to invoke (``func_jax(*args)``)
      args: the arguments.
      enable_xla: if True, allows the use of XLA ops in jax2tf.convert
        (default: True).
      limitations: the set of limitations for this harness (not yet filtered
        by mode).
    """
        # Run JAX. This should not fail; we assume the harness has already been
        # filtered to exclude primitives unimplemented in JAX.
        result_jax = func_jax(*args)  # JAX
        result_tf = None

        func_tf = jax2tf.convert(func_jax, enable_xla=enable_xla)

        unexpected_successes: List[str] = []
        # Run the "compiled" mode first, it is most important
        for mode in ("compiled", "eager", "graph"):

            def log_message(extra):
                return f"[{self._testMethodName}] mode={mode}: {extra}"

            jax2tf_limits = tuple(
                filter(lambda l: l.filter(mode=mode), limitations))

            skip_tf_run = [l for l in jax2tf_limits if l.skip_tf_run]
            if skip_tf_run:
                logging.info(
                    log_message(
                        f"Skip TF run due to limitations {skip_tf_run}"))
                continue

            try:
                result_tf = _run_tf_function(func_tf, *args, mode=mode)
                tf_exception = None
            except Exception as e:
                tf_exception = e

            expect_tf_error = [l for l in jax2tf_limits if l.expect_tf_error]
            if tf_exception:
                if expect_tf_error:
                    logging.info(
                        log_message(
                            "Found expected TF error with enabled limitations "
                            f"{expect_tf_error}; TF error is {tf_exception}"))
                    continue
                else:
                    raise tf_exception
            else:
                if expect_tf_error:
                    # It is more ergonomic to print all successful modes once
                    logging.warning(
                        log_message(
                            f"Unexpected success with known limitations {expect_tf_error}"
                        ))
                    unexpected_successes.append(f"{mode}: {expect_tf_error}")

            if (jtu.device_under_test() == "gpu"
                    and "dot_general_preferred" in self._testMethodName):
                logging.info(
                    log_message(
                        f"Arguments are {args}, JAX result is {result_jax}\nand TF result is {result_tf}"
                    ))

            skip_comparison = [l for l in jax2tf_limits if l.skip_comparison]
            if skip_comparison:
                logging.warning(
                    log_message(
                        f"Skip result comparison due to {skip_comparison}"))
                continue

            max_tol = None
            max_tol_lim = (None if not jax2tf_limits else
                           jax2tf_limits[0].get_max_tolerance_limitation(jax2tf_limits))
            if max_tol_lim is not None:
                max_tol = max_tol_lim.tol
                logging.info(
                    log_message(f"Using tol={max_tol} due to {max_tol_lim}"))

            # Convert results to np.arrays
            result_tf = tf.nest.map_structure(lambda t: t.numpy(),
                                              result_tf)  # type: ignore

            custom_assert_lim = [l for l in jax2tf_limits if l.custom_assert]
            assert len(
                custom_assert_lim
            ) <= 1, f"Expecting at most one applicable limitation with custom_assert, found {custom_assert_lim}"

            try:
                err_msg = f"TF mode {mode}."
                log_hlo_on_error = (mode == "compiled"
                                    or jtu.device_under_test() == "tpu")
                if log_hlo_on_error:
                    err_msg += " See the logs for JAX and TF HLO comparisons."
                if custom_assert_lim:
                    logging.info(
                        log_message(
                            f"Running custom_assert with tol={max_tol} due to {custom_assert_lim[0]}"
                        ))
                    custom_assert_lim[0].custom_assert(self,
                                                       result_jax,
                                                       result_tf,
                                                       args=args,
                                                       tol=max_tol,
                                                       err_msg=err_msg)
                else:
                    logging.info(
                        log_message(
                            f"Running default assert with tol={max_tol}"))
                    self.assertAllClose(result_jax,
                                        result_tf,
                                        atol=max_tol,
                                        rtol=max_tol,
                                        err_msg=err_msg)
            except AssertionError as e:
                # Print the HLO for comparison
                if not log_hlo_on_error:
                    print(
                        f"[{self._testMethodName}] Not logging HLO because the "
                        f"mode was {mode}")
                    raise

                logging.info(
                    f"[{self._testMethodName}] Logging HLO for exception in mode {mode}: {e}"
                )
                jax_comp = jax.xla_computation(func_jax)(*args)
                jax_hlo = jax_comp.as_hlo_text()
                logging.info(f"[{self._testMethodName}] "
                             f"JAX NON_OPT HLO\n{jax_hlo}")

                tf_args_signature = _make_tf_input_signature(*args)
                # If we give the signature, we cannot pass scalars
                tf_args_no_scalars = tuple(
                    map(
                        lambda a, sig: tf.convert_to_tensor(
                            a, dtype=sig.dtype), args, tf_args_signature))

                tf_func_compiled = tf.function(
                    func_tf,
                    autograph=False,
                    jit_compile=True,
                    input_signature=tf_args_signature)
                tf_hlo = tf_func_compiled.experimental_get_compiler_ir(
                    *tf_args_no_scalars)(stage="hlo")
                logging.info(
                    f"[{self._testMethodName}] TF NON OPT HLO\n{tf_hlo}")

                backend = jax.lib.xla_bridge.get_backend()
                modules = backend.compile(jax_comp).hlo_modules()
                jax_opt_hlo = modules[0].to_string()
                logging.info(f"[{self._testMethodName}] "
                             f"JAX OPT HLO\n{jax_opt_hlo}")

                # TODO(b/189265364): Remove this workaround
                if (jtu.device_under_test() == "gpu"
                        and "dot_general" in self._testMethodName):
                    print(
                        f"[{self._testMethodName}] Not logging TF OPT HLO because of "
                        f"crash in tf.experimental_get_compiler_ir (b/189265364)"
                    )
                else:
                    tf_opt_hlo = tf_func_compiled.experimental_get_compiler_ir(
                        *tf_args_no_scalars)(stage="optimized_hlo")
                    logging.info(
                        f"[{self._testMethodName}] TF OPT HLO\n{tf_opt_hlo}")

                raise

        # end "for mode"

        if unexpected_successes:
            msg = (f"[{self._testMethodName}] The following are unexpected "
                   "successful modes:\n" + "\n".join(unexpected_successes))
            logging.warning(msg)
            # Uncomment the below if you want to see warnings as failures
            # self.assertEmpty(msg)
        return result_jax, result_tf
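A minimal, hypothetical sketch (not part of the original suite) of how a harness-based test might drive ConvertAndCompare, assuming the primitive_harness helpers and the Jax2TfLimitation class used later in this file:

    def test_one_harness(self, harness: primitive_harness.Harness):
        # Hypothetical sketch: gather the known limitations for this harness and
        # let ConvertAndCompare filter them per TF execution mode.
        limitations = Jax2TfLimitation.limitations_for_harness(harness)
        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               limitations=tuple(limitations))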
class SparseObjectTest(jtu.JaxTestCase):
  @parameterized.named_parameters(
    {"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
    for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO])
  def test_attrs(self, Obj, shape=(5, 8), dtype=np.float16):
    rng = rand_sparse(self.rng(), post=Obj.fromdense)
    M = rng(shape, dtype)

    assert isinstance(M, Obj)
    assert M.shape == shape
    assert M.dtype == dtype
    assert M.nnz == (M.todense() != 0).sum()
    assert M.data.dtype == dtype

    if isinstance(M, sparse_ops.CSR):
      assert len(M.data) == len(M.indices)
      assert len(M.indptr) == M.shape[0] + 1
    elif isinstance(M, sparse_ops.CSC):
      assert len(M.data) == len(M.indices)
      assert len(M.indptr) == M.shape[1] + 1
    elif isinstance(M, sparse_ops.COO):
      assert len(M.data) == len(M.row) == len(M.col)
    else:
      raise ValueError(f"Obj={Obj} not expected.")

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
      {"testcase_name": "_{}_Obj={}".format(
        jtu.format_shape_dtype_string(shape, dtype), Obj.__name__),
       "shape": shape, "dtype": dtype, "Obj": Obj}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
    for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO]))
  def test_dense_round_trip(self, shape, dtype, Obj):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    Msparse = Obj.fromdense(M)
    self.assertArraysEqual(M, Msparse.todense())

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
      {"testcase_name": "_{}_Obj={}".format(
        jtu.format_shape_dtype_string(shape, dtype), Obj.__name__),
       "shape": shape, "dtype": dtype, "Obj": Obj}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
    for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO]))
  def test_transpose(self, shape, dtype, Obj):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    Msparse = Obj.fromdense(M)
    self.assertArraysEqual(M.T, Msparse.T.todense())

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
      {"testcase_name": "_{}_Obj={}_bshape={}".format(
        jtu.format_shape_dtype_string(shape, dtype), Obj.__name__, bshape),
       "shape": shape, "dtype": dtype, "Obj": Obj, "bshape": bshape}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for bshape in [shape[-1:] + s for s in [(), (3,), (4,)]]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
    for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO]))
  def test_matmul(self, shape, dtype, Obj, bshape):
    rng = rand_sparse(self.rng(), post=jnp.array)
    rng_b = jtu.rand_default(self.rng())
    M = rng(shape, dtype)
    Msp = Obj.fromdense(M)
    x = rng_b(bshape, dtype)
    x = jnp.asarray(x)

    self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)
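The attribute invariants checked by test_attrs above mirror the standard CSR/COO layouts. A small standalone sketch with scipy.sparse (an assumption; the classes under test wrap the same buffers) makes them concrete:

import numpy as np
from scipy import sparse

M = np.array([[1., 0., 2.],
              [0., 0., 3.]])
csr = sparse.csr_matrix(M)
assert len(csr.data) == len(csr.indices) == 3   # one entry per stored nonzero
assert len(csr.indptr) == M.shape[0] + 1        # row pointers: nrows + 1
coo = sparse.coo_matrix(M)
assert len(coo.data) == len(coo.row) == len(coo.col) == 3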
    def test_generate_primitives_coverage_doc(self):
        harnesses = primitive_harness.all_harnesses
        print(f"Found {len(harnesses)} harnesses")

        harness_groups: Dict[str,
                             Sequence[primitive_harness.
                                      Harness]] = collections.defaultdict(list)

        def unique_hash(h: primitive_harness.Harness,
                        l: primitive_harness.Limitation):
            return (h.group_name, l.description, l.devices,
                    tuple([np.dtype(d).name for d in l.dtypes]))

        unique_limitations: Dict[Any,
                                 Tuple[primitive_harness.Harness,
                                       primitive_harness.Limitation]] = {}

        for h in harnesses:
            harness_groups[h.group_name].append(h)
            for l in h.jax_unimplemented:
                if l.enabled:
                    unique_limitations[hash(unique_hash(h, l))] = (h, l)

        primitive_coverage_table = [
            """
| Primitive | Total test harnesses | dtypes supported on at least one device | dtypes NOT tested on any device |
| --- | --- | --- | --- |"""
        ]
        all_dtypes = set(jtu.dtypes.all)

        for group_name in sorted(harness_groups.keys()):
            hlist = harness_groups[group_name]
            dtypes_tested = set()  # Tested on at least some device
            for h in hlist:
                dtypes_tested = dtypes_tested.union({h.dtype})

            primitive_coverage_table.append(
                f"| {group_name} | {len(hlist)} | "
                f"{primitive_harness.dtypes_to_str(dtypes_tested)} | "
                f"{primitive_harness.dtypes_to_str(all_dtypes - dtypes_tested)} |"
            )

        print(f"Found {len(unique_limitations)} unique limitations")
        primitive_unimpl_table = [
            """
| Affected primitive | Description of limitation | Affected dtypes | Affected devices |
| --- | --- | --- | --- |"""
        ]
        for h, l in sorted(unique_limitations.values(),
                           key=lambda pair: unique_hash(*pair)):
            devices = ", ".join(l.devices)
            primitive_unimpl_table.append(
                f"|{h.group_name}|{l.description}|"
                f"{primitive_harness.dtypes_to_str(l.dtypes, empty_means_all=True)}|{devices}|"
            )

        if not os.environ.get("JAX_OUTPUT_LIMITATIONS_DOC"):
            raise unittest.SkipTest(
                "Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the documentation"
            )
        # The CPU/GPU have more supported types than TPU.
        self.assertEqual("cpu", jtu.device_under_test(),
                         "The documentation can be generated only on CPU")
        self.assertTrue(
            FLAGS.jax_enable_x64,
            "The documentation must be generated with JAX_ENABLE_X64=1")

        with open(
                os.path.join(
                    os.path.dirname(__file__),
                    '../g3doc/jax_primitives_coverage.md.template')) as f:
            template = f.read()
        output_file = os.path.join(os.path.dirname(__file__),
                                   '../g3doc/jax_primitives_coverage.md')

        with open(output_file, "w") as f:
            f.write(template.replace("{{generation_date}}", str(datetime.date.today())) \
                    .replace("{{nr_harnesses}}", str(len(harnesses))) \
                    .replace("{{nr_primitives}}", str(len(harness_groups))) \
                    .replace("{{primitive_unimpl_table}}", "\n".join(primitive_unimpl_table)) \
                    .replace("{{primitive_coverage_table}}", "\n".join(primitive_coverage_table)))
Example #4
    def test_svd(self, harness: primitive_harness.Harness):
        if jtu.device_under_test() == "tpu":
            raise unittest.SkipTest(
                "TODO: test crashes the XLA compiler for some TPU variants")
        expect_tf_exceptions = False
        if harness.params["dtype"] in [jnp.float16, dtypes.bfloat16]:
            if jtu.device_under_test() == "tpu":
                # TODO: SVD on TPU for bfloat16 seems to work for JAX but fails for TF
                expect_tf_exceptions = True
            else:
                # Does not work in JAX
                with self.assertRaisesRegex(NotImplementedError,
                                            "Unsupported dtype"):
                    harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
                return

        if harness.params["dtype"] in [jnp.complex64, jnp.complex128]:
            if jtu.device_under_test() == "tpu":
                # TODO: on JAX on TPU there is no SVD implementation for complex
                with self.assertRaisesRegex(
                        RuntimeError,
                        "Binary op compare with different element types"):
                    harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
                return
            else:
                # TODO: on CPU and GPU "No registered 'Svd' OpKernel for XLA_CPU_JIT devices".
                # Works on JAX because JAX uses a custom implementation.
                expect_tf_exceptions = True

        def _custom_assert(r_jax, r_tf, atol=1e-6, rtol=1e-6):
            def _reconstruct_operand(result, is_tf: bool):
                # Reconstructing operand as documented in numpy.linalg.svd (see
                # https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html)
                s, u, v = result
                if is_tf:
                    s = s.numpy()
                    u = u.numpy()
                    v = v.numpy()
                U = u[..., :s.shape[-1]]
                V = v[..., :s.shape[-1], :]
                S = s[..., None, :]
                return jnp.matmul(U * S, V), s.shape, u.shape, v.shape

            if harness.params["compute_uv"]:
                r_jax_reconstructed = _reconstruct_operand(r_jax, False)
                r_tf_reconstructed = _reconstruct_operand(r_tf, True)
                self.assertAllClose(r_jax_reconstructed,
                                    r_tf_reconstructed,
                                    atol=atol,
                                    rtol=rtol)
            else:
                self.assertAllClose(r_jax, r_tf, atol=atol, rtol=rtol)

        tol = 1e-4
        custom_assert = partial(_custom_assert, atol=tol, rtol=tol)

        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               atol=tol,
                               rtol=tol,
                               expect_tf_exceptions=expect_tf_exceptions,
                               custom_assert=custom_assert,
                               always_custom_assert=True)
Example #5
 op_record("square", 1, float_dtypes + complex_dtypes, jtu.rand_default),
 op_record("reciprocal", 1, float_dtypes + complex_dtypes,
           jtu.rand_positive),
 op_record("tan", 1, float_dtypes, jtu.rand_default, {np.float32: 3e-5}),
 op_record("asin", 1, float_dtypes, jtu.rand_small),
 # TODO(j-towns) fix: op_record("acos", 1, float_dtypes, jtu.rand_small),
 op_record("atan", 1, float_dtypes, jtu.rand_small),
 op_record("asinh", 1, float_dtypes, jtu.rand_default),
 op_record("acosh", 1, float_dtypes, jtu.rand_positive),
 # TODO(b/155331781): atanh has only ~float precision
 op_record("atanh", 1, float_dtypes, jtu.rand_small, {np.float64: 1e-9}),
 op_record("sinh", 1, float_dtypes + complex_dtypes, jtu.rand_default),
 op_record("cosh", 1, float_dtypes + complex_dtypes, jtu.rand_default),
 op_record(
     "lgamma", 1, float_dtypes, jtu.rand_positive, {
         np.float32: 1e-3 if jtu.device_under_test() == "tpu" else 1e-5,
         np.float64: 1e-14
     }),
 op_record("digamma", 1, float_dtypes, jtu.rand_positive,
           {np.float64: 1e-14}),
 op_record("betainc", 3, float_dtypes, jtu.rand_positive,
           {np.float64: 1e-14}),
 op_record(
     "igamma", 2,
     [f for f in float_dtypes if f not in [dtypes.bfloat16, np.float16]],
     jtu.rand_positive, {np.float64: 1e-14}),
 op_record(
     "igammac", 2,
     [f for f in float_dtypes if f not in [dtypes.bfloat16, np.float16]],
     jtu.rand_positive, {np.float64: 1e-14}),
 op_record("erf", 1, float_dtypes, jtu.rand_small),
Example #6
    def test_unary_elementwise(self, harness: primitive_harness.Harness):
        dtype = harness.params["dtype"]
        lax_name = harness.params["lax_name"]
        arg, = harness.dyn_args_maker(self.rng())
        custom_assert = None
        if lax_name == "digamma":
            # TODO(necula): fix bug with digamma/(f32|f16) on TPU
            if (dtype in [np.float16, np.float32]
                    and jtu.device_under_test() == "tpu"):
                raise unittest.SkipTest("TODO: fix bug: nan vs not-nan")

            # In the bfloat16 case, TF and lax both return NaN in undefined cases.
            if dtype is not dtypes.bfloat16:
                # digamma is not defined at 0 and -1
                def custom_assert(result_jax, result_tf):
                    # lax.digamma returns NaN and tf.math.digamma returns inf
                    special_cases = (arg == 0.) | (arg == -1.)
                    nr_special_cases = np.count_nonzero(special_cases)
                    self.assertAllClose(
                        np.full((nr_special_cases, ), dtype(np.nan)),
                        result_jax[special_cases])
                    self.assertAllClose(
                        np.full((nr_special_cases, ), dtype(np.inf)),
                        result_tf[special_cases])
                    # non-special cases are equal
                    self.assertAllClose(result_jax[~special_cases],
                                        result_tf[~special_cases])

        if lax_name == "erf_inv":
            # TODO(necula): fix erf_inv bug on TPU
            if jtu.device_under_test() == "tpu":
                raise unittest.SkipTest("erf_inv bug on TPU: nan vs non-nan")
            # TODO: investigate: in the (b)float16 cases, TF and lax both return the
            # same result in undefined cases.
            if dtype not in [np.float16, dtypes.bfloat16]:
                # erf_inv is not defined for arg <= -1 or arg >= 1
                def custom_assert(result_jax, result_tf):  # noqa: F811
                    # for arg < -1 or arg > 1
                    # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
                    special_cases = (arg < -1.) | (arg > 1.)
                    nr_special_cases = np.count_nonzero(special_cases)
                    self.assertAllClose(
                        np.full((nr_special_cases, ),
                                dtype(np.nan),
                                dtype=dtype), result_jax[special_cases])
                    signs = np.where(arg[special_cases] < 0., -1., 1.)
                    self.assertAllClose(
                        np.full((nr_special_cases, ),
                                signs * dtype(np.inf),
                                dtype=dtype), result_tf[special_cases])
                    # non-special cases are equal
                    self.assertAllClose(result_jax[~special_cases],
                                        result_tf[~special_cases])

        atol = None
        if jtu.device_under_test() == "gpu":
            # TODO(necula): revisit once we fix the GPU tests
            atol = 1e-3
        self.ConvertAndCompare(harness.dyn_fun,
                               arg,
                               custom_assert=custom_assert,
                               atol=atol)
Example #7
class LaxBackedScipyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Scipy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        return lambda: [
            rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)
        ]

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_axis={}_keepdims={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
            # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
            "rng_factory":
            jtu.rand_some_inf_and_nan
            if jtu.device_under_test() != "cpu" else jtu.rand_default,
            "shape":
            shape,
            "dtype":
            dtype,
            "axis":
            axis,
            "keepdims":
            keepdims
        } for shape in all_shapes for dtype in float_dtypes
                            for axis in range(-len(shape), len(shape))
                            for keepdims in [False, True]))
    @jtu.skip_on_flag("jax_xla_backend", "xrt")
    def testLogSumExp(self, rng_factory, shape, dtype, axis, keepdims):
        rng = rng_factory()

        # TODO(mattjj): test autodiff
        def scipy_fun(array_to_reduce):
            return osp_special.logsumexp(array_to_reduce,
                                         axis,
                                         keepdims=keepdims)

        def lax_fun(array_to_reduce):
            return lsp_special.logsumexp(array_to_reduce,
                                         axis,
                                         keepdims=keepdims)

        args_maker = lambda: [rng(shape, dtype)]
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list(
                {
                    "testcase_name":
                    jtu.format_test_name_suffix(rec.test_name, shapes, dtypes),
                    "rng_factory":
                    rec.rng_factory,
                    "shapes":
                    shapes,
                    "dtypes":
                    dtypes,
                    "test_autodiff":
                    rec.test_autodiff,
                    "scipy_op":
                    getattr(osp_special, rec.name),
                    "lax_op":
                    getattr(lsp_special, rec.name)
                } for shapes in CombosWithReplacement(all_shapes, rec.nargs)
                for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
            for rec in JAX_SPECIAL_FUNCTION_RECORDS))
    def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes,
                            dtypes, test_autodiff):
        rng = rng_factory()
        args_maker = self._GetArgsMaker(rng, shapes, dtypes)
        args = args_maker()
        self.assertAllClose(scipy_op(*args),
                            lax_op(*args),
                            atol=1e-3,
                            rtol=1e-3,
                            check_dtypes=False)
        self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)

        if test_autodiff:
            jtu.check_grads(lax_op,
                            args,
                            order=1,
                            atol=1e-3,
                            rtol=3e-3,
                            eps=1e-3)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_d={}".format(
                jtu.format_shape_dtype_string(shape, dtype), d),
            "rng_factory":
            jtu.rand_positive,
            "shape":
            shape,
            "dtype":
            dtype,
            "d":
            d
        } for shape in all_shapes for dtype in float_dtypes
                            for d in [1, 2, 5]))
    def testMultigammaln(self, rng_factory, shape, dtype, d):
        def scipy_fun(a):
            return osp_special.multigammaln(a, d)

        def lax_fun(a):
            return lsp_special.multigammaln(a, d)

        rng = rng_factory()
        args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    def testIssue980(self):
        x = onp.full((4, ), -1e20, dtype=onp.float32)
        self.assertAllClose(onp.zeros((4, ), dtype=onp.float32),
                            lsp_special.expit(x),
                            check_dtypes=True)
Example #8
 def setUp(self):
     super().setUp()
     if jtu.device_under_test() not in ["tpu", "gpu"]:
         raise SkipTest
     if jtu.device_under_test() == "gpu":
         os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
Example #9
 def default_tolerance():
     if jtu.device_under_test() != 'tpu':
         return jtu._default_tolerance
     tol = jtu._default_tolerance.copy()
     tol[onp.dtype(onp.float32)] = 5e-2
     return tol
Example #10
 def testSoftplusGrad(self):
     check_grads(nn.softplus, (1e-8, ),
                 4,
                 rtol=1e-2 if jtu.device_under_test() == "tpu" else None)
Example #11
 def setUp(self):
     super().setUp()
     if jtu.device_under_test() not in ["tpu", "gpu"]:
         raise SkipTest
Example #12
    # TODO(mattjj): make some-equal checks more robust, enable second-order
    # grad_test_spec(lax.max, nargs=2, order=1, rng_factory=jtu.rand_some_equal,
    #                dtypes=grad_float_dtypes, name="MaxSomeEqual"),
    # grad_test_spec(lax.min, nargs=2, order=1, rng_factory=jtu.rand_some_equal,
    #                dtypes=grad_float_dtypes, name="MinSomeEqual"),
]

GradSpecialValuesTestSpec = collections.namedtuple(
    "GradSpecialValuesTestSpec", ["op", "values", "tol"])
def grad_special_values_test_spec(op, values, tol=None):
  return GradSpecialValuesTestSpec(op, values, tol)

LAX_GRAD_SPECIAL_VALUE_TESTS = [
    grad_special_values_test_spec(
      lax.sinh, [0.],
      tol={np.float32: 1e-2} if jtu.device_under_test() == "tpu" else None),
    grad_special_values_test_spec(
      lax.cosh, [0.],
      tol={np.float32: 1e-2} if jtu.device_under_test() == "tpu" else None),
    grad_special_values_test_spec(lax.tanh, [0., 1000.]),
    grad_special_values_test_spec(lax.sin, [0., np.pi, np.pi/2., np.pi/4.]),
    grad_special_values_test_spec(lax.cos, [0., np.pi, np.pi/2., np.pi/4.]),
    grad_special_values_test_spec(lax.tan, [0.]),
    grad_special_values_test_spec(lax.asin, [0.]),
    grad_special_values_test_spec(lax.acos, [0.]),
    grad_special_values_test_spec(lax.atan, [0., 1000.]),
    grad_special_values_test_spec(lax.erf, [0., 10.]),
    grad_special_values_test_spec(lax.erfc, [0., 10.]),
]
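A hypothetical sketch (not shown in this excerpt) of one way these specs could be consumed, assuming check_grads is imported as in the nn tests above; each op is grad-checked at its special values with the per-spec tolerance:

def run_grad_special_value_checks():
    for op, values, tol in LAX_GRAD_SPECIAL_VALUE_TESTS:
        for value in values:
            # First- and second-order gradient checks; tol=None falls back to
            # the default tolerances.
            check_grads(op, (value,), order=2, atol=tol, rtol=tol)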

Example #13
 def setUp(self):
     if jtu.device_under_test() != "gpu":
         self.skipTest("__cuda_array_interface__ is only supported on GPU")
Example #14
 def setUp(self):
   if jtu.device_under_test() != "tpu":
     raise SkipTest
Example #15
    def test_eig(self, harness: primitive_harness.Harness):
        operand = harness.dyn_args_maker(self.rng())[0]
        compute_left_eigenvectors = harness.params["compute_left_eigenvectors"]
        compute_right_eigenvectors = harness.params[
            "compute_right_eigenvectors"]
        dtype = harness.params["dtype"]

        if jtu.device_under_test() != "cpu":
            raise unittest.SkipTest("eig only supported on CPU in JAX")

        if dtype in [np.float16, dtypes.bfloat16]:
            raise unittest.SkipTest("eig unsupported with (b)float16 in JAX")

        def custom_assert(result_jax, result_tf):
            result_tf = tuple(map(lambda e: e.numpy(), result_tf))
            inner_dimension = operand.shape[-1]

            # Test ported from tests.lax_test.testEig
            # Norm, adjusted for dimension and type.
            def norm(x):
                norm = np.linalg.norm(x, axis=(-2, -1))
                return norm / ((inner_dimension + 1) * jnp.finfo(dtype).eps)

            def check_right_eigenvectors(a, w, vr):
                self.assertTrue(
                    np.all(
                        norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))

            def check_left_eigenvectors(a, w, vl):
                rank = len(a.shape)
                aH = jnp.conj(
                    a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
                wC = jnp.conj(w)
                check_right_eigenvectors(aH, wC, vl)

            def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
                tol = None
                # TODO(bchetioui): numerical discrepancies
                if dtype in [np.float32, np.complex64]:
                    tol = 1e-4
                elif dtype in [np.float64, np.complex128]:
                    tol = 1e-13
                closest_diff = min(abs(eigenvalues_array - eigenvalue))
                self.assertAllClose(closest_diff,
                                    np.array(0., closest_diff.dtype),
                                    atol=tol)

            all_w_jax, all_w_tf = result_jax[0], result_tf[0]
            for idx in itertools.product(*map(range, operand.shape[:-2])):
                w_jax, w_tf = all_w_jax[idx], all_w_tf[idx]
                for i in range(inner_dimension):
                    check_eigenvalue_is_in_array(w_jax[i], w_tf)
                    check_eigenvalue_is_in_array(w_tf[i], w_jax)

            if compute_left_eigenvectors:
                check_left_eigenvectors(operand, all_w_tf, result_tf[1])
            if compute_right_eigenvectors:
                check_right_eigenvectors(
                    operand, all_w_tf,
                    result_tf[1 + compute_left_eigenvectors])

        self.ConvertAndCompare(harness.dyn_fun,
                               operand,
                               custom_assert=custom_assert)
Example #16
class cuSparseTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex))
    def test_csr_todense(self, shape, dtype):
        rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
        M = rng(shape, dtype)

        args = (M.data, M.indices, M.indptr)
        todense = lambda *args: sparse_ops.csr_todense(*args, shape=M.shape)

        self.assertArraysEqual(M.toarray(), todense(*args))
        self.assertArraysEqual(M.toarray(), jit(todense)(*args))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex))
    def test_csr_fromdense(self, shape, dtype):
        rng = rand_sparse(self.rng())
        M = rng(shape, dtype)
        M_csr = sparse.csr_matrix(M)

        nnz = M_csr.nnz
        index_dtype = jnp.int32
        fromdense = lambda M: sparse_ops.csr_fromdense(
            M, nnz=nnz, index_dtype=jnp.int32)

        data, indices, indptr = fromdense(M)
        self.assertArraysEqual(data, M_csr.data.astype(dtype))
        self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
        self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

        data, indices, indptr = jit(fromdense)(M)
        self.assertArraysEqual(data, M_csr.data.astype(dtype))
        self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
        self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype),
                              transpose),
            "shape":
            shape,
            "dtype":
            dtype,
            "transpose":
            transpose
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex for transpose in [True, False]))
    def test_csr_matvec(self, shape, dtype, transpose):
        op = lambda M: M.T if transpose else M

        v_rng = jtu.rand_default(self.rng())
        rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
        M = rng(shape, dtype)
        v = v_rng(op(M).shape[1], dtype)

        args = (M.data, M.indices, M.indptr, v)
        matvec = lambda *args: sparse_ops.csr_matvec(
            *args, shape=M.shape, transpose=transpose)

        self.assertAllClose(op(M) @ v, matvec(*args))
        self.assertAllClose(op(M) @ v, jit(matvec)(*args))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype),
                              transpose),
            "shape":
            shape,
            "dtype":
            dtype,
            "transpose":
            transpose
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex for transpose in [True, False]))
    def test_csr_matmat(self, shape, dtype, transpose):
        op = lambda M: M.T if transpose else M

        B_rng = jtu.rand_default(self.rng())
        rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
        M = rng(shape, dtype)
        B = B_rng((op(M).shape[1], 4), dtype)

        args = (M.data, M.indices, M.indptr, B)
        matmat = lambda *args: sparse_ops.csr_matmat(
            *args, shape=shape, transpose=transpose)

        self.assertAllClose(op(M) @ B, matmat(*args))
        self.assertAllClose(op(M) @ B, jit(matmat)(*args))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex))
    def test_coo_todense(self, shape, dtype):
        rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
        M = rng(shape, dtype)

        args = (M.data, M.row, M.col)
        todense = lambda *args: sparse_ops.coo_todense(*args, shape=M.shape)

        self.assertArraysEqual(M.toarray(), todense(*args))
        self.assertArraysEqual(M.toarray(), jit(todense)(*args))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex))
    def test_coo_fromdense(self, shape, dtype):
        rng = rand_sparse(self.rng())
        M = rng(shape, dtype)
        M_coo = sparse.coo_matrix(M)

        nnz = M_coo.nnz
        index_dtype = jnp.int32
        fromdense = lambda M: sparse_ops.coo_fromdense(
            M, nnz=nnz, index_dtype=jnp.int32)

        data, row, col = fromdense(M)
        self.assertArraysEqual(data, M_coo.data.astype(dtype))
        self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
        self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

        data, row, col = jit(fromdense)(M)
        self.assertArraysEqual(data, M_coo.data.astype(dtype))
        self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
        self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype),
                              transpose),
            "shape":
            shape,
            "dtype":
            dtype,
            "transpose":
            transpose
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex for transpose in [True, False]))
    def test_coo_matvec(self, shape, dtype, transpose):
        op = lambda M: M.T if transpose else M

        v_rng = jtu.rand_default(self.rng())
        rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
        M = rng(shape, dtype)
        v = v_rng(op(M).shape[1], dtype)

        args = (M.data, M.row, M.col, v)
        matvec = lambda *args: sparse_ops.coo_matvec(
            *args, shape=M.shape, transpose=transpose)

        self.assertAllClose(op(M) @ v, matvec(*args))
        self.assertAllClose(op(M) @ v, jit(matvec)(*args))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype),
                              transpose),
            "shape":
            shape,
            "dtype":
            dtype,
            "transpose":
            transpose
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex for transpose in [True, False]))
    def test_coo_matmat(self, shape, dtype, transpose):
        op = lambda M: M.T if transpose else M

        B_rng = jtu.rand_default(self.rng())
        rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
        M = rng(shape, dtype)
        B = B_rng((op(M).shape[1], 4), dtype)

        args = (M.data, M.row, M.col, B)
        matmat = lambda *args: sparse_ops.coo_matmat(
            *args, shape=shape, transpose=transpose)

        self.assertAllClose(op(M) @ B, matmat(*args))
        self.assertAllClose(op(M) @ B, jit(matmat)(*args))

    @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
    def test_gpu_translation_rule(self):
        version = xla_bridge.get_backend().platform_version
        cuda_version = None if version == "<unknown>" else int(
            version.split()[-1])
        if cuda_version is None or cuda_version < 11000:
            self.assertNotIn(sparse_ops.csr_todense_p,
                             xla.backend_specific_translations["gpu"])
        else:
            self.assertIn(sparse_ops.csr_todense_p,
                          xla.backend_specific_translations["gpu"])

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_{}".format(jtu.format_shape_dtype_string(shape, dtype),
                            mat_type),
            "shape":
            shape,
            "dtype":
            dtype,
            "mat_type":
            mat_type
        } for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex for mat_type in ['csr', 'coo']))
    def test_extra_nnz(self, shape, dtype, mat_type):
        rng = rand_sparse(self.rng())
        M = rng(shape, dtype)
        nnz = (M != 0).sum() + 5
        fromdense = getattr(sparse_ops, f"{mat_type}_fromdense")
        todense = getattr(sparse_ops, f"{mat_type}_todense")
        args = fromdense(M, nnz=nnz, index_dtype=jnp.int32)
        M_out = todense(*args, shape=M.shape)
        self.assertArraysEqual(M, M_out)
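For orientation, a minimal sketch of the functional CSR kernels exercised by cuSparseTest, reusing the sparse_ops, jnp and np names assumed by the tests above; the shape is always passed explicitly:

M = np.array([[0., 1., 0.],
              [2., 0., 3.]], dtype=np.float32)
# Extract the CSR buffers, rebuild the dense matrix, and apply a matvec.
data, indices, indptr = sparse_ops.csr_fromdense(M, nnz=3, index_dtype=jnp.int32)
dense = sparse_ops.csr_todense(data, indices, indptr, shape=M.shape)   # equals M
v = np.ones(3, dtype=np.float32)
y = sparse_ops.csr_matvec(data, indices, indptr, v, shape=M.shape, transpose=False)   # equals M @ v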
Example #17
    def test_eigh(self, harness: primitive_harness.Harness):
        operand = harness.dyn_args_maker(self.rng())[0]
        lower = harness.params["lower"]
        # Make operand self-adjoint
        operand = (operand + np.conj(np.swapaxes(operand, -1, -2))) / 2
        # Make operand lower/upper triangular
        triangular_operand = np.tril(operand) if lower else np.triu(operand)
        dtype = harness.params["dtype"]

        if (dtype in [np.complex64, np.complex128]
                and jtu.device_under_test() == "tpu"):
            raise unittest.SkipTest(
                "TODO: complex eigh not supported on TPU in JAX")

        def custom_assert(result_jax, result_tf):
            result_tf = tuple(map(lambda e: e.numpy(), result_tf))
            inner_dimension = operand.shape[-1]

            def check_right_eigenvectors(a, w, vr):
                tol = 1e-16
                # TODO(bchetioui): tolerance needs to be very high in compiled mode,
                # specifically for eigenvectors.
                if dtype == np.float64:
                    tol = 1e-6
                elif dtype == np.float32:
                    tol = 1e-2
                elif dtype in [dtypes.bfloat16, np.complex64]:
                    tol = 1e-3
                elif dtype == np.complex128:
                    tol = 1e-13
                self.assertAllClose(np.matmul(a, vr) - w[..., None, :] * vr,
                                    np.zeros(a.shape, dtype=vr.dtype),
                                    atol=tol)

            def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
                tol = None
                if dtype in [dtypes.bfloat16, np.float32, np.complex64]:
                    tol = 1e-3
                elif dtype in [np.float64, np.complex128]:
                    tol = 1e-11
                closest_diff = min(abs(eigenvalues_array - eigenvalue))
                self.assertAllClose(closest_diff,
                                    np.array(0., closest_diff.dtype),
                                    atol=tol)

            _, all_w_jax = result_jax
            all_vr_tf, all_w_tf = result_tf

            for idx in itertools.product(*map(range, operand.shape[:-2])):
                w_jax, w_tf = all_w_jax[idx], all_w_tf[idx]
                for i in range(inner_dimension):
                    check_eigenvalue_is_in_array(w_jax[i], w_tf)
                    check_eigenvalue_is_in_array(w_tf[i], w_jax)

            check_right_eigenvectors(operand, all_w_tf, all_vr_tf)

        # On CPU and GPU, JAX makes custom calls
        always_custom_assert = True
        # On TPU, JAX calls xops.Eigh
        if jtu.device_under_test() == "tpu":
            always_custom_assert = False

        self.ConvertAndCompare(harness.dyn_fun,
                               triangular_operand,
                               custom_assert=custom_assert,
                               always_custom_assert=always_custom_assert)
Example #18
 def testDtypeMatchesInput(self, dtype, fn):
     if dtype is jnp.float16 and jtu.device_under_test() == "tpu":
         self.skipTest("float16 not supported on TPU")
     x = jnp.zeros((), dtype=dtype)
     out = fn(x)
     self.assertEqual(out.dtype, dtype)
Example #19
    def test_binary_elementwise(self, harness):
        tol = None
        lax_name, dtype = harness.params["lax_name"], harness.params["dtype"]
        if lax_name in ("igamma", "igammac"):
            # TODO(necula): fix bug with igamma/f16
            if dtype in [np.float16, dtypes.bfloat16]:
                raise unittest.SkipTest(
                    "TODO: igamma(c) unsupported with (b)float16 in JAX")
            # TODO(necula): fix bug with igamma/f32 on TPU
            if dtype is np.float32 and jtu.device_under_test() == "tpu":
                raise unittest.SkipTest("TODO: fix bug: nan vs not-nan")
        arg1, arg2 = harness.dyn_args_maker(self.rng())
        custom_assert = None
        if lax_name == "igamma":
            # igamma is not defined when the first argument is <=0
            def custom_assert(result_jax, result_tf):
                # lax.igamma returns NaN when arg1 == arg2 == 0; tf.math.igamma returns 0
                special_cases = (arg1 == 0.) & (arg2 == 0.)
                nr_special_cases = np.count_nonzero(special_cases)
                self.assertAllClose(
                    np.full((nr_special_cases, ), np.nan, dtype=dtype),
                    result_jax[special_cases])
                self.assertAllClose(
                    np.full((nr_special_cases, ), 0., dtype=dtype),
                    result_tf[special_cases])
                # non-special cases are equal
                self.assertAllClose(result_jax[~special_cases],
                                    result_tf[~special_cases])

        if lax_name == "igammac":
            # On GPU, tolerance also needs to be adjusted in compiled mode
            if dtype == np.float64 and jtu.device_under_test() == 'gpu':
                tol = 1e-14
            # igammac is not defined when the first argument is <=0
            def custom_assert(result_jax, result_tf):  # noqa: F811
                # lax.igammac returns 1. when arg1 <= 0; tf.math.igammac returns NaN
                special_cases = (arg1 <= 0.) | (arg2 <= 0)
                nr_special_cases = np.count_nonzero(special_cases)
                self.assertAllClose(
                    np.full((nr_special_cases, ), 1., dtype=dtype),
                    result_jax[special_cases])
                self.assertAllClose(
                    np.full((nr_special_cases, ), np.nan, dtype=dtype),
                    result_tf[special_cases])
                # On CPU, tolerance only needs to be adjusted in eager & graph modes
                tol = None
                if dtype == np.float64:
                    tol = 1e-14

                # non-special cases are equal
                self.assertAllClose(result_jax[~special_cases],
                                    result_tf[~special_cases],
                                    atol=tol,
                                    rtol=tol)

        self.ConvertAndCompare(harness.dyn_fun,
                               arg1,
                               arg2,
                               custom_assert=custom_assert,
                               atol=tol,
                               rtol=tol)
Example #20
 def testSoftplusGradZero(self):
     check_grads(nn.softplus, (0., ),
                 order=1,
                 rtol=1e-2 if jtu.device_under_test() == "tpu" else None)
 def setUp(self):
     super(DLPackTest, self).setUp()
     if jtu.device_under_test() == "tpu":
         self.skipTest("DLPack not supported on TPU")
Example #22
 def testSoftplusGradNan(self):
     check_grads(nn.softplus, (float('nan'), ),
                 order=1,
                 rtol=1e-2 if jtu.device_under_test() == "tpu" else None)
Example #23
    def test_unary_elementwise(self, harness: primitive_harness.Harness):
        dtype = harness.params["dtype"]
        lax_name = harness.params["lax_name"]
        if dtype is dtypes.bfloat16:
            raise unittest.SkipTest("bfloat16 not implemented")
        if lax_name in ("sinh", "cosh", "atanh", "asinh",
                        "acosh") and dtype is np.float16:
            raise unittest.SkipTest(
                "b/158006398: float16 support is missing from '%s' TF kernel" %
                lax_name)
        arg, = harness.dyn_args_maker(self.rng())
        custom_assert = None
        if lax_name == "digamma":
            # TODO(necula): fix bug with digamma/f32 on TPU
            if (harness.params["dtype"] is np.float32
                    and jtu.device_under_test() == "tpu"):
                raise unittest.SkipTest("TODO: fix bug: nan vs not-nan")
            if (harness.params["dtype"] is np.float16
                    and jtu.device_under_test() == "tpu"):
                raise unittest.SkipTest("TODO: fix bug: nans and infs")

            # digamma is not defined at 0 and -1
            def custom_assert(result_jax, result_tf):
                # lax.digamma returns NaN and tf.math.digamma returns inf
                special_cases = (arg == 0.) | (arg == -1.)
                nr_special_cases = np.count_nonzero(special_cases)
                self.assertAllClose(
                    np.full((nr_special_cases, ), dtype(np.nan)),
                    result_jax[special_cases])
                self.assertAllClose(
                    np.full((nr_special_cases, ), dtype(np.inf)),
                    result_tf[special_cases])
                # non-special cases are equal
                self.assertAllClose(result_jax[~special_cases],
                                    result_tf[~special_cases])

        if lax_name == "erf_inv":
            # TODO(necula): fix bug with erf_inv/f16
            if dtype is np.float16:
                raise unittest.SkipTest("TODO: fix bug")
            # TODO(necula): fix erf_inv bug on TPU
            if jtu.device_under_test() == "tpu":
                raise unittest.SkipTest("erf_inv bug on TPU: nan vs non-nan")
            # erf_inv is not defined for arg <= -1 or arg >= 1
            def custom_assert(result_jax, result_tf):  # noqa: F811
                # for arg < -1 or arg > 1
                # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
                special_cases = (arg < -1.) | (arg > 1.)
                nr_special_cases = np.count_nonzero(special_cases)
                self.assertAllClose(
                    np.full((nr_special_cases, ), dtype(np.nan)),
                    result_jax[special_cases])
                signs = np.where(arg[special_cases] < 0., -1., 1.)
                self.assertAllClose(
                    np.full((nr_special_cases, ), signs * dtype(np.inf)),
                    result_tf[special_cases])
                # non-special cases are equal
                self.assertAllClose(result_jax[~special_cases],
                                    result_tf[~special_cases])

        atol = None
        if jtu.device_under_test() == "gpu":
            # TODO(necula): revisit once we fix the GPU tests
            atol = 1e-3
        self.ConvertAndCompare(harness.dyn_fun,
                               arg,
                               custom_assert=custom_assert,
                               atol=atol)
Example #24
 def testReluGrad(self):
     rtol = 1e-2 if jtu.device_under_test() == "tpu" else None
     check_grads(nn.relu, (1., ), order=3, rtol=rtol)
     check_grads(nn.relu, (-1., ), order=3, rtol=rtol)
     jaxpr = jax.make_jaxpr(jax.grad(nn.relu))(0.)
     self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)
Example #25
    def test_generate_limitations_doc(self):
        """Generates primitives_with_limited_support.md.

    See the doc for instructions.
    """

        harnesses = [
            h for h in primitive_harness.all_harnesses
            if h.filter(h, include_jax_unimpl=True)
        ]
        print(f"Found {len(harnesses)} test harnesses that work in JAX")

        def unique_hash(h: primitive_harness.Harness, l: Jax2TfLimitation):
            return (h.group_name, l.description, l.devices,
                    tuple([np.dtype(d).name for d in l.dtypes]), l.modes)

        unique_limitations: Dict[Any, Tuple[primitive_harness.Harness,
                                            Jax2TfLimitation]] = {}
        for h in harnesses:
            for l in h.jax_unimplemented:
                if l.enabled:
                    # Fake a Jax2TfLimitation from the Limitation
                    tfl = Jax2TfLimitation(
                        description="Not implemented in JAX: " + l.description,
                        devices=l.devices,
                        dtypes=l.dtypes,
                        expect_tf_error=False,
                        skip_tf_run=True)
                    unique_limitations[hash(unique_hash(h, tfl))] = (h, tfl)
        for h in harnesses:
            for l in Jax2TfLimitation.limitations_for_harness(h):
                unique_limitations[hash(unique_hash(h, l))] = (h, l)

        print(f"Found {len(unique_limitations)} unique limitations")
        tf_error_table = [
            """
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
| --- | --- | --- | --- | --- |"""
        ]
        tf_numerical_discrepancies_table = list(tf_error_table)  # a copy
        for h, l in sorted(unique_limitations.values(),
                           key=lambda pair: unique_hash(*pair)):
            devices = ", ".join(sorted(l.devices))
            modes = ", ".join(sorted(l.modes))
            description = l.description
            if l.skip_comparison:
                description = "Numeric comparision disabled: " + description
            if l.expect_tf_error:
                description = "TF error: " + description
            if l.skip_tf_run:
                description = "TF test skipped: " + description

            if l.skip_tf_run or l.expect_tf_error:
                to_table = tf_error_table
            elif l.skip_comparison or l.custom_assert:
                to_table = tf_numerical_discrepancies_table
            else:
                continue

            to_table.append(
                f"| {h.group_name} | {description} | "
                f"{primitive_harness.dtypes_to_str(l.dtypes, empty_means_all=True)} | {devices} | {modes} |"
            )

        if not os.environ.get("JAX_OUTPUT_LIMITATIONS_DOC"):
            raise unittest.SkipTest(
                "Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the documentation"
            )
        # The CPU has more supported types, and hence more harnesses.
        self.assertEqual("cpu", jtu.device_under_test())
        self.assertTrue(
            config.x64_enabled,
            "Documentation generation must be run with JAX_ENABLE_X64=1")

        with open(
                os.path.join(
                    os.path.dirname(__file__),
                    "../g3doc/primitives_with_limited_support.md.template")
        ) as f:
            template = f.read()
        output_file = os.path.join(
            os.path.dirname(__file__),
            "../g3doc/primitives_with_limited_support.md")

        with open(output_file, "w") as f:
            f.write(template.replace("{{generation_date}}", str(datetime.date.today())) \
                    .replace("{{tf_error_table}}", "\n".join(tf_error_table)) \
                    .replace("{{tf_numerical_discrepancies_table}}", "\n".join(tf_numerical_discrepancies_table)) \
                    )
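
The final step above is plain string substitution on the {{...}} placeholders. A simplified sketch with a toy template and a made-up table row, only to illustrate that step (the real template file and rows come from the harness scan above):

import datetime

template = ("*Last generated on: {{generation_date}}*\n\n"
            "{{tf_error_table}}\n")
tf_error_table = [
    "| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |",
    "| --- | --- | --- | --- | --- |",
    "| add | TF error: example entry | all | cpu, gpu, tpu | eager, graph, compiled |",
]

doc = (template
       .replace("{{generation_date}}", str(datetime.date.today()))
       .replace("{{tf_error_table}}", "\n".join(tf_error_table)))
print(doc)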
Example #26
 def setUp(self):
     super(ShardedJitTest, self).setUp()
     if jtu.device_under_test() != "tpu":
         raise SkipTest("sharded_jit tests require TPU")

class cuSparseTest(jtu.JaxTestCase):
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.indices, M.indptr)
    todense = lambda *args: sparse_ops.csr_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_csr = sparse.csr_matrix(M)

    nnz = M_csr.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse_ops.csr_fromdense(M, nnz=nnz, index_dtype=index_dtype)

    data, indices, indptr = fromdense(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

    data, indices, indptr = jit(fromdense)(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_csr_matvec(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.indices, M.indptr, v)
    matvec = lambda *args: sparse_ops.csr_matvec(*args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_csr_matmat(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.indices, M.indptr, B)
    matmat = lambda *args: sparse_ops.csr_matmat(*args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.row, M.col)
    todense = lambda *args: sparse_ops.coo_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_coo = sparse.coo_matrix(M)

    nnz = M_coo.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse_ops.coo_fromdense(M, nnz=nnz, index_dtype=index_dtype)

    data, row, col = fromdense(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

    data, row, col = jit(fromdense)(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_coo_matvec(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.row, M.col, v)
    matvec = lambda *args: sparse_ops.coo_matvec(*args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_coo_matmat(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.row, M.col, B)
    matmat = lambda *args: sparse_ops.coo_matmat(*args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

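    # The jvp calls below are a smoke test for the coo_matmat JVP rules with
    # respect to the dense operand and the sparse data; only the primal output
    # is compared against the dense reference (dy is not checked).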
    y, dy = jvp(lambda x: sparse_ops.coo_matmat(M.data, M.row, M.col, x, shape=shape, transpose=transpose).sum(), (B, ), (jnp.ones_like(B), ))
    self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)

    y, dy = jvp(lambda x: sparse_ops.coo_matmat(x, M.row, M.col, B, shape=shape, transpose=transpose).sum(), (M.data, ), (jnp.ones_like(M.data), ))
    self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  def test_gpu_translation_rule(self):
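    # Parse the CUDA version from the backend's platform_version string (its
    # last whitespace-separated token); the cuSPARSE-backed translation rules
    # are only expected when CUDA >= 11.0 and cusparse reports support.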
    version = xla_bridge.get_backend().platform_version
    cuda_version = None if version == "<unknown>" else int(version.split()[-1])
    if cuda_version is None or cuda_version < 11000:
      self.assertFalse(cusparse and cusparse.is_supported)
      self.assertNotIn(sparse_ops.csr_todense_p, xla.backend_specific_translations["gpu"])
    else:
      self.assertTrue(cusparse and cusparse.is_supported)
      self.assertIn(sparse_ops.csr_todense_p, xla.backend_specific_translations["gpu"])

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
         jtu.format_shape_dtype_string(shape, dtype), mat_type),
       "shape": shape, "dtype": dtype, "mat_type": mat_type}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for mat_type in ['csr', 'coo']))
  def test_extra_nnz(self, shape, dtype, mat_type):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    nnz = (M != 0).sum() + 5
    fromdense = getattr(sparse_ops, f"{mat_type}_fromdense")
    todense = getattr(sparse_ops, f"{mat_type}_todense")
    args = fromdense(M, nnz=nnz, index_dtype=jnp.int32)
    M_out = todense(*args, shape=M.shape)
    self.assertArraysEqual(M, M_out)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_todense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())
    f = lambda data: sparse_ops.coo_todense(data, row, col, shape=M.shape)

    # Forward-mode
    primals, tangents = api.jvp(f, [data], [jnp.ones_like(data)])
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(tangents, jnp.zeros_like(M).at[row, col].set(1))

    # Reverse-mode
    primals, vjp_fun = api.vjp(f, data)
    data_out, = vjp_fun(primals)
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(data_out, data)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_fromdense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    nnz = (M != 0).sum()
    f = lambda M: sparse_ops.coo_fromdense(M, nnz=nnz)

    # Forward-mode
    primals, tangents = api.jvp(f, [M], [jnp.ones_like(M)])
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(tangents[0], jnp.ones(nnz, dtype=dtype))
    self.assertEqual(tangents[1].dtype, dtypes.float0)
    self.assertEqual(tangents[2].dtype, dtypes.float0)

    # Reverse-mode
    primals, vjp_fun = api.vjp(f, M)
    M_out, = vjp_fun(primals)
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(M_out, M)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        jtu.format_shape_dtype_string(bshape, dtype)),
       "shape": shape, "dtype": dtype, "bshape": bshape}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for bshape in [shape[-1:] + s for s in [()]]  # TODO: matmul autodiff
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))  # TODO: other types
  def test_coo_matvec_ad(self, shape, dtype, bshape):
    tol = {np.float32: 1E-6, np.float64: 1E-13, np.complex64: 1E-6, np.complex128: 1E-13}

    rng = rand_sparse(self.rng(), post=jnp.array)
    rng_b = jtu.rand_default(self.rng())

    M = rng(shape, dtype)
    data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())
    x = rng_b(bshape, dtype)
    xdot = rng_b(bshape, dtype)

    # Forward-mode with respect to the vector
    f_dense = lambda x: M @ x
    f_sparse = lambda x: sparse_ops.coo_matvec(data, row, col, x, shape=M.shape)
    v_sparse, t_sparse = api.jvp(f_sparse, [x], [xdot])
    v_dense, t_dense = api.jvp(f_dense, [x], [xdot])
    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to the vector
    primals_dense, vjp_dense = api.vjp(f_dense, x)
    primals_sparse, vjp_sparse = api.vjp(f_sparse, x)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)

    # Forward-mode with respect to nonzero elements of the matrix
    f_sparse = lambda data: sparse_ops.coo_matvec(data, row, col, x, shape=M.shape)
    f_dense = lambda data: sparse_ops.coo_todense(data, row, col, shape=M.shape) @ x
    data = rng((len(data),), data.dtype)
    data_dot = rng((len(data),), data.dtype)
    v_sparse, t_sparse = api.jvp(f_sparse, [data], [data_dot])
    v_dense, t_dense = api.jvp(f_dense, [data], [data_dot])

    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to nonzero elements of the matrix
    primals_dense, vjp_dense = api.vjp(f_dense, data)
    primals_sparse, vjp_sparse = api.vjp(f_sparse, data)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)
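
The cuSparseTest cases above all follow the same pattern: build the sparse representation with *_fromdense, apply the sparse kernel, and compare against the dense reference. A minimal COO sketch of that pattern follows; the import path is an assumption (these tests used an experimental sparse_ops module, which newer JAX ships as jax.experimental.sparse).

import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse_ops  # assumed module path, see note above

M = jnp.array([[0., 2., 0.],
               [1., 0., 3.]])
data, row, col = sparse_ops.coo_fromdense(M, nnz=int((M != 0).sum()))
v = jnp.arange(1., 4.)
# The sparse matvec must match the dense product.
np.testing.assert_allclose(np.asarray(sparse_ops.coo_matvec(data, row, col, v, shape=M.shape)),
                           np.asarray(M @ v))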
Example #28
    def test_unary_elementwise(self, harness: primitive_harness.Harness):
        dtype = harness.params["dtype"]
        lax_name = harness.params["lax_name"]
        if (lax_name in ("acosh", "asinh", "atanh", "bessel_i0e", "bessel_i1e",
                         "digamma", "erf", "erf_inv", "erfc", "lgamma",
                         "round", "rsqrt") and dtype is dtypes.bfloat16
                and jtu.device_under_test() in ["cpu", "gpu"]):
            raise unittest.SkipTest(
                f"bfloat16 support is missing from '{lax_name}' TF kernel on {jtu.device_under_test()} devices."
            )
        # TODO(bchetioui): do they have bfloat16 support, though?
        if lax_name in ("sinh", "cosh", "atanh", "asinh", "acosh",
                        "erf_inv") and dtype is np.float16:
            raise unittest.SkipTest(
                "b/158006398: float16 support is missing from '%s' TF kernel" %
                lax_name)
        arg, = harness.dyn_args_maker(self.rng())
        custom_assert = None
        if lax_name == "digamma":
            # TODO(necula): fix bug with digamma/(f32|f16) on TPU
            if dtype in [np.float16, np.float32
                         ] and jtu.device_under_test() == "tpu":
                raise unittest.SkipTest("TODO: fix bug: nan vs not-nan")

            # In the bfloat16 case, TF and lax both return NaN in undefined cases.
            if dtype is not dtypes.bfloat16:
                # digamma is not defined at 0 and -1
                def custom_assert(result_jax, result_tf):
                    # lax.digamma returns NaN and tf.math.digamma returns inf
                    special_cases = (arg == 0.) | (arg == -1.)
                    nr_special_cases = np.count_nonzero(special_cases)
                    self.assertAllClose(
                        np.full((nr_special_cases, ), dtype(np.nan)),
                        result_jax[special_cases])
                    self.assertAllClose(
                        np.full((nr_special_cases, ), dtype(np.inf)),
                        result_tf[special_cases])
                    # non-special cases are equal
                    self.assertAllClose(result_jax[~special_cases],
                                        result_tf[~special_cases])

        if lax_name == "erf_inv":
            # TODO(necula): fix erf_inv bug on TPU
            if jtu.device_under_test() == "tpu":
                raise unittest.SkipTest("erf_inv bug on TPU: nan vs non-nan")
            # TODO: investigate: in the (b)float16 cases, TF and lax both return the same
            # result in undefined cases.
            if dtype not in [np.float16, dtypes.bfloat16]:
                # erf_inv is not defined for arg <= -1 or arg >= 1
                def custom_assert(result_jax, result_tf):  # noqa: F811
                    # for arg < -1 or arg > 1
                    # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
                    special_cases = (arg < -1.) | (arg > 1.)
                    nr_special_cases = np.count_nonzero(special_cases)
                    self.assertAllClose(
                        np.full((nr_special_cases, ), dtype(np.nan)),
                        result_jax[special_cases])
                    signs = np.where(arg[special_cases] < 0., -1., 1.)
                    self.assertAllClose(
                        np.full((nr_special_cases, ), signs * dtype(np.inf)),
                        result_tf[special_cases])
                    # non-special cases are equal
                    self.assertAllClose(result_jax[~special_cases],
                                        result_tf[~special_cases])

        atol = None
        if jtu.device_under_test() == "gpu":
            # TODO(necula): revisit once we fix the GPU tests
            atol = 1e-3
        self.ConvertAndCompare(harness.dyn_fun,
                               arg,
                               custom_assert=custom_assert,
                               atol=atol)
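
A toy NumPy illustration of the custom_assert masking pattern used above: the fabricated arrays stand in for the JAX and TF outputs, agreeing on regular inputs but returning NaN (JAX) versus inf (TF) on the special cases, which are therefore checked separately.

import numpy as np

arg = np.array([-2., -1., 0., 0.5, 2.])
special_cases = (arg == 0.) | (arg == -1.)
result_jax = np.where(special_cases, np.nan, arg * 2.)  # pretend JAX output
result_tf = np.where(special_cases, np.inf, arg * 2.)   # pretend TF output

nr_special_cases = np.count_nonzero(special_cases)
np.testing.assert_allclose(np.full((nr_special_cases,), np.nan),
                           result_jax[special_cases])
np.testing.assert_allclose(np.full((nr_special_cases,), np.inf),
                           result_tf[special_cases])
# Non-special cases must agree.
np.testing.assert_allclose(result_jax[~special_cases], result_tf[~special_cases])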
Example #29
 def setUp(self):
   super(CudaArrayInterfaceTest, self).setUp()
   if jtu.device_under_test() != "gpu":
     self.skipTest("__cuda_array_interface__ is only supported on GPU")
Example #30
def _add_vmap_primitive_harnesses():
    """For each harness group, pick a single dtype.

  Ignore harnesses that fail in graph mode in jax2tf.
  """
    all_h = primitive_harness.all_harnesses

    # Index by group
    harness_groups: Dict[
        str,
        Sequence[primitive_harness.Harness]] = collections.defaultdict(list)
    device = jtu.device_under_test()

    for h in all_h:
        # Drop the harnesses that hit JAX limitations
        if not h.filter(device_under_test=device, include_jax_unimpl=False):
            continue
        # And the jax2tf limitations that are known to result in TF error.
        if any(l.expect_tf_error for l in _get_jax2tf_limitations(device, h)):
            continue
        harness_groups[h.group_name].append(h)

    selected_harnesses = []
    for group_name, hlist in harness_groups.items():
        # Pick the dtype with the most harnesses in this group. Some harness
        # groups only test different use cases at a few dtypes.
        c = collections.Counter([h.dtype for h in hlist])
        (dtype, _), = c.most_common(1)
        selected_harnesses.extend([h for h in hlist if h.dtype == dtype])

    # We do not yet support shape polymorphism for vmap for some primitives
    _NOT_SUPPORTED_YET = frozenset([
        # In the random._gamma_impl we do reshape(-1, 2) for the keys
        "random_gamma",

        # In linalg._lu_python we do reshape(-1, ...)
        "lu",
        "custom_linear_solve",

        # We do *= shapes in the batching rule for conv_general_dilated
        "conv_general_dilated",

        # vmap(clamp) fails in JAX
        "clamp",
        "iota",  # vmap does not make sense for 0-argument functions
    ])

    batch_size = 3
    for h in selected_harnesses:
        if h.group_name in _NOT_SUPPORTED_YET:
            continue

        def make_batched_arg_descriptor(
            ad: primitive_harness.ArgDescriptor
        ) -> Optional[primitive_harness.ArgDescriptor]:
            if isinstance(ad, RandArg):
                return RandArg((batch_size, ) + ad.shape, ad.dtype)
            elif isinstance(ad, CustomArg):

                def wrap_custom(rng):
                    arg = ad.make(rng)
                    return np.stack([arg] * batch_size)

                return CustomArg(wrap_custom)
            else:
                assert isinstance(ad, np.ndarray), ad
                return np.stack([ad] * batch_size)

        new_args = [
            make_batched_arg_descriptor(ad) for ad in h.arg_descriptors
            if not isinstance(ad, StaticArg)
        ]

        # We do not check the result of harnesses that require custom assertions.
        check_result = all(
            not l.custom_assert and not l.skip_comparison and l.tol is None
            for l in _get_jax2tf_limitations(device, h))
        vmap_harness = _make_harness(h.group_name,
                                     f"vmap_{h.name}",
                                     jax.vmap(h.dyn_fun, in_axes=0,
                                              out_axes=0),
                                     new_args,
                                     poly_axes=[0] * len(new_args),
                                     check_result=check_result,
                                     **h.params)
        _POLY_SHAPE_TEST_HARNESSES.append(vmap_harness)
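
A minimal sketch of what the vmap wrapping above does to a harness: each argument gains a new leading batch axis (as in make_batched_arg_descriptor) and jax.vmap maps the original function over that axis. dyn_fun here is a hypothetical stand-in for a harness's h.dyn_fun.

import numpy as np
import jax
import jax.numpy as jnp

batch_size = 3

def dyn_fun(x, y):  # hypothetical stand-in for h.dyn_fun
    return jnp.dot(x, y)

x = np.ones((4, 5), dtype=np.float32)
y = np.ones((5,), dtype=np.float32)
bx = np.stack([x] * batch_size)  # stack along a new leading batch axis
by = np.stack([y] * batch_size)

out = jax.vmap(dyn_fun, in_axes=0, out_axes=0)(bx, by)
assert out.shape == (batch_size, 4)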