Example no. 1
class LaxVmapTest(jtu.JaxTestCase):

  def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
                     rtol=None, atol=None):
    batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
    args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)]
    args_slice = args_slicer(args, bdims)
    ans = api.vmap(op, bdims)(*args)
    if bdim_size == 0:
      args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
      out = op(*args)
      expected = np.zeros((0,) + out.shape, out.dtype)
    else:
      expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)])
    self.assertAllClose(ans, expected, rtol=rtol, atol=atol)
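
  # The helper above compares vmap(op) applied to batched inputs against a
  # reference built by slicing out each batch element and stacking the
  # per-example results. A minimal sketch of the same check, with hypothetical
  # values (not one of the parameterized cases below):
  #
  #   xs = np.ones((7, 2, 3), np.float32)              # batch dim 0, size 7
  #   ans = api.vmap(lax.exp, (0,))(xs)
  #   expected = np.stack([lax.exp(xs[i]) for i in range(7)])
  #   np.testing.assert_allclose(ans, expected)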

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": "{}_bdims={}".format(
            jtu.format_test_name_suffix(rec.op, shapes,
                                        itertools.repeat(dtype)), bdims),
         "op_name": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes,
         "dtype": dtype, "bdims": bdims, "tol": rec.tol}
        for shape_group in compatible_shapes
        for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs)
        for bdims in all_bdims(*shapes)
        for dtype in rec.dtypes)
      for rec in LAX_OPS))
  def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol):
    rng = rng_factory(self.rng())
    op = getattr(lax, op_name)
    self._CheckBatching(op, 10, bdims, shapes, [dtype] * len(shapes), rng,
                        atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
       "testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
       "_lhs_bdim={}_rhs_bdim={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
               feature_group_count, batch_group_count, lhs_bdim, rhs_bdim),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "dimension_numbers": dim_nums,
       "perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim,
       "feature_group_count": feature_group_count,
       "batch_group_count": batch_group_count,
     } for batch_group_count, feature_group_count in s([(1, 1), (2, 1), (1, 2)])
       for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in s([
           ((b * batch_group_count, i * feature_group_count, 6, 7),  # lhs_shape
            (j * batch_group_count * feature_group_count, i, 1, 2),  # rhs_shape
            [(1, 1), (1, 2), (2, 1)],  # strides
            [((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))],  # pads
            [(1, 1), (2, 1)],  # lhs_dils
            [(1, 1), (2, 2)])  # rhs_dils
           for b, i, j in itertools.product([1, 2], repeat=3)])
       for strides in s(all_strides)
       for rhs_dil in s(rhs_dils)
       for lhs_dil in s(lhs_dils)
       for dtype in s([np.float32])
       for padding in s(all_pads)
       for dim_nums, perms in s([
           (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
           (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
           (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))])
       for lhs_bdim in s(itertools.chain([cast(Optional[int], None)],
                                         range(len(lhs_shape) + 1)))
       for rhs_bdim in s(itertools.chain([cast(Optional[int], None)],
                                         range(len(rhs_shape) + 1)))
       if (lhs_bdim, rhs_bdim) != (None, None)
       )))
  def testConvGeneralDilatedBatching(
      self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil,
      dimension_numbers, perms, feature_group_count, batch_group_count,
      lhs_bdim, rhs_bdim):
    rng = jtu.rand_default(self.rng())
    tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3

    # permute shapes to match the dimension_numbers spec (the group-count
    # scaling was already applied when the shapes were generated above)
    lhs_perm, rhs_perm = perms
    lhs_shape = list(np.take(lhs_shape, lhs_perm))
    rhs_shape = list(np.take(rhs_shape, rhs_perm))

    conv = partial(lax.conv_general_dilated, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   dimension_numbers=dimension_numbers,
                   feature_group_count=feature_group_count,
                   batch_group_count=batch_group_count,
                   precision=lax.Precision.HIGHEST)
    self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape),
                        (dtype, dtype), rng, rtol=tol, atol=tol)
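
  # The dimension_numbers triple above names the (lhs, rhs, out) layouts, e.g.
  # ("NHWC", "HWIO", "NHWC") places the channel axis last, and `perms` is the
  # permutation from the canonical NCHW / OIHW shapes generated by the sampler
  # to those layouts, applied to the shapes just before building the operands.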

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
          shape, from_dtype, to_dtype, bdims),
       "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype,
       "bdims": bdims}
      for from_dtype, to_dtype in itertools.product(
          [np.float32, np.int32, "float32", "int32"], repeat=2)
      for shape in [(2, 3)]
      for bdims in all_bdims(shape)))
  def testConvertElementType(self, shape, from_dtype, to_dtype, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.convert_element_type(x, to_dtype)
    self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nmant={}_nexp={}_bdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), nmant, nexp, bdims),
       "shape": shape, "dtype": dtype, "nmant": nmant, "nexp": nexp, "bdims": bdims}
      for dtype in float_dtypes
      for shape in [(2, 4)]
      for nexp in [1, 3, 5]
      for nmant in [0, 2, 4]
      for bdims in all_bdims(shape)))
  def testReducePrecision(self, shape, dtype, nmant, nexp, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.reduce_precision(x, exponent_bits=nexp, mantissa_bits=nmant)
    self._CheckBatching(op, 10, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
          shape, from_dtype, to_dtype, bdims),
       "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype,
       "bdims": bdims}
      for from_dtype, to_dtype in itertools.product(
          [np.float32, np.int32, "float32", "int32"], repeat=2)
      for shape in [(2, 3)]
      for bdims in all_bdims(shape)))
  def testBitcastElementType(self, shape, from_dtype, to_dtype, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.bitcast_convert_type(x, to_dtype)
    self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}_bdims={}"
       .format(jtu.format_shape_dtype_string(min_shape, dtype),
               jtu.format_shape_dtype_string(operand_shape, dtype),
               jtu.format_shape_dtype_string(max_shape, dtype),
               bdims),
       "min_shape": min_shape, "operand_shape": operand_shape,
       "max_shape": max_shape, "dtype": dtype, "bdims": bdims}
      for min_shape, operand_shape, max_shape in [
          [(), (2, 3), ()],
          [(2, 3), (2, 3), ()],
          [(), (2, 3), (2, 3)],
          [(2, 3), (2, 3), (2, 3)],
      ]
      for dtype in default_dtypes
      for bdims in all_bdims(min_shape, operand_shape, max_shape)))
  def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims):
    rng = jtu.rand_default(self.rng())
    shapes = [min_shape, operand_shape, max_shape]
    self._CheckBatching(lax.clamp, 10, bdims, shapes, [dtype] * 3, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}_bdims={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),
          jtu.format_shape_dtype_string(rhs_shape, dtype),
          bdims),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "bdims": bdims}
      for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
      for bdims in all_bdims(lhs_shape, rhs_shape)
      for dtype in default_dtypes))
  def testDot(self, lhs_shape, rhs_shape, dtype, bdims):
    rng = jtu.rand_default(self.rng())
    op = partial(lax.dot, precision=lax.Precision.HIGHEST)
    self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng, rtol={np.float16: 5e-2, np.float64: 5e-14})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_bdims={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               lhs_contracting, rhs_contracting, bdims),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting,
       "bdims": bdims}
      for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
          [(5,), (5,), [0], [0]],
          [(5, 7), (5,), [0], [0]],
          [(7, 5), (5,), [1], [0]],
          [(3, 5), (2, 5), [1], [1]],
          [(5, 3), (5, 2), [0], [0]],
          [(5, 3, 2), (5, 2, 4), [0], [0]],
          [(5, 3, 2), (5, 2, 4), [0,2], [0,1]],
          [(5, 3, 2), (3, 5, 2, 4), [0,2], [1,2]],
          [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
          [(3, 2), (2, 4), [1], [0]],
      ]
      for bdims in all_bdims(lhs_shape, rhs_shape)
      for dtype in default_dtypes))
  def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                                 lhs_contracting, rhs_contracting, bdims):
    rng = jtu.rand_small(self.rng())
    dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))
    dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
    self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng)
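
  # dimension_numbers for lax.dot_general has the form
  # ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)). A sketch with
  # hypothetical operands (not one of the cases above): contracting axis 1 of a
  # (3, 5) lhs with axis 1 of a (2, 5) rhs and no batch dims gives shape (3, 2):
  #
  #   out = lax.dot_general(np.ones((3, 5)), np.ones((2, 5)),
  #                         dimension_numbers=(([1], [1]), ([], [])))
  #   assert out.shape == (3, 2)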

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_bdims={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers, bdims),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers, "bdims": bdims}
      for lhs_shape, rhs_shape, dimension_numbers in [
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
      ]
      for bdims in all_bdims(lhs_shape, rhs_shape)
      for dtype in default_dtypes))
  def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                     dimension_numbers, bdims):
    rng = jtu.rand_small(self.rng())
    dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
    self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng)

    # Checks that batching didn't introduce any transposes or broadcasts.
    jaxpr = api.make_jaxpr(dot)(np.zeros(lhs_shape, dtype),
                                np.zeros(rhs_shape, dtype))
    for eqn in jtu.iter_eqns(jaxpr.jaxpr):
      self.assertNotIn(eqn.primitive.name, ["transpose", "broadcast"])

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format(
          shape, np.dtype(dtype).name, broadcast_sizes, bdims),
       "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
       "bdims": bdims}
      for shape in [(), (2, 3)]
      for dtype in default_dtypes
      for broadcast_sizes in [(), (2,), (1, 2)]
      for bdims in all_bdims(shape)))
  def testBroadcast(self, shape, dtype, broadcast_sizes, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.broadcast(x, broadcast_sizes)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_bcdims={}_bdims={}".format(
          jtu.format_shape_dtype_string(inshape, dtype),
          outshape, broadcast_dimensions, bdims),
       "inshape": inshape, "dtype": dtype, "outshape": outshape,
       "dimensions": broadcast_dimensions, "bdims": bdims}
      for inshape, outshape, broadcast_dimensions in [
          ([2], [2, 2], [0]),
          ([2], [2, 2], [1]),
          ([2], [2, 3], [0]),
          ([], [2, 3], []),
      ]
      for dtype in default_dtypes
      for bdims in all_bdims(inshape)))
  def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims):
    rng = jtu.rand_default(self.rng())
    raise SkipTest("this test has failures in some cases")  # TODO(mattjj)
    op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    self._CheckBatching(op, 5, bdims, (inshape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_dimensions={}_bdims={}".format(
          jtu.format_shape_dtype_string(arg_shape, np.float32),
          dimensions, bdims),
       "arg_shape": arg_shape, "dimensions": dimensions, "bdims": bdims}
      for arg_shape, dimensions in [
          [(1,), (0,)],
          [(1,), (-1,)],
          [(2, 1, 4), (1,)],
          [(2, 1, 4), (-2,)],
          [(2, 1, 3, 1), (1,)],
          [(2, 1, 3, 1), (1, 3)],
          [(2, 1, 3, 1), (3,)],
          [(2, 1, 3, 1), (1, -1)],
      ]
      for bdims in all_bdims(arg_shape)))
  def testSqueeze(self, arg_shape, dimensions, bdims):
    dtype = np.float32
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.squeeze(x, dimensions)
    self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_dims={}_bdims={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype),
          dimensions, bdims),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "dimensions": dimensions, "bdims": bdims}
      for dtype in default_dtypes
      for arg_shape, dimensions, out_shape in [
          [(3, 4), None, (12,)],
          [(2, 1, 4), None, (8,)],
          [(2, 2, 4), None, (2, 8)],
          [(2, 2, 4), (0, 1, 2), (2, 8)],
          [(2, 2, 4), (1, 0, 2), (8, 2)],
          [(2, 2, 4), (2, 1, 0), (4, 2, 2)]
      ]
      for bdims in all_bdims(arg_shape)))
  def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions)
    self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_pads={}_bdims={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), pads, bdims),
       "shape": shape, "dtype": dtype, "pads": pads, "bdims": bdims}
      for shape in [(2, 3)]
      for bdims in all_bdims(shape, ())
      for dtype in default_dtypes
      for pads in [[(1, 2, 1), (0, 1, 0)]]))
  def testPad(self, shape, dtype, pads, bdims):
    rng = jtu.rand_small(self.rng())
    fun = lambda operand, padding: lax.pad(operand, padding, pads)
    self._CheckBatching(fun, 5, bdims, (shape, ()), (dtype, dtype), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_predshape={}_argshapes={}_bdims={}".format(
          jtu.format_shape_dtype_string(pred_shape, np.bool_),
          jtu.format_shape_dtype_string(arg_shape, arg_dtype),
          bdims),
       "pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype,
       "bdims": bdims}
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
      for bdims in all_bdims(pred_shape, arg_shape, arg_shape)
      for arg_dtype in default_dtypes))
  def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda c, x, y: lax.select(c < 0, x, y)
    self._CheckBatching(op, 5, bdims, (pred_shape, arg_shape, arg_shape,),
                        (np.bool_, arg_dtype, arg_dtype), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_start_indices={}_limit_indices={}_strides={}_bdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, limit_indices, strides, bdims),
       "shape": shape, "dtype": dtype, "starts": start_indices,
       "limits": limit_indices, "strides": strides, "bdims": bdims}
      for shape, start_indices, limit_indices, strides in [
        [(3,), (1,), (2,), None],
        [(7,), (4,), (7,), None],
        [(5,), (1,), (5,), (2,)],
        [(8,), (1,), (6,), (2,)],
        [(5, 3), (1, 1), (3, 2), None],
        [(5, 3), (1, 1), (3, 1), None],
        [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
        [(5, 3), (1, 1), (2, 1), (1, 1)],
        [(5, 3), (1, 1), (5, 3), (2, 1)],
      ]
      for bdims in all_bdims(shape)
      for dtype in default_dtypes))
  def testSlice(self, shape, dtype, starts, limits, strides, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.slice(x, starts, limits, strides)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_perm={}_bdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), perm, bdims),
       "shape": shape, "dtype": dtype, "perm": perm, "bdims": bdims}
      for shape, perm in [
        [(3, 4), (1, 0)],
        [(3, 4), (0, 1)],
        [(3, 4, 5), (2, 1, 0)],
        [(3, 4, 5), (1, 0, 2)],
      ]
      for bdims in all_bdims(shape)
      for dtype in default_dtypes))
  def testTranspose(self, shape, dtype, perm, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.transpose(x, perm)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}_bdims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
               init_val, bdims),
       "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
       "dims": dims, "bdims": bdims}
      for init_val, op, dtypes in [
          (0, lax.add, default_dtypes),
          (1, lax.mul, default_dtypes),
          (0, lax.max, all_dtypes), # non-monoidal
          (-np.inf, lax.max, float_dtypes),
          (dtypes.iinfo(np.int32).min, lax.max, [np.int32]),
          (dtypes.iinfo(np.int64).min, lax.max, [np.int64]),
          (dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]),
          (dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]),
          (np.inf, lax.min, float_dtypes),
          (dtypes.iinfo(np.int32).max, lax.min, [np.int32]),
          (dtypes.iinfo(np.int64).max, lax.min, [np.int64]),
          (dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]),
          (dtypes.iinfo(np.uint64).max, lax.min, [np.uint64]),
      ]
      for dtype in dtypes
      for shape, dims in [
          [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
      ]
      for bdims in all_bdims(shape)))
  def testReduce(self, op, init_val, shape, dtype, dims, bdims):
    rng = jtu.rand_small(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)
    fun = lambda operand: lax.reduce(operand, init_val, op, dims)
    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}_bdims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dim,
               bdims),
       "op": op, "shape": shape, "dtype": dtype,
       "dim": dim, "bdims": bdims}
      for op in [lax.argmin, lax.argmax]
      for dtype in default_dtypes
      for shape in [(3, 4, 5)]
      for dim in range(len(shape))
      for bdims in all_bdims(shape)))
  def testArgminmax(self, op, shape, dtype, dim, bdims):
    rng = jtu.rand_default(self.rng())
    fun = lambda operand: op(operand, dim, np.int32)
    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
                         "_basedilation={}_windowdilation={}")
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype),
               dims, strides, padding, base_dilation, window_dilation),
       "op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
       "dims": dims, "strides": strides, "padding": padding,
       "base_dilation": base_dilation, "window_dilation": window_dilation}
      for init_val, op, dtypes in [
          (0, lax.add, [np.float32]),
          (-np.inf, lax.max, [np.float32]),
          (np.inf, lax.min, [np.float32]),
      ]
      for shape, dims, strides, padding, base_dilation, window_dilation in (
        itertools.chain(
          itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)],
            ["VALID", "SAME", [(0, 3), (1, 2)]],
            [(1, 1), (2, 3)],
            [(1, 1), (1, 2)]),
          itertools.product(
            [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)],
            ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
            [(1, 1, 1, 1), (2, 1, 3, 2)],
            [(1, 1, 1, 1), (1, 2, 2, 1)])))
      for dtype in dtypes))
  def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding,
                       base_dilation, window_dilation):
    rng = jtu.rand_small(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)

    def fun(operand):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding,
                               base_dilation, window_dilation)

    for bdims in all_bdims(shape):
      self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng)
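
  # lax.reduce_window slides a window over the operand and combines each window
  # with the given reducer. A sketch of the un-batched primitive, with assumed
  # values rather than one of the parameterized cases above:
  #
  #   x = np.arange(6.0).reshape(2, 3)
  #   out = lax.reduce_window(x, 0., lax.add, window_dimensions=(1, 2),
  #                           window_strides=(1, 1), padding='VALID')
  #   # each output element sums a 1x2 window, so out.shape == (2, 2)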

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_shape={}_axis={}_bdims={}_reverse={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
               bdims, reverse),
       "op": op, "shape": shape, "dtype": dtype, "bdims": bdims,
       "axis": axis, "reverse": reverse}
      for op, types in [
          (lax.cumsum, [np.float32, np.float64]),
          (lax.cumprod, [np.float32, np.float64]),
      ]
      for dtype in types
      for shape in [[10], [3, 4, 5]]
      for axis in range(len(shape))
      for bdims in all_bdims(shape)
      for reverse in [False, True]))
  def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse):
    rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
                   else jtu.rand_small)
    rng = rng_factory(self.rng())
    self._CheckBatching(partial(op, axis=axis, reverse=reverse), 7, bdims,
                        (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_padding={}".format(np.dtype(dtype).name,
                                                      padding),
       "dtype": dtype, "padding": padding}
      for dtype in float_dtypes
      for padding in ["VALID", "SAME"]))
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jtu.ignore_warning(message="Using reduced precision for gradient.*")
  def testSelectAndGatherAdd(self, dtype, padding):
    if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
      raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")
    rng = jtu.rand_small(self.rng())
    all_configs = itertools.chain(
        itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)]),
        itertools.product(
            [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)]))

    def fun(operand, tangents):
      pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
      ones = (1,) * len(operand.shape)
      return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims,
                                        strides, pads, ones, ones)

    for shape, dims, strides in all_configs:
      for bdims in all_bdims(shape, shape):
        self._CheckBatching(fun, 3, bdims, (shape, shape), (dtype, dtype), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_dtype={jtu.format_shape_dtype_string(shape, dtype)}"
      f"_padding={padding}_dims={dims}_strides={strides}",
       "dtype": dtype, "padding": padding, "shape": shape,
       "dims": dims, "strides": strides}
      for dtype in float_dtypes
      for padding in ["VALID", "SAME"]
      for shape in [(3, 2, 4, 6)]
      for dims in [(1, 1, 2, 1)]
      for strides in [(1, 2, 2, 1), (1, 1, 1, 1)]))
  def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides):
    rng = jtu.rand_small(self.rng())

    pads = lax.padtype_to_pads(shape, dims, strides, padding)

    def fun(operand, cotangents):
      return lax._select_and_scatter_add(operand, cotangents, lax.ge_p, dims,
                                         strides, pads)
    ones = (1,) * len(shape)
    cotangent_shape = api.eval_shape(
      lambda x: lax._select_and_gather_add(x, x, lax.ge_p, dims, strides,
                                           pads, ones, ones),
      np.ones(shape, dtype)).shape

    for bdims in all_bdims(cotangent_shape, shape):
      self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape),
                          (dtype, dtype), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_bdims={}_fft_ndims={}"
       .format(shape, bdims, fft_ndims),
       "shape": shape, "bdims": bdims, "fft_ndims": fft_ndims}
      for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)]
      for bdims in all_bdims(shape)
      for fft_ndims in range(0, min(3, len(shape)) + 1)))
  @jtu.skip_on_devices("tpu")  # TODO(b/137993701): unimplemented cases.
  def testFft(self, fft_ndims, shape, bdims):
    rng = jtu.rand_default(self.rng())
    ndims = len(shape)
    axes = range(ndims - fft_ndims, ndims)
    fft_lengths = [shape[axis] for axis in axes]
    op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths)
    self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
               slice_sizes, bdims),
       "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "bdims": bdims}
      for dtype in all_dtypes
      for shape, idxs, dnums, slice_sizes in [
          ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,)),
          ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,)),
          ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3)),
          ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
            (1, 3)),
      ]
      for bdims in all_bdims(shape, idxs.shape)))
  def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    self._CheckBatching(fun, 0, bdims, [shape, idxs.shape], [dtype, idxs.dtype],
                        jtu.rand_default(self.rng()))
    self._CheckBatching(fun, 5, bdims, [shape, idxs.shape], [dtype, idxs.dtype],
                        jtu.rand_default(self.rng()))
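
  # For the first gather case above (operand shape (5,), indices [[0], [2]],
  # offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
  # slice_sizes=(1,)), the gather extracts operand[0] and operand[2], i.e. it
  # behaves like operand[np.array([0, 2])] and returns a shape-(2,) result.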

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}_bdims={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums, bdims),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "bdims": bdims}
      for dtype in float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for bdims in all_bdims(arg_shape, idxs.shape, update_shape)))
  def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims):
    fun = partial(lax.scatter_add, dimension_numbers=dnums)
    self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape, update_shape],
                        [dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()),
                        rtol={np.float16: 5e-3, dtypes.bfloat16: 3e-2})

  def testShapeUsesBuiltinInt(self):
    x = lax.iota(np.int32, 3) + 1
    self.assertIsInstance(x.shape[0], int)  # not np.int64

  def testBroadcastShapesReturnsPythonInts(self):
    shape1, shape2 = (1, 2, 3), (2, 3)
    out_shape = lax.broadcast_shapes(shape1, shape2)
    self.assertTrue(all(type(s) is int for s in out_shape))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}_bdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k, bdims),
       "shape": shape, "dtype": dtype, "k": k, "bdims": bdims, "rng_factory": rng_factory}
      for shape in [(4,), (3, 5, 3)]
      for k in [1, 3]
      for bdims in all_bdims(shape)
      # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug is fixed:
      # The top_k indices for integer arrays with identical entries won't match between
      # vmap'd version and manual reference, so only test unique integer arrays for int_dtypes.
      # Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of
      # values a bfloat16 can represent exactly to avoid ties.
      for dtype, rng_factory in itertools.chain(
        unsafe_zip(default_dtypes, itertools.repeat(jtu.rand_unique_int)))))
  def testTopK(self, shape, dtype, k, bdims, rng_factory):
    rng = rng_factory(self.rng())
    # _CheckBatching doesn't work with tuple outputs, so test outputs separately.
    op1 = lambda x: lax.top_k(x, k=k)[0]
    self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng)
    op2 = lambda x: lax.top_k(x, k=k)[1]
    self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng)


  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dimension={}_arity={}_bdims={}_isstable={}"
       .format(jtu.format_shape_dtype_string(shape, np.float32), dimension,
               arity, bdims, is_stable),
       "shape": shape, "dimension": dimension, "arity": arity, "bdims": bdims,
       "is_stable": is_stable}
      for shape in [(2, 3)]
      for dimension in [0, 1]
      for arity in range(3)
      for bdims in all_bdims(*((shape,) * arity))
      for is_stable in [False, True]))
  def testSort(self, shape, dimension, arity, bdims, is_stable):
    rng = jtu.rand_default(self.rng())
    if arity == 1:
      fun = partial(lax.sort, dimension=dimension)
      self._CheckBatching(fun, 5, bdims, (shape,) * arity, (np.float32,) * arity,
                          rng)
    else:
      for i in range(arity):
        fun = lambda *args, i=i: lax.sort(args,
                                          dimension=dimension,
                                          is_stable=is_stable)[i]
        self._CheckBatching(fun, 5, bdims, (shape,) * arity,
                            (np.float32,) * arity, rng)
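
# A minimal standalone sketch of the bdims convention exercised by the class
# above: one operand is mapped along a batch axis while the other is left
# unbatched (its bdim is None). Hypothetical helper, not part of the original
# suite; it reuses the np / api / lax module aliases imported by this file.
def _vmap_bdims_sketch():
  xs = np.arange(15.0).reshape(5, 3)   # operand batched along axis 0
  y = np.ones(3)                       # operand left unbatched (bdim is None)
  ans = api.vmap(lax.add, in_axes=(0, None))(xs, y)
  expected = np.stack([lax.add(xs[i], y) for i in range(5)])
  np.testing.assert_allclose(ans, expected)
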
Example no. 2
class XMapTest(XMapTestCase):

  def testBasic(self):
    local_devices = list(jax.local_devices())
    if len(local_devices) < 4:
      raise SkipTest("Test requires at least 4 local devices")
    def f(a, b):
      return a * 2, b * 4
    devices = np.array(local_devices[:4]).reshape((2, 2))
    with mesh(devices, ('x', 'y')):
      fm = xmap(f,
                in_axes=[{0: 'a', 1: 'b'}, ['c', ...]],
                out_axes=[{0: 'a', 1: 'b'}, ['c', ...]],
                axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
      ashape = (16, 8, 5)
      a = jnp.arange(np.prod(ashape)).reshape(ashape)
      bshape = (2, 7)
      b = jnp.arange(np.prod(bshape)).reshape(bshape)
      c, d = fm(a, b)
      self.assertAllClose(c, a * 2)
      self.assertAllClose(d, b * 4)
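
  # in_axes={0: 'a', 1: 'b'} names the first two positional axes of `a` and
  # ['c', ...] names the leading axis of `b`; axis_resources then maps each
  # named axis onto a mesh axis ('a' and 'c' onto 'x', 'b' onto 'y'), and
  # out_axes reinserts the named axes at the same positions on the way out.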

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testCollectiveReduce(self):
    fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4),
              in_axes=[['a', 'b', ...], {0: 'c'}],
              out_axes=[['b', ...], {0: 'c'}],
              axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
    ashape = (16, 8, 5)
    a = jnp.arange(np.prod(ashape)).reshape(ashape)
    bshape = (2, 7)
    b = jnp.arange(np.prod(bshape)).reshape(bshape)
    c, d = fm(a, b)
    self.assertAllClose(c, (a * 2).sum(0))
    self.assertAllClose(d, b * 4)

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testCollectivePermute2D(self):
    perm = np.array([3, 1, 2, 0])
    x = jnp.arange(4).reshape((2, 2))
    result = xmap(lambda x: lax.pshuffle(x, ('i', 'j'), perm),
                  in_axes=['i', 'j', ...],
                  out_axes=['i', 'j', ...],
                  axis_resources={'i': 'x', 'j': 'y'})(x).reshape((-1,))
    self.assertAllClose(result, perm)

  def testCollectivePermute1D(self):
    perm = np.array([3, 1, 2, 0])
    x = jnp.arange(4)
    result = xmap(lambda x: lax.pshuffle(x, 'i', perm),
                  in_axes=['i', ...],
                  out_axes=['i', ...])(x)
    self.assertAllClose(result, perm)

  def testCollectiveAllGather(self):
    x = jnp.arange(4)
    result = xmap(lambda x: lax.all_gather(x, 'i') + lax.axis_index('i'),
                  in_axes=['i', ...], out_axes=['i', ...])(x)
    self.assertAllClose(result, x + x[jnp.newaxis].T)

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testOneLogicalTwoMeshAxesBasic(self):
    def f(v):
      return lax.psum(v * 2, 'a'), v * 4
    fm = xmap(f, in_axes=['a', ...], out_axes=[{}, {1: 'a'}],
              axis_resources={'a': ('x', 'y')})
    vshape = (4, 5)
    v = jnp.arange(np.prod(vshape)).reshape(vshape)
    ans, ans2 = fm(v)
    self.assertAllClose(ans, (v * 2).sum(0))
    self.assertAllClose(ans2, v.T * 4)

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testOneLogicalTwoMeshAxesSharding(self):
    def f(v):
      return v * 4
    fxy = xmap(f, in_axes=['a', ...], out_axes={1: 'a'},
               axis_resources={'a': ('x', 'y')})
    fyx = xmap(f, in_axes=['a', ...], out_axes={1: 'a'},
               axis_resources={'a': ('y', 'x')})
    vshape = (4, 5)
    v = jnp.arange(np.prod(vshape)).reshape(vshape)
    zxy = fxy(v)
    self.assertEqual(
        zxy.sharding_spec,
        pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
                          (pxla.ShardedAxis(0), pxla.ShardedAxis(1))))
    zyx = fyx(v)
    self.assertEqual(
        zyx.sharding_spec,
        pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
                          (pxla.ShardedAxis(1), pxla.ShardedAxis(0))))

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testSkipFirstMeshDim(self):
    def run(axis_resources):
      return xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...],
                  axis_resources=axis_resources)(jnp.ones((4,)))
    self.assertAllClose(run({'i': 'x'}), run({'i': 'y'}))

  def testCaching(self):
    def f(x):
      assert python_should_be_executing
      return x * 2
    devices = np.array(jax.local_devices()[:2])
    if devices.size < 2:
      raise SkipTest("Test requires 2 devices")
    x = np.arange(8).reshape((2, 2, 2))
    with mesh(devices, ('x',)):
      python_should_be_executing = True
      xmap(f, in_axes=['a', ...], out_axes=['a', ...],
           axis_resources={'a': 'x'})(x)
      python_should_be_executing = False
      xmap(f, in_axes=['a', ...], out_axes=['a', ...],
           axis_resources={'a': 'x'})(x)
    with mesh(devices, ('x',)):
      python_should_be_executing = False
      xmap(f, in_axes=['a', ...], out_axes=['a', ...],
           axis_resources={'a': 'x'})(x)

  @parameterized.named_parameters(
    {"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
    for name, mesh, axis_resources in (
      ('OneToOne', (('x', 2), ('y', 2)), (('a', 'y'), ('b', 'x'))),
      ('Multiple', (('x', 2), ('y', 2), ('z', 2)), (('a', 'y'), ('b', ('x', 'z')))),
    ))
  @jtu.with_mesh_from_kwargs
  def testNestedMesh(self, mesh, axis_resources):
    @partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
              axis_resources=dict([axis_resources[0]]))
    def f(x):
      y = x * 2
      @partial(xmap, in_axes={0: 'b'}, out_axes=({1: 'b'}, {}),
               axis_resources=dict([axis_resources[1]]))
      def h(y):
        # Multiply by a constant array to better exercise the partial_eval rule
        return jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b'))
      return h(y)

    xshape = (4, 2, 5)
    x = jnp.arange(np.prod(xshape)).reshape(xshape)
    y = f(x)
    self.assertAllClose(y, ((jnp.sin(x * 2) * np.arange(xshape[-1])).transpose((1, 2, 0)), (x * 2).sum((0, 1))))
    self.assertEqual(y[0].sharding_spec.sharding,
                     (pxla.Chunked([2]), pxla.NoSharding(), pxla.NoSharding()))
    self.assertEqual(y[0].sharding_spec.mesh_mapping,
                     (pxla.Replicated(2), pxla.ShardedAxis(0)) + (pxla.Replicated(2),) * (len(mesh) - 2))
    if maps.EXPERIMENTAL_SPMD_LOWERING:
      hlo = jax.xla_computation(f)(x).as_hlo_text()
      # Make sure that there are non-partial sharding specs in the HLO
      self.assertRegex(hlo, r"sharding={devices=\[[0-9,]+\][0-9,]+}")

  @jtu.with_and_without_mesh
  def testMultipleCalls(self, mesh, axis_resources):
    def f(x, y):
      assert x.shape == y.shape == (3, 5)
      return jnp.tensordot(x, y, axes=([1], [1]))

    f_mapped = xmap(f,
                    in_axes=(['i', ...], ['j', ...]),
                    out_axes=['i', 'j', ...],
                    axis_resources=dict(axis_resources))
    x = jnp.arange(30).reshape(2, 3, 5)
    expected = jnp.einsum('imk,jnk->ijmn', x, x)
    for i in range(10):
      self.assertAllClose(f_mapped(x, x), expected)

  @jtu.with_and_without_mesh
  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def testBufferDonation(self, mesh, axis_resources):
    shard = lambda x: x
    if axis_resources:
      shard = xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...],
                   axis_resources=dict(axis_resources))
    f = xmap(lambda x, y: x + y * 4,
             in_axes=['i', ...], out_axes=['i', ...],
             axis_resources=dict(axis_resources),
             donate_argnums=0)
    # The multiplications below disable some optimizations that prevent reuse
    x = shard(jnp.zeros((2, 5)) * 4)
    y = shard(jnp.ones((2, 5)) * 2)
    f(x, y)
    self.assertNotDeleted(y)
    self.assertDeleted(x)

  def testControlFlow(self):
    x = jnp.arange(5)
    xmap(lambda x: lax.fori_loop(0, 10, lambda _, x: lax.psum(x, 'i'), x),
         in_axes=['i', ...], out_axes=['i', ...])(x)

  @jtu.with_and_without_mesh
  def testAxisSizes(self, mesh, axis_resources):
    result = xmap(lambda: lax.axis_index('i'),
                  in_axes=(), out_axes=['i', ...],
                  axis_sizes={'i': 6},
                  axis_resources=dict(axis_resources))()
    self.assertAllClose(result, jnp.arange(6, dtype=result.dtype))

  def testCollectiveOverNoName(self):
    result = xmap(lambda: lax.psum(jnp.array(2) ** 2, 'i'),
                  in_axes={}, out_axes={}, axis_sizes={'i': 4})()
    self.assertEqual(result, 16)

  def VmapOfXmapCases(s):
    xmap_in_axes = ([{}] +
                    [{i: 'x'} for i in range(3)] +
                    [{i: 'x', j: 'y'} for i in range(4) for j in range(4) if i != j])
    for xmap_dim_x, xmap_dim_y in s(product(xmap_in_axes, repeat=2)):
      xmap_axes = sorted(set(xmap_dim_x.values()) | set(xmap_dim_y.values()))
      num_axes = len(xmap_axes)
      if xmap_axes is None:
        continue
      xmap_out_axes = [dict(zip(dims, xmap_axes))
                       for dims in permutations(range(2 + num_axes), num_axes)]
      for xmap_dim_z in s(xmap_out_axes):
        for vmap_dim_x in s([*range(2 + len(xmap_dim_x)), None]):
          for vmap_dim_y in s([*range(2 + len(xmap_dim_y)), None]):
            if vmap_dim_x is None and vmap_dim_y is None:
              continue
            for vmap_dim_result in s(range(3)):
              for vmap_dim_z in s(range(2 + len(xmap_axes))):
                for vmap_as_xmap in s([False, True]):
                  yield {"testcase_name":
                             f"_xin={(sorted(xmap_dim_x.items()), sorted(xmap_dim_y.items()))}_"
                             f"xout={sorted(xmap_dim_z.items())}_vin={(vmap_dim_x, vmap_dim_y)}_"
                             f"vout={vmap_dim_z}_vresult={vmap_dim_result}_vmap_as_xmap={vmap_as_xmap}",
                         "xmap_in_axes": (xmap_dim_x, xmap_dim_y),
                         "xmap_out_axes": xmap_dim_z,
                         "vmap_in_axes": (vmap_dim_x, vmap_dim_y),
                         "vmap_out_axes": vmap_dim_z,
                         "vmap_result_axis": vmap_dim_result,
                         "vmap_as_xmap": vmap_as_xmap}

  @parameterized.named_parameters(jtu.named_cases_from_sampler(VmapOfXmapCases))
  def testNestedMap(self,
                    xmap_in_axes, xmap_out_axes,
                    vmap_in_axes, vmap_out_axes, vmap_result_axis,
                    vmap_as_xmap):
    """Test various vmap(xmap) and xmap(xmap) combinations.

    The outer map always introduces a single dimension, the inner map introduces one or two.
    """
    (xin_x, xin_y) = xmap_in_axes
    (vin_x, vin_y) = vmap_in_axes
    vmap_size = 7
    xmap_sizes = {'x': 11, 'y': 13}

    xshape = [2, 3]
    yshape = [3, 5]
    zshape = [2, 5]
    xind = ['n', 'k']
    yind = ['k', 'm']
    zind = ['n', 'm']
    f = lambda x, y: ensure_bdim(jnp.einsum('nk,km->nm', x, y), 'v', vmap_result_axis)

    for pos, name in sorted(xin_x.items()):
      xshape.insert(pos, xmap_sizes[name])
      xind.insert(pos, name)
    for pos, name in sorted(xin_y.items()):
      yshape.insert(pos, xmap_sizes[name])
      yind.insert(pos, name)
    for pos, name in sorted(xmap_out_axes.items()):
      zshape.insert(pos, xmap_sizes[name])
      zind.insert(pos, name)

    if vin_x is not None:
      xshape.insert(vin_x, vmap_size)
      xind.insert(vin_x, 'v')
    if vin_y is not None:
      yshape.insert(vin_y, vmap_size)
      yind.insert(vin_y, 'v')
    zshape.insert(vmap_out_axes, vmap_size)
    zind.insert(vmap_out_axes, 'v')

    if vmap_as_xmap:
      do_vmap = partial(xmap,
                        in_axes=({vin_x: 'v'} if vin_x is not None else {},
                                 {vin_y: 'v'} if vin_y is not None else {}),
                        out_axes={vmap_out_axes: 'v'})
    else:
      do_vmap = partial(vmap, in_axes=vmap_in_axes, out_axes=vmap_out_axes, axis_name='v')

    fm = do_vmap(xmap(f, in_axes=xmap_in_axes, out_axes=xmap_out_axes))
    fref = partial(jnp.einsum, f"{''.join(xind)},{''.join(yind)}->{''.join(zind)}")

    rng = np.random.RandomState(0)
    x = rng.randn(*xshape)
    y = rng.randn(*yshape)
    self.assertAllClose(fm(x, y), fref(x, y))

  def testJVP(self):
    f = xmap(lambda x, y: jnp.cos(lax.dot(x, jnp.sin(y),
                                          precision=lax.Precision.HIGHEST)),
             in_axes=[['i', ...], {}], out_axes=['i', ...])
    x = jnp.arange(12, dtype=jnp.float32).reshape((3, 4)) / 100
    y = jnp.arange(20, dtype=jnp.float32).reshape((4, 5)) / 100
    jtu.check_grads(f, (x, y), order=2, modes=['fwd'])

  @jtu.with_and_without_mesh
  def testNamedShape(self, mesh, axis_resources):
    x = np.arange(4,)
    y = 2
    f = xmap(lambda x, y: (x + y, y * lax.axis_index('i')),
             in_axes=(['i', ...], {}),
             out_axes=(['i', ...], ['i', ...]),
             axis_resources=dict(axis_resources))
    z, w = f(x, y)
    self.assertEqual(z.aval.named_shape, {})
    self.assertEqual(w.aval.named_shape, {})

  @jtu.with_and_without_mesh
  def testBroadcast(self, mesh, axis_resources):
    x = jnp.asarray(2.0)
    f = xmap(lambda x: x, in_axes={}, out_axes=['i'],
             axis_sizes={'i': 4}, axis_resources=dict(axis_resources))
    self.assertAllClose(f(x), jnp.asarray([2.0, 2.0, 2.0, 2.0]))

  def testNestedBroadcast(self):
    x = jnp.asarray(2.0)
    f = xmap(lambda x: x, in_axes={}, out_axes=['i'], axis_sizes={'i': 4})
    g = xmap(f, in_axes={}, out_axes=['j', ...], axis_sizes={'j': 7})
    self.assertAllClose(g(x), jnp.tile(x.reshape((1, 1)), (7, 4)))

  @loop('l', 4)
  def testLoopBasic(self):
    x = jnp.arange(16)
    y = xmap(lambda x: x + 4, in_axes=['i'], out_axes=['i'],
              axis_resources={'i': 'l'})(x)
    self.assertAllClose(y, x + 4)

  @jtu.with_mesh([('x', 2)])
  @loop('l', 4)
  def testLoopWithMesh(self):
    x = jnp.arange(16)
    y = xmap(lambda x: x + 4, in_axes=['i'], out_axes=['i'],
              axis_resources={'i': ('x', 'l')})(x)
    self.assertAllClose(y, x + 4)
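
# A minimal standalone sketch of the xmap semantics exercised above: a named
# axis introduced via in_axes is mapped over and reinserted at the position
# given by out_axes; with no axis_resources the named axis is purely
# vectorized. Hypothetical helper, not part of the original suite; it reuses
# the np / jnp / xmap imports of this file.
def _xmap_named_axis_sketch():
  x = jnp.arange(6.0).reshape(2, 3)
  doubled = xmap(lambda row: row * 2,
                 in_axes=['i', ...], out_axes=['i', ...])(x)
  np.testing.assert_allclose(doubled, x * 2)
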
Example no. 3
class PDotTests(XMapTestCase):

  @jtu.with_mesh([('r1', 2)])
  def testPdotBasic(self):
    def f(x, y):
      return lax.pdot(x, y, 'i')

    f_mapped = xmap(f,
                    in_axes=[{1: 'i'}, {0: 'i'}],
                    out_axes={},
                    axis_resources={'i': 'r1'})

    rng = np.random.RandomState(0)
    x = rng.randn(3, 8)
    y = rng.randn(8, 5)

    z = f_mapped(x, y)

    self.assertAllClose(z, jnp.dot(x, y))

  @jtu.with_mesh([('r1', 2)])
  def testPdotBatching(self):
    def f(x, y):
      return lax.pdot(x, y, 'i')

    rng = np.random.RandomState(0)
    x = rng.randn(2, 3, 8)
    y = rng.randn(2, 8, 5)

    f_mapped = xmap(f,
                    in_axes=[{0: 'j', 2: 'i'}, {0: 'j', 1: 'i'}],
                    out_axes=['j', ...],
                    axis_resources={'i': 'r1'})

    z = f_mapped(x, y)

    self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))

  @jtu.with_mesh([('r1', 2)])
  def testPdotBatchingShardUncontractedDim(self):
    def f(x, y):
      return lax.pdot(x, y, 'i')

    rng = np.random.RandomState(0)
    x = rng.randn(2, 3, 8)
    y = rng.randn(2, 8, 5)

    f_mapped = xmap(f,
                    in_axes=[{0: 'j', 2: 'i'}, {0: 'j', 1: 'i'}],
                    out_axes=['j', ...],
                    axis_resources={'j': 'r1'})

    z = f_mapped(x, y)

    self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))

  @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
       "testcase_name": f"_{next(test_counter)}",
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "pdot_spec": pdot_spec,
       "axis_resources": axis_resources, "mesh_data": mesh_data
    } for test_counter in [it.count()]
      for lhs_shape, rhs_shape in s(product([(2,), (2, 4, 2, 1)], repeat=2))
      for pdot_spec in s(all_pdot_specs(lhs_shape, rhs_shape))
      for axis_resources, mesh_data in s(schedules_from_pdot_spec(
          pdot_spec, lhs_shape, rhs_shape))
  )))
  def testPdotSystematic(self, lhs_shape, rhs_shape, pdot_spec, axis_resources,
                         mesh_data):
    rng = jtu.rand_default(self.rng())
    lhs = rng(lhs_shape, np.float32)
    rhs = rng(rhs_shape, np.float32)

    def pdot_fun(x, y):
      # print(f'pdot(x:{x.aval.str_short()}, y:{y.aval.str_short()},\n'
      #       f'     axis_name={contract_names},\n'
      #       f'     pos_contract={spec.pos_contract_after_mapping}\n'
      #       f'     pos_batch={spec.pos_batch_after_mapping})')
      return jax.lax.pdot(x, y, axis_name=pdot_spec.contract_names,
                          pos_batch=pdot_spec.pos_batch_after_mapping,
                          pos_contract=pdot_spec.pos_contract_after_mapping)

    fun = xmap(pdot_fun, in_axes=[pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes],
               out_axes=[*pdot_spec.batch_names, ...],
               axis_resources=axis_resources)

    with jtu.with_mesh(mesh_data):
      result = fun(lhs, rhs)

    expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums)
    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
    self.assertAllClose(result, expected, check_dtypes=False,
                        atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
       "testcase_name": f"_{next(test_counter)}",
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "pdot_spec": pdot_spec,
       "axis_resources": axis_resources, "mesh_data": mesh_data
    } for test_counter in [it.count()]
      for lhs_shape, rhs_shape in s(product([(2,), (2, 4, 2, 1)], repeat=2))
      for pdot_spec in s(all_pdot_specs(lhs_shape, rhs_shape))
      for axis_resources, mesh_data in s(schedules_from_pdot_spec(
          pdot_spec, lhs_shape, rhs_shape))
  )))
  def testPdotVJPSystematic(self, lhs_shape, rhs_shape, pdot_spec,
                            axis_resources, mesh_data):
    rng = jtu.rand_default(self.rng())
    lhs = rng(lhs_shape, np.float32)
    rhs = rng(rhs_shape, np.float32)

    expected_out, ref_vjp = jax.vjp(
        lambda x, y: lax.dot_general(x, y, pdot_spec.dot_general_dim_nums),
        lhs, rhs)
    out_bar = rng(expected_out.shape, np.float32)
    expected_lhs, expected_rhs = ref_vjp(out_bar)

    def pdot_fun(x, y, out_bar):
      pdot = partial(jax.lax.pdot,
                     axis_name=pdot_spec.contract_names,
                     pos_batch=pdot_spec.pos_batch_after_mapping,
                     pos_contract=pdot_spec.pos_contract_after_mapping)
      _, pdot_vjp = jax.vjp(pdot, x, y)
      return pdot_vjp(out_bar)

    fun = xmap(pdot_fun,
               in_axes=[pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes,
                        [*pdot_spec.batch_names, ...]],
               out_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes),
               axis_resources=axis_resources)

    with jtu.with_mesh(mesh_data):
      lhs_bar, rhs_bar = fun(lhs, rhs, out_bar)

    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
    self.assertAllClose(lhs_bar, expected_lhs, check_dtypes=False,
                        atol=tol, rtol=tol)
    self.assertAllClose(rhs_bar, expected_rhs, check_dtypes=False,
                        atol=tol, rtol=tol)

  def test_xeinsum_vector_dot(self):
    rng = np.random.RandomState(0)
    x = rng.randn(3)
    y = rng.randn(3)
    out = xmap(partial(jnp.einsum, '{i},{i}->'),
               in_axes=(['i'], ['i']), out_axes=[])(x, y)
    expected = np.einsum('i,i->', x, y)
    self.assertAllClose(out, expected, check_dtypes=False)

  def test_xeinsum_outer_product(self):
    rng = np.random.RandomState(0)
    x = rng.randn(3)
    y = rng.randn(3)
    out = xmap(partial(jnp.einsum, '{i},{j}->{i,j}'),
               in_axes=(['i'], ['j']), out_axes=['i', 'j'])(x, y)
    expected = np.einsum('i,j->ij', x, y)
    self.assertAllClose(out, expected, check_dtypes=True)

  def test_xeinsum_matmul(self):
    rng = np.random.RandomState(0)
    x = rng.randn(3, 4)
    y = rng.randn(4, 5)

    def check(spec):
      out = xmap(partial(jnp.einsum, spec),
                 in_axes=(['i', 'j'], ['j', 'k']),
                 out_axes=['i', 'k'])(x, y)
      expected = np.einsum('ij,jk->ik', x, y)
      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
      self.assertAllClose(out, expected, check_dtypes=True,
                          atol=tol, rtol=tol)
    check('{i,j},{j,k}->{i,k}')
    check('{i,j},{k,j}->{k,i}')  # order of named axes in the spec doesn't matter!
    check('{j},{k,j}->{k}')
    check('{i,j},{j}->{i}')
    check('{j},{j}->{}')

  def test_xeinsum_no_named_axes_vector_dot(self):
    rng = np.random.RandomState(0)
    x = rng.randn(3)
    y = rng.randn(3)
    out = jnp.einsum('i,i->', x, y, _use_xeinsum=True)
    expected = np.einsum('i,i->', x, y)
    self.assertAllClose(out, expected, check_dtypes=False)

  def test_xeinsum_no_named_axes_batch_vector_dot(self):
    rng = np.random.RandomState(0)
    x = rng.randn(3, 2)
    y = rng.randn(3, 2)
    out = jnp.einsum('ij,ij->i', x, y, _use_xeinsum=True)
    expected = np.einsum('ij,ij->i', x, y)
    self.assertAllClose(out, expected, check_dtypes=True)

  def test_xeinsum_no_named_axes_reduce_sum(self):
    rng = np.random.RandomState(0)
    x = rng.randn(3)
    y = rng.randn()
    out = jnp.einsum('i,->', x, y, _use_xeinsum=True)
    expected = np.einsum('i,->', x, y)
    self.assertAllClose(out, expected, check_dtypes=True)
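
# A minimal standalone sketch of lax.pdot as exercised above: inside xmap it
# contracts over the *named* axis 'k', which matches an ordinary matrix product
# over the corresponding positional axes outside the map. Hypothetical helper,
# not part of the original suite; it reuses the np / xmap / lax imports of this
# file and assumes no mesh is needed when axis_resources is omitted.
def _pdot_sketch():
  rng = np.random.RandomState(0)
  x, y = rng.randn(3, 8), rng.randn(8, 5)
  z = xmap(lambda a, b: lax.pdot(a, b, 'k'),
           in_axes=[{1: 'k'}, {0: 'k'}], out_axes={})(x, y)
  np.testing.assert_allclose(z, np.dot(x, y), rtol=1e-5, atol=1e-5)
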
Example no. 4
class IndexedUpdateTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
      "testcase_name": "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer,
          jtu.format_shape_dtype_string(update_shape, update_dtype), sugared, op.name),
       "shape": shape, "dtype": dtype, "indexer": indexer,
       "update_shape": update_shape, "update_dtype": update_dtype,
       "op": op, "sugared": sugared
  } for name, index_specs in s(STATIC_INDEXING_TESTS)
    for shape, indexer in s(index_specs)
    for op in s(UpdateOps)
    for dtype in s(UpdateOps.dtypes(op))
    for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer)))
    for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)
    for sugared in (s([True, False]) if op not in [UpdateOps.DIV, UpdateOps.POW] else [True]))))
  def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
                         indexer, sugared, op):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
    np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
    if sugared:
      jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
    else:
      jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
    self._CheckAgainstNumpy(np_fn, jax_fn, args_maker,
                            tol={np.complex128: 1e-14})
    self._CompileAndCheck(jax_fn, args_maker)
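
  # The sugared path above goes through the functional x.at[...] update API;
  # a sketch with hypothetical values (independent of the parameterized cases):
  #
  #   x = jnp.zeros(3)
  #   y = x.at[1].add(5.0)   # x is unchanged; y == [0., 5., 0.]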

  @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
      "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer,
          jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
       "shape": shape, "dtype": dtype, "indexer": indexer,
       "update_shape": update_shape, "update_dtype": update_dtype,
       "op": op
  } for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS)
    for shape, indexer in s(index_specs)
    for op in s(UpdateOps)
    for dtype in s(UpdateOps.dtypes(op))
    for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer)))
    for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
  def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                           indexer, op):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
    np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
    jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y, unique_indices=True)
    self._CheckAgainstNumpy(np_fn, jax_fn, args_maker,
                            tol={np.complex128: 1e-14})
    self._CompileAndCheck(jax_fn, args_maker)

  @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
      "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer,
          jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
       "shape": shape, "dtype": dtype, "indexer": indexer,
       "update_shape": update_shape, "update_dtype": update_dtype,
       "op": op
  } for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED)
    for shape, indexer in s(index_specs)
    for op in s(UpdateOps)
    for dtype in s(UpdateOps.dtypes(op))
    for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer)))
    for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
  def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype,
                                 indexer, op):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
    np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
    jax_fn = lambda x, y: UpdateOps.sugar_fn(
      op, indexer, x, y, indices_are_sorted=True, unique_indices=True)
    self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, check_dtypes=True,
                            tol={np.complex128: 1e-14})
    self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
      "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer,
          jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
       "shape": shape, "dtype": dtype, "indexer": indexer,
       "update_shape": update_shape, "update_dtype": update_dtype,
       "op": op
  } for name, index_specs in s(MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS)
    for shape, indexer in s(index_specs)
    for op in s(UpdateOps)
    for dtype in s(UpdateOps.dtypes(op))
    for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer)))
    for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
  def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                                indexer, op):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
    np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
    jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
    self._CheckAgainstNumpy(np_fn, jax_fn, args_maker,
                            tol={np.complex128: 1e-14})
    self._CompileAndCheck(jax_fn, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list({
      "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer,
          jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
       "shape": shape, "dtype": dtype, "indexer": indexer,
       "update_shape": update_shape, "update_dtype": update_dtype,
       "op": op
  } for name, index_specs in STATIC_INDEXING_TESTS
    for shape, indexer in index_specs
    for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
    for dtype in float_dtypes
    for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
    for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes)))
  @jtu.skip_on_devices("tpu")  # TODO(mattjj,phawkins): tpu issues
  def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
                              indexer, op):
    rng = jtu.rand_default(self.rng())
    jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
    x = rng(shape, dtype)
    y = rng(update_shape, update_dtype)
    check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.)

  def testSegmentSumBehavior(self):
    # testAdvancedIndexing compares against NumPy, and as a result doesn't check
    # repeated indices. This test is just a simple manual check, based on
    # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum
    data = np.array([5, 1, 7, 2, 3, 4, 1, 3])
    segment_ids = np.array([0, 0, 0, 1, 2, 2, 3, 3])
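    # With repeated ids the contributions accumulate: segment 0 collects
    # 5 + 1 + 7 = 13, segment 1 gets 2, segment 2 gets 3 + 4 = 7, and
    # segment 3 gets 1 + 3 = 4.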

    ans = ops.index_add(np.zeros(np.max(segment_ids) + 1), segment_ids, data)
    expected = np.array([13, 2, 7, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testSegmentSum(self):
    data = jnp.array([5, 1, 7, 2, 3, 4, 1, 3])
    segment_ids = jnp.array([0, 0, 0, 1, 2, 2, 3, 3])

    # test with explicit num_segments
    ans = ops.segment_sum(data, segment_ids, num_segments=4)
    expected = jnp.array([13, 2, 7, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)

    # test with explicit num_segments larger than the largest segment id.
    ans = ops.segment_sum(data, segment_ids, num_segments=5)
    expected = jnp.array([13, 2, 7, 4, 0])
    self.assertAllClose(ans, expected, check_dtypes=False)

    # test without explicit num_segments
    ans = ops.segment_sum(data, segment_ids)
    expected = jnp.array([13, 2, 7, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)

    # test with segment ids that are negative or larger than num_segments;
    # ids outside the range [0, num_segments) are dropped, so only the pairs
    # (id=0, 5), (id=1, 2), (id=2, 3) and (id=3, 3) contribute below.
    segment_ids = jnp.array([0, 4, 8, 1, 2, -6, -1, 3])
    ans = ops.segment_sum(data, segment_ids, num_segments=4)
    expected = jnp.array([5, 2, 3, 3])
    self.assertAllClose(ans, expected, check_dtypes=False)

    # test with negative segment ids and without an explicit num_segments;
    # num_segments is then inferred as max(segment_ids) + 1 = 6, and the
    # negative ids are dropped.
    segment_ids = jnp.array([3, 3, 3, 4, 5, 5, -7, -6])
    ans = ops.segment_sum(data, segment_ids)
    expected = jnp.array([0, 0, 0, 13, 2, 7])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list({
        "testcase_name": "_{}_{}_num_segments={}_bucket_size={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          reducer.__name__, num_segments, bucket_size),
        "dtype": dtype, "shape": shape,
        "reducer": reducer, "op": op, "identity": identity,
        "num_segments": num_segments, "bucket_size": bucket_size}
      for dtype in default_dtypes
      for shape in [(8,), (7, 4), (6, 4, 2)]
      for bucket_size in [None, 2]
      for num_segments in [None, 1, 3])
    for reducer, op, identity in [
      (ops.segment_sum, np.add, 0),
      (ops.segment_prod, np.multiply, 1),
      (ops.segment_min, np.minimum, float('inf')),
      (ops.segment_max, np.maximum, -float('inf')),
    ]))
  def testSegmentReduce(self, shape, dtype, reducer, op, identity, num_segments, bucket_size):
    rng = jtu.rand_default(self.rng())
    idx_rng = jtu.rand_int(self.rng(), low=-2, high=3)
    args_maker = lambda: [rng(shape, dtype), idx_rng(shape[:1], jnp.int32)]

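    # Integer dtypes cannot represent +/-inf, so substitute the dtype's extreme
    # finite values as the reduction identity for segment_min / segment_max.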
    if np.issubdtype(dtype, np.integer):
      if np.isposinf(identity):
        identity = np.iinfo(dtype).max
      elif np.isneginf(identity):
        identity = np.iinfo(dtype).min

    jnp_fun = lambda data, segment_ids: reducer(
      data, segment_ids, num_segments=num_segments, bucket_size=bucket_size)

    def np_fun(data, segment_ids):
      size = num_segments if num_segments is not None else (segment_ids.max() + 1)
      out = np.full((size,) + shape[1:], identity, dtype)
      for i, val in zip(segment_ids, data):
        if 0 <= i < size:
          out[i] = op(out[i], val).astype(dtype)
      return out

    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    if num_segments is not None:
      self._CompileAndCheck(jnp_fun, args_maker)

  def testIndexDtypeError(self):
    # https://github.com/google/jax/issues/2795
    jnp.array(1)  # get rid of startup warning
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("error")
      jnp.zeros(5).at[::2].set(1)
      self.assertLen(w, 0)

  @contextmanager
  def assertNoWarnings(self):
    with warnings.catch_warnings(record=True) as caught_warnings:
      yield
    self.assertEmpty(caught_warnings)

  @parameterized.named_parameters(jtu.cases_from_list({
      "testcase_name": "idx={}".format(idx), "idx": idx, "idx_type": idx_type}
    for idx, idx_type in [
      ([0], "array"),
      ([0, 0], "array"),
      ([[0, 0]], "tuple"),
      ([0, [0, 1]], "tuple"),
      ([0, np.arange(2)], "tuple"),
      ([0, None], "tuple"),
      ([0, slice(None)], "tuple"),
    ]))
  def testIndexSequenceDeprecation(self, idx, idx_type):
    normalize = {"array": np.array, "tuple": tuple}[idx_type]
    msg = {"array": ARRAY_MSG, "tuple": TUPLE_MSG}[idx_type]
    x = jnp.arange(6).reshape(3, 2)

    with self.assertRaisesRegex(TypeError, msg):
      x[idx]
    with self.assertNoWarnings():
      x[normalize(idx)]

    with self.assertRaisesRegex(TypeError, msg):
      x.at[idx].set(0)
    with self.assertNoWarnings():
      x.at[normalize(idx)].set(0)
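
All of the indexed-update tests above exercise the same idea: JAX arrays are immutable, so what would be an in-place update in NumPy is written functionally, either through the older jax.ops helpers or through the sugared x.at[...] property, and the result is checked against a NumPy reference that mutates a copy. A minimal standalone sketch of that comparison, using only the public .at[] API, might look like this:

import numpy as np
import jax.numpy as jnp

x_np = np.arange(6.0).reshape(3, 2)
update = np.array([10.0, 20.0])

# NumPy reference: mutate a copy in place.
expected = x_np.copy()
expected[1] += update

# JAX: the .at[] property returns a new, updated array; x itself is unchanged.
x = jnp.asarray(x_np)
out = x.at[1].add(update)

np.testing.assert_allclose(np.asarray(out), expected)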