def _get_min_identity(dtype):
  if dtypes.issubdtype(dtype, np.inexact):
    return np.array(np.inf, dtype)
  elif dtypes.issubdtype(dtype, np.integer):
    return np.array(dtypes.iinfo(dtype).max, dtype)
  elif dtypes.issubdtype(dtype, np.bool_):
    return np.array(True, np.bool_)
def _get_max_identity(dtype):
  if dtypes.issubdtype(dtype, np.inexact):
    return np.array(-np.inf, dtype)
  elif dtypes.issubdtype(dtype, np.integer):
    return np.array(dtypes.iinfo(dtype).min, dtype)
  elif dtypes.issubdtype(dtype, np.bool_):
    return np.array(False, np.bool_)
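# A quick sanity sketch of the helpers above (values follow directly from the
# code): the "min identity" is the value that never wins a `min` reduction,
# and symmetrically for max. For example:
#
#   _get_min_identity(np.float32)  # -> array(inf, dtype=float32)
#   _get_max_identity(np.float32)  # -> array(-inf, dtype=float32)
#   _get_min_identity(np.int32)    # -> array(2147483647, dtype=int32)
#   _get_max_identity(np.int32)    # -> array(-2147483648, dtype=int32)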
class LaxVmapTest(jtu.JaxTestCase):

  def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
                     rtol=None, atol=None):
    batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
    args = [rng(shape, dtype)
            for shape, dtype in zip(batched_shapes, dtypes)]
    args_slice = args_slicer(args, bdims)
    ans = api.vmap(op, bdims)(*args)
    if bdim_size == 0:
      args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
      out = op(*args)
      expected = np.zeros((0,) + out.shape, out.dtype)
    else:
      expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)])
    self.assertAllClose(ans, expected, rtol=rtol, atol=atol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "{}_bdims={}".format(
               jtu.format_test_name_suffix(rec.op, shapes,
                                           itertools.repeat(dtype)), bdims),
           "op_name": rec.op, "rng_factory": rec.rng_factory,
           "shapes": shapes, "dtype": dtype, "bdims": bdims, "tol": rec.tol}
          for shape_group in compatible_shapes
          for shapes in itertools.combinations_with_replacement(
              shape_group, rec.nargs)
          for bdims in all_bdims(*shapes)
          for dtype in rec.dtypes)
      for rec in LAX_OPS))
  def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol):
    rng = rng_factory(self.rng())
    op = getattr(lax, op_name)
    self._CheckBatching(op, 10, bdims, shapes, [dtype] * len(shapes), rng,
                        atol=tol, rtol=tol)
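  # The invariant _CheckBatching exercises, written out for one concrete case
  # (an illustrative sketch, not an additional test): with bdims=(0, None)
  # only the first argument carries a batch axis, so
  #
  #   ans      = api.vmap(lax.add, (0, None))(xs, y)        # xs: (10, 3), y: (3,)
  #   expected = np.stack([lax.add(xs[i], y) for i in range(10)])
  #
  # must agree elementwise. The bdim_size == 0 branch covers empty batches,
  # where the expected result is an empty array with the unbatched out shape.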
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
       "_lhs_bdim={}_rhs_bdim={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
           feature_group_count, batch_group_count, lhs_bdim, rhs_bdim),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "dimension_numbers": dim_nums, "perms": perms,
       "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim,
       "feature_group_count": feature_group_count,
       "batch_group_count": batch_group_count}
      for batch_group_count, feature_group_count in [(1, 1), (2, 1), (1, 2)]
      for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in [
          ((b * batch_group_count, i * feature_group_count, 6, 7),  # lhs_shape
           (j * batch_group_count * feature_group_count, i, 1, 2),  # rhs_shape
           [(1, 1), (1, 2), (2, 1)],                                # strides
           [((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))], # pads
           [(1, 1), (2, 1)],                                        # lhs_dils
           [(1, 1), (2, 2)])                                        # rhs_dils
          for b, i, j in itertools.product([1, 2], repeat=3)]
      for strides in all_strides
      for rhs_dil in rhs_dils
      for lhs_dil in lhs_dils
      for dtype in [np.float32]
      for padding in all_pads
      for dim_nums, perms in [
          (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
      for lhs_bdim in itertools.chain([cast(Optional[int], None)],
                                      range(len(lhs_shape) + 1))
      for rhs_bdim in itertools.chain([cast(Optional[int], None)],
                                      range(len(rhs_shape) + 1))
      if (lhs_bdim, rhs_bdim) != (None, None)))
  def testConvGeneralDilatedBatching(
      self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil,
      dimension_numbers, perms, feature_group_count, batch_group_count,
      lhs_bdim, rhs_bdim):
    rng = jtu.rand_default(self.rng())
    tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3

    # permute shapes to match dim_spec, scale by feature_group_count
    lhs_perm, rhs_perm = perms
    lhs_shape = list(np.take(lhs_shape, lhs_perm))
    rhs_shape = list(np.take(rhs_shape, rhs_perm))

    conv = partial(lax.conv_general_dilated, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   dimension_numbers=dimension_numbers,
                   feature_group_count=feature_group_count,
                   batch_group_count=batch_group_count,
                   precision=lax.Precision.HIGHEST)
    self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape),
                        (dtype, dtype), rng, rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
           shape, from_dtype, to_dtype, bdims),
       "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype,
       "bdims": bdims}
      for from_dtype, to_dtype in itertools.product(
          [np.float32, np.int32, "float32", "int32"], repeat=2)
      for shape in [(2, 3)]
      for bdims in all_bdims(shape)))
  def testConvertElementType(self, shape, from_dtype, to_dtype, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.convert_element_type(x, to_dtype)
    self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
           shape, from_dtype, to_dtype, bdims),
       "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype,
       "bdims": bdims}
      for from_dtype, to_dtype in itertools.product(
          [np.float32, np.int32, "float32", "int32"], repeat=2)
      for shape in [(2, 3)]
      for bdims in all_bdims(shape)))
  def testBitcastElementType(self, shape, from_dtype, to_dtype, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.bitcast_convert_type(x, to_dtype)
    self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}_bdims={}"
       .format(jtu.format_shape_dtype_string(min_shape, dtype),
               jtu.format_shape_dtype_string(operand_shape, dtype),
               jtu.format_shape_dtype_string(max_shape, dtype), bdims),
       "min_shape": min_shape, "operand_shape": operand_shape,
       "max_shape": max_shape, "dtype": dtype, "bdims": bdims}
      for min_shape, operand_shape, max_shape in [
          [(), (2, 3), ()],
          [(2, 3), (2, 3), ()],
          [(), (2, 3), (2, 3)],
          [(2, 3), (2, 3), (2, 3)],
      ]
      for dtype in default_dtypes
      for bdims in all_bdims(min_shape, operand_shape, max_shape)))
  def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims):
    rng = jtu.rand_default(self.rng())
    raise SkipTest("batching rule for clamp not implemented")  # TODO(mattjj)
    shapes = [min_shape, operand_shape, max_shape]
    self._CheckBatching(lax.clamp, 10, bdims, shapes, [dtype] * 3, rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}_bdims={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype), bdims),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "bdims": bdims}
      for lhs_shape in [(3,), (4, 3)]
      for rhs_shape in [(3,), (3, 6)]
      for bdims in all_bdims(lhs_shape, rhs_shape)
      for dtype in default_dtypes))
  def testDot(self, lhs_shape, rhs_shape, dtype, bdims):
    rng = jtu.rand_default(self.rng())
    op = partial(lax.dot, precision=lax.Precision.HIGHEST)
    self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng, rtol={np.float16: 5e-2, np.float64: 5e-14})
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_bdims={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               lhs_contracting, rhs_contracting, bdims),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting,
       "bdims": bdims}
      for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
          [(5,), (5,), [0], [0]],
          [(5, 7), (5,), [0], [0]],
          [(7, 5), (5,), [1], [0]],
          [(3, 5), (2, 5), [1], [1]],
          [(5, 3), (5, 2), [0], [0]],
          [(5, 3, 2), (5, 2, 4), [0], [0]],
          [(5, 3, 2), (5, 2, 4), [0, 2], [0, 1]],
          [(5, 3, 2), (3, 5, 2, 4), [0, 2], [1, 2]],
          [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
          [(3, 2), (2, 4), [1], [0]],
      ]
      for bdims in all_bdims(lhs_shape, rhs_shape)
      for dtype in default_dtypes))
  def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                                 lhs_contracting, rhs_contracting, bdims):
    rng = jtu.rand_small(self.rng())
    dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))
    dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
    self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_bdims={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers, bdims),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers, "bdims": bdims}
      for lhs_shape, rhs_shape, dimension_numbers in [
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
      ]
      for bdims in all_bdims(lhs_shape, rhs_shape)
      for dtype in default_dtypes))
  def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                     dimension_numbers, bdims):
    rng = jtu.rand_small(self.rng())
    dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
    self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng)

    # Checks that batching didn't introduce any transposes or broadcasts.
    jaxpr = api.make_jaxpr(dot)(np.zeros(lhs_shape, dtype),
                                np.zeros(rhs_shape, dtype))
    for eqn in jtu.iter_eqns(jaxpr.jaxpr):
      self.assertFalse(eqn.primitive.name in ["transpose", "broadcast"])
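  # For reference (a sketch of the semantics, using only values already in the
  # tests above): dot_general's dimension_numbers are
  # ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)). The first
  # ContractAndBatch case, (([2], [1]), ([0], [0])) with shapes (3, 3, 2) and
  # (3, 2, 4), contracts lhs axis 2 against rhs axis 1 and treats axis 0 of
  # both operands as a shared batch axis, giving an output of shape (3, 3, 4):
  # batch dims first, then remaining lhs dims, then remaining rhs dims.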
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format(
           shape, np.dtype(dtype).name, broadcast_sizes, bdims),
       "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
       "bdims": bdims}
      for shape in [(), (2, 3)]
      for dtype in default_dtypes
      for broadcast_sizes in [(), (2,), (1, 2)]
      for bdims in all_bdims(shape)))
  def testBroadcast(self, shape, dtype, broadcast_sizes, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.broadcast(x, broadcast_sizes)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_bcdims={}_bdims={}".format(
           jtu.format_shape_dtype_string(inshape, dtype),
           outshape, broadcast_dimensions, bdims),
       "inshape": inshape, "dtype": dtype, "outshape": outshape,
       "dimensions": broadcast_dimensions, "bdims": bdims}
      for inshape, outshape, broadcast_dimensions in [
          ([2], [2, 2], [0]),
          ([2], [2, 2], [1]),
          ([2], [2, 3], [0]),
          ([], [2, 3], []),
      ]
      for dtype in default_dtypes
      for bdims in all_bdims(inshape)))
  def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims):
    rng = jtu.rand_default(self.rng())
    raise SkipTest("this test has failures in some cases")  # TODO(mattjj)
    op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    self._CheckBatching(op, 5, bdims, (inshape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_dimensions={}_bdims={}".format(
           jtu.format_shape_dtype_string(arg_shape, np.float32),
           dimensions, bdims),
       "arg_shape": arg_shape, "dimensions": dimensions, "bdims": bdims}
      for arg_shape, dimensions in [
          [(1,), (0,)],
          [(1,), (-1,)],
          [(2, 1, 4), (1,)],
          [(2, 1, 4), (-2,)],
          [(2, 1, 3, 1), (1,)],
          [(2, 1, 3, 1), (1, 3)],
          [(2, 1, 3, 1), (3,)],
          [(2, 1, 3, 1), (1, -1)],
      ]
      for bdims in all_bdims(arg_shape)))
  def testSqueeze(self, arg_shape, dimensions, bdims):
    dtype = np.float32
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.squeeze(x, dimensions)
    self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_dims={}_bdims={}".format(
           jtu.format_shape_dtype_string(arg_shape, dtype),
           jtu.format_shape_dtype_string(out_shape, dtype),
           dimensions, bdims),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "dimensions": dimensions, "bdims": bdims}
      for dtype in default_dtypes
      for arg_shape, dimensions, out_shape in [
          [(3, 4), None, (12,)],
          [(2, 1, 4), None, (8,)],
          [(2, 2, 4), None, (2, 8)],
          [(2, 2, 4), (0, 1, 2), (2, 8)],
          [(2, 2, 4), (1, 0, 2), (8, 2)],
          [(2, 2, 4), (2, 1, 0), (4, 2, 2)],
      ]
      for bdims in all_bdims(arg_shape)))
  def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions)
    self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_pads={}_bdims={}".format(
           jtu.format_shape_dtype_string(shape, dtype), pads, bdims),
       "shape": shape, "dtype": dtype, "pads": pads, "bdims": bdims}
      for shape in [(2, 3)]
      for bdims in all_bdims(shape)
      for dtype in default_dtypes
      for pads in [[(1, 2, 1), (0, 1, 0)]]))
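  # lax.pad's config is a (low, high, interior) triple per dimension, so the
  # pads=[(1, 2, 1), (0, 1, 0)] case below pads dimension 0 with one leading
  # zero, two trailing zeros, and one zero between each pair of elements.
  # A one-dimensional sketch of the same convention (illustration only):
  #
  #   lax.pad(np.array([1., 2., 3.]), np.float32(0), [(1, 2, 1)])
  #   # -> [0., 1., 0., 2., 0., 3., 0., 0.]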
  def testPad(self, shape, dtype, pads, bdims):
    rng = jtu.rand_small(self.rng())
    fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_predshape={}_argshapes={}_bdims={}".format(
           jtu.format_shape_dtype_string(pred_shape, np.bool_),
           jtu.format_shape_dtype_string(arg_shape, arg_dtype), bdims),
       "pred_shape": pred_shape, "arg_shape": arg_shape,
       "arg_dtype": arg_dtype, "bdims": bdims}
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
      for bdims in all_bdims(pred_shape, arg_shape, arg_shape)
      for arg_dtype in default_dtypes))
  def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda c, x, y: lax.select(c < 0, x, y)
    self._CheckBatching(op, 5, bdims, (pred_shape, arg_shape, arg_shape),
                        (np.bool_, arg_dtype, arg_dtype), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_start_indices={}_limit_indices={}_strides={}_bdims={}".format(
           jtu.format_shape_dtype_string(shape, dtype), start_indices,
           limit_indices, strides, bdims),
       "shape": shape, "dtype": dtype, "starts": start_indices,
       "limits": limit_indices, "strides": strides, "bdims": bdims}
      for shape, start_indices, limit_indices, strides in [
          [(3,), (1,), (2,), None],
          [(7,), (4,), (7,), None],
          [(5,), (1,), (5,), (2,)],
          [(8,), (1,), (6,), (2,)],
          [(5, 3), (1, 1), (3, 2), None],
          [(5, 3), (1, 1), (3, 1), None],
          [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
          [(5, 3), (1, 1), (2, 1), (1, 1)],
          [(5, 3), (1, 1), (5, 3), (2, 1)],
      ]
      for bdims in all_bdims(shape)
      for dtype in default_dtypes))
  def testSlice(self, shape, dtype, starts, limits, strides, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.slice(x, starts, limits, strides)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_perm={}_bdims={}".format(
           jtu.format_shape_dtype_string(shape, dtype), perm, bdims),
       "shape": shape, "dtype": dtype, "perm": perm, "bdims": bdims}
      for shape, perm in [
          [(3, 4), (1, 0)],
          [(3, 4), (0, 1)],
          [(3, 4, 5), (2, 1, 0)],
          [(3, 4, 5), (1, 0, 2)],
      ]
      for bdims in all_bdims(shape)
      for dtype in default_dtypes))
  def testTranspose(self, shape, dtype, perm, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.transpose(x, perm)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}_bdims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
               init_val, bdims),
       "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
       "dims": dims, "bdims": bdims}
      for init_val, op, dtypes in [
          (0, lax.add, default_dtypes),
          (1, lax.mul, default_dtypes),
          (0, lax.max, all_dtypes),  # non-monoidal
          (-np.inf, lax.max, float_dtypes),
          (dtypes.iinfo(np.int32).min, lax.max, [np.int32]),
          (dtypes.iinfo(np.int64).min, lax.max, [np.int64]),
          (dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]),
          (dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]),
          (np.inf, lax.min, float_dtypes),
          (dtypes.iinfo(np.int32).max, lax.min, [np.int32]),
          (dtypes.iinfo(np.int64).max, lax.min, [np.int64]),
          (dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]),
          (dtypes.iinfo(np.uint64).max, lax.min, [np.uint64]),
      ]
      for dtype in dtypes
      for shape, dims in [
          [(3, 4, 5), (0,)],
          [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)],
          [(3, 4, 5), (0, 1, 2)],
      ]
      for bdims in all_bdims(shape)))
  def testReduce(self, op, init_val, shape, dtype, dims, bdims):
    rng = jtu.rand_small(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)
    fun = lambda operand: lax.reduce(operand, init_val, op, dims)
    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}_bdims={}".format(
           op.__name__, jtu.format_shape_dtype_string(shape, dtype), dim,
           bdims),
       "op": op, "shape": shape, "dtype": dtype, "dim": dim, "bdims": bdims}
      for op in [lax.argmin, lax.argmax]
      for dtype in default_dtypes
      for shape in [(3, 4, 5)]
      for dim in range(len(shape))
      for bdims in all_bdims(shape)))
  def testArgminmax(self, op, shape, dtype, dim, bdims):
    rng = jtu.rand_default(self.rng())
    fun = lambda operand: op(operand, dim, np.int32)
    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
                         "_basedilation={}_windowdilation={}").format(
           op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
           strides, padding, base_dilation, window_dilation),
       "op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
       "dims": dims, "strides": strides, "padding": padding,
       "base_dilation": base_dilation, "window_dilation": window_dilation}
      for init_val, op, dtypes in [
          (0, lax.add, [np.float32]),
          (-np.inf, lax.max, [np.float32]),
          (np.inf, lax.min, [np.float32]),
      ]
      for shape, dims, strides, padding, base_dilation, window_dilation in (
          itertools.chain(
              itertools.product(
                  [(4, 6)],
                  [(2, 1), (1, 2)],
                  [(1, 1), (2, 1), (1, 2)],
                  ["VALID", "SAME", [(0, 3), (1, 2)]],
                  [(1, 1), (2, 3)],
                  [(1, 1), (1, 2)]),
              itertools.product(
                  [(3, 2, 4, 6)],
                  [(1, 1, 2, 1), (2, 1, 2, 1)],
                  [(1, 2, 2, 1), (1, 1, 1, 1)],
                  ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
                  [(1, 1, 1, 1), (2, 1, 3, 2)],
                  [(1, 1, 1, 1), (1, 2, 2, 1)])))
      for dtype in dtypes))
  def testReduceWindow(self, op, init_val, dtype, shape, dims, strides,
                       padding, base_dilation, window_dilation):
    rng = jtu.rand_small(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)

    def fun(operand):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding,
                               base_dilation, window_dilation)

    for bdims in all_bdims(shape):
      self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng)
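  # For orientation (an illustrative sketch with hand-picked numbers, not an
  # additional test case): in 1D, a window of 2 with stride 1 and "VALID"
  # padding under lax.add yields pairwise sums, and all-ones base/window
  # dilations, as in several cases above, are no-ops:
  #
  #   lax.reduce_window(np.array([1., 2., 3., 4.]), 0., lax.add,
  #                     (2,), (1,), "VALID", (1,), (1,))
  #   # -> [3., 5., 7.]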
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_shape={}_axis={}_bdims={}_reverse={}".format(
           op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
           bdims, reverse),
       "op": op, "shape": shape, "dtype": dtype, "bdims": bdims,
       "axis": axis, "reverse": reverse}
      for op, types in [
          (lax.cumsum, [np.float32, np.float64]),
          (lax.cumprod, [np.float32, np.float64]),
      ]
      for dtype in types
      for shape in [[10], [3, 4, 5]]
      for axis in range(len(shape))
      for bdims in all_bdims(shape)
      for reverse in [False, True]))
  def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse):
    rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
                   else jtu.rand_small)
    rng = rng_factory(self.rng())
    self._CheckBatching(partial(op, axis=axis, reverse=reverse), 7, bdims,
                        (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_padding={}".format(np.dtype(dtype).name,
                                                      padding),
       "dtype": dtype, "padding": padding}
      for dtype in float_dtypes
      for padding in ["VALID", "SAME"]))
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jtu.ignore_warning(message="Using reduced precision for gradient.*")
  def testSelectAndGatherAdd(self, dtype, padding):
    if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
      raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")
    rng = jtu.rand_small(self.rng())
    all_configs = itertools.chain(
        itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)]),
        itertools.product(
            [(3, 2, 4, 6)],
            [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)]))

    def fun(operand, tangents):
      pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
      ones = (1,) * len(operand.shape)
      return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims,
                                        strides, pads, ones, ones)

    for shape, dims, strides in all_configs:
      for bdims in all_bdims(shape, shape):
        self._CheckBatching(fun, 3, bdims, (shape, shape), (dtype, dtype), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_dtype={jtu.format_shape_dtype_string(shape, dtype)}"
       f"_padding={padding}_dims={dims}_strides={strides}",
       "dtype": dtype, "padding": padding, "shape": shape,
       "dims": dims, "strides": strides}
      for dtype in float_dtypes
      for padding in ["VALID", "SAME"]
      for shape in [(3, 2, 4, 6)]
      for dims in [(1, 1, 2, 1)]
      for strides in [(1, 2, 2, 1), (1, 1, 1, 1)]))
  def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides):
    rng = jtu.rand_small(self.rng())

    pads = lax.padtype_to_pads(shape, dims, strides, padding)

    def fun(operand, cotangents):
      return lax._select_and_scatter_add(operand, cotangents, lax.ge_p, dims,
                                         strides, pads)

    ones = (1,) * len(shape)
    cotangent_shape = api.eval_shape(
        lambda x: lax._select_and_gather_add(x, x, lax.ge_p, dims, strides,
                                             pads, ones, ones),
        np.ones(shape, dtype)).shape

    for bdims in all_bdims(cotangent_shape, shape):
      self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape),
                          (dtype, dtype), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_bdims={}_fft_ndims={}".format(
           shape, bdims, fft_ndims),
       "shape": shape, "bdims": bdims, "fft_ndims": fft_ndims}
      for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)]
      for bdims in all_bdims(shape)
      for fft_ndims in range(0, min(3, len(shape)) + 1)))
  @jtu.skip_on_devices("tpu")  # TODO(b/137993701): unimplemented cases.
  def testFft(self, fft_ndims, shape, bdims):
    rng = jtu.rand_default(self.rng())
    ndims = len(shape)
    axes = range(ndims - fft_ndims, ndims)
    fft_lengths = [shape[axis] for axis in axes]
    op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths)
    self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
               slice_sizes, bdims),
       "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "bdims": bdims}
      for dtype in all_dtypes
      for shape, idxs, dnums, slice_sizes in [
          ((5,), np.array([[0], [2]]),
           lax.GatherDimensionNumbers(offset_dims=(),
                                      collapsed_slice_dims=(0,),
                                      start_index_map=(0,)),
           (1,)),
          ((10,), np.array([[0], [0], [0]]),
           lax.GatherDimensionNumbers(offset_dims=(1,),
                                      collapsed_slice_dims=(),
                                      start_index_map=(0,)),
           (2,)),
          ((10, 5), np.array([[0], [2], [1]]),
           lax.GatherDimensionNumbers(offset_dims=(1,),
                                      collapsed_slice_dims=(0,),
                                      start_index_map=(0,)),
           (1, 3)),
          ((10, 5), np.array([[0, 2], [1, 0]]),
           lax.GatherDimensionNumbers(offset_dims=(1,),
                                      collapsed_slice_dims=(0,),
                                      start_index_map=(0, 1)),
           (1, 3)),
      ]
      for bdims in all_bdims(shape, idxs.shape)))
  def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    self._CheckBatching(fun, 0, bdims, [shape, idxs.shape],
                        [dtype, idxs.dtype], jtu.rand_default(self.rng()))
    self._CheckBatching(fun, 5, bdims, [shape, idxs.shape],
                        [dtype, idxs.dtype], jtu.rand_default(self.rng()))
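  # To unpack the first gather case above (a sketch of the semantics, using
  # only values already in the test): operand shape (5,), idxs [[0], [2]],
  # slice_sizes (1,), with offset_dims=(), collapsed_slice_dims=(0,), and
  # start_index_map=(0,). Each row of idxs names a start index into axis 0; a
  # length-1 slice is taken there and then collapsed, so the result is simply
  # operand[[0, 2]] with shape (2,).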
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}_bdims={}".format(
           jtu.format_shape_dtype_string(arg_shape, dtype), idxs, update_shape,
           dnums, bdims),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "bdims": bdims}
      for dtype in float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), np.array([[0], [2]]), (2,),
           lax.ScatterDimensionNumbers(update_window_dims=(),
                                       inserted_window_dims=(0,),
                                       scatter_dims_to_operand_dims=(0,))),
          ((10,), np.array([[0], [0], [0]]), (3, 2),
           lax.ScatterDimensionNumbers(update_window_dims=(1,),
                                       inserted_window_dims=(),
                                       scatter_dims_to_operand_dims=(0,))),
          ((10, 5), np.array([[0], [2], [1]]), (3, 3),
           lax.ScatterDimensionNumbers(update_window_dims=(1,),
                                       inserted_window_dims=(0,),
                                       scatter_dims_to_operand_dims=(0,))),
      ]
      for bdims in all_bdims(arg_shape, idxs.shape, update_shape)))
  def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims):
    fun = partial(lax.scatter_add, dimension_numbers=dnums)
    self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape, update_shape],
                        [dtype, idxs.dtype, dtype],
                        jtu.rand_default(self.rng()),
                        rtol={np.float16: 5e-3})

  def testShapeUsesBuiltinInt(self):
    x = lax.iota(np.int32, 3) + 1
    self.assertIsInstance(x.shape[0], int)  # not np.int64

  def testBroadcastShapesReturnsPythonInts(self):
    shape1, shape2 = (1, 2, 3), (2, 3)
    out_shape = lax.broadcast_shapes(shape1, shape2)
    self.assertTrue(all(type(s) is int for s in out_shape))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}_bdims={}".format(
           jtu.format_shape_dtype_string(shape, dtype), k, bdims),
       "shape": shape, "dtype": dtype, "k": k, "bdims": bdims,
       "rng_factory": rng_factory}
      for shape in [(4,), (3, 5, 3)]
      for k in [1, 3]
      for bdims in all_bdims(shape)
      # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug
      # is fixed: the top_k indices for integer arrays with identical entries
      # won't match between the vmap'd version and the manual reference, so
      # only test unique integer arrays for int_dtypes. Note also that we
      # chose 3 * 5 * 3 * 5 such that it fits in the range of values a
      # bfloat16 can represent exactly to avoid ties.
      for dtype, rng_factory in itertools.chain(
          unsafe_zip(default_dtypes, itertools.repeat(jtu.rand_unique_int)))))
  def testTopK(self, shape, dtype, k, bdims, rng_factory):
    rng = rng_factory(self.rng())
    # _CheckBatching doesn't work with tuple outputs, so test outputs separately.
    op1 = lambda x: lax.top_k(x, k=k)[0]
    self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng)
    op2 = lambda x: lax.top_k(x, k=k)[1]
    self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dimension={}_arity={}_bdims={}_isstable={}"
       .format(jtu.format_shape_dtype_string(shape, np.float32), dimension,
               arity, bdims, is_stable),
       "shape": shape, "dimension": dimension, "arity": arity, "bdims": bdims,
       "is_stable": is_stable}
      for shape in [(2, 3)]
      for dimension in [0, 1]
      for arity in range(3)
      for bdims in all_bdims(*((shape,) * arity))
      for is_stable in [False, True]))
  def testSort(self, shape, dimension, arity, bdims, is_stable):
    rng = jtu.rand_default(self.rng())
    if arity == 1:
      fun = partial(lax.sort, dimension=dimension)
      self._CheckBatching(fun, 5, bdims, (shape,) * arity,
                          (np.float32,) * arity, rng)
    else:
      for i in range(arity):
        fun = lambda *args, i=i: lax.sort(args, dimension=dimension,
                                          is_stable=is_stable)[i]
        self._CheckBatching(fun, 5, bdims, (shape,) * arity,
                            (np.float32,) * arity, rng)