Example #1
def _get_max_identity(dtype):
  if dtypes.issubdtype(dtype, np.inexact):
    return np.array(-np.inf, dtype)
  elif dtypes.issubdtype(dtype, np.integer):
    return np.array(dtypes.iinfo(dtype).min, dtype)
  elif dtypes.issubdtype(dtype, np.bool_):
    return np.array(False, np.bool_)
Example #2
def _get_min_identity(dtype):
  if dtypes.issubdtype(dtype, np.inexact):
    return np.array(np.inf, dtype)
  elif dtypes.issubdtype(dtype, np.integer):
    return np.array(dtypes.iinfo(dtype).max, dtype)
  elif dtypes.issubdtype(dtype, np.bool_):
    return np.array(True, np.bool_)
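
Both helpers rely on the identity-element property: combining the returned value with any operand under max (or min) leaves the operand unchanged. A quick NumPy-only check of that property (illustrative only, not part of the JAX source):

import numpy as np

# -inf is the identity for max over floats; iinfo(...).min plays the same role
# for integers. Symmetrically, +inf and iinfo(...).max are identities for min.
x = np.array([3.0, -1.5, 7.25], dtype=np.float32)
assert np.maximum(x, np.float32(-np.inf)).tolist() == x.tolist()

i = np.array([4, -2, 9], dtype=np.int32)
assert np.maximum(i, np.iinfo(np.int32).min).tolist() == i.tolist()

assert np.minimum(x, np.float32(np.inf)).tolist() == x.tolist()
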
Example #3
 def testIsSubdtype(self):
   for t in scalar_types:
     self.assertTrue(dtypes.issubdtype(t, t))
     self.assertTrue(dtypes.issubdtype(np.dtype(t).type, t))
     self.assertTrue(dtypes.issubdtype(t, np.dtype(t).type))
     if t != jnp.bfloat16:
       for category in [np.generic, jnp.inexact, jnp.integer, jnp.signedinteger,
                        jnp.unsignedinteger, jnp.floating, jnp.complexfloating]:
         self.assertEqual(dtypes.issubdtype(t, category),
                          np.issubdtype(np.dtype(t).type, category))
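
The bfloat16 exclusion above exists because jax.dtypes.issubdtype understands JAX's extended float types while plain NumPy does not. A small sketch of the difference (assuming a recent JAX install; the hedged comments mark the expected, not guaranteed, results):

import numpy as np
import jax.numpy as jnp
from jax import dtypes

print(dtypes.issubdtype(jnp.bfloat16, jnp.floating))            # expected: True
print(np.issubdtype(np.dtype(jnp.bfloat16).type, np.floating))  # typically False, which is
                                                                # why the comparison loop
                                                                # skips bfloat16
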
Example #4
def _check_input_dtype_jacfwd(holomorphic, x):
    _check_arg(x)
    aval = core.get_aval(x)
    if holomorphic:
        if not (dtypes.issubdtype(aval.dtype, np.complexfloating)
                and not dtypes.issubdtype(aval.dtype, np.floating)):
            raise TypeError(
                "jacfwd with holomorphic=True requires inputs with complex dtype, "
                f"but got {aval.dtype.name}.")
    elif not dtypes.issubdtype(aval.dtype, np.floating):
        raise TypeError(
            "jacfwd requires real-valued inputs (input dtype that is "
            f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
            "For holomorphic differentiation, pass holomorphic=True. "
            "For differentiation of non-holomorphic functions involving complex "
            "inputs or integer inputs, use jax.jvp directly.")
Example #5
def closure_convert(fun, in_tree, in_avals):
    if config.omnistaging_enabled:
        wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun),
                                                     in_tree)
        jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(
            wrapped_fun, in_avals)
    else:
        in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
        wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun),
                                                     in_tree)
        with core.initial_style_staging():  # type: ignore
            jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
                wrapped_fun, in_pvals, instantiate=True,
                stage_out=False)  # type: ignore
    out_tree = out_tree()

    # We only want to closure convert for constants with respect to which we're
    # differentiating. As a proxy for that, we hoist consts with float dtype.
    # TODO(mattjj): revise this approach
    is_float = lambda c: dtypes.issubdtype(dtypes.dtype(c), jnp.inexact)
    (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
    num_consts = len(hoisted_consts)

    def converted_fun(y, t, *hconsts_args):
        hoisted_consts, args = split_list(hconsts_args, [num_consts])
        consts = merge(closure_consts, hoisted_consts)
        all_args, in_tree2 = tree_flatten((y, t, *args))
        assert in_tree == in_tree2
        out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
        return tree_unflatten(out_tree, out_flat)

    return converted_fun, hoisted_consts
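
The float-dtype partition in the middle is what lets gradients flow through hoisted constants: inexact (float/complex) constants become explicit arguments, everything else stays closed over, and merge restores the original ordering. A simplified stand-in for partition_list to show the split-and-merge idea (the real helper lives in JAX's internals; this version is only a sketch):

import numpy as np
import jax.numpy as jnp
from jax import dtypes

def partition_list(pred, xs):
  # Returns ((elements where pred is False, elements where pred is True), merge),
  # where merge reassembles values in the original order.
  which = [bool(pred(x)) for x in xs]
  falses = [x for x, w in zip(xs, which) if not w]
  trues = [x for x, w in zip(xs, which) if w]
  def merge(new_falses, new_trues):
    fs, ts = iter(new_falses), iter(new_trues)
    return [next(ts) if w else next(fs) for w in which]
  return (falses, trues), merge

is_float = lambda c: dtypes.issubdtype(np.asarray(c).dtype, jnp.inexact)
consts = [np.int32(2), np.float32(3.5), np.arange(3)]
(closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
assert hoisted_consts == [np.float32(3.5)]
assert merge(closure_consts, hoisted_consts)[1] == np.float32(3.5)
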
Example #6
def _check_output_dtype_jacfwd(holomorphic, x):
    aval = core.get_aval(x)
    if holomorphic:
        if not dtypes.issubdtype(aval.dtype, np.complexfloating):
            raise TypeError(
                "jacfwd with holomorphic=True requires outputs with complex dtype, "
                f"but got {aval.dtype.name}.")
Example #7
def _psum_translation_rule(c, *args, replica_groups=None, platform=None):
  if platform in ("cpu", "tpu"):
    return _notuple_psum_translation_rule(c, *args, replica_groups=replica_groups)

  # XLA's tuple all-reduce doesn't support different dtypes in the same
  # allreduce. Instead, we perform one all-reduce for each argument input type.
  args_by_type = collections.defaultdict(lambda: ([], []))
  for i, arg in enumerate(args):
    indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
    indices.append(i)
    dtype_args.append(arg)

  # The outputs, in the original argument order.
  out = [None] * len(args)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
    is_complex = dtypes.issubdtype(dtype, onp.complexfloating)
    n = len(dtype_args)
    if is_complex:
      dtype_args = ([xops.Real(x) for x in dtype_args] +
                    [xops.Imag(x) for x in dtype_args])
    scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
    computation = xla.primitive_subcomputation(lax.add_p, scalar, scalar)
    all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                replica_groups_protos, None, None)
    if is_complex:
      xs = [xops.Complex(xops.GetTupleElement(all_reduce, i),
                         xops.GetTupleElement(all_reduce, n + i)) for i in range(n)]
    else:
      xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
    for i, x in zip(indices, xs):
      out[i] = x
  return xops.Tuple(c, out)
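
The bookkeeping above buckets arguments by dtype, issues one all-reduce per bucket, and scatters the results back into the original argument order. The same pattern with plain Python values standing in for XLA ops (a sketch, not the JAX source):

import collections
import numpy as np

args = [np.float32(1), np.int32(2), np.float32(3)]

# Bucket argument positions and values by dtype.
args_by_type = collections.defaultdict(lambda: ([], []))
for i, arg in enumerate(args):
  indices, dtype_args = args_by_type[arg.dtype]
  indices.append(i)
  dtype_args.append(arg)

# "Reduce" each bucket (identity here), then restore the original order.
out = [None] * len(args)
for dtype, (indices, dtype_args) in sorted(args_by_type.items(), key=lambda kv: kv[0].name):
  for i, x in zip(indices, dtype_args):
    out[i] = x
assert out == args
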
Example #8
 def _translate(val):
   psum = partial(_allreduce_translation_rule, lax.add_p, c,
                  replica_groups=replica_groups)
   dtype = c.get_shape(val).numpy_dtype()
   if dtypes.issubdtype(dtype, onp.complexfloating):
     return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val)))
   else:
     return psum(val)
Example #9
def _psum_translation_rule(c, val, replica_groups, backend=None):
  psum = partial(_allreduce_translation_rule, lax.add_p, c,
                 replica_groups=replica_groups, backend=backend)
  dtype = c.GetShape(val).numpy_dtype()
  if dtypes.issubdtype(dtype, onp.complexfloating):
    return c.Complex(psum(c.Real(val)), psum(c.Imag(val)))
  else:
    return psum(val)
Example #10
 def _translate(val):
   psum = partial(_allreduce_translation_rule, lax.add_p, c,
                  axis_name=axis_name, axis_env=axis_env,
                  axis_index_groups=axis_index_groups, platform=platform)
   dtype = c.get_shape(val).numpy_dtype()
   if dtypes.issubdtype(dtype, np.complexfloating):
     return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val)))
   else:
     return psum(val)
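
Examples #8-#10 all decompose a complex psum into two real ones: sum the real parts and the imaginary parts separately, then recombine. The identity being relied on, checked with NumPy (purely illustrative):

import numpy as np

z = np.array([1 + 2j, 3 - 1j, -0.5 + 0.25j], dtype=np.complex64)
assert np.isclose(z.sum(), np.sum(z.real) + 1j * np.sum(z.imag))
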
Example #11
 def test_div(self, harness: primitive_harness.Harness):
     dividend, divisor = harness.dyn_args_maker(self.rng())
     prim = harness.params["prim"]
     if dtypes.issubdtype(dividend.dtype, np.integer):
         if (prim is lax.div_p
                 and np.any(divisor == np.array(0, dtype=divisor.dtype))):
             raise unittest.SkipTest(
                 "Divisor contains a 0, and TF returns an error value in compiled "
                 "mode instead of failing like in eager and graph mode for dtype "
                 f"{divisor.dtype}")
     self.ConvertAndCompare(harness.dyn_fun, dividend, divisor)
Example #12
def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
                                axis_env, platform):
    if platform in ("cpu", "tpu"):
        return _notuple_allreduce_translation_rule(
            prim,
            c,
            *args,
            axis_name=axis_name,
            axis_index_groups=axis_index_groups,
            axis_env=axis_env,
            platform=platform)

    # XLA's tuple all-reduce doesn't support different dtypes in the same
    # allreduce. Instead, we perform one all-reduce for each argument input type.
    args_by_type = collections.defaultdict(lambda: ([], []))
    for i, arg in enumerate(args):
        indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
        indices.append(i)
        dtype_args.append(arg)

    # The outputs, in the original argument order.
    out = [None] * len(args)
    replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
    replica_groups_protos = xc.make_replica_groups(replica_groups)
    for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
        is_complex = dtypes.issubdtype(dtype, np.complexfloating)
        n = len(dtype_args)
        if is_complex and prim is lax.add_p:
            # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
            # special case because it's not currently handled by XLA:GPU
            dtype_args = ([xops.Real(x) for x in dtype_args] +
                          [xops.Imag(x) for x in dtype_args])
        scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
        computation = xla.primitive_subcomputation(prim, scalar, scalar)
        all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                    replica_groups_protos, None, None)
        if is_complex and prim is lax.add_p:
            xs = [
                xops.Complex(xops.GetTupleElement(all_reduce, i),
                             xops.GetTupleElement(all_reduce, n + i))
                for i in range(n)
            ]
        else:
            xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
        for i, x in zip(indices, xs):
            out[i] = x
    return xops.Tuple(c, out)
Example #13
def _notuple_allreduce_translation_rule(prim, c, *args, axis_name, axis_env,
                                        axis_index_groups, platform):
  def all_reduce(x):
    replica_groups_protos = xc.make_replica_groups(
        _replica_groups(axis_env, axis_name, axis_index_groups))
    scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
    computation = xla.primitive_subcomputation(prim, scalar, scalar)
    return xops.AllReduce(x, computation, replica_groups_protos, None, None)

  if prim is not lax.add_p:
    outs = [all_reduce(x) for x in args]
  else:
    # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
    # special case because it's not currently handled by XLA:GPU
    outs = [xops.Complex(all_reduce(xops.Real(x)), all_reduce(xops.Imag(x)))
            if dtypes.issubdtype(c.get_shape(x).numpy_dtype(), np.complexfloating)
            else all_reduce(x) for x in args]
  return xops.Tuple(c, outs)
Example #14
class LaxVmapTest(jtu.JaxTestCase):
    def _CheckBatching(self,
                       op,
                       bdim_size,
                       bdims,
                       shapes,
                       dtypes,
                       rng,
                       rtol=None,
                       atol=None):
        batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
        args = [
            rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)
        ]
        args_slice = args_slicer(args, bdims)
        ans = api.vmap(op, bdims)(*args)
        expected = onp.stack([op(*args_slice(i)) for i in range(bdim_size)])
        self.assertAllClose(ans, expected, rtol=rtol, atol=atol)

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list(
                {
                    "testcase_name":
                    "{}_bdims={}".format(
                        jtu.format_test_name_suffix(
                            rec.op, shapes, itertools.repeat(dtype)), bdims),
                    "op_name":
                    rec.op,
                    "rng_factory":
                    rec.rng_factory,
                    "shapes":
                    shapes,
                    "dtype":
                    dtype,
                    "bdims":
                    bdims,
                    "tol":
                    rec.tol
                } for shape_group in compatible_shapes
                for shapes in CombosWithReplacement(shape_group, rec.nargs)
                for bdims in all_bdims(*shapes) for dtype in rec.dtypes)
            for rec in LAX_OPS))
    def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol):
        rng = rng_factory(self.rng())
        op = getattr(lax, op_name)
        self._CheckBatching(op,
                            10,
                            bdims,
                            shapes, [dtype] * len(shapes),
                            rng,
                            atol=tol,
                            rtol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
            "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
            "_lhs_bdim={}_rhs_bdim={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), strides,
                padding, lhs_dil, rhs_dil, ",".join(dim_nums),
                feature_group_count, batch_group_count, lhs_bdim, rhs_bdim),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "strides":
            strides,
            "padding":
            padding,
            "lhs_dil":
            lhs_dil,
            "rhs_dil":
            rhs_dil,
            "rng_factory":
            rng_factory,
            "dimension_numbers":
            dim_nums,
            "perms":
            perms,
            "lhs_bdim":
            lhs_bdim,
            "rhs_bdim":
            rhs_bdim,
            "feature_group_count":
            feature_group_count,
            "batch_group_count":
            batch_group_count,
        } for batch_group_count, feature_group_count in [(1, 1), (2, 1), (1, 2)]
                            for lhs_shape, rhs_shape, all_strides, all_pads,
                            lhs_dils, rhs_dils in [
                                ((b * batch_group_count,
                                  i * feature_group_count, 6, 7),  # lhs_shape
                                 (j * batch_group_count *
                                  feature_group_count, i, 1, 2),  # rhs_shape
                                 [(1, 1), (1, 2), (2, 1)],  # strides
                                 [((0, 0), (0, 0)), ((1, 0), (0, 1)),
                                  ((0, -1), (0, 0))],  # pads
                                 [(1, 1), (2, 1)],  # lhs_dils
                                 [(1, 1), (2, 2)])  # rhs_dils
                                for b, i, j in itertools.product([1, 2], repeat=3)
                            ] for strides in all_strides
                            for rhs_dil in rhs_dils for lhs_dil in lhs_dils
                            for dtype in [onp.float32] for padding in all_pads
                            for dim_nums, perms in [(("NCHW", "OIHW", "NCHW"),
                                                     ([0, 1, 2, 3],
                                                      [0, 1, 2, 3])),
                                                    (("NHWC", "HWIO", "NHWC"),
                                                     ([0, 2, 3, 1],
                                                      [2, 3, 1, 0])),
                                                    (("NHWC", "OIHW", "NCHW"),
                                                     ([0, 2, 3, 1],
                                                      [0, 1, 2, 3]))]
                            for lhs_bdim in itertools.chain(
                                [cast(Optional[int], None)],
                                range(len(lhs_shape) + 1))
                            for rhs_bdim in itertools.chain(
                                [cast(Optional[int], None)],
                                range(len(rhs_shape) + 1))
                            if (lhs_bdim, rhs_bdim) != (None, None)
                            for rng_factory in [jtu.rand_default]))
    def testConvGeneralDilatedBatching(self, lhs_shape, rhs_shape, dtype,
                                       strides, padding, lhs_dil, rhs_dil,
                                       dimension_numbers, perms,
                                       feature_group_count, batch_group_count,
                                       lhs_bdim, rhs_bdim, rng_factory):
        rng = rng_factory(self.rng())
        tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3

        # permute shapes to match dim_spec, scale by feature_group_count
        lhs_perm, rhs_perm = perms
        lhs_shape = list(onp.take(lhs_shape, lhs_perm))
        rhs_shape = list(onp.take(rhs_shape, rhs_perm))

        conv = partial(lax.conv_general_dilated,
                       window_strides=strides,
                       padding=padding,
                       lhs_dilation=lhs_dil,
                       rhs_dilation=rhs_dil,
                       dimension_numbers=dimension_numbers,
                       feature_group_count=feature_group_count,
                       batch_group_count=batch_group_count,
                       precision=lax.Precision.HIGHEST)
        self._CheckBatching(conv,
                            5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape),
                            (dtype, dtype),
                            rng,
                            rtol=tol,
                            atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
                shape, from_dtype, to_dtype, bdims),
            "shape":
            shape,
            "from_dtype":
            from_dtype,
            "to_dtype":
            to_dtype,
            "bdims":
            bdims,
            "rng_factory":
            rng_factory
        } for from_dtype, to_dtype in itertools.product(
            [onp.float32, onp.int32, "float32", "int32"], repeat=2)
                            for shape in [(2, 3)] for bdims in all_bdims(shape)
                            for rng_factory in [jtu.rand_default]))
    def testConvertElementType(self, shape, from_dtype, to_dtype, bdims,
                               rng_factory):
        rng = rng_factory(self.rng())
        op = lambda x: lax.convert_element_type(x, to_dtype)
        self._CheckBatching(op, 10, bdims, (shape, ), (from_dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
                shape, from_dtype, to_dtype, bdims),
            "shape":
            shape,
            "from_dtype":
            from_dtype,
            "to_dtype":
            to_dtype,
            "bdims":
            bdims,
            "rng_factory":
            rng_factory
        } for from_dtype, to_dtype in itertools.product(
            [onp.float32, onp.int32, "float32", "int32"], repeat=2)
                            for shape in [(2, 3)] for bdims in all_bdims(shape)
                            for rng_factory in [jtu.rand_default]))
    def testBitcastElementType(self, shape, from_dtype, to_dtype, bdims,
                               rng_factory):
        rng = rng_factory(self.rng())
        op = lambda x: lax.bitcast_convert_type(x, to_dtype)
        self._CheckBatching(op, 10, bdims, (shape, ), (from_dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_min_shape={}_operand_shape={}_max_shape={}_bdims={}".format(
                    jtu.format_shape_dtype_string(min_shape, dtype),
                    jtu.format_shape_dtype_string(operand_shape, dtype),
                    jtu.format_shape_dtype_string(max_shape, dtype), bdims),
                "min_shape":
                min_shape,
                "operand_shape":
                operand_shape,
                "max_shape":
                max_shape,
                "dtype":
                dtype,
                "bdims":
                bdims,
                "rng_factory":
                rng_factory
            } for min_shape, operand_shape, max_shape in [
                [(), (2, 3), ()],
                [(2, 3), (2, 3), ()],
                [(), (2, 3), (2, 3)],
                [(2, 3), (2, 3), (2, 3)],
            ] for dtype in default_dtypes
            for bdims in all_bdims(min_shape, operand_shape, max_shape)
            for rng_factory in [jtu.rand_default]))
    def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims,
                  rng_factory):
        rng = rng_factory(self.rng())
        raise SkipTest(
            "batching rule for clamp not implemented")  # TODO(mattjj)
        shapes = [min_shape, operand_shape, max_shape]
        self._CheckBatching(lax.clamp, 10, bdims, shapes, [dtype] * 3, rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs_shape={}_rhs_shape={}_bdims={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), bdims),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "bdims":
            bdims,
            "rng_factory":
            rng_factory
        } for lhs_shape in [(3, ), (4, 3)] for rhs_shape in [(3, ), (3, 6)]
                            for bdims in all_bdims(lhs_shape, rhs_shape)
                            for dtype in default_dtypes
                            for rng_factory in [jtu.rand_default]))
    def testDot(self, lhs_shape, rhs_shape, dtype, bdims, rng_factory):
        rng = rng_factory(self.rng())
        op = partial(lax.dot, precision=lax.Precision.HIGHEST)
        self._CheckBatching(op,
                            5,
                            bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                            rng,
                            rtol={
                                onp.float16: 5e-2,
                                onp.float64: 5e-14
                            })

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_bdims={}"
            .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                    jtu.format_shape_dtype_string(rhs_shape, dtype),
                    lhs_contracting, rhs_contracting, bdims),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "lhs_contracting":
            lhs_contracting,
            "rhs_contracting":
            rhs_contracting,
            "bdims":
            bdims,
            "rng_factory":
            rng_factory
        } for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
            [(3, 5), (2, 5), [1], [1]],
            [(5, 3), (5, 2), [0], [0]],
            [(5, 3, 2), (5, 2, 4), [0], [0]],
            [(5, 3, 2), (5, 2, 4), [0, 2], [0, 1]],
            [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
            [(3, 2), (2, 4), [1], [0]],
        ] for bdims in all_bdims(lhs_shape, rhs_shape)
                            for dtype in default_dtypes
                            for rng_factory in [jtu.rand_small]))
    def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                                   lhs_contracting, rhs_contracting, bdims,
                                   rng_factory):
        rng = rng_factory(self.rng())
        dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))
        dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
        self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape),
                            (dtype, dtype), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_bdims={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype),
                dimension_numbers, bdims),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "dimension_numbers":
            dimension_numbers,
            "bdims":
            bdims,
            "rng_factory":
            rng_factory
        } for lhs_shape, rhs_shape, dimension_numbers in [
            ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
            ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
        ] for bdims in all_bdims(lhs_shape, rhs_shape)
                            for dtype in default_dtypes
                            for rng_factory in [jtu.rand_small]))
    def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                       dimension_numbers, bdims, rng_factory):
        rng = rng_factory(self.rng())
        dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
        self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape),
                            (dtype, dtype), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format(
                shape,
                onp.dtype(dtype).name, broadcast_sizes, bdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "broadcast_sizes":
            broadcast_sizes,
            "bdims":
            bdims,
            "rng_factory":
            rng_factory
        } for shape in [(), (2, 3)] for dtype in default_dtypes
                            for broadcast_sizes in [(), (2, ), (1, 2)]
                            for bdims in all_bdims(shape)
                            for rng_factory in [jtu.rand_default]))
    def testBroadcast(self, shape, dtype, broadcast_sizes, bdims, rng_factory):
        rng = rng_factory(self.rng())
        op = lambda x: lax.broadcast(x, broadcast_sizes)
        self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_outshape={}_bcdims={}_bdims={}".format(
                jtu.format_shape_dtype_string(inshape, dtype), outshape,
                broadcast_dimensions, bdims),
            "inshape":
            inshape,
            "dtype":
            dtype,
            "outshape":
            outshape,
            "dimensions":
            broadcast_dimensions,
            "bdims":
            bdims,
            "rng_factory":
            rng_factory
        } for inshape, outshape, broadcast_dimensions in [
            ([2], [2, 2], [0]),
            ([2], [2, 2], [1]),
            ([2], [2, 3], [0]),
            ([], [2, 3], []),
        ] for dtype in default_dtypes for bdims in all_bdims(inshape)
                            for rng_factory in [jtu.rand_default]))
    def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims,
                           rng_factory):
        rng = rng_factory(self.rng())
        raise SkipTest("this test has failures in some cases")  # TODO(mattjj)
        op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
        self._CheckBatching(op, 5, bdims, (inshape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_dimensions={}_bdims={}".format(
                jtu.format_shape_dtype_string(arg_shape, onp.float32),
                dimensions, bdims),
            "arg_shape":
            arg_shape,
            "dimensions":
            dimensions,
            "bdims":
            bdims,
            "rng_factory":
            rng_factory
        } for arg_shape, dimensions in [
            [(1, ), (0, )],
            [(1, ), (-1, )],
            [(2, 1, 4), (1, )],
            [(2, 1, 4), (-2, )],
            [(2, 1, 3, 1), (1, )],
            [(2, 1, 3, 1), (1, 3)],
            [(2, 1, 3, 1), (3, )],
            [(2, 1, 3, 1), (1, -1)],
        ] for bdims in all_bdims(arg_shape)
                            for rng_factory in [jtu.rand_default]))
    def testSqueeze(self, arg_shape, dimensions, bdims, rng_factory):
        dtype = onp.float32
        rng = rng_factory(self.rng())
        op = lambda x: lax.squeeze(x, dimensions)
        self._CheckBatching(op, 10, bdims, (arg_shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_inshape={}_outshape={}_dims={}_bdims={}".format(
                    jtu.format_shape_dtype_string(arg_shape, dtype),
                    jtu.format_shape_dtype_string(out_shape, dtype),
                    dimensions, bdims),
                "arg_shape":
                arg_shape,
                "out_shape":
                out_shape,
                "dtype":
                dtype,
                "dimensions":
                dimensions,
                "bdims":
                bdims,
                "rng_factory":
                rng_factory
            } for dtype in default_dtypes
            for arg_shape, dimensions, out_shape in [
                [(3, 4), None, (12,)],
                [(2, 1, 4), None, (8,)],
                [(2, 2, 4), None, (2, 8)],
                [(2, 2, 4), (0, 1, 2), (2, 8)],
                [(2, 2, 4), (1, 0, 2), (8, 2)],
                [(2, 2, 4), (2, 1, 0), (4, 2, 2)]]
            for bdims in all_bdims(arg_shape)
            for rng_factory in [jtu.rand_default]))
    def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims,
                    rng_factory):
        rng = rng_factory(self.rng())
        op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions)
        self._CheckBatching(op, 10, bdims, (arg_shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_pads={}_bdims={}".format(
                jtu.format_shape_dtype_string(shape, dtype), pads, bdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "pads":
            pads,
            "rng_factory":
            jtu.rand_small,
            "bdims":
            bdims
        } for shape in [(2, 3)] for bdims in all_bdims(shape)
                            for dtype in default_dtypes
                            for pads in [[(1, 2, 1), (0, 1, 0)]]))
    def testPad(self, shape, dtype, pads, bdims, rng_factory):
        rng = rng_factory(self.rng())
        fun = lambda operand: lax.pad(operand, onp.array(0, dtype), pads)
        self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_predshape={}_argshapes={}_bdims={}".format(
                    jtu.format_shape_dtype_string(pred_shape, onp.bool_),
                    jtu.format_shape_dtype_string(arg_shape, arg_dtype),
                    bdims),
                "pred_shape":
                pred_shape,
                "arg_shape":
                arg_shape,
                "arg_dtype":
                arg_dtype,
                "bdims":
                bdims,
                "rng_factory":
                rng_factory
            }
            for arg_shape in [(), (3, ), (2, 3)]
            for pred_shape in ([(), arg_shape] if arg_shape else [()])
            for bdims in all_bdims(pred_shape, arg_shape, arg_shape)
            for arg_dtype in default_dtypes
            for rng_factory in [jtu.rand_default]))
    def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims, rng_factory):
        rng = rng_factory(self.rng())
        op = lambda c, x, y: lax.select(c < 0, x, y)
        self._CheckBatching(op, 5, bdims, (pred_shape, arg_shape, arg_shape),
                            (onp.bool_, arg_dtype, arg_dtype), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_start_indices={}_limit_indices={}_strides={}_bdims={}".
            format(jtu.format_shape_dtype_string(shape, dtype), start_indices,
                   limit_indices, strides, bdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "starts":
            start_indices,
            "limits":
            limit_indices,
            "strides":
            strides,
            "bdims":
            bdims,
            "rng_factory":
            rng_factory
        } for shape, start_indices, limit_indices, strides in [
            [(3, ), (1, ), (2, ), None],
            [(7, ), (4, ), (7, ), None],
            [(5, ), (1, ), (5, ), (2, )],
            [(8, ), (1, ), (6, ), (2, )],
            [(5, 3), (1, 1), (3, 2), None],
            [(5, 3), (1, 1), (3, 1), None],
            [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
            [(5, 3), (1, 1), (2, 1), (1, 1)],
            [(5, 3), (1, 1), (5, 3), (2, 1)],
        ] for bdims in all_bdims(shape) for dtype in default_dtypes
                            for rng_factory in [jtu.rand_default]))
    def testSlice(self, shape, dtype, starts, limits, strides, bdims,
                  rng_factory):
        rng = rng_factory(self.rng())
        op = lambda x: lax.slice(x, starts, limits, strides)
        self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_perm={}_bdims={}".format(
                jtu.format_shape_dtype_string(shape, dtype), perm, bdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "perm":
            perm,
            "bdims":
            bdims,
            "rng_factory":
            rng_factory
        } for shape, perm in [
            [(3, 4), (1, 0)],
            [(3, 4), (0, 1)],
            [(3, 4, 5), (2, 1, 0)],
            [(3, 4, 5), (1, 0, 2)],
        ] for bdims in all_bdims(shape) for dtype in default_dtypes
                            for rng_factory in [jtu.rand_default]))
    def testTranspose(self, shape, dtype, perm, bdims, rng_factory):
        rng = rng_factory(self.rng())
        op = lambda x: lax.transpose(x, perm)
        self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_op={}_inshape={}_reducedims={}_initval={}_bdims={}".format(
                op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
                init_val, bdims),
            "op":
            op,
            "init_val":
            init_val,
            "shape":
            shape,
            "dtype":
            dtype,
            "dims":
            dims,
            "bdims":
            bdims,
            "rng_factory":
            rng_factory
        } for init_val, op, dtypes in [
            (0, lax.add, default_dtypes),
            (1, lax.mul, default_dtypes),
            (0, lax.max, all_dtypes),  # non-monoidal
            (-onp.inf, lax.max, float_dtypes),
            (dtypes.iinfo(onp.int32).min, lax.max, [onp.int32]),
            (dtypes.iinfo(onp.int64).min, lax.max, [onp.int64]),
            (dtypes.iinfo(onp.uint32).min, lax.max, [onp.uint32]),
            (dtypes.iinfo(onp.uint64).min, lax.max, [onp.uint64]),
            (onp.inf, lax.min, float_dtypes),
            (dtypes.iinfo(onp.int32).max, lax.min, [onp.int32]),
            (dtypes.iinfo(onp.int64).max, lax.min, [onp.int64]),
            (dtypes.iinfo(onp.uint32).max, lax.min, [onp.uint32]),
            (dtypes.iinfo(onp.uint64).max, lax.min, [onp.uint64]),
        ] for dtype in dtypes
          for shape, dims in [[(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
                              [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]]
                            for bdims in all_bdims(shape)
                            for rng_factory in [jtu.rand_small]))
    def testReduce(self, op, init_val, shape, dtype, dims, bdims, rng_factory):
        rng = rng_factory(self.rng())
        init_val = onp.asarray(init_val, dtype=dtype)
        fun = lambda operand: lax.reduce(operand, init_val, op, dims)
        self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_op={}_dtype={}_padding={}".format(op.__name__,
                                                onp.dtype(dtype).name,
                                                padding),
            "op":
            op,
            "init_val":
            init_val,
            "dtype":
            dtype,
            "padding":
            padding,
            "rng_factory":
            rng_factory
        } for init_val, op, dtypes in [
            (0, lax.add, [onp.float32]),
            (-onp.inf, lax.max, [onp.float32]),
            (onp.inf, lax.min, [onp.float32]),
        ] for dtype in dtypes for padding in ["VALID", "SAME"]
                            for rng_factory in [jtu.rand_small]))
    def testReduceWindow(self, op, init_val, dtype, padding, rng_factory):
        rng = rng_factory(self.rng())
        init_val = onp.asarray(init_val, dtype=dtype)

        all_configs = itertools.chain(
            itertools.product([(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1),
                                                           (1, 2)]),
            itertools.product([(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
                              [(1, 2, 2, 1), (1, 1, 1, 1)]))

        def fun(operand):
            return lax.reduce_window(operand, init_val, op, dims, strides,
                                     padding)

        for shape, dims, strides in all_configs:
            for bdims in all_bdims(shape):
                self._CheckBatching(fun, 3, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_op={}_shape={}_axis={}_bdims={}".format(
                    op.__name__, jtu.format_shape_dtype_string(shape, dtype),
                    axis, bdims),
                "op":
                op,
                "shape":
                shape,
                "dtype":
                dtype,
                "bdims":
                bdims,
                "axis":
                axis,
                "rng_factory":
                rng_factory
            } for op, types in [
                (lax.cumsum, [onp.float32, onp.float64]),
                (lax.cumprod, [onp.float32, onp.float64]),
            ] for dtype in types for shape in [[10], [3, 4, 5]]
            for axis in range(len(shape)) for bdims in all_bdims(shape)
            for rng_factory in [
                jtu.rand_default if dtypes.issubdtype(dtype, onp.integer
                                                      ) else jtu.rand_small
            ]))
    def testCumulativeReduce(self, op, shape, dtype, axis, bdims, rng_factory):
        rng = rng_factory(self.rng())
        self._CheckBatching(partial(op, axis=axis), 7, bdims, (shape, ),
                            (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_dtype={}_padding={}".format(onp.dtype(dtype).name, padding),
            "dtype":
            dtype,
            "padding":
            padding,
            "rng_factory":
            rng_factory
        } for dtype in float_dtypes for padding in ["VALID", "SAME"]
                            for rng_factory in [jtu.rand_small]))
    @jtu.skip_on_flag("jax_skip_slow_tests", True)
    @jtu.ignore_warning(message="Using reduced precision for gradient.*")
    def testSelectAndGatherAdd(self, dtype, padding, rng_factory):
        if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
            raise SkipTest(
                "bfloat16 _select_and_gather_add doesn't work on tpu")
        rng = rng_factory(self.rng())
        all_configs = itertools.chain(
            itertools.product([(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1),
                                                           (1, 2)]),
            itertools.product([(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
                              [(1, 2, 2, 1), (1, 1, 1, 1)]))

        def fun(operand, tangents):
            return lax._select_and_gather_add(operand, tangents, lax.ge_p,
                                              dims, strides, padding)

        for shape, dims, strides in all_configs:
            for bdims in all_bdims(shape, shape):
                self._CheckBatching(fun, 3, bdims, (shape, shape),
                                    (dtype, dtype), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_bdims={}_fft_ndims={}".format(shape, bdims, fft_ndims),
            "shape":
            shape,
            "bdims":
            bdims,
            "fft_ndims":
            fft_ndims,
            "rng_factory":
            rng_factory
        } for shape in [(5, ), (3, 4, 5), (2, 3, 4, 5)]
                            for bdims in all_bdims(shape)
                            for fft_ndims in range(0,
                                                   min(3, len(shape)) + 1)
                            for rng_factory in [jtu.rand_default]))
    @jtu.skip_on_devices("tpu")  # TODO(b/137993701): unimplemented cases.
    def testFft(self, fft_ndims, shape, bdims, rng_factory):
        rng = rng_factory(self.rng())
        ndims = len(shape)
        axes = range(ndims - fft_ndims, ndims)
        fft_lengths = [shape[axis] for axis in axes]
        op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths)
        self._CheckBatching(op, 5, bdims, [shape], [onp.complex64], rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}".format(
                jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
                slice_sizes, bdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "bdims":
            bdims
        } for dtype in all_dtypes for shape, idxs, dnums, slice_sizes in [
            ((5, ), onp.array([[0], [2]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, )),
            ((10, ), onp.array([[0], [0], [0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, )), (2, )),
            ((10, 5), onp.array([[0], [2], [1]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, 3)),
            ((10, 5), onp.array([[0, 2], [1, 0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1)), (1, 3)),
        ] for bdims in all_bdims(shape, idxs.shape)))
    def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        self._CheckBatching(fun, 5, bdims, [shape, idxs.shape],
                            [dtype, idxs.dtype], jtu.rand_default(self.rng()))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_shape={}_idxs={}_update={}_dnums={}_bdims={}".format(
                    jtu.format_shape_dtype_string(arg_shape, dtype), idxs,
                    update_shape, dnums, bdims),
                "arg_shape":
                arg_shape,
                "dtype":
                dtype,
                "idxs":
                idxs,
                "update_shape":
                update_shape,
                "dnums":
                dnums,
                "bdims":
                bdims
            } for dtype in float_dtypes
            for arg_shape, idxs, update_shape, dnums in [
                ((5, ), onp.array([[0], [2]]), (2, ),
                 lax.ScatterDimensionNumbers(update_window_dims=(),
                                             inserted_window_dims=(0, ),
                                             scatter_dims_to_operand_dims=(
                                                 0, ))),
                ((10, ), onp.array([[0], [0], [0]]), (3, 2),
                 lax.ScatterDimensionNumbers(update_window_dims=(1, ),
                                             inserted_window_dims=(),
                                             scatter_dims_to_operand_dims=(
                                                 0, ))),
                ((10, 5), onp.array([[0], [2], [1]]), (3, 3),
                 lax.ScatterDimensionNumbers(update_window_dims=(1, ),
                                             inserted_window_dims=(0, ),
                                             scatter_dims_to_operand_dims=(
                                                 0, ))),
            ] for bdims in all_bdims(arg_shape, idxs.shape, update_shape)))
    def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums,
                       bdims):
        fun = partial(lax.scatter_add, dimension_numbers=dnums)
        self._CheckBatching(fun,
                            5,
                            bdims, [arg_shape, idxs.shape, update_shape],
                            [dtype, idxs.dtype, dtype],
                            jtu.rand_default(self.rng()),
                            rtol={onp.float16: 5e-3})

    def testShapeUsesBuiltinInt(self):
        x = lax.iota(onp.int32, 3) + 1
        self.assertIsInstance(x.shape[0], int)  # not np.int64

    def testBroadcastShapesReturnsPythonInts(self):
        shape1, shape2 = (1, 2, 3), (2, 3)
        out_shape = lax.broadcast_shapes(shape1, shape2)
        self.assertTrue(all(type(s) is int for s in out_shape))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_shape={}_k={}_bdims={}".format(
                    jtu.format_shape_dtype_string(shape, dtype), k, bdims),
                "shape":
                shape,
                "dtype":
                dtype,
                "k":
                k,
                "bdims":
                bdims,
                "rng_factory":
                rng_factory
            } for shape in [(4, ), (3, 5, 3)] for k in [1, 3]
            for bdims in all_bdims(shape)
            # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug is fixed:
            # The top_k indices for integer arrays with identical entries won't match between
            # vmap'd version and manual reference, so only test unique integer arrays for int_dtypes.
            # Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of
            # values a bfloat16 can represent exactly to avoid ties.
            for dtype, rng_factory in itertools.chain(
                unsafe_zip(float_dtypes +
                           int_dtypes, itertools.repeat(jtu.rand_unique_int))))
    )
    def testTopK(self, shape, dtype, k, bdims, rng_factory):
        rng = rng_factory(self.rng())
        # _CheckBatching doesn't work with tuple outputs, so test outputs separately.
        op1 = lambda x: lax.top_k(x, k=k)[0]
        self._CheckBatching(op1, 5, bdims, (shape, ), (dtype, ), rng)
        op2 = lambda x: lax.top_k(x, k=k)[1]
        self._CheckBatching(op2, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_dimension={}_arity={}_bdims={}_isstable={}".format(
                jtu.format_shape_dtype_string(shape, onp.float32), dimension,
                arity, bdims, is_stable),
            "shape":
            shape,
            "dimension":
            dimension,
            "arity":
            arity,
            "bdims":
            bdims,
            "is_stable":
            is_stable
        } for shape in [(2, 3)] for dimension in [0, 1] for arity in range(3)
                            for bdims in all_bdims(*((shape, ) * arity))
                            for is_stable in [False, True]))
    def testSort(self, shape, dimension, arity, bdims, is_stable):
        rng = jtu.rand_default(self.rng())
        if arity == 1:
            fun = partial(lax.sort, dimension=dimension)
            self._CheckBatching(fun, 5, bdims, (shape, ) * arity,
                                (onp.float32, ) * arity, rng)
        else:
            for i in range(arity):
                fun = lambda *args, i=i: lax.sort(
                    args, dimension=dimension, is_stable=is_stable)[i]
                self._CheckBatching(fun, 5, bdims, (shape, ) * arity,
                                    (onp.float32, ) * arity, rng)
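
_CheckBatching at the top of this class compares jax.vmap over the chosen batch dimensions against a reference built by slicing out each batch element, applying the op, and stacking the results. The same pattern in isolation, for a single batched argument (a sketch with made-up op and data, not taken from the test file):

import numpy as np
import jax
import jax.numpy as jnp

op = lambda x: jnp.sum(x ** 2)          # per-example operation
batch = jnp.arange(12.0).reshape(4, 3)  # batch dimension 0, size 4

ans = jax.vmap(op, in_axes=0)(batch)
expected = np.stack([op(batch[i]) for i in range(batch.shape[0])])
np.testing.assert_allclose(ans, expected)
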
Example #15
def categorize(prim: core.Primitive, *args, **kwargs) \
    -> List[Limitation]:
  """
  Given a primitive and a set of parameters one would like to pass to it,
  categorize identifies the potential limitations the call would encounter when
  converted to TF through jax2tf.

  Args:
    prim: the primitive to call.
    args: the arguments to pass to prim.
    kwargs: the keyword arguments to pass to prim.

  Returns:
    A list of limitations.
  """
  limitations: List[Limitation] = []
  all_devices = ["CPU", "GPU", "TPU"]

  def _report_failure(error_type: str, msg: str,
                      affected_dtype: Optional[NpDType] = None,
                      devs: Sequence[str] = all_devices) -> None:
    affected_dtypes = (
      tuple([affected_dtype]) if affected_dtype is not None else tuple())
    limitations.append(Limitation(prim.name, error_type, msg,
                                  affected_dtypes, tuple(devs)))

  def tf_unimpl(np_dtype: Optional[NpDType] = None,
                additional_msg: Optional[str] = None,
                devs: Sequence[str] = all_devices) -> None:
    msg = "Primitive is unimplemented in TF"
    if additional_msg:
      msg += '; ' + additional_msg
    _report_failure(CATEGORY_MISSING_TF_SUPPORT, msg, np_dtype, devs=devs)

  def tf_possible_incorrect(np_dtype: Optional[NpDType] = None,
                            msg: str = "",
                            devs: Sequence[str] = all_devices) -> None:
    _report_failure(CATEGORY_POSSIBLE_INCORRECT_RESULTS, msg, np_dtype, devs=devs)

  def _to_np_dtype(dtype) -> NpDType:
    try:
      dtype = to_jax_dtype(dtype)
    except Exception:
      pass
    return np.dtype(dtype)

  if args and args[0] is not core.unit:
    np_dtype = _to_np_dtype(args[0].dtype)
  else:
    np_dtype = None

  if prim is lax.regularized_incomplete_beta_p:
    if np_dtype in [np.float16, dtypes.bfloat16]:
      tf_unimpl(np_dtype)

  if prim in [lax.reduce_min_p, lax.reduce_max_p]:
    if np_dtype in [np.complex64, np.complex128]:
      tf_unimpl(np_dtype)

  if prim in [lax.min_p, lax.max_p, lax.reduce_window_min_p,
              lax.reduce_window_max_p]:
    if np_dtype in [np.bool_, np.int8, np.uint16, np.uint32, np.uint64,
                    np.complex64, np.complex128]:
      tf_unimpl(np_dtype)

  if prim is lax.div_p:
    if np_dtype in [np.uint8, np.uint16, np.uint32, np.uint64,
                    np.int8, np.int16]:
      tf_unimpl(np_dtype)
    elif dtypes.issubdtype(np_dtype, np.integer):
      tf_unimpl(np_dtype, additional_msg=("integer division fails if the "
                                          "divisor contains a 0"))

  if prim is lax.rem_p:
    if np_dtype in [np.uint8, np.uint16, np.uint32, np.uint64,
                    np.int8, np.int16, np.float16]:
      tf_unimpl(np_dtype)
    elif dtypes.issubdtype(np_dtype, np.integer):
      tf_unimpl(np_dtype, additional_msg=("integer division fails if the "
                                          "divisor contains a 0"))

  if prim is lax.atan2_p and np_dtype in [np.float16, dtypes.bfloat16]:
    # b/158006398: TF kernels are missing for 'rem' and 'atan2'
    tf_unimpl(np_dtype)

  if prim is lax.nextafter_p:
    if np_dtype in [np.float16, dtypes.bfloat16]:
      tf_unimpl(np_dtype)

  if prim is lax.linalg.cholesky_p:
    if np_dtype in [np.complex64, np.complex128]:
      # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
      # experimental_compile=True breaks for complex types.
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax.linalg.qr_p:
    if np_dtype in [np.complex64, np.complex128]:
      # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
      # experimental_compile=True breaks for complex types.
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax.linalg.eig_p:
    tf_unimpl(additional_msg=("this is a problem only in compiled mode "
                              "(experimental_compile=True))"))
    compute_left_eigenvectors = kwargs['compute_left_eigenvectors']
    compute_right_eigenvectors = kwargs['compute_right_eigenvectors']
    if compute_left_eigenvectors and compute_right_eigenvectors:
      tf_unimpl(additional_msg=("it is not possible to request both left and "
                                "right eigenvectors for now"))

  if prim is lax.linalg.eigh_p:
    if np_dtype in [np.complex64, np.complex128]:
      # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
      # experimental_compile=True breaks for complex types.
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax.linalg.lu_p:
    if np_dtype == np.complex64:
      tf_unimpl(np_dtype, devs=["TPU"])

  if prim is lax.linalg.triangular_solve_p:
    if np_dtype in [dtypes.bfloat16, np.float16]:
      tf_unimpl(np_dtype)

  if prim is lax.linalg.svd_p:
    if np_dtype in [dtypes.bfloat16]:
      # TODO: SVD on TPU for bfloat16 seems to work for JAX but fails for TF
      tf_unimpl(np_dtype, devs=["TPU"])
    elif np_dtype in [np.complex64, np.complex128]:
      # TODO: on CPU and GPU "No registered 'Svd' OpKernel for XLA_CPU_JIT
      # devices". Works on JAX because JAX uses a custom implementation.
      # There exists a XlaSvd operation that could replace tf.linalg.svd in
      # these cases but complex numbers support is not implemented in XLA yet,
      # and the API of XlaSvd is different than the one in JAX/TF, which also
      # limits its usability (e.g. no full_matrices argument, …).
      additional_msg = ("this works on JAX because JAX uses a custom "
                        "implementation")
      tf_unimpl(np_dtype, additional_msg=additional_msg, devs=["CPU", "GPU"])

  if prim is lax.select_and_scatter_add_p:
    if np_dtype in [np.uint64, np.uint32, np.uint16]:
      tf_unimpl(np_dtype)

  if prim is lax.select_and_gather_add_p:
    # TODO: the conversion is only supported for float16/float32 on CPU/GPU,
    # and float16 on TPU. This is because we do not implement a precision
    # reduction in the case where packing 2 n-bit values together results in
    # more than the maximum number of bits allowed on the platform (64 on
    # CPU/GPU, 32 on TPU). This could be fixed by implementing a variadic
    # reduce_window in tfxla, or we can require the user to reduce the
    # precision of their arrays manually based on the platform they run on.
    devices_and_max_bits = [ (["CPU", "GPU"], 64)
                           , (["TPU"], 32)
                           ]
    for devs, max_bits in devices_and_max_bits:
      if dtypes.finfo(np_dtype).bits * 2 > max_bits:
        # TODO: getting an exception "XLA encountered an HLO for which this
        # rewriting is not implemented"
        tf_unimpl(np_dtype, devs=devs)

  if prim in [lax.add_p, lax.reduce_window_sum_p]:
    if np_dtype in [np.uint16, np.uint32, np.uint64]:
      # TODO(bchetioui): tf.math.add is not defined for the above types.
      tf_unimpl(np_dtype)

  if prim is lax.mul_p:
    if np_dtype in [np.uint32, np.uint64]:
      # TODO(bchetioui): tf.math.multiply is not defined for the above types.
      tf_unimpl(np_dtype)

  if prim is lax.sort_p:
    if np_dtype in [np.complex64, np.complex128]:
      tf_unimpl(np_dtype)
    if np_dtype == np.bool_ and len(args) == 2:
      tf_unimpl(np_dtype, additional_msg=(
        "sorting 2 arrays where the first one is an array of booleans is not "
        "supported for XlaSort"))
    if kwargs["is_stable"]:
      tf_unimpl(additional_msg="stable sort not implemented for XlaSort")
    if kwargs["dimension"] != len(np.shape(args[0])) - 1:
      tf_unimpl(additional_msg="only sorting on last dimension is supported "
                               "for XlaSort")
    if len(args) > 2:
      tf_unimpl(additional_msg=(
        "sorting more than 2 arrays is not supported for XlaSort"))

  if prim is lax.population_count_p:
    if np_dtype in [np.uint32, np.uint64]:
      tf_unimpl(np_dtype)

  if prim is lax.clamp_p:
    if np_dtype in [np.int8, np.uint16, np.uint32, np.uint64]:
      tf_unimpl(np_dtype)

  # Testing with matmul (TODO: comment out and test without matmul)
  if prim is lax.dot_general_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.bool_, np.uint8, np.uint16, np.uint32, np.uint64,
                    np.int8]:
      tf_unimpl(np_dtype)
    elif np_dtype == np.int16:
      # TODO(bchetioui): the path using 'einsum' is not compatible with int16
      # arguments on CPU/GPU, while the one using 'matmul' is (but not in
      # compiled mode).
      tf_unimpl(np_dtype, additional_msg=("only cases representable as 2D "
                                          "matrix multiplication can be "
                                          "converted properly"),
                devs=['CPU', 'GPU'])
      tf_unimpl(np_dtype, devs=['TPU'])
    elif np_dtype == np.int64:
      # np.int16 is already handled by the branch above, so only int64 can
      # reach this point.
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True)"),
                devs=['CPU', 'GPU'])
  if prim is lax.conv_general_dilated_p:
    batch_group_count = kwargs['batch_group_count']
    if batch_group_count != 1:
      tf_unimpl(additional_msg="batch_group_count != 1 unsupported")
    if np_dtype in [np.complex64, np.complex128]:
      tf_unimpl(np_dtype, additional_msg="likely bug in the HLO -> LLVM IR "
                                         "lowering of XlaConv")

  if prim in [lax.acosh_p, lax.asinh_p, lax.atanh_p, lax.bessel_i0e_p,
              lax.bessel_i1e_p, lax.digamma_p, lax.erf_p, lax.erf_inv_p,
              lax.erfc_p, lax.lgamma_p, lax.round_p, lax.rsqrt_p]:
    if np_dtype == dtypes.bfloat16:
      tf_unimpl(np_dtype, devs=["CPU", "GPU"])

  if prim is lax.convert_element_type_p:
    if np_dtype == dtypes.bfloat16:
      tf_unimpl(np_dtype, devs=["CPU", "GPU"])

  if prim in [lax.sinh_p, lax.cosh_p, lax.atanh_p, lax.asinh_p, lax.acosh_p,
              lax.erf_inv_p]:
    if np_dtype == np.float16:
      # b/158006398: float16 support missing from the kernel of the above
      # operations.
      tf_unimpl(np_dtype)

  if prim in [lax.le_p, lax.lt_p, lax.ge_p, lax.gt_p]:
    if np_dtype in [np.bool_, np.uint16, np.uint32, np.uint64]:
      tf_unimpl(np_dtype)

  if prim is lax.fft_p:
    if np_dtype in [np.float64, np.complex128]:
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax.top_k_p:
    if np_dtype in [np.float64, np.int64, np.uint64]:
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))
  return limitations
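A minimal, self-contained sketch of the dtype-gating idiom used in the branches above (e.g. the lax.rem_p case): dtypes with known-missing TF kernels are listed explicitly, while the remaining integer dtypes are caught with dtypes.issubdtype. The report_unimplemented helper below is hypothetical and only stands in for tf_unimpl.

import numpy as np
from jax import dtypes

def report_unimplemented(np_dtype, additional_msg=None):
  # Hypothetical stand-in for tf_unimpl: just prints the limitation.
  msg = "unimplemented for %s" % np.dtype(np_dtype).name
  if additional_msg:
    msg += ": " + additional_msg
  print(msg)

def check_rem_like(np_dtype):
  # Dtypes with known-missing TF kernels are rejected outright ...
  if np_dtype in [np.uint8, np.uint16, np.uint32, np.uint64,
                  np.int8, np.int16, np.float16]:
    report_unimplemented(np_dtype)
  # ... and any remaining integer dtype gets the division-by-zero caveat,
  # using issubdtype instead of enumerating every integer type.
  elif dtypes.issubdtype(np_dtype, np.integer):
    report_unimplemented(np_dtype,
                         additional_msg="integer division fails if the "
                                        "divisor contains a 0")

check_rem_like(np.int32)    # caught by the issubdtype branch
check_rem_like(np.float32)  # no limitation reported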
Example #16
0
class LaxAutodiffTest(jtu.JaxTestCase):

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.name, shapes, itertools.repeat(dtype)),
         "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype,
         "order": rec.order, "tol": rec.tol}
        for shape_group in compatible_shapes
        for shapes in CombosWithReplacement(shape_group, rec.nargs)
        for dtype in rec.dtypes)
      for rec in LAX_GRAD_OPS))
  def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
    rng = rng_factory(self.rng())
    if jtu.device_under_test() == "tpu" and op is lax.pow:
      raise SkipTest("pow grad imprecise on tpu")
    tol = jtu.join_tolerance(1e-1, tol) if num_float_bits(dtype) == 32 else tol
    args = tuple(rng(shape, dtype) for shape in shapes)
    check_grads(op, args, order, ["fwd", "rev"], tol, tol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
           "op": rec.op, "special_value": special_value, "tol": rec.tol}
          for special_value in rec.values)
      for rec in LAX_GRAD_SPECIAL_VALUE_TESTS))
  def testOpGradSpecialValue(self, op, special_value, tol):
    check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_from_dtype={}_to_dtype={}".format(
          jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)),
       "from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory}
      for from_dtype, to_dtype in itertools.product(
          float_dtypes + complex_dtypes, repeat=2)
      for rng_factory in [jtu.rand_default]))
  def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory):
    rng = rng_factory(self.rng())
    tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance),
              jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
    args = (rng((2, 3), from_dtype),)
    convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
    convert_element_type = jtu.ignore_warning(category=onp.ComplexWarning)(
      convert_element_type)
    check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype,
       "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in grad_float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testClampGrad(self, shape, dtype, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    low = operand - dtype(10)
    high = operand + dtype(10)
    # Avoids points near the boundary where the gradient may be inaccurate.
    check_grads(lax.clamp, (operand, low, high), 2, ["fwd", "rev"], eps=1e-2)
    check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2)
    check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
          dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name,
          num_arrs),
       "dim": dim, "base_shape": base_shape, "dtype": dtype,
       "num_arrs": num_arrs, "rng_factory": rng_factory}
      for num_arrs in [3]
      for dtype in float_dtypes
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for dim in range(len(base_shape))
      for rng_factory in [jtu.rand_default]))
  def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs, rng_factory):
    rng = rng_factory(self.rng())
    shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
    operands = tuple(rng(shape, dtype) for shape in shapes)
    concatenate = lambda *args: lax.concatenate(args, dim)
    check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "rng_factory": rng_factory,}
       for lhs_shape, rhs_shape, all_strides in itertools.chain(
           [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)])
            for b, i, j in itertools.product([2, 3], repeat=3)],
           [((4, 2, 1), (3, 2, 1), [(1,)])])
       for strides in all_strides
       for dtype in float_dtypes
       for padding in ["VALID", "SAME"]
       for rng_factory in [jtu.rand_small]))
  def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, rng_factory):
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv, window_strides=strides, padding=padding,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=1e-2, rtol=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory}
       for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in
       itertools.chain(
           [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)],
             [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))],
             [(1, 1), (2, 1)], [(1, 1)])
            for b, i, j in itertools.product([2, 3], repeat=3)],
           [((4, 2, 1), (3, 2, 1), [(1,)], [((1, 1),), ((0, 0),)],
             [(1,), (2,)], [(1,), (2,)])])
       for strides in all_strides
       for rhs_dil in rhs_dils
       for lhs_dil in lhs_dils
       for dtype in float_dtypes
       for padding in all_pads
       for rng_factory in [jtu.rand_small]))
  def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                     padding, lhs_dil, rhs_dil, rng_factory):
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_with_general_padding, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=1e-2, rtol=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
               feature_group_count, batch_group_count),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums,
       "perms": perms, "feature_group_count": feature_group_count,
       "batch_group_count": batch_group_count}
      for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)])
      for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [
          ([(b * batch_group_count, i * feature_group_count, 6, 7),
            (b * batch_group_count, i * feature_group_count, 0, 4)],  # lhs_shape
           (j * batch_group_count * feature_group_count, i, 1, 2),  # rhs_shape
           [(1, 1), (1, 2), (2, 1)],  # strides
           [(1, 1), (2, 1)],  # lhs_dils
           [(1, 1), (2, 2)])  # rhs_dils
          for b, i, j in itertools.product([1, 2], repeat=3)]
      for lhs_shape in lhs_shapes
      for strides in all_strides
      for rhs_dil in rhs_dils
      for lhs_dil in lhs_dils
      for dtype in grad_float_dtypes
      for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] +
        ([((0, -1), (0, 0))] if lhs_shape[2] != 0 else []))
      for dim_nums, perms in [
          (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
      for rng_factory in [jtu.rand_default]
  ))
  def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                 padding, lhs_dil, rhs_dil, dimension_numbers,
                                 perms, feature_group_count, batch_group_count,
                                 rng_factory):
    if dtype == onp.float16:
      raise SkipTest("float16 numerical issues")  # TODO(mattjj): resolve

    rng = rng_factory(self.rng())
    tol = {dtypes.bfloat16: 1e-0, onp.float16: 5e-1, onp.float32: 2e-4}

    # permute shapes to match dim_spec, scale by feature_group_count
    lhs_perm, rhs_perm = perms
    lhs_shape = list(onp.take(lhs_shape, lhs_perm))
    rhs_shape = list(onp.take(rhs_shape, rhs_perm))

    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_general_dilated, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   dimension_numbers=dimension_numbers,
                   feature_group_count=feature_group_count,
                   batch_group_count=batch_group_count,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),
          jtu.format_shape_dtype_string(rhs_shape, dtype)),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng_factory": jtu.rand_default}
      for lhs_shape in [(2,), (3, 2)] for rhs_shape in [(2,), (2, 4)]
      for dtype in float_dtypes))
  def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory):
    rng = rng_factory(self.rng())
    tol = {onp.float16: 1e-1, onp.float32: 1e-4}
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)
    # check that precision config is preserved
    result, pullback = api.vjp(dot, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    assert "precision=HIGHEST" in s

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers, "rng_factory": jtu.rand_small}
      for lhs_shape, rhs_shape, dimension_numbers in [
          ((3, 2), (2, 4), (([1], [0]), ([], []))),
          ((3, 5), (2, 5), (([1], [1]), ([], []))),
          ((5, 3), (5, 2), (([0], [0]), ([], []))),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
      ]
      for dtype in float_dtypes))
  def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                          dimension_numbers, rng_factory):
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers,
                          precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"])
    # check that precision config is preserved
    result, pullback = api.vjp(dot_general, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    assert "precision=HIGHEST" in s

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
          shape, onp.dtype(dtype).name, broadcast_sizes),
       "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
       "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in float_dtypes
      for broadcast_sizes in [(), (2,), (1, 2)]
      for rng_factory in [jtu.rand_default]))
  def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory):
    rng = rng_factory(self.rng())
    args = (rng(shape, dtype),)
    broadcast = lambda x: lax.broadcast(x, broadcast_sizes)
    check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
          jtu.format_shape_dtype_string(inshape, dtype),
          outshape, broadcast_dimensions),
       "inshape": inshape, "dtype": dtype, "outshape": outshape,
       "dimensions": broadcast_dimensions, "rng_factory": rng_factory}
      for inshape, outshape, broadcast_dimensions in [
          ([2], [2, 2], [0]),
          ([2], [2, 2], [1]),
          ([2], [2, 3], [0]),
          ([], [2, 3], []),
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(inshape, dtype)
    broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_perm={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype),
          permutation),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "rng_factory": rng_factory, "permutation": permutation}
      for dtype in float_dtypes
      for arg_shape, out_shape, permutation in [
          [(3, 4), (12,), None],
          [(2, 1, 4), (8,), None],
          [(2, 2, 4), (2, 8), None],
          [(3, 4), (12,), (0, 1)],
          [(3, 4), (12,), (1, 0)],
          [(2, 1, 4), (8,), (0, 2, 1)],
          [(2, 1, 4), (8,), (2, 0, 1)],
          [(2, 2, 4), (2, 8), (0, 2, 1)],
          [(2, 2, 4), (2, 8), (2, 0, 1)],
      ]
      for rng_factory in [jtu.rand_default]))
  def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(arg_shape, dtype)
    reshape = lambda x: lax.reshape(x, out_shape, permutation)
    check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_pads={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), pads),
       "shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small}
      for shape in [(2, 3)]
      for dtype in float_dtypes
      for pads in [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]]))
  def testPadGrad(self, shape, dtype, pads, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    pad = lambda operand: lax.pad(operand, onp.array(0, dtype), pads)
    check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.)

    operand = rng(shape, dtype)
    padding_value = onp.array(0., dtype)
    pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads)
    check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.)

  def testReverseGrad(self):
    rev = lambda operand: lax.rev(operand, dimensions)

    dimensions = [0]
    check_grads(rev, (onp.array([3., 2., 1.]),), 2)

    dimensions = [0, 1]
    check_grads(rev, (onp.array([[6., 5., 4.], [3., 2., 1.]]),), 2,
                rtol={onp.float32: 3e-3})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_predshape={}_argshapes={}".format(
          jtu.format_shape_dtype_string(pred_shape, onp.bool_),
          jtu.format_shape_dtype_string(arg_shape, dtype)),
       "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype,
       "rng_factory": rng_factory}
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSelectGrad(self, pred_shape, arg_shape, dtype, rng_factory):
    rng = rng_factory(self.rng())
    pred = rng(pred_shape, onp.bool_)
    on_true = rng(arg_shape, dtype)
    on_false = rng(arg_shape, dtype)
    select = lambda on_true, on_false: lax.select(pred, on_true, on_false)
    check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_start_indices={}_limit_indices={}_strides={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, limit_indices, strides),
       "shape": shape, "dtype": dtype, "starts": start_indices,
       "limits": limit_indices, "strides": strides, "rng_factory": rng_factory}
      for shape, start_indices, limit_indices, strides in [
        [(3,), (1,), (2,), None],
        [(7,), (4,), (7,), None],
        [(5,), (1,), (5,), (2,)],
        [(8,), (1,), (6,), (2,)],
        [(5, 3), (1, 1), (3, 2), None],
        [(5, 3), (1, 1), (3, 1), None],
        [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
        [(5, 3), (1, 1), (2, 1), (1, 1)],
        [(5, 3), (1, 1), (5, 3), (2, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSliceGrad(self, shape, dtype, starts, limits, strides, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    slice = lambda x: lax.slice(x, starts, limits, strides)
    check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, size_indices),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "size_indices": size_indices, "rng_factory": rng_factory}
      for shape, start_indices, size_indices in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices,
                           rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices)
    check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, update_shape),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "update_shape": update_shape, "rng_factory": rng_factory}
      for shape, start_indices, update_shape in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices,
                                 update_shape, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    update = rng(update_shape, dtype)
    start_indices = onp.array(start_indices)

    dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices)
    check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.)

    dus = lambda x: lax.dynamic_update_slice(x, update, start_indices)
    check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.)

    dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices)
    check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_perm={}".format(
          jtu.format_shape_dtype_string(shape, dtype), perm),
       "shape": shape, "dtype": dtype, "perm": perm, "rng_factory": rng_factory}
      for shape, perm in [
        [(3, 4), (1, 0)],
        [(3, 4), (0, 1)],
        [(3, 4, 5), (2, 1, 0)],
        [(3, 4, 5), (1, 0, 2)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testTransposeGrad(self, shape, dtype, perm, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    transpose = lambda x: lax.transpose(x, perm)
    check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims),
       "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
       "dims": dims, "rng_factory": rng_factory}
      for init_val, op, dtypes, rng_factory in [
          (0, lax.add, inexact_dtypes, jtu.rand_default),
          (-onp.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int),
          (onp.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int),
          (1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)),
      ]
      for dtype in dtypes
      for shape, dims in [
          [(3, 4, 5), ()],
          [(3, 4, 5), (0,)],
          [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)],
          [(3, 4, 5), (0, 1, 2)],
          [(3, 1), (1,)],
      ]))
  def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
    rng = rng_factory(self.rng())
    if jtu.device_under_test() == "tpu" and op is lax.mul:
      raise SkipTest("unimplemented case")
    tol = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 1e-1,
           onp.float64: 1e-3, onp.complex64: 1e-1}
    operand = rng(shape, dtype)
    init_val = onp.asarray(init_val, dtype=dtype)
    reduce = lambda operand: lax.reduce(operand, init_val, op, dims)
    eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else
           1e-1 if dtype == dtypes.bfloat16 else
           1e-2 if dtypes.finfo(dtype).bits == 32 else None)
    check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_dtype={}_padding={}"
       .format(op.__name__, onp.dtype(dtype).name, padding),
       "op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
       "rng_factory": rng_factory}
      for init_val, op, dtypes, rng_factory in [
          (0, lax.add, grad_float_dtypes, jtu.rand_small),
          (-onp.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
          (onp.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
      ]
      for dtype in dtypes
      for padding in ["VALID", "SAME"]))
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jtu.ignore_warning(category=UserWarning,
                      message="Using reduced precision for gradient.*")
  def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory):
    rng = rng_factory(self.rng())
    init_val = onp.asarray(init_val, dtype=dtype)

    # We need this conditional and the corresponding loop logic to be in the
    # test method, rather than at the parameterized test level, because it
    # depends on FLAGS for the device under test.
    # TODO(b/31565929): enable when fixed.
    if jtu.device_under_test() == "tpu" and op is not lax.add:
      all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))]

      # TODO(b/73062247): need variadic reduce-window for better precision.
      gradient_order = 1
    else:
      all_configs = itertools.chain(
          itertools.product(
              [(4, 6)],  # shapes
              [(2, 1), (1, 2)],  # window_dimensions
              [(1, 1), (2, 1), (1, 2)]  # strides
          ),
          itertools.product(
              [(3, 2, 4, 6)],  # shapes
              [(1, 1, 2, 1), (2, 1, 2, 1)],  # window_dimensions
              [(1, 2, 2, 1), (1, 1, 1, 1)]),  # strides
      )
      gradient_order = 3

    def fun(operand):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding)

    for shape, dims, strides in all_configs:
      operand = rng(shape, dtype)
      if op is lax.add:
        eps = 1.
        tol = None
      else:
        # this test can fail if there are duplicates in operand
        self.assertEqual(onp.unique(operand).size, operand.size,
                         msg="test requires operand elements to be unique.")
        eps = 1e-2
        tol = {onp.float16: 1e-1, onp.float32: 6e-2, onp.float64: 6e-2}
      check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol,
                  eps)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_shape={}_axis={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis),
       "op": op, "shape": shape, "dtype": dtype,
       "axis": axis, "rng_factory": rng_factory}
      for op, types in [
          (lax.cumsum, [onp.float32, onp.float64]),
          (lax.cumprod, [onp.float32, onp.float64]),
      ]
      for dtype in types
      for shape in [[10], [3, 4, 5]]
      for axis in range(len(shape))
      for rng_factory in [
          jtu.rand_default if dtypes.issubdtype(dtype, onp.integer)
          else jtu.rand_small]))
  def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory):
    rng = rng_factory(self.rng())
    check_grads(partial(op, axis=axis), (rng(shape, dtype),), order=2)


  # TODO(b/205052657): enable more tests when supported
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}_isstable={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),
       "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis,
       "is_stable": is_stable}
      for dtype in [onp.float32]
      for shape in [(5,), (5, 7)]
      for axis in [len(shape) - 1]
      for is_stable in [False, True]
      for rng_factory in [jtu.rand_default]))
  def testSortGrad(self, shape, dtype, axis, is_stable, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    sort = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable)
    check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2)

  # TODO(b/205052657): enable more tests when supported
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_keyshape={}_valshape={}_axis={}_isstable={}".format(
          jtu.format_shape_dtype_string(shape, key_dtype),
          jtu.format_shape_dtype_string(shape, val_dtype),
          axis, is_stable),
       "rng_factory": rng_factory, "shape": shape,
       "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis,
       "is_stable": is_stable}
      for key_dtype in [onp.float32]
      for val_dtype in [onp.float32]
      for shape in [(3,), (5, 3)]
      for axis in [len(shape) - 1]
      for is_stable in [False, True]
      for rng_factory in [jtu.rand_default]))
  def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable,
                         rng_factory):
    rng = rng_factory(self.rng())
    # This test relies on the property that wherever keys are tied, values are
    # too, since we don't guarantee the same ordering of values with equal keys.
    # To avoid that case, we generate unique keys (globally in the key array).
    def args_maker():
      flat_keys = onp.arange(onp.prod(shape, dtype=int), dtype=key_dtype)
      keys = self.rng().permutation(flat_keys).reshape(shape)
      values = rng(shape, val_dtype)
      return keys, values
    keys, values = args_maker()

    fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable)
    check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k),
       "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k}
      for dtype in [onp.float32,]
      for shape in [(4,), (5, 5), (2, 1, 4)]
      for k in [1, 3]
      for rng_factory in [jtu.rand_default]))
  def testTopKGrad(self, shape, dtype, k, rng_factory):
    flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype)
    values = self.rng().permutation(flat_values).reshape(shape)
    fun = lambda vs: lax.top_k(vs, k=k)[0]
    check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes,
       "rng_factory": rng_factory}
      for dtype in float_dtypes
      for shape, idxs, axes in [
          [(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
          [(3, 4, 5), (onp.array([-1, -2]),), (0,)],
          [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
          [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
      ]
      for rng_factory in [jtu.rand_default]))
  def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng_factory):
    rng = rng_factory(self.rng())
    src = rng(shape, dtype)
    index_take = lambda src: lax.index_take(src, idxs, axes)
    check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
          slice_sizes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for shape, idxs, dnums, slice_sizes, max_idx in [
          ((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,), 5),
          ((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,), 9),
          ((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3), 3),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)]
      for rng_factory in [jtu.rand_default]))
  def testGatherGrad(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory,
                     rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums,
                                  slice_sizes=slice_sizes)
    x = rng(shape, dtype)
    check_grads(gather, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums, max_idx in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 4),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,)), 9),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 3),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)]
      for rng_factory in [jtu.rand_default]))
  def testScatterAddGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                         rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_add = lambda x, y: lax.scatter_add(x, idxs, y,
                                               dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums, max_idx in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 4),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,)), 9),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 3),
      ]
      # Scatters with conflicting indices are not deterministic on GPU, so we
      # use indices that do not collide.
      for rng_idx_factory in [partial(jtu.rand_unique_int, high=max_idx)]
      for rng_factory in [jtu.rand_default]))
  def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                      rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  def testScatterGradSymbolicZeroUpdate(self):
    # https://github.com/google/jax/issues/1901
    def f(x):
      n = x.shape[0]
      y = onp.arange(n, dtype=x.dtype)
      return jax.ops.index_update(x, onp.diag_indices(n), y)
    rng = jtu.rand_default(self.rng())
    check_grads(f, (rng((5, 5), onp.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2,
                1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums,
                     rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_max = lambda x, y: lax.scatter_max(x, idxs, y, dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums,
                     rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_min = lambda x, y: lax.scatter_min(x, idxs, y, dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_min, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

  def testStopGradient(self):
    def f(x):
      return lax.sin(x) * lax.cos(lax.stop_gradient(x))

    def f2(x, y):
      return lax.sin(x) * lax.cos(y)

    x = 3.14
    ans = api.grad(f)(x)
    expected = api.grad(f2)(x, x)
    self.assertAllClose(ans, expected)

    ans = api.grad(api.grad(f))(x)
    expected = api.grad(api.grad(f2))(x, x)
    self.assertAllClose(ans, expected)

    ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.)
    expected = onp.array(0.0)
    self.assertAllClose(ans, expected, check_dtypes=False)

    with core.skipping_checks():
      with self.assertRaises(TypeError):
        lax.stop_gradient(lambda x: x)

  # TODO(mattjj): make this a more systematic test
  def testRemainder(self):
    rng = onp.random.RandomState(0)
    x = rng.uniform(-0.9, 9, size=(3, 4))
    y = rng.uniform(0.7, 1.9, size=(3, 1))
    assert not set(onp.unique(x)) & set(onp.unique(y))
    tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

    rng = onp.random.RandomState(0)
    x = rng.uniform(-0.9, 9, size=(1, 4))
    y = rng.uniform(0.7, 1.9, size=(3, 4))
    assert not set(onp.unique(x)) & set(onp.unique(y))
    tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

  def testHigherOrderGradientOfReciprocal(self):
    # Regression test for https://github.com/google/jax/issues/3136
    def inv(x):
      # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x)
      return 1 / x
    grad_fn = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv))))))
    self.assertAllClose(onp.float32(0.0439453125), grad_fn(onp.float32(4.)))
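The tests in this class all funnel through check_grads, which compares forward- and reverse-mode derivatives against numerical differences. A minimal standalone sketch of that pattern (assuming a JAX version where check_grads is importable from jax.test_util; the tolerances are ad hoc):

import numpy as np
from jax import lax
from jax.test_util import check_grads

x = np.linspace(0.1, 1.0, 5).astype(np.float32)
# Check first- and second-order gradients of lax.sin in both forward and
# reverse mode against finite differences.
check_grads(lax.sin, (x,), order=2, modes=["fwd", "rev"], atol=1e-2, rtol=1e-2)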
Example #17
0
 def testCumulativeReduceGrad(self, op, shape, dtype, axis, reverse):
   rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
                  else jtu.rand_small)
   rng = rng_factory(self.rng())
   check_grads(partial(op, axis=axis, reverse=reverse), (rng(shape, dtype),),
               order=2)
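A compact illustration of the rng_factory selection above: integer dtypes use the default generator, while inexact dtypes use small values so the finite-difference checks stay well conditioned. make_rng is a hypothetical wrapper around the rand_* factories (assuming a JAX version where they live in jax.test_util).

import numpy as np
from jax import dtypes
from jax import test_util as jtu

def make_rng(dtype, np_rng):
  # Choose a generator factory based on the dtype family, as in the test above.
  rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
                 else jtu.rand_small)
  return rng_factory(np_rng)

rng = make_rng(np.float32, np.random.RandomState(0))
x = rng((3, 4), np.float32)  # small float32 values from jtu.rand_small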
Example #18
0
 def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse):
     rng_factory = (jtu.rand_default if dtypes.issubdtype(
         dtype, np.integer) else jtu.rand_small)
     rng = rng_factory(self.rng())
     self._CheckBatching(partial(op, axis=axis, reverse=reverse), 7, bdims,
                         (shape, ), (dtype, ), rng)
Example #19
0
 def is_float(c):
     return dtypes.issubdtype(dtypes.dtype(c), jnp.inexact)
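A hedged usage sketch for the is_float predicate above: splitting a mixed list of constants into inexact (float/complex) values and everything else. The partition helper is hypothetical and only for illustration; it assumes dtypes.dtype accepts NumPy scalars and arrays as in the snippet above.

import numpy as np
import jax.numpy as jnp
from jax import dtypes

def is_float(c):
  return dtypes.issubdtype(dtypes.dtype(c), jnp.inexact)

def partition(pred, xs):
  # Split xs into (matching, non-matching) lists, preserving order.
  return [x for x in xs if pred(x)], [x for x in xs if not pred(x)]

consts = [np.float32(1.0), np.int32(2), np.arange(3), np.complex64(1j)]
float_consts, other_consts = partition(is_float, consts)
# float_consts holds the float32 and complex64 entries; other_consts the rest.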