class LaxVmapTest(jtu.JaxTestCase): def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng, rtol=None, atol=None): batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes) args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)] args_slice = args_slicer(args, bdims) ans = api.vmap(op, bdims)(*args) if bdim_size == 0: args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] out = op(*args) expected = np.zeros((0,) + out.shape, out.dtype) else: expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)]) self.assertAllClose(ans, expected, rtol=rtol, atol=atol) @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": "{}_bdims={}".format( jtu.format_test_name_suffix(rec.op, shapes, itertools.repeat(dtype)), bdims), "op_name": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype, "bdims": bdims, "tol": rec.tol} for shape_group in compatible_shapes for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs) for bdims in all_bdims(*shapes) for dtype in rec.dtypes) for rec in LAX_OPS)) def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol): rng = rng_factory(self.rng()) op = getattr(lax, op_name) self._CheckBatching(op, 10, bdims, shapes, [dtype] * len(shapes), rng, atol=tol, rtol=tol) @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ "testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}" "_lhs_bdim={}_rhs_bdim={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums), feature_group_count, batch_group_count, lhs_bdim, rhs_bdim), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "strides": strides, "padding": padding, "lhs_dil": lhs_dil, "rhs_dil": rhs_dil, "dimension_numbers": dim_nums, "perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim, "feature_group_count": feature_group_count, "batch_group_count": batch_group_count, } for batch_group_count, feature_group_count in s([(1, 1), (2, 1), (1, 2)]) for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in s([ ((b * batch_group_count, i * feature_group_count, 6, 7), # lhs_shape (j * batch_group_count * feature_group_count, i, 1, 2), # rhs_shape [(1, 1), (1, 2), (2, 1)], # strides [((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))], # pads [(1, 1), (2, 1)], # lhs_dils [(1, 1), (2, 2)]) # rhs_dils for b, i, j in itertools.product([1, 2], repeat=3)]) for strides in s(all_strides) for rhs_dil in s(rhs_dils) for lhs_dil in s(lhs_dils) for dtype in s([np.float32]) for padding in s(all_pads) for dim_nums, perms in s([ (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])), (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])), (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]) for lhs_bdim in s(itertools.chain([cast(Optional[int], None)], range(len(lhs_shape) + 1))) for rhs_bdim in s(itertools.chain([cast(Optional[int], None)], range(len(rhs_shape) + 1))) if (lhs_bdim, rhs_bdim) != (None, None) ))) def testConvGeneralDilatedBatching( self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil, dimension_numbers, perms, feature_group_count, batch_group_count, lhs_bdim, rhs_bdim): rng = jtu.rand_default(self.rng()) tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3 # permute shapes to match dim_spec, scale by 
feature_group_count lhs_perm, rhs_perm = perms lhs_shape = list(np.take(lhs_shape, lhs_perm)) rhs_shape = list(np.take(rhs_shape, rhs_perm)) conv = partial(lax.conv_general_dilated, window_strides=strides, padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil, dimension_numbers=dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count, precision=lax.Precision.HIGHEST) self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape), (dtype, dtype), rng, rtol=tol, atol=tol) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format( shape, from_dtype, to_dtype, bdims), "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype, "bdims": bdims} for from_dtype, to_dtype in itertools.product( [np.float32, np.int32, "float32", "int32"], repeat=2) for shape in [(2, 3)] for bdims in all_bdims(shape))) def testConvertElementType(self, shape, from_dtype, to_dtype, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.convert_element_type(x, to_dtype) self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_nmant={}_nexp={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), nmant, nexp, bdims), "shape": shape, "dtype": dtype, "nmant": nmant, "nexp": nexp, "bdims": bdims} for dtype in float_dtypes for shape in [(2, 4)] for nexp in [1, 3, 5] for nmant in [0, 2, 4] for bdims in all_bdims(shape))) def testReducePrecision(self, shape, dtype, nmant, nexp, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.reduce_precision(x, exponent_bits=nexp, mantissa_bits=nmant) self._CheckBatching(op, 10, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format( shape, from_dtype, to_dtype, bdims), "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype, "bdims": bdims} for from_dtype, to_dtype in itertools.product( [np.float32, np.int32, "float32", "int32"], repeat=2) for shape in [(2, 3)] for bdims in all_bdims(shape))) def testBitcastElementType(self, shape, from_dtype, to_dtype, bdims,): rng = jtu.rand_default(self.rng()) op = lambda x: lax.bitcast_convert_type(x, to_dtype) self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}_bdims={}" .format(jtu.format_shape_dtype_string(min_shape, dtype), jtu.format_shape_dtype_string(operand_shape, dtype), jtu.format_shape_dtype_string(max_shape, dtype), bdims), "min_shape": min_shape, "operand_shape": operand_shape, "max_shape": max_shape, "dtype": dtype, "bdims": bdims} for min_shape, operand_shape, max_shape in [ [(), (2, 3), ()], [(2, 3), (2, 3), ()], [(), (2, 3), (2, 3)], [(2, 3), (2, 3), (2, 3)], ] for dtype in default_dtypes for bdims in all_bdims(min_shape, operand_shape, max_shape))) def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims): rng = jtu.rand_default(self.rng()) shapes = [min_shape, operand_shape, max_shape] self._CheckBatching(lax.clamp, 10, bdims, shapes, [dtype] * 3, rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_bdims={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), bdims), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "bdims": bdims} for 
lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)] for bdims in all_bdims(lhs_shape, rhs_shape) for dtype in default_dtypes)) def testDot(self, lhs_shape, rhs_shape, dtype, bdims): rng = jtu.rand_default(self.rng()) op = partial(lax.dot, precision=lax.Precision.HIGHEST) self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), rng, rtol={np.float16: 5e-2, np.float64: 5e-14}) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_bdims={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), lhs_contracting, rhs_contracting, bdims), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting, "bdims": bdims} for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [ [(5,), (5,), [0], [0]], [(5, 7), (5,), [0], [0]], [(7, 5), (5,), [1], [0]], [(3, 5), (2, 5), [1], [1]], [(5, 3), (5, 2), [0], [0]], [(5, 3, 2), (5, 2, 4), [0], [0]], [(5, 3, 2), (5, 2, 4), [0,2], [0,1]], [(5, 3, 2), (3, 5, 2, 4), [0,2], [1,2]], [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]], [(3, 2), (2, 4), [1], [0]], ] for bdims in all_bdims(lhs_shape, rhs_shape) for dtype in default_dtypes)) def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype, lhs_contracting, rhs_contracting, bdims): rng = jtu.rand_small(self.rng()) dimension_numbers = ((lhs_contracting, rhs_contracting), ([], [])) dot = partial(lax.dot_general, dimension_numbers=dimension_numbers) self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_bdims={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), dimension_numbers, bdims), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "dimension_numbers": dimension_numbers, "bdims": bdims} for lhs_shape, rhs_shape, dimension_numbers in [ ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))), ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))), ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))), ] for bdims in all_bdims(lhs_shape, rhs_shape) for dtype in default_dtypes)) def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype, dimension_numbers, bdims): rng = jtu.rand_small(self.rng()) dot = partial(lax.dot_general, dimension_numbers=dimension_numbers) self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), rng) # Checks that batching didn't introduce any transposes or broadcasts. 
jaxpr = api.make_jaxpr(dot)(np.zeros(lhs_shape, dtype), np.zeros(rhs_shape, dtype)) for eqn in jtu.iter_eqns(jaxpr.jaxpr): self.assertNotIn(eqn.primitive.name, ["transpose", "broadcast"]) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format( shape, np.dtype(dtype).name, broadcast_sizes, bdims), "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes, "bdims": bdims} for shape in [(), (2, 3)] for dtype in default_dtypes for broadcast_sizes in [(), (2,), (1, 2)] for bdims in all_bdims(shape))) def testBroadcast(self, shape, dtype, broadcast_sizes, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.broadcast(x, broadcast_sizes) self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_bcdims={}_bdims={}".format( jtu.format_shape_dtype_string(inshape, dtype), outshape, broadcast_dimensions, bdims), "inshape": inshape, "dtype": dtype, "outshape": outshape, "dimensions": broadcast_dimensions, "bdims": bdims} for inshape, outshape, broadcast_dimensions in [ ([2], [2, 2], [0]), ([2], [2, 2], [1]), ([2], [2, 3], [0]), ([], [2, 3], []), ] for dtype in default_dtypes for bdims in all_bdims(inshape))) def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims): rng = jtu.rand_default(self.rng()) raise SkipTest("this test has failures in some cases") # TODO(mattjj) op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions) self._CheckBatching(op, 5, bdims, (inshape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_dimensions={}_bdims={}".format( jtu.format_shape_dtype_string(arg_shape, np.float32), dimensions, bdims), "arg_shape": arg_shape, "dimensions": dimensions, "bdims": bdims} for arg_shape, dimensions in [ [(1,), (0,)], [(1,), (-1,)], [(2, 1, 4), (1,)], [(2, 1, 4), (-2,)], [(2, 1, 3, 1), (1,)], [(2, 1, 3, 1), (1, 3)], [(2, 1, 3, 1), (3,)], [(2, 1, 3, 1), (1, -1)], ] for bdims in all_bdims(arg_shape))) def testSqueeze(self, arg_shape, dimensions, bdims): dtype = np.float32 rng = jtu.rand_default(self.rng()) op = lambda x: lax.squeeze(x, dimensions) self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_dims={}_bdims={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), jtu.format_shape_dtype_string(out_shape, dtype), dimensions, bdims), "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, "dimensions": dimensions, "bdims": bdims} for dtype in default_dtypes for arg_shape, dimensions, out_shape in [ [(3, 4), None, (12,)], [(2, 1, 4), None, (8,)], [(2, 2, 4), None, (2, 8)], [(2, 2, 4), (0, 1, 2), (2, 8)], [(2, 2, 4), (1, 0, 2), (8, 2)], [(2, 2, 4), (2, 1, 0), (4, 2, 2)] ] for bdims in all_bdims(arg_shape))) def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions) self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_pads={}_bdims={}" .format(jtu.format_shape_dtype_string(shape, dtype), pads, bdims), "shape": shape, "dtype": dtype, "pads": pads, "bdims": bdims} for shape in [(2, 3)] for bdims in all_bdims(shape, ()) for dtype in default_dtypes for pads in [[(1, 2, 1), (0, 1, 0)]])) def testPad(self, shape, dtype, pads,
bdims): rng = jtu.rand_small(self.rng()) fun = lambda operand, padding: lax.pad(operand, padding, pads) self._CheckBatching(fun, 5, bdims, (shape, ()), (dtype, dtype), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_predshape={}_argshapes={}_bdims={}".format( jtu.format_shape_dtype_string(pred_shape, np.bool_), jtu.format_shape_dtype_string(arg_shape, arg_dtype), bdims), "pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype, "bdims": bdims} for arg_shape in [(), (3,), (2, 3)] for pred_shape in ([(), arg_shape] if arg_shape else [()]) for bdims in all_bdims(pred_shape, arg_shape, arg_shape) for arg_dtype in default_dtypes)) def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims): rng = jtu.rand_default(self.rng()) op = lambda c, x, y: lax.select(c < 0, x, y) self._CheckBatching(op, 5, bdims, (pred_shape, arg_shape, arg_shape,), (np.bool_, arg_dtype, arg_dtype), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_start_indices={}_limit_indices={}_strides={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), start_indices, limit_indices, strides, bdims), "shape": shape, "dtype": dtype, "starts": start_indices, "limits": limit_indices, "strides": strides, "bdims": bdims} for shape, start_indices, limit_indices, strides in [ [(3,), (1,), (2,), None], [(7,), (4,), (7,), None], [(5,), (1,), (5,), (2,)], [(8,), (1,), (6,), (2,)], [(5, 3), (1, 1), (3, 2), None], [(5, 3), (1, 1), (3, 1), None], [(7, 5, 3), (4, 0, 1), (7, 1, 3), None], [(5, 3), (1, 1), (2, 1), (1, 1)], [(5, 3), (1, 1), (5, 3), (2, 1)], ] for bdims in all_bdims(shape) for dtype in default_dtypes)) def testSlice(self, shape, dtype, starts, limits, strides, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.slice(x, starts, limits, strides) self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_perm={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), perm, bdims), "shape": shape, "dtype": dtype, "perm": perm, "bdims": bdims} for shape, perm in [ [(3, 4), (1, 0)], [(3, 4), (0, 1)], [(3, 4, 5), (2, 1, 0)], [(3, 4, 5), (1, 0, 2)], ] for bdims in all_bdims(shape) for dtype in default_dtypes)) def testTranspose(self, shape, dtype, perm, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.transpose(x, perm) self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}_bdims={}" .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims, init_val, bdims), "op": op, "init_val": init_val, "shape": shape, "dtype": dtype, "dims": dims, "bdims": bdims} for init_val, op, dtypes in [ (0, lax.add, default_dtypes), (1, lax.mul, default_dtypes), (0, lax.max, all_dtypes), # non-monoidal (-np.inf, lax.max, float_dtypes), (dtypes.iinfo(np.int32).min, lax.max, [np.int32]), (dtypes.iinfo(np.int64).min, lax.max, [np.int64]), (dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]), (dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]), (np.inf, lax.min, float_dtypes), (dtypes.iinfo(np.int32).max, lax.min, [np.int32]), (dtypes.iinfo(np.int64).max, lax.min, [np.int64]), (dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]), (dtypes.iinfo(np.uint64).max, lax.min, [np.uint64]), ] for dtype in dtypes for shape, dims in [ [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)], [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)] ] for bdims in 
all_bdims(shape))) def testReduce(self, op, init_val, shape, dtype, dims, bdims): rng = jtu.rand_small(self.rng()) init_val = np.asarray(init_val, dtype=dtype) fun = lambda operand: lax.reduce(operand, init_val, op, dims) self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_inshape={}_reducedims={}_bdims={}" .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dim, bdims), "op": op, "shape": shape, "dtype": dtype, "dim": dim, "bdims": bdims} for op in [lax.argmin, lax.argmax] for dtype in default_dtypes for shape in [(3, 4, 5)] for dim in range(len(shape)) for bdims in all_bdims(shape))) def testArgminmax(self, op, shape, dtype, dim, bdims): rng = jtu.rand_default(self.rng()) fun = lambda operand: op(operand, dim, np.int32) self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}" "_basedilation={}_windowdilation={}") .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims, strides, padding, base_dilation, window_dilation), "op": op, "init_val": init_val, "dtype": dtype, "shape": shape, "dims": dims, "strides": strides, "padding": padding, "base_dilation": base_dilation, "window_dilation": window_dilation} for init_val, op, dtypes in [ (0, lax.add, [np.float32]), (-np.inf, lax.max, [np.float32]), (np.inf, lax.min, [np.float32]), ] for shape, dims, strides, padding, base_dilation, window_dilation in ( itertools.chain( itertools.product( [(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1), (1, 2)], ["VALID", "SAME", [(0, 3), (1, 2)]], [(1, 1), (2, 3)], [(1, 1), (1, 2)]), itertools.product( [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)], [(1, 2, 2, 1), (1, 1, 1, 1)], ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]], [(1, 1, 1, 1), (2, 1, 3, 2)], [(1, 1, 1, 1), (1, 2, 2, 1)]))) for dtype in dtypes)) def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding, base_dilation, window_dilation): rng = jtu.rand_small(self.rng()) init_val = np.asarray(init_val, dtype=dtype) def fun(operand): return lax.reduce_window(operand, init_val, op, dims, strides, padding, base_dilation, window_dilation) for bdims in all_bdims(shape): self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_shape={}_axis={}_bdims={}_reverse={}" .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis, bdims, reverse), "op": op, "shape": shape, "dtype": dtype, "bdims": bdims, "axis": axis, "reverse": reverse} for op, types in [ (lax.cumsum, [np.float32, np.float64]), (lax.cumprod, [np.float32, np.float64]), ] for dtype in types for shape in [[10], [3, 4, 5]] for axis in range(len(shape)) for bdims in all_bdims(shape) for reverse in [False, True])) def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse): rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer) else jtu.rand_small) rng = rng_factory(self.rng()) self._CheckBatching(partial(op, axis=axis, reverse=reverse), 7, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dtype={}_padding={}".format(np.dtype(dtype).name, padding), "dtype": dtype, "padding": padding} for dtype in float_dtypes for padding in ["VALID", "SAME"])) @jtu.skip_on_flag("jax_skip_slow_tests", True) @jtu.ignore_warning(message="Using reduced precision for gradient.*") def 
testSelectAndGatherAdd(self, dtype, padding): if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16: raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu") rng = jtu.rand_small(self.rng()) all_configs = itertools.chain( itertools.product( [(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1), (1, 2)]), itertools.product( [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)], [(1, 2, 2, 1), (1, 1, 1, 1)])) def fun(operand, tangents): pads = lax.padtype_to_pads(operand.shape, dims, strides, padding) ones = (1,) * len(operand.shape) return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims, strides, pads, ones, ones) for shape, dims, strides in all_configs: for bdims in all_bdims(shape, shape): self._CheckBatching(fun, 3, bdims, (shape, shape), (dtype, dtype), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": f"_dtype={jtu.format_shape_dtype_string(shape, dtype)}" f"_padding={padding}_dims={dims}_strides={strides}", "dtype": dtype, "padding": padding, "shape": shape, "dims": dims, "strides": strides} for dtype in float_dtypes for padding in ["VALID", "SAME"] for shape in [(3, 2, 4, 6)] for dims in [(1, 1, 2, 1)] for strides in [(1, 2, 2, 1), (1, 1, 1, 1)])) def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides): rng = jtu.rand_small(self.rng()) pads = lax.padtype_to_pads(shape, dims, strides, padding) def fun(operand, cotangents): return lax._select_and_scatter_add(operand, cotangents, lax.ge_p, dims, strides, pads) ones = (1,) * len(shape) cotangent_shape = api.eval_shape( lambda x: lax._select_and_gather_add(x, x, lax.ge_p, dims, strides, pads, ones, ones), np.ones(shape, dtype)).shape for bdims in all_bdims(cotangent_shape, shape): self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape), (dtype, dtype), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_bdims={}_fft_ndims={}" .format(shape, bdims, fft_ndims), "shape": shape, "bdims": bdims, "fft_ndims": fft_ndims} for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)] for bdims in all_bdims(shape) for fft_ndims in range(0, min(3, len(shape)) + 1))) @jtu.skip_on_devices("tpu") # TODO(b/137993701): unimplemented cases. 
def testFft(self, fft_ndims, shape, bdims): rng = jtu.rand_default(self.rng()) ndims = len(shape) axes = range(ndims - fft_ndims, ndims) fft_lengths = [shape[axis] for axis in axes] op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths) self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}" .format(jtu.format_shape_dtype_string(shape, dtype), idxs, dnums, slice_sizes, bdims), "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "bdims": bdims} for dtype in all_dtypes for shape, idxs, dnums, slice_sizes in [ ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)), ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)), ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)), ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), ] for bdims in all_bdims(shape, idxs.shape))) def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) self._CheckBatching(fun, 0, bdims, [shape, idxs.shape], [dtype, idxs.dtype], jtu.rand_default(self.rng())) self._CheckBatching(fun, 5, bdims, [shape, idxs.shape], [dtype, idxs.dtype], jtu.rand_default(self.rng())) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}_bdims={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), idxs, update_shape, dnums, bdims), "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, "update_shape": update_shape, "dnums": dnums, "bdims": bdims} for dtype in float_dtypes for arg_shape, idxs, update_shape, dnums in [ ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), ] for bdims in all_bdims(arg_shape, idxs.shape, update_shape))) def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims): fun = partial(lax.scatter_add, dimension_numbers=dnums) self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape, update_shape], [dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()), rtol={np.float16: 5e-3, dtypes.bfloat16: 3e-2}) def testShapeUsesBuiltinInt(self): x = lax.iota(np.int32, 3) + 1 self.assertIsInstance(x.shape[0], int) # not np.int64 def testBroadcastShapesReturnsPythonInts(self): shape1, shape2 = (1, 2, 3), (2, 3) out_shape = lax.broadcast_shapes(shape1, shape2) self.assertTrue(all(type(s) is int for s in out_shape)) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_k={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), k, bdims), "shape": shape, "dtype": dtype, "k": k, "bdims": bdims, "rng_factory": rng_factory} for shape in [(4,), (3, 5, 3)] for k in [1, 3] for bdims in all_bdims(shape) # TODO(b/155170120): 
test with repeats once the XLA:CPU stable top_k bug is fixed: # The top_k indices for integer arrays with identical entries won't match between # vmap'd version and manual reference, so only test unique integer arrays for int_dtypes. # Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of # values a bfloat16 can represent exactly to avoid ties. for dtype, rng_factory in itertools.chain( unsafe_zip(default_dtypes, itertools.repeat(jtu.rand_unique_int))))) def testTopK(self, shape, dtype, k, bdims, rng_factory): rng = rng_factory(self.rng()) # _CheckBatching doesn't work with tuple outputs, so test outputs separately. op1 = lambda x: lax.top_k(x, k=k)[0] self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng) op2 = lambda x: lax.top_k(x, k=k)[1] self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_dimension={}_arity={}_bdims={}_isstable={}" .format(jtu.format_shape_dtype_string(shape, np.float32), dimension, arity, bdims, is_stable), "shape": shape, "dimension": dimension, "arity": arity, "bdims": bdims, "is_stable": is_stable} for shape in [(2, 3)] for dimension in [0, 1] for arity in range(3) for bdims in all_bdims(*((shape,) * arity)) for is_stable in [False, True])) def testSort(self, shape, dimension, arity, bdims, is_stable): rng = jtu.rand_default(self.rng()) if arity == 1: fun = partial(lax.sort, dimension=dimension) self._CheckBatching(fun, 5, bdims, (shape,) * arity, (np.float32,) * arity, rng) else: for i in range(arity): fun = lambda *args, i=i: lax.sort(args, dimension=dimension, is_stable=is_stable)[i] self._CheckBatching(fun, 5, bdims, (shape,) * arity, (np.float32,) * arity, rng)
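# Editor's note: a minimal, hedged sketch (not a collected test case) of the property
# _CheckBatching asserts above -- applying jax.vmap to an op over a batched input
# should agree with stacking the op applied slice by slice. The op (lax.exp), shapes,
# and tolerance below are illustrative choices, not values taken from the test cases.
def _example_vmap_matches_manual_stacking():
  import numpy as np
  from jax import vmap, lax
  x = np.linspace(0., 1., 6, dtype=np.float32).reshape(3, 2)      # batch of 3 along axis 0
  batched = vmap(lax.exp, in_axes=0)(x)                           # batched application
  manual = np.stack([lax.exp(x[i]) for i in range(x.shape[0])])   # per-slice reference
  np.testing.assert_allclose(batched, manual, rtol=1e-6)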
class XMapTest(XMapTestCase): def testBasic(self): local_devices = list(jax.local_devices()) if len(local_devices) < 4: raise SkipTest("Test requires at least 4 local devices") def f(a, b): return a * 2, b * 4 devices = np.array(local_devices[:4]).reshape((2, 2)) with mesh(devices, ('x', 'y')): fm = xmap(f, in_axes=[{0: 'a', 1: 'b'}, ['c', ...]], out_axes=[{0: 'a', 1: 'b'}, ['c', ...]], axis_resources={'a': 'x', 'b': 'y', 'c': 'x'}) ashape = (16, 8, 5) a = jnp.arange(np.prod(ashape)).reshape(ashape) bshape = (2, 7) b = jnp.arange(np.prod(bshape)).reshape(bshape) c, d = fm(a, b) self.assertAllClose(c, a * 2) self.assertAllClose(d, b * 4) @jtu.with_mesh([('x', 2), ('y', 2)]) def testCollectiveReduce(self): fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4), in_axes=[['a', 'b', ...], {0: 'c'}], out_axes=[['b', ...], {0: 'c'}], axis_resources={'a': 'x', 'b': 'y', 'c': 'x'}) ashape = (16, 8, 5) a = jnp.arange(np.prod(ashape)).reshape(ashape) bshape = (2, 7) b = jnp.arange(np.prod(bshape)).reshape(bshape) c, d = fm(a, b) self.assertAllClose(c, (a * 2).sum(0)) self.assertAllClose(d, b * 4) @jtu.with_mesh([('x', 2), ('y', 2)]) def testCollectivePermute2D(self): perm = np.array([3, 1, 2, 0]) x = jnp.arange(4).reshape((2, 2)) result = xmap(lambda x: lax.pshuffle(x, ('i', 'j'), perm), in_axes=['i', 'j', ...], out_axes=['i', 'j', ...], axis_resources={'i': 'x', 'j': 'y'})(x).reshape((-1,)) self.assertAllClose(result, perm) def testCollectivePermute1D(self): perm = np.array([3, 1, 2, 0]) x = jnp.arange(4) result = xmap(lambda x: lax.pshuffle(x, 'i', perm), in_axes=['i', ...], out_axes=['i', ...])(x) self.assertAllClose(result, perm) def testCollectiveAllGather(self): x = jnp.arange(4) result = xmap(lambda x: lax.all_gather(x, 'i') + lax.axis_index('i'), in_axes=['i', ...], out_axes=['i', ...])(x) self.assertAllClose(result, x + x[jnp.newaxis].T) @jtu.with_mesh([('x', 2), ('y', 2)]) def testOneLogicalTwoMeshAxesBasic(self): def f(v): return lax.psum(v * 2, 'a'), v * 4 fm = xmap(f, in_axes=['a', ...], out_axes=[{}, {1: 'a'}], axis_resources={'a': ('x', 'y')}) vshape = (4, 5) v = jnp.arange(np.prod(vshape)).reshape(vshape) ans, ans2 = fm(v) self.assertAllClose(ans, (v * 2).sum(0)) self.assertAllClose(ans2, v.T * 4) @jtu.with_mesh([('x', 2), ('y', 2)]) def testOneLogicalTwoMeshAxesSharding(self): def f(v): return v * 4 fxy = xmap(f, in_axes=['a', ...], out_axes={1: 'a'}, axis_resources={'a': ('x', 'y')}) fyx = xmap(f, in_axes=['a', ...], out_axes={1: 'a'}, axis_resources={'a': ('y', 'x')}) vshape = (4, 5) v = jnp.arange(np.prod(vshape)).reshape(vshape) zxy = fxy(v) self.assertEqual( zxy.sharding_spec, pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))), (pxla.ShardedAxis(0), pxla.ShardedAxis(1)))) zyx = fyx(v) self.assertEqual( zyx.sharding_spec, pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))), (pxla.ShardedAxis(1), pxla.ShardedAxis(0)))) @jtu.with_mesh([('x', 2), ('y', 2)]) def testSkipFirstMeshDim(self): def run(axis_resources): return xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...], axis_resources=axis_resources)(jnp.ones((4,))) self.assertAllClose(run({'i': 'x'}), run({'i': 'y'})) def testCaching(self): def f(x): assert python_should_be_executing return x * 2 devices = np.array(jax.local_devices()[:2]) if devices.size < 2: raise SkipTest("Test requires 2 devices") x = np.arange(8).reshape((2, 2, 2)) with mesh(devices, ('x',)): python_should_be_executing = True xmap(f, in_axes=['a', ...], out_axes=['a', ...], axis_resources={'a': 'x'})(x) python_should_be_executing = 
False xmap(f, in_axes=['a', ...], out_axes=['a', ...], axis_resources={'a': 'x'})(x) with mesh(devices, ('x',)): python_should_be_executing = False xmap(f, in_axes=['a', ...], out_axes=['a', ...], axis_resources={'a': 'x'})(x) @parameterized.named_parameters( {"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources} for name, mesh, axis_resources in ( ('OneToOne', (('x', 2), ('y', 2)), (('a', 'y'), ('b', 'x'))), ('Multiple', (('x', 2), ('y', 2), ('z', 2)), (('a', 'y'), ('b', ('x', 'z')))), )) @jtu.with_mesh_from_kwargs def testNestedMesh(self, mesh, axis_resources): @partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}), axis_resources=dict([axis_resources[0]])) def f(x): y = x * 2 @partial(xmap, in_axes={0: 'b'}, out_axes=({1: 'b'}, {}), axis_resources=dict([axis_resources[1]])) def h(y): # Multiply by a constant array to better exercise the partial_eval rule return jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b')) return h(y) xshape = (4, 2, 5) x = jnp.arange(np.prod(xshape)).reshape(xshape) y = f(x) self.assertAllClose(y, ((jnp.sin(x * 2) * np.arange(xshape[-1])).transpose((1, 2, 0)), (x * 2).sum((0, 1)))) self.assertEqual(y[0].sharding_spec.sharding, (pxla.Chunked([2]), pxla.NoSharding(), pxla.NoSharding())) self.assertEqual(y[0].sharding_spec.mesh_mapping, (pxla.Replicated(2), pxla.ShardedAxis(0)) + (pxla.Replicated(2),) * (len(mesh) - 2)) if maps.EXPERIMENTAL_SPMD_LOWERING: hlo = jax.xla_computation(f)(x).as_hlo_text() # Make sure that there are non-partial sharding specs in the HLO self.assertRegex(hlo, r"sharding={devices=\[[0-9,]+\][0-9,]+}") @jtu.with_and_without_mesh def testMultipleCalls(self, mesh, axis_resources): def f(x, y): assert x.shape == y.shape == (3, 5) return jnp.tensordot(x, y, axes=([1], [1])) f_mapped = xmap(f, in_axes=(['i', ...], ['j', ...]), out_axes=['i', 'j', ...], axis_resources=dict(axis_resources)) x = jnp.arange(30).reshape(2, 3, 5) expected = jnp.einsum('imk,jnk->ijmn', x, x) for i in range(10): self.assertAllClose(f_mapped(x, x), expected) @jtu.with_and_without_mesh @jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU. 
def testBufferDonation(self, mesh, axis_resources): shard = lambda x: x if axis_resources: shard = xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...], axis_resources=dict(axis_resources)) f = xmap(lambda x, y: x + y * 4, in_axes=['i', ...], out_axes=['i', ...], axis_resources=dict(axis_resources), donate_argnums=0) # The multiplications below disable some optimizations that prevent reuse x = shard(jnp.zeros((2, 5)) * 4) y = shard(jnp.ones((2, 5)) * 2) f(x, y) self.assertNotDeleted(y) self.assertDeleted(x) def testControlFlow(self): x = jnp.arange(5) xmap(lambda x: lax.fori_loop(0, 10, lambda _, x: lax.psum(x, 'i'), x), in_axes=['i', ...], out_axes=['i', ...])(x) @jtu.with_and_without_mesh def testAxisSizes(self, mesh, axis_resources): result = xmap(lambda: lax.axis_index('i'), in_axes=(), out_axes=['i', ...], axis_sizes={'i': 6}, axis_resources=dict(axis_resources))() self.assertAllClose(result, jnp.arange(6, dtype=result.dtype)) def testCollectiveOverNoName(self): result = xmap(lambda: lax.psum(jnp.array(2) ** 2, 'i'), in_axes={}, out_axes={}, axis_sizes={'i': 4})() self.assertEqual(result, 16) def VmapOfXmapCases(s): xmap_in_axes = ([{}] + [{i: 'x'} for i in range(3)] + [{i: 'x', j: 'y'} for i in range(4) for j in range(4) if i != j]) for xmap_dim_x, xmap_dim_y in s(product(xmap_in_axes, repeat=2)): xmap_axes = sorted(set(xmap_dim_x.values()) | set(xmap_dim_y.values())) num_axes = len(xmap_axes) if xmap_axes is None: continue xmap_out_axes = [dict(zip(dims, xmap_axes)) for dims in permutations(range(2 + num_axes), num_axes)] for xmap_dim_z in s(xmap_out_axes): for vmap_dim_x in s([*range(2 + len(xmap_dim_x)), None]): for vmap_dim_y in s([*range(2 + len(xmap_dim_y)), None]): if vmap_dim_x is None and vmap_dim_y is None: continue for vmap_dim_result in s(range(3)): for vmap_dim_z in s(range(2 + len(xmap_axes))): for vmap_as_xmap in s([False, True]): yield {"testcase_name": f"_xin={(sorted(xmap_dim_x.items()), sorted(xmap_dim_y.items()))}_" f"xout={sorted(xmap_dim_z.items())}_vin={(vmap_dim_x, vmap_dim_y)}_" f"vout={vmap_dim_z}_vresult={vmap_dim_result}_vmap_as_xmap={vmap_as_xmap}", "xmap_in_axes": (xmap_dim_x, xmap_dim_y), "xmap_out_axes": xmap_dim_z, "vmap_in_axes": (vmap_dim_x, vmap_dim_y), "vmap_out_axes": vmap_dim_z, "vmap_result_axis": vmap_dim_result, "vmap_as_xmap": vmap_as_xmap} @parameterized.named_parameters(jtu.named_cases_from_sampler(VmapOfXmapCases)) def testNestedMap(self, xmap_in_axes, xmap_out_axes, vmap_in_axes, vmap_out_axes, vmap_result_axis, vmap_as_xmap): """Test various vmap(xmap) and xmap(xmap) combinations. The outer map always introduces a single dimension, the inner map introduces one or two. 
""" (xin_x, xin_y) = xmap_in_axes (vin_x, vin_y) = vmap_in_axes vmap_size = 7 xmap_sizes = {'x': 11, 'y': 13} xshape = [2, 3] yshape = [3, 5] zshape = [2, 5] xind = ['n', 'k'] yind = ['k', 'm'] zind = ['n', 'm'] f = lambda x, y: ensure_bdim(jnp.einsum('nk,km->nm', x, y), 'v', vmap_result_axis) for pos, name in sorted(xin_x.items()): xshape.insert(pos, xmap_sizes[name]) xind.insert(pos, name) for pos, name in sorted(xin_y.items()): yshape.insert(pos, xmap_sizes[name]) yind.insert(pos, name) for pos, name in sorted(xmap_out_axes.items()): zshape.insert(pos, xmap_sizes[name]) zind.insert(pos, name) if vin_x is not None: xshape.insert(vin_x, vmap_size) xind.insert(vin_x, 'v') if vin_y is not None: yshape.insert(vin_y, vmap_size) yind.insert(vin_y, 'v') zshape.insert(vmap_out_axes, vmap_size) zind.insert(vmap_out_axes, 'v') if vmap_as_xmap: do_vmap = partial(xmap, in_axes=({vin_x: 'v'} if vin_x is not None else {}, {vin_y: 'v'} if vin_y is not None else {}), out_axes={vmap_out_axes: 'v'}) else: do_vmap = partial(vmap, in_axes=vmap_in_axes, out_axes=vmap_out_axes, axis_name='v') fm = do_vmap(xmap(f, in_axes=xmap_in_axes, out_axes=xmap_out_axes)) fref = partial(jnp.einsum, f"{''.join(xind)},{''.join(yind)}->{''.join(zind)}") rng = np.random.RandomState(0) x = rng.randn(*xshape) y = rng.randn(*yshape) self.assertAllClose(fm(x, y), fref(x, y)) def testJVP(self): f = xmap(lambda x, y: jnp.cos(lax.dot(x, jnp.sin(y), precision=lax.Precision.HIGHEST)), in_axes=[['i', ...], {}], out_axes=['i', ...]) x = jnp.arange(12, dtype=jnp.float32).reshape((3, 4)) / 100 y = jnp.arange(20, dtype=jnp.float32).reshape((4, 5)) / 100 jtu.check_grads(f, (x, y), order=2, modes=['fwd']) @jtu.with_and_without_mesh def testNamedShape(self, mesh, axis_resources): x = np.arange(4,) y = 2 f = xmap(lambda x, y: (x + y, y * lax.axis_index('i')), in_axes=(['i', ...], {}), out_axes=(['i', ...], ['i', ...]), axis_resources=dict(axis_resources)) z, w = f(x, y) self.assertEqual(z.aval.named_shape, {}) self.assertEqual(w.aval.named_shape, {}) @jtu.with_and_without_mesh def testBroadcast(self, mesh, axis_resources): x = jnp.asarray(2.0) f = xmap(lambda x: x, in_axes={}, out_axes=['i'], axis_sizes={'i': 4}, axis_resources=dict(axis_resources)) self.assertAllClose(f(x), jnp.asarray([2.0, 2.0, 2.0, 2.0])) def testNestedBroadcast(self): x = jnp.asarray(2.0) f = xmap(lambda x: x, in_axes={}, out_axes=['i'], axis_sizes={'i': 4}) g = xmap(f, in_axes={}, out_axes=['j', ...], axis_sizes={'j': 7}) self.assertAllClose(g(x), jnp.tile(x.reshape((1, 1)), (7, 4))) @loop('l', 4) def testLoopBasic(self): x = jnp.arange(16) y = xmap(lambda x: x + 4, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'l'})(x) self.assertAllClose(y, x + 4) @jtu.with_mesh([('x', 2)]) @loop('l', 4) def testLoopWithMesh(self): x = jnp.arange(16) y = xmap(lambda x: x + 4, in_axes=['i'], out_axes=['i'], axis_resources={'i': ('x', 'l')})(x) self.assertAllClose(y, x + 4)
class PDotTests(XMapTestCase): @jtu.with_mesh([('r1', 2)]) def testPdotBasic(self): def f(x, y): return lax.pdot(x, y, 'i') f_mapped = xmap(f, in_axes=[{1: 'i'}, {0: 'i'}], out_axes={}, axis_resources={'i': 'r1'}) rng = np.random.RandomState(0) x = rng.randn(3, 8) y = rng.randn(8, 5) z = f_mapped(x, y) self.assertAllClose(z, jnp.dot(x, y)) @jtu.with_mesh([('r1', 2)]) def testPdotBatching(self): def f(x, y): return lax.pdot(x, y, 'i') rng = np.random.RandomState(0) x = rng.randn(2, 3, 8) y = rng.randn(2, 8, 5) f_mapped = xmap(f, in_axes=[{0: 'j', 2: 'i'}, {0: 'j', 1: 'i'}], out_axes=['j', ...], axis_resources={'i': 'r1'}) z = f_mapped(x, y) self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y)) @jtu.with_mesh([('r1', 2)]) def testPdotBatchingShardUncontractedDim(self): def f(x, y): return lax.pdot(x, y, 'i') rng = np.random.RandomState(0) x = rng.randn(2, 3, 8) y = rng.randn(2, 8, 5) f_mapped = xmap(f, in_axes=[{0: 'j', 2: 'i'}, {0: 'j', 1: 'i'}], out_axes=['j', ...], axis_resources={'j': 'r1'}) z = f_mapped(x, y) self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y)) @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ "testcase_name": f"_{next(test_counter)}", "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "pdot_spec": pdot_spec, "axis_resources": axis_resources, "mesh_data": mesh_data } for test_counter in [it.count()] for lhs_shape, rhs_shape in s(product([(2,), (2, 4, 2, 1)], repeat=2)) for pdot_spec in s(all_pdot_specs(lhs_shape, rhs_shape)) for axis_resources, mesh_data in s(schedules_from_pdot_spec( pdot_spec, lhs_shape, rhs_shape)) ))) def testPdotSystematic(self, lhs_shape, rhs_shape, pdot_spec, axis_resources, mesh_data): rng = jtu.rand_default(self.rng()) lhs = rng(lhs_shape, np.float32) rhs = rng(rhs_shape, np.float32) def pdot_fun(x, y): # print(f'pdot(x:{x.aval.str_short()}, y:{y.aval.str_short()},\n' # f' axis_name={contract_names},\n' # f' pos_contract={spec.pos_contract_after_mapping}\n' # f' pos_batch={spec.pos_batch_after_mapping})') return jax.lax.pdot(x, y, axis_name=pdot_spec.contract_names, pos_batch=pdot_spec.pos_batch_after_mapping, pos_contract=pdot_spec.pos_contract_after_mapping) fun = xmap(pdot_fun, in_axes=[pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes], out_axes=[*pdot_spec.batch_names, ...], axis_resources=axis_resources) with jtu.with_mesh(mesh_data): result = fun(lhs, rhs) expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums) tol = 1e-1 if jtu.device_under_test() == "tpu" else None self.assertAllClose(result, expected, check_dtypes=False, atol=tol, rtol=tol) @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ "testcase_name": f"_{next(test_counter)}", "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "pdot_spec": pdot_spec, "axis_resources": axis_resources, "mesh_data": mesh_data } for test_counter in [it.count()] for lhs_shape, rhs_shape in s(product([(2,), (2, 4, 2, 1)], repeat=2)) for pdot_spec in s(all_pdot_specs(lhs_shape, rhs_shape)) for axis_resources, mesh_data in s(schedules_from_pdot_spec( pdot_spec, lhs_shape, rhs_shape)) ))) def testPdotVJPSystematic(self, lhs_shape, rhs_shape, pdot_spec, axis_resources, mesh_data): rng = jtu.rand_default(self.rng()) lhs = rng(lhs_shape, np.float32) rhs = rng(rhs_shape, np.float32) expected_out, ref_vjp = jax.vjp( lambda x, y: lax.dot_general(x, y, pdot_spec.dot_general_dim_nums), lhs, rhs) out_bar = rng(expected_out.shape, np.float32) expected_lhs, expected_rhs = ref_vjp(out_bar) def pdot_fun(x, y, out_bar): pdot = partial(jax.lax.pdot, 
axis_name=pdot_spec.contract_names, pos_batch=pdot_spec.pos_batch_after_mapping, pos_contract=pdot_spec.pos_contract_after_mapping) _, pdot_vjp = jax.vjp(pdot, x, y) return pdot_vjp(out_bar) fun = xmap(pdot_fun, in_axes=[pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes, [*pdot_spec.batch_names, ...]], out_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes), axis_resources=axis_resources) with jtu.with_mesh(mesh_data): lhs_bar, rhs_bar = fun(lhs, rhs, out_bar) tol = 1e-1 if jtu.device_under_test() == "tpu" else None self.assertAllClose(lhs_bar, expected_lhs, check_dtypes=False, atol=tol, rtol=tol) self.assertAllClose(rhs_bar, expected_rhs, check_dtypes=False, atol=tol, rtol=tol) def test_xeinsum_vector_dot(self): rng = np.random.RandomState(0) x = rng.randn(3) y = rng.randn(3) out = xmap(partial(jnp.einsum, '{i},{i}->'), in_axes=(['i'], ['i']), out_axes=[])(x, y) expected = np.einsum('i,i->', x, y) self.assertAllClose(out, expected, check_dtypes=False) def test_xeinsum_outer_product(self): rng = np.random.RandomState(0) x = rng.randn(3) y = rng.randn(3) out = xmap(partial(jnp.einsum, '{i},{j}->{i,j}'), in_axes=(['i'], ['j']), out_axes=['i', 'j'])(x, y) expected = np.einsum('i,j->ij', x, y) self.assertAllClose(out, expected, check_dtypes=True) def test_xeinsum_matmul(self): rng = np.random.RandomState(0) x = rng.randn(3, 4) y = rng.randn(4, 5) def check(spec): out = xmap(partial(jnp.einsum, spec), in_axes=(['i', 'j'], ['j', 'k']), out_axes=['i', 'k'])(x, y) expected = np.einsum('ij,jk->ik', x, y) tol = 1e-1 if jtu.device_under_test() == "tpu" else None self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol) check('{i,j},{j,k}->{i,k}') check('{i,j},{k,j}->{k,i}') # order of named axes in the spec doesn't matter! check('{j},{k,j}->{k}') check('{i,j},{j}->{i}') check('{j},{j}->{}') def test_xeinsum_no_named_axes_vector_dot(self): rng = np.random.RandomState(0) x = rng.randn(3) y = rng.randn(3) out = jnp.einsum('i,i->', x, y, _use_xeinsum=True) expected = np.einsum('i,i->', x, y) self.assertAllClose(out, expected, check_dtypes=False) def test_xeinsum_no_named_axes_batch_vector_dot(self): rng = np.random.RandomState(0) x = rng.randn(3, 2) y = rng.randn(3, 2) out = jnp.einsum('ij,ij->i', x, y, _use_xeinsum=True) expected = np.einsum('ij,ij->i', x, y) self.assertAllClose(out, expected, check_dtypes=True) def test_xeinsum_no_named_axes_reduce_sum(self): rng = np.random.RandomState(0) x = rng.randn(3) y = rng.randn() out = jnp.einsum('i,->', x, y, _use_xeinsum=True) expected = np.einsum('i,->', x, y) self.assertAllClose(out, expected, check_dtypes=True)
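# Editor's note: a hedged sketch restating what test_xeinsum_matmul checks above -- a
# brace-style einsum spec over named axes, run under xmap, matches the ordinary
# positional einsum. The shapes, dtype, and tolerance are illustrative; as in the
# test, no mesh is required.
def _example_xeinsum_matmul():
  from functools import partial
  import numpy as np
  import jax.numpy as jnp
  from jax.experimental.maps import xmap
  rng = np.random.RandomState(0)
  x = rng.randn(3, 4).astype(np.float32)
  y = rng.randn(4, 5).astype(np.float32)
  out = xmap(partial(jnp.einsum, '{i,j},{j,k}->{i,k}'),
             in_axes=(['i', 'j'], ['j', 'k']),
             out_axes=['i', 'k'])(x, y)
  np.testing.assert_allclose(out, np.einsum('ij,jk->ik', x, y), rtol=1e-5)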
class IndexedUpdateTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ "testcase_name": "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer, jtu.format_shape_dtype_string(update_shape, update_dtype), sugared, op.name), "shape": shape, "dtype": dtype, "indexer": indexer, "update_shape": update_shape, "update_dtype": update_dtype, "op": op, "sugared": sugared } for name, index_specs in s(STATIC_INDEXING_TESTS) for shape, indexer in s(index_specs) for op in s(UpdateOps) for dtype in s(UpdateOps.dtypes(op)) for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer))) for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes) for sugared in (s([True, False]) if op not in [UpdateOps.DIV, UpdateOps.POW] else [True])))) def testStaticIndexing(self, shape, dtype, update_shape, update_dtype, indexer, sugared, op): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) if sugared: jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y) else: jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y) self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol={np.complex128: 1e-14}) self._CompileAndCheck(jax_fn, args_maker) @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer, jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), "shape": shape, "dtype": dtype, "indexer": indexer, "update_shape": update_shape, "update_dtype": update_dtype, "op": op } for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS) for shape, indexer in s(index_specs) for op in s(UpdateOps) for dtype in s(UpdateOps.dtypes(op)) for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer))) for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)))) def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, indexer, op): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y, unique_indices=True) self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol={np.complex128: 1e-14}) self._CompileAndCheck(jax_fn, args_maker) @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer, jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), "shape": shape, "dtype": dtype, "indexer": indexer, "update_shape": update_shape, "update_dtype": update_dtype, "op": op } for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED) for shape, indexer in s(index_specs) for op in s(UpdateOps) for dtype in s(UpdateOps.dtypes(op)) for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer))) for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)))) def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype, indexer, op): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) jax_fn = lambda x, y: UpdateOps.sugar_fn( op, 
indexer, x, y, indices_are_sorted=True, unique_indices=True) self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, check_dtypes=True, tol={np.complex128: 1e-14}) self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer, jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), "shape": shape, "dtype": dtype, "indexer": indexer, "update_shape": update_shape, "update_dtype": update_dtype, "op": op } for name, index_specs in s(MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS) for shape, indexer in s(index_specs) for op in s(UpdateOps) for dtype in s(UpdateOps.dtypes(op)) for update_shape in s(_broadcastable_shapes(_update_shape(shape, indexer))) for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)))) def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, indexer, op): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y) self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol={np.complex128: 1e-14}) self._CompileAndCheck(jax_fn, args_maker) @parameterized.named_parameters(jtu.cases_from_list({ "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer, jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), "shape": shape, "dtype": dtype, "indexer": indexer, "update_shape": update_shape, "update_dtype": update_dtype, "op": op } for name, index_specs in STATIC_INDEXING_TESTS for shape, indexer in index_specs for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] for dtype in float_dtypes for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes))) @jtu.skip_on_devices("tpu") # TODO(mattjj,phawkins): tpu issues def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype, indexer, op): rng = jtu.rand_default(self.rng()) jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y) x = rng(shape, dtype) y = rng(update_shape, update_dtype) check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.) def testSegmentSumBehavior(self): # testAdvancedIndexing compares against NumPy, and as a result doesn't check # repeated indices. This test is just a simple manual check, based on # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum data = np.array([5, 1, 7, 2, 3, 4, 1, 3]) segment_ids = np.array([0, 0, 0, 1, 2, 2, 3, 3]) ans = ops.index_add(np.zeros(np.max(segment_ids) + 1), segment_ids, data) expected = np.array([13, 2, 7, 4]) self.assertAllClose(ans, expected, check_dtypes=False) def testSegmentSum(self): data = jnp.array([5, 1, 7, 2, 3, 4, 1, 3]) segment_ids = jnp.array([0, 0, 0, 1, 2, 2, 3, 3]) # test with explicit num_segments ans = ops.segment_sum(data, segment_ids, num_segments=4) expected = jnp.array([13, 2, 7, 4]) self.assertAllClose(ans, expected, check_dtypes=False) # test with explicit num_segments larger than the higher index. 
ans = ops.segment_sum(data, segment_ids, num_segments=5) expected = jnp.array([13, 2, 7, 4, 0]) self.assertAllClose(ans, expected, check_dtypes=False) # test without explicit num_segments ans = ops.segment_sum(data, segment_ids) expected = jnp.array([13, 2, 7, 4]) self.assertAllClose(ans, expected, check_dtypes=False) # test with negative segment ids and segment ids larger than num_segments; # ids outside [0, num_segments) do not contribute to any segment. segment_ids = jnp.array([0, 4, 8, 1, 2, -6, -1, 3]) ans = ops.segment_sum(data, segment_ids, num_segments=4) expected = jnp.array([5, 2, 3, 3]) self.assertAllClose(ans, expected, check_dtypes=False) # test with negative segment ids and without an explicit num_segments, # so num_segments is inferred from the largest segment id. segment_ids = jnp.array([3, 3, 3, 4, 5, 5, -7, -6]) ans = ops.segment_sum(data, segment_ids) expected = jnp.array([0, 0, 0, 13, 2, 7]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list({ "testcase_name": "_{}_{}_num_segments={}_bucket_size={}".format( jtu.format_shape_dtype_string(shape, dtype), reducer.__name__, num_segments, bucket_size), "dtype": dtype, "shape": shape, "reducer": reducer, "op": op, "identity": identity, "num_segments": num_segments, "bucket_size": bucket_size} for dtype in default_dtypes for shape in [(8,), (7, 4), (6, 4, 2)] for bucket_size in [None, 2] for num_segments in [None, 1, 3]) for reducer, op, identity in [ (ops.segment_sum, np.add, 0), (ops.segment_prod, np.multiply, 1), (ops.segment_min, np.minimum, float('inf')), (ops.segment_max, np.maximum, -float('inf')), ])) def testSegmentReduce(self, shape, dtype, reducer, op, identity, num_segments, bucket_size): rng = jtu.rand_default(self.rng()) idx_rng = jtu.rand_int(self.rng(), low=-2, high=3) args_maker = lambda: [rng(shape, dtype), idx_rng(shape[:1], jnp.int32)] if np.issubdtype(dtype, np.integer): if np.isposinf(identity): identity = np.iinfo(dtype).max elif np.isneginf(identity): identity = np.iinfo(dtype).min jnp_fun = lambda data, segment_ids: reducer( data, segment_ids, num_segments=num_segments, bucket_size=bucket_size) def np_fun(data, segment_ids): size = num_segments if num_segments is not None else (segment_ids.max() + 1) out = np.full((size,) + shape[1:], identity, dtype) for i, val in zip(segment_ids, data): if 0 <= i < size: out[i] = op(out[i], val).astype(dtype) return out self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) if num_segments is not None: self._CompileAndCheck(jnp_fun, args_maker) def testIndexDtypeError(self): # https://github.com/google/jax/issues/2795 jnp.array(1) # get rid of startup warning with warnings.catch_warnings(record=True) as w: warnings.simplefilter("error") jnp.zeros(5).at[::2].set(1) self.assertLen(w, 0) @contextmanager def assertNoWarnings(self): with warnings.catch_warnings(record=True) as caught_warnings: yield self.assertEmpty(caught_warnings) @parameterized.named_parameters(jtu.cases_from_list({ "testcase_name": "idx={}".format(idx), "idx": idx, "idx_type": idx_type} for idx, idx_type in [ ([0], "array"), ([0, 0], "array"), ([[0, 0]], "tuple"), ([0, [0, 1]], "tuple"), ([0, np.arange(2)], "tuple"), ([0, None], "tuple"), ([0, slice(None)], "tuple"), ])) def testIndexSequenceDeprecation(self, idx, idx_type): normalize = {"array": np.array, "tuple": tuple}[idx_type] msg = {"array": ARRAY_MSG, "tuple": TUPLE_MSG}[idx_type] x = jnp.arange(6).reshape(3, 2) with self.assertRaisesRegex(TypeError, msg): x[idx] with
self.assertNoWarnings(): x[normalize(idx)] with self.assertRaisesRegex(TypeError, msg): x.at[idx].set(0) with self.assertNoWarnings(): x.at[normalize(idx)].set(0)
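# Editor's note: a hedged sketch of the scatter-add behavior the indexed-update and
# segment_sum tests above rely on -- the functional x.at[idx].add(y) accumulates
# repeated indices, matching NumPy's unbuffered np.add.at. The data values reuse the
# segment-sum example above; the shapes and dtype are illustrative.
def _example_at_add_matches_numpy():
  import numpy as np
  import jax.numpy as jnp
  idx = np.array([0, 0, 0, 1, 2, 2, 3, 3])
  data = np.array([5., 1., 7., 2., 3., 4., 1., 3.], np.float32)
  ans = jnp.zeros(4, jnp.float32).at[idx].add(data)    # repeated indices accumulate
  expected = np.zeros(4, np.float32)
  np.add.at(expected, idx, data)                       # NumPy's in-place equivalent
  np.testing.assert_allclose(ans, expected)            # both give [13., 2., 7., 4.]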