def _scale_and_translate(x, output_shape, scale, translate, kernel, antialias):
  """Scales and translates `x` into an array of shape `output_shape`.

  For each dimension whose size, scale, or translation changes, gathers a
  window of input samples for every output position and contracts it against
  the corresponding resampling-kernel weights in a single einsum.
  """
  input_shape = x.shape
  assert len(input_shape) == len(output_shape)
  assert len(input_shape) == len(scale)
  assert len(input_shape) == len(translate)
  # Only dimensions whose size, scale, or translation actually change need work.
  spatial_dims = np.nonzero(
      np.not_equal(input_shape, output_shape) | np.not_equal(scale, 1) |
      np.not_equal(translate, 0))[0]
  if len(spatial_dims) == 0:
    return x
  output_spatial_shape = tuple(np.array(output_shape)[spatial_dims])
  indices = []
  contractions = []
  slice_shape = list(input_shape)
  in_indices = list(range(len(output_shape) + len(spatial_dims)))
  out_indices = list(range(len(output_shape)))
  for i, d in enumerate(spatial_dims):
    m = input_shape[d]
    n = output_shape[d]
    # Per-output-position start indices and kernel weights along dimension d.
    starts, weights = _compute_spans(m, n, scale[d], translate[d], kernel,
                                     antialias=antialias)
    starts = lax.broadcast_in_dim(starts, output_spatial_shape + (1,), (i,))
    slice_shape[d] = weights.shape[1]
    indices.append(starts.astype(np.int32))
    contractions.append(weights.astype(x.dtype))
    contractions.append([len(output_shape) + i, d])
    out_indices[d] = len(output_shape) + i
  index = lax.concatenate(indices, len(output_spatial_shape))
  dnums = lax.GatherDimensionNumbers(
      offset_dims=tuple(range(len(output_shape))),
      collapsed_slice_dims=(),
      start_index_map=tuple(spatial_dims))
  # Gather a window of the input for every output position, then contract the
  # window axes against the per-dimension weight matrices in one einsum.
  out = lax.gather(x, index, dnums, slice_shape)
  contractions.append(out_indices)
  return jnp.einsum(out, in_indices, *contractions,
                    precision=lax.Precision.HIGHEST)
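
# A minimal sketch (not part of the original module) of the interleaved-operand
# einsum form used in the final contraction above: operands alternate with
# lists of integer axis labels, and a trailing list picks the output axes. This
# is how each per-dimension weight matrix (labelled [len(output_shape) + i, d])
# is contracted against the gathered windows without building a format string.
# The helper name below is illustrative only.
def _interleaved_einsum_sketch():
  import jax.numpy as jnp
  a = jnp.arange(6.0).reshape(2, 3)
  w = jnp.array([0.25, 0.5, 0.25])
  # Equivalent to jnp.einsum('ij,j->i', a, w): contract axis 1 of `a` with `w`.
  return jnp.einsum(a, [0, 1], w, [1], [0])
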
class LaxVmapTest(jtu.JaxTestCase): def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng, rtol=None, atol=None): batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes) args = [ rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes) ] args_slice = args_slicer(args, bdims) ans = api.vmap(op, bdims)(*args) if bdim_size == 0: args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] out = op(*args) expected = np.zeros((0, ) + out.shape, out.dtype) else: expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)]) self.assertAllClose(ans, expected, rtol=rtol, atol=atol) @parameterized.named_parameters( itertools.chain.from_iterable( jtu.cases_from_list( { "testcase_name": "{}_bdims={}".format( jtu.format_test_name_suffix( rec.op, shapes, itertools.repeat(dtype)), bdims), "op_name": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype, "bdims": bdims, "tol": rec.tol } for shape_group in compatible_shapes for shapes in itertools.combinations_with_replacement( shape_group, rec.nargs) for bdims in all_bdims(*shapes) for dtype in rec.dtypes) for rec in LAX_OPS)) def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol): rng = rng_factory(self.rng()) op = getattr(lax, op_name) self._CheckBatching(op, 10, bdims, shapes, [dtype] * len(shapes), rng, atol=tol, rtol=tol) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}" "_lhs_bdim={}_rhs_bdim={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums), feature_group_count, batch_group_count, lhs_bdim, rhs_bdim), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "strides": strides, "padding": padding, "lhs_dil": lhs_dil, "rhs_dil": rhs_dil, "dimension_numbers": dim_nums, "perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim, "feature_group_count": feature_group_count, "batch_group_count": batch_group_count, } for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)]) for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in [ ( (b * batch_group_count, i * feature_group_count, 6, 7), # lhs_shape (j * batch_group_count * feature_group_count, i, 1, 2), # rhs_shape [(1, 1), (1, 2), (2, 1)], # strides [((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))], # pads [(1, 1), (2, 1)], # lhs_dils [(1, 1), (2, 2)]) # rhs_dils for b, i, j in itertools.product([1, 2], repeat=3) ] for strides in all_strides for rhs_dil in rhs_dils for lhs_dil in lhs_dils for dtype in [np.float32] for padding in all_pads for dim_nums, perms in [(("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])), (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])), (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))] for lhs_bdim in itertools.chain( [cast(Optional[int], None)], range(len(lhs_shape) + 1)) for rhs_bdim in itertools.chain( [cast(Optional[int], None)], range(len(rhs_shape) + 1)) if (lhs_bdim, rhs_bdim) != (None, None))) def testConvGeneralDilatedBatching(self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil, dimension_numbers, perms, feature_group_count, batch_group_count, lhs_bdim, rhs_bdim): rng = jtu.rand_default(self.rng()) tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3 # permute shapes to match dim_spec, scale by feature_group_count lhs_perm, rhs_perm = perms 
lhs_shape = list(np.take(lhs_shape, lhs_perm)) rhs_shape = list(np.take(rhs_shape, rhs_perm)) conv = partial(lax.conv_general_dilated, window_strides=strides, padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil, dimension_numbers=dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count, precision=lax.Precision.HIGHEST) self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape), (dtype, dtype), rng, rtol=tol, atol=tol) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format( shape, from_dtype, to_dtype, bdims), "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype, "bdims": bdims } for from_dtype, to_dtype in itertools.product( [np.float32, np.int32, "float32", "int32"], repeat=2) for shape in [(2, 3)] for bdims in all_bdims(shape))) def testConvertElementType(self, shape, from_dtype, to_dtype, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.convert_element_type(x, to_dtype) self._CheckBatching(op, 10, bdims, (shape, ), (from_dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format( shape, from_dtype, to_dtype, bdims), "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype, "bdims": bdims } for from_dtype, to_dtype in itertools.product( [np.float32, np.int32, "float32", "int32"], repeat=2) for shape in [(2, 3)] for bdims in all_bdims(shape))) def testBitcastElementType( self, shape, from_dtype, to_dtype, bdims, ): rng = jtu.rand_default(self.rng()) op = lambda x: lax.bitcast_convert_type(x, to_dtype) self._CheckBatching(op, 10, bdims, (shape, ), (from_dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_min_shape={}_operand_shape={}_max_shape={}_bdims={}".format( jtu.format_shape_dtype_string(min_shape, dtype), jtu.format_shape_dtype_string(operand_shape, dtype), jtu.format_shape_dtype_string(max_shape, dtype), bdims), "min_shape": min_shape, "operand_shape": operand_shape, "max_shape": max_shape, "dtype": dtype, "bdims": bdims } for min_shape, operand_shape, max_shape in [ [(), (2, 3), ()], [(2, 3), (2, 3), ()], [(), (2, 3), (2, 3)], [(2, 3), (2, 3), (2, 3)], ] for dtype in default_dtypes for bdims in all_bdims(min_shape, operand_shape, max_shape))) def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims): rng = jtu.rand_default(self.rng()) raise SkipTest( "batching rule for clamp not implemented") # TODO(mattj) shapes = [min_shape, operand_shape, max_shape] self._CheckBatching(lax.clamp, 10, bdims, shapes, [dtype] * 3, rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lhs_shape={}_rhs_shape={}_bdims={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), bdims), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "bdims": bdims } for lhs_shape in [(3, ), (4, 3)] for rhs_shape in [(3, ), (3, 6)] for bdims in all_bdims(lhs_shape, rhs_shape) for dtype in default_dtypes)) def testDot(self, lhs_shape, rhs_shape, dtype, bdims): rng = jtu.rand_default(self.rng()) op = partial(lax.dot, precision=lax.Precision.HIGHEST) self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), rng, rtol={ np.float16: 5e-2, np.float64: 5e-14 }) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_bdims={}" 
.format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), lhs_contracting, rhs_contracting, bdims), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting, "bdims": bdims } for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [ [(5, ), (5, ), [0], [0]], [(5, 7), (5, ), [0], [0]], [(7, 5), (5, ), [1], [0]], [(3, 5), (2, 5), [1], [1]], [(5, 3), (5, 2), [0], [0]], [(5, 3, 2), (5, 2, 4), [0], [0]], [(5, 3, 2), (5, 2, 4), [0, 2], [0, 1]], [(5, 3, 2), (3, 5, 2, 4), [0, 2], [1, 2]], [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]], [(3, 2), (2, 4), [1], [0]], ] for bdims in all_bdims(lhs_shape, rhs_shape) for dtype in default_dtypes)) def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype, lhs_contracting, rhs_contracting, bdims): rng = jtu.rand_small(self.rng()) dimension_numbers = ((lhs_contracting, rhs_contracting), ([], [])) dot = partial(lax.dot_general, dimension_numbers=dimension_numbers) self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_bdims={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), dimension_numbers, bdims), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "dimension_numbers": dimension_numbers, "bdims": bdims } for lhs_shape, rhs_shape, dimension_numbers in [ ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))), ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))), ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))), ] for bdims in all_bdims(lhs_shape, rhs_shape) for dtype in default_dtypes)) def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype, dimension_numbers, bdims): rng = jtu.rand_small(self.rng()) dot = partial(lax.dot_general, dimension_numbers=dimension_numbers) self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), rng) # Checks that batching didn't introduce any transposes or broadcasts. 
jaxpr = api.make_jaxpr(dot)(np.zeros(lhs_shape, dtype), np.zeros(rhs_shape, dtype)) for eqn in jtu.iter_eqns(jaxpr.jaxpr): self.assertFalse(eqn.primitive in ["transpose", "broadcast"]) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format( shape, np.dtype(dtype).name, broadcast_sizes, bdims), "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes, "bdims": bdims } for shape in [(), (2, 3)] for dtype in default_dtypes for broadcast_sizes in [(), (2, ), (1, 2)] for bdims in all_bdims(shape))) def testBroadcast(self, shape, dtype, broadcast_sizes, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.broadcast(x, broadcast_sizes) self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_inshape={}_outshape={}_bcdims={}_bdims={}".format( jtu.format_shape_dtype_string(inshape, dtype), outshape, broadcast_dimensions, bdims), "inshape": inshape, "dtype": dtype, "outshape": outshape, "dimensions": broadcast_dimensions, "bdims": bdims } for inshape, outshape, broadcast_dimensions in [ ([2], [2, 2], [0]), ([2], [2, 2], [1]), ([2], [2, 3], [0]), ([], [2, 3], []), ] for dtype in default_dtypes for bdims in all_bdims(inshape))) def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims): rng = jtu.rand_default(self.rng()) raise SkipTest("this test has failures in some cases") # TODO(mattjj) op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions) self._CheckBatching(op, 5, bdims, (inshape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_inshape={}_dimensions={}_bdims={}".format( jtu.format_shape_dtype_string(arg_shape, np.float32), dimensions, bdims), "arg_shape": arg_shape, "dimensions": dimensions, "bdims": bdims } for arg_shape, dimensions in [ [(1, ), (0, )], [(1, ), (-1, )], [(2, 1, 4), (1, )], [(2, 1, 4), (-2, )], [(2, 1, 3, 1), (1, )], [(2, 1, 3, 1), (1, 3)], [(2, 1, 3, 1), (3, )], [(2, 1, 3, 1), (1, -1)], ] for bdims in all_bdims(arg_shape))) def testSqueeze(self, arg_shape, dimensions, bdims): dtype = np.float32 rng = jtu.rand_default(self.rng()) op = lambda x: lax.squeeze(x, dimensions) self._CheckBatching(op, 10, bdims, (arg_shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_inshape={}_outshape={}_dims={}_bdims={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), jtu.format_shape_dtype_string(out_shape, dtype), dimensions, bdims), "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, "dimensions": dimensions, "bdims": bdims } for dtype in default_dtypes for arg_shape, dimensions, out_shape in [[(3, 4), None, ( 12, )], [(2, 1, 4), None, (8, )], [(2, 2, 4), None, ( 2, 8)], [(2, 2, 4), (0, 1, 2), (2, 8)], [(2, 2, 4), (1, 0, 2), ( 8, 2)], [(2, 2, 4), (2, 1, 0), (4, 2, 2)]] for bdims in all_bdims(arg_shape))) def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions) self._CheckBatching(op, 10, bdims, (arg_shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_inshape={}_pads={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), pads, bdims), "shape": shape, "dtype": dtype, "pads": pads, "bdims": bdims } for shape in [(2, 3)] for bdims in all_bdims(shape) for dtype in default_dtypes for pads in [[(1, 2, 1), (0, 1, 0)]])) def 
testPad(self, shape, dtype, pads, bdims): rng = jtu.rand_small(self.rng()) fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads) self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_predshape={}_argshapes={}_bdims={}".format( jtu.format_shape_dtype_string(pred_shape, np.bool_), jtu.format_shape_dtype_string(arg_shape, arg_dtype), bdims), "pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype, "bdims": bdims } for arg_shape in [(), (3, ), (2, 3)] for pred_shape in ([(), arg_shape] if arg_shape else [()]) for bdims in all_bdims(pred_shape, arg_shape, arg_shape) for arg_dtype in default_dtypes)) def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims): rng = jtu.rand_default(self.rng()) op = lambda c, x, y: lax.select(c < 0, x, y) self._CheckBatching(op, 5, bdims, ( pred_shape, arg_shape, arg_shape, ), (np.bool_, arg_dtype, arg_dtype), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_start_indices={}_limit_indices={}_strides={}_bdims={}". format(jtu.format_shape_dtype_string(shape, dtype), start_indices, limit_indices, strides, bdims), "shape": shape, "dtype": dtype, "starts": start_indices, "limits": limit_indices, "strides": strides, "bdims": bdims } for shape, start_indices, limit_indices, strides in [ [(3, ), (1, ), (2, ), None], [(7, ), (4, ), (7, ), None], [(5, ), (1, ), (5, ), (2, )], [(8, ), (1, ), (6, ), (2, )], [(5, 3), (1, 1), (3, 2), None], [(5, 3), (1, 1), (3, 1), None], [(7, 5, 3), (4, 0, 1), (7, 1, 3), None], [(5, 3), (1, 1), (2, 1), (1, 1)], [(5, 3), (1, 1), (5, 3), (2, 1)], ] for bdims in all_bdims(shape) for dtype in default_dtypes)) def testSlice(self, shape, dtype, starts, limits, strides, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.slice(x, starts, limits, strides) self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_perm={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), perm, bdims), "shape": shape, "dtype": dtype, "perm": perm, "bdims": bdims } for shape, perm in [ [(3, 4), (1, 0)], [(3, 4), (0, 1)], [(3, 4, 5), (2, 1, 0)], [(3, 4, 5), (1, 0, 2)], ] for bdims in all_bdims(shape) for dtype in default_dtypes)) def testTranspose(self, shape, dtype, perm, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.transpose(x, perm) self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_op={}_inshape={}_reducedims={}_initval={}_bdims={}".format( op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims, init_val, bdims), "op": op, "init_val": init_val, "shape": shape, "dtype": dtype, "dims": dims, "bdims": bdims } for init_val, op, dtypes in [ (0, lax.add, default_dtypes), (1, lax.mul, default_dtypes), (0, lax.max, all_dtypes), # non-monoidal (-np.inf, lax.max, float_dtypes), (dtypes.iinfo(np.int32).min, lax.max, [np.int32]), (dtypes.iinfo(np.int64).min, lax.max, [np.int64]), (dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]), (dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]), (np.inf, lax.min, float_dtypes), (dtypes.iinfo(np.int32).max, lax.min, [np.int32]), (dtypes.iinfo(np.int64).max, lax.min, [np.int64]), (dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]), (dtypes.iinfo(np.uint64).max, lax.min, [np.uint64]), ] for dtype in dtypes for shape, dims in [[(3, 4, 5), ( 0, )], [(3, 4, 5), (1, 2)], [(3, 4, 5), 
(0, 2)], [(3, 4, 5), (0, 1, 2)]] for bdims in all_bdims(shape))) def testReduce(self, op, init_val, shape, dtype, dims, bdims): rng = jtu.rand_small(self.rng()) init_val = np.asarray(init_val, dtype=dtype) fun = lambda operand: lax.reduce(operand, init_val, op, dims) self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_op={}_inshape={}_reducedims={}_bdims={}".format( op.__name__, jtu.format_shape_dtype_string(shape, dtype), dim, bdims), "op": op, "shape": shape, "dtype": dtype, "dim": dim, "bdims": bdims } for op in [lax.argmin, lax.argmax] for dtype in default_dtypes for shape in [(3, 4, 5)] for dim in range(len(shape)) for bdims in all_bdims(shape))) def testArgminmax(self, op, shape, dtype, dim, bdims): rng = jtu.rand_default(self.rng()) fun = lambda operand: op(operand, dim, np.int32) self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}" "_basedilation={}_windowdilation={}").format( op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims, strides, padding, base_dilation, window_dilation), "op": op, "init_val": init_val, "dtype": dtype, "shape": shape, "dims": dims, "strides": strides, "padding": padding, "base_dilation": base_dilation, "window_dilation": window_dilation } for init_val, op, dtypes in [ (0, lax.add, [np.float32]), (-np.inf, lax.max, [np.float32]), (np.inf, lax.min, [np.float32]), ] for shape, dims, strides, padding, base_dilation, window_dilation in (itertools.chain( itertools.product([(4, 6)], [(2, 1), (1, 2)], [(1, 1), ( 2, 1), (1, 2)], ["VALID", "SAME", [(0, 3), (1, 2)]], [( 1, 1), (2, 3)], [(1, 1), (1, 2)]), itertools.product([(3, 2, 4, 6)], [(1, 1, 2, 1), ( 2, 1, 2, 1)], [(1, 2, 2, 1), ( 1, 1, 1, 1 )], ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]], [( 1, 1, 1, 1), (2, 1, 3, 2)], [(1, 1, 1, 1), (1, 2, 2, 1)]))) for dtype in dtypes)) def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding, base_dilation, window_dilation): rng = jtu.rand_small(self.rng()) init_val = np.asarray(init_val, dtype=dtype) def fun(operand): return lax.reduce_window(operand, init_val, op, dims, strides, padding, base_dilation, window_dilation) for bdims in all_bdims(shape): self._CheckBatching(fun, 3, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_op={}_shape={}_axis={}_bdims={}_reverse={}".format( op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis, bdims, reverse), "op": op, "shape": shape, "dtype": dtype, "bdims": bdims, "axis": axis, "reverse": reverse } for op, types in [ (lax.cumsum, [np.float32, np.float64]), (lax.cumprod, [np.float32, np.float64]), ] for dtype in types for shape in [[10], [3, 4, 5]] for axis in range(len(shape)) for bdims in all_bdims(shape) for reverse in [False, True])) def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse): rng_factory = (jtu.rand_default if dtypes.issubdtype( dtype, np.integer) else jtu.rand_small) rng = rng_factory(self.rng()) self._CheckBatching(partial(op, axis=axis, reverse=reverse), 7, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_dtype={}_padding={}".format(np.dtype(dtype).name, padding), "dtype": dtype, "padding": padding } for dtype in float_dtypes for padding in ["VALID", "SAME"])) @jtu.skip_on_flag("jax_skip_slow_tests", True) 
@jtu.ignore_warning(message="Using reduced precision for gradient.*") def testSelectAndGatherAdd(self, dtype, padding): if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16: raise SkipTest( "bfloat16 _select_and_gather_add doesn't work on tpu") rng = jtu.rand_small(self.rng()) all_configs = itertools.chain( itertools.product([(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1), (1, 2)]), itertools.product([(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)], [(1, 2, 2, 1), (1, 1, 1, 1)])) def fun(operand, tangents): pads = lax.padtype_to_pads(operand.shape, dims, strides, padding) ones = (1, ) * len(operand.shape) return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims, strides, pads, ones, ones) for shape, dims, strides in all_configs: for bdims in all_bdims(shape, shape): self._CheckBatching(fun, 3, bdims, (shape, shape), (dtype, dtype), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": f"_dtype={jtu.format_shape_dtype_string(shape, dtype)}" f"_padding={padding}_dims={dims}_strides={strides}", "dtype": dtype, "padding": padding, "shape": shape, "dims": dims, "strides": strides } for dtype in float_dtypes for padding in ["VALID", "SAME"] for shape in [(3, 2, 4, 6)] for dims in [(1, 1, 2, 1)] for strides in [(1, 2, 2, 1), (1, 1, 1, 1)])) def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides): rng = jtu.rand_small(self.rng()) pads = lax.padtype_to_pads(shape, dims, strides, padding) def fun(operand, cotangents): return lax._select_and_scatter_add(operand, cotangents, lax.ge_p, dims, strides, pads) ones = (1, ) * len(shape) cotangent_shape = api.eval_shape( lambda x: lax._select_and_gather_add(x, x, lax.ge_p, dims, strides, pads, ones, ones), np.ones(shape, dtype)).shape for bdims in all_bdims(cotangent_shape, shape): self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape), (dtype, dtype), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_bdims={}_fft_ndims={}".format(shape, bdims, fft_ndims), "shape": shape, "bdims": bdims, "fft_ndims": fft_ndims } for shape in [(5, ), (3, 4, 5), (2, 3, 4, 5)] for bdims in all_bdims(shape) for fft_ndims in range(0, min(3, len(shape)) + 1))) @jtu.skip_on_devices("tpu") # TODO(b/137993701): unimplemented cases. 
def testFft(self, fft_ndims, shape, bdims): rng = jtu.rand_default(self.rng()) ndims = len(shape) axes = range(ndims - fft_ndims, ndims) fft_lengths = [shape[axis] for axis in axes] op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths) self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), idxs, dnums, slice_sizes, bdims), "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "bdims": bdims } for dtype in all_dtypes for shape, idxs, dnums, slice_sizes in [ ((5, ), np.array([[0], [2]]), lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, )), ((10, ), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(), start_index_map=(0, )), (2, )), (( 10, 5, ), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, 3)), ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, 1)), (1, 3)), ] for bdims in all_bdims(shape, idxs.shape))) def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) self._CheckBatching(fun, 0, bdims, [shape, idxs.shape], [dtype, idxs.dtype], jtu.rand_default(self.rng())) self._CheckBatching(fun, 5, bdims, [shape, idxs.shape], [dtype, idxs.dtype], jtu.rand_default(self.rng())) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_shape={}_idxs={}_update={}_dnums={}_bdims={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), idxs, update_shape, dnums, bdims), "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, "update_shape": update_shape, "dnums": dnums, "bdims": bdims } for dtype in float_dtypes for arg_shape, idxs, update_shape, dnums in [ ((5, ), np.array([[0], [2]]), (2, ), lax.ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, ), scatter_dims_to_operand_dims=( 0, ))), ((10, ), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(update_window_dims=(1, ), inserted_window_dims=(), scatter_dims_to_operand_dims=( 0, ))), (( 10, 5, ), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(update_window_dims=(1, ), inserted_window_dims=(0, ), scatter_dims_to_operand_dims=( 0, ))), ] for bdims in all_bdims(arg_shape, idxs.shape, update_shape))) def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims): fun = partial(lax.scatter_add, dimension_numbers=dnums) self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape, update_shape], [dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()), rtol={np.float16: 5e-3}) def testShapeUsesBuiltinInt(self): x = lax.iota(np.int32, 3) + 1 self.assertIsInstance(x.shape[0], int) # not np.int64 def testBroadcastShapesReturnsPythonInts(self): shape1, shape2 = (1, 2, 3), (2, 3) out_shape = lax.broadcast_shapes(shape1, shape2) self.assertTrue(all(type(s) is int for s in out_shape)) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_shape={}_k={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), k, bdims), "shape": shape, "dtype": dtype, "k": k, "bdims": bdims, "rng_factory": rng_factory } for shape in [(4, ), (3, 5, 3)] for k in [1, 3] for bdims in all_bdims(shape) # 
      # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug
      # is fixed: the top_k indices for integer arrays with identical entries
      # won't match between vmap'd version and manual reference, so only test
      # unique integer arrays for int_dtypes.
      # Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of
      # values a bfloat16 can represent exactly to avoid ties.
      for dtype, rng_factory in itertools.chain(
          unsafe_zip(default_dtypes, itertools.repeat(jtu.rand_unique_int)))))
  def testTopK(self, shape, dtype, k, bdims, rng_factory):
    rng = rng_factory(self.rng())
    # _CheckBatching doesn't work with tuple outputs, so test outputs separately.
    op1 = lambda x: lax.top_k(x, k=k)[0]
    self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng)
    op2 = lambda x: lax.top_k(x, k=k)[1]
    self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name":
               "_shape={}_dimension={}_arity={}_bdims={}_isstable={}".format(
                   jtu.format_shape_dtype_string(shape, np.float32), dimension,
                   arity, bdims, is_stable),
           "shape": shape, "dimension": dimension, "arity": arity,
           "bdims": bdims, "is_stable": is_stable}
          for shape in [(2, 3)]
          for dimension in [0, 1]
          for arity in range(3)
          for bdims in all_bdims(*((shape,) * arity))
          for is_stable in [False, True]))
  def testSort(self, shape, dimension, arity, bdims, is_stable):
    rng = jtu.rand_default(self.rng())
    if arity == 1:
      fun = partial(lax.sort, dimension=dimension)
      self._CheckBatching(fun, 5, bdims, (shape,) * arity,
                          (np.float32,) * arity, rng)
    else:
      for i in range(arity):
        fun = lambda *args, i=i: lax.sort(
            args, dimension=dimension, is_stable=is_stable)[i]
        self._CheckBatching(fun, 5, bdims, (shape,) * arity,
                            (np.float32,) * arity, rng)
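
# A minimal, self-contained sketch (not part of the original test suite) of the
# property _CheckBatching asserts throughout LaxVmapTest above: vmapping an op
# over a batched axis must agree with slicing that axis, applying the op per
# example, and stacking the results. The helper name is illustrative only.
def _vmap_matches_loop_sketch():
  import numpy as np
  import jax
  import jax.numpy as jnp
  xs = jnp.arange(12.0).reshape(4, 3)        # batch of 4 length-3 vectors
  op = lambda x: jnp.cumsum(x, axis=0)       # per-example op on a rank-1 array
  vmapped = jax.vmap(op)(xs)                 # batched in a single call
  looped = jnp.stack([op(x) for x in xs])    # manual per-example reference
  np.testing.assert_allclose(np.asarray(vmapped), np.asarray(looped))
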
class BatchingTest(jtu.JaxTestCase): def testConstantFunction(self): ans = vmap(lambda x: 3)(np.ones(4)) expected = 3 * np.ones(4) self.assertAllClose(ans, expected, check_dtypes=False) def testNestedBatchingMatMat(self): matvec = vmap(jnp.vdot, in_axes=(0, None)) matmat = vmap(matvec, in_axes=(None, 1), out_axes=1) R = np.random.RandomState(0).randn A = R(4, 3) B = R(3, 2) ans = matmat(A, B) expected = np.dot(A, B) self.assertAllClose( ans, expected, check_dtypes=False, rtol={np.float32:1e-2} if jtu.device_under_test() == "tpu" else None) jaxpr = make_jaxpr(matmat)(A, B) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) def testPerExampleGradients(self): def predict(params, inputs): for W, b in params: outputs = jnp.dot(W, inputs) + b inputs = jnp.tanh(outputs) return outputs def loss(params, data): inputs, targets = data predictions = predict(params, inputs) return jnp.sum((predictions - targets)**2) batch_size = 5 layer_sizes = [3, 2, 4] R = np.random.RandomState(0).randn params = [(R(m, n), R(m)) for m, n in zip(layer_sizes[1:], layer_sizes[:-1])] input_batch = R(5, 3) target_batch = R(5, 4) batch = (input_batch, target_batch) ans = vmap(partial(grad(loss), params))(batch) for ans_pair, param_pair in zip(ans, params): dW, db = ans_pair W, b = param_pair self.assertEqual(dW.shape, (batch_size,) + W.shape) self.assertEqual(db.shape, (batch_size,) + b.shape) def testJacobians(self): def jacbwd(f, x): y, pullback = vjp(f, x) std_basis = np.eye(np.size(y)).reshape((-1,) + np.shape(y)) jac_flat, = vmap(pullback, out_axes=np.ndim(y))(std_basis) return jac_flat.reshape(np.shape(y) + np.shape(x)) def jacfwd(f, x): pushfwd = lambda v: jvp(f, (x,), (v,)) std_basis = np.eye(np.size(x)).reshape((-1,) + np.shape(x)) y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis) return jac_flat.reshape(np.shape(y) + np.shape(x)) R = np.random.RandomState(0).randn A = R(4, 3) b = R(4) f = lambda x: jnp.tanh(jnp.dot(A, x) + b) x = R(3) self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False) def testBatchOfCompile(self): side = [] @jit def f(x): side.append(None) return x + x g = jit(vmap(f)) self.assertAllClose(g(np.ones(2)), 2 * np.ones(2), check_dtypes=False) self.assertEqual(len(side), 1) self.assertAllClose(g(2 * np.ones(2)), 4 * np.ones(2), check_dtypes=False) self.assertEqual(len(side), 1) def testSliceLax(self): fun = lambda x: lax.slice(x, (2,), (4,)) R = np.random.RandomState(0).randn x = R(5, 10) ans = vmap(fun)(x) expected_ans = x[:, 2:4] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testSliceNumpy(self): fun = lambda x: x[:, 2] R = np.random.RandomState(0).randn x = R(10, 5, 3, 7) ans = vmap(fun)(x) expected_ans = x[:, :, 2] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testRevLax(self): fun = lambda x: lax.rev(x, [0]) R = np.random.RandomState(0).randn x = R(2, 3) ans = vmap(fun)(x) expected_ans = x[:, ::-1] self.assertAllClose(ans, expected_ans, check_dtypes=False) ans = vmap(fun, (1,), 1)(x) expected_ans = x[::-1, :] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testRevNumpy(self): fun = lambda x: x[:, ::-1] R = np.random.RandomState(0).randn x = R(3, 2, 4) ans = vmap(fun)(x) expected_ans = x[:, :, ::-1] self.assertAllClose(ans, expected_ans, check_dtypes=False) ans = vmap(fun, (1,), 1)(x) expected_ans = x[:, :, ::-1] self.assertAllClose(ans, expected_ans, check_dtypes=False) ans = vmap(fun, (2,), 2)(x) expected_ans = x[:, ::-1, :] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testNpMaximum(self): fun = lambda x: 
jnp.maximum(x, 0.0) R = np.random.RandomState(0).randn x = R(10, 5, 3, 7) ans = vmap(fun)(x) expected_ans = np.maximum(x, 0.0) self.assertAllClose(ans, expected_ans, check_dtypes=False) def testNpGtrThan(self): R = np.random.RandomState(0).randn x = R(10, 5, 3, 7) ans = vmap(lambda x: x > 1.0)(x) expected_ans = x > 1.0 self.assertAllClose(ans, expected_ans) def testNpMaximumPerExampleGrad(self): R = np.random.RandomState(0).randn x = R(10, 5) W = R(5, 5) fun = lambda W, x: jnp.sum(jnp.maximum(jnp.dot(x, W), 0.0) ** 2) ans = vmap(partial(grad(fun), W))(x) W_t = jnp.transpose(W) for i in range(10): x_ex = x[i:i + 1] expected_ans = 2.0 * jnp.dot( jnp.maximum(jnp.dot(W_t, jnp.transpose(x_ex)), 0.0), x_ex) expected_ans = jnp.transpose(expected_ans) self.assertAllClose( ans[i], expected_ans, check_dtypes=False, atol={np.float32:5e-2} if jtu.device_under_test() == "tpu" else None) def testDotGeneral(self): R = np.random.RandomState(0).randn x = R(10, 3, 4, 5) y = R(10, 3, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun)(x, y) expected = lax.dot_general(x, y, [((3,), (2,)), ((0, 1), (0, 1))]) self.assertAllClose(ans, expected) x = R(3, 4, 10, 5) y = R(3, 10, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun, in_axes=(2, 1))(x, y) expected = np.stack([fun(x[..., i, :], y[:, i, ...]) for i in range(10)]) self.assertAllClose(ans, expected) x = R(3, 4, 5, 10) y = R(3, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun, in_axes=(3, None))(x, y) expected = np.stack([fun(x[..., i], y) for i in range(10)]) self.assertAllClose(ans, expected) x = R(3, 4, 5) y = R(3, 5, 10, 6) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun, in_axes=(None, 2))(x, y) expected = np.stack([fun(x, y[..., i, :]) for i in range(10)]) self.assertAllClose(ans, expected) x = R(4) y = R(4, 10) fun = lambda x, y: lax.dot_general(x, y, [((0,), (0,)), ((), ())]) ans = vmap(fun, in_axes=(None, 1))(x, y) expected = np.stack([fun(x, y[..., i]) for i in range(10)]) self.assertAllClose(ans, expected) def testDot(self): # these tests are based on @shoyer's notebook studying gufuncs def vecvec(a, b): dot = jnp.dot for ndim in range(1, max(a.ndim, b.ndim)): a_ax = 0 if a.ndim > ndim else None b_ax = 0 if b.ndim > ndim else None dot = vmap(dot, in_axes=(a_ax, b_ax)) return dot(a, b) assert vecvec(jnp.zeros((3,)), jnp.zeros((3,))).shape == () assert vecvec(jnp.zeros((2, 3)), jnp.zeros((3,))).shape == (2,) assert vecvec(jnp.zeros((4, 2, 3)), jnp.zeros((3,))).shape == (4, 2) def testDot2(self): R = np.random.RandomState(0).randn xs = R(10, 3) ys = R(10, 3) ans = vmap(jnp.dot)(xs, ys) expected = np.einsum('ni,ni->n', xs, ys) self.assertAllClose(ans, expected, check_dtypes=False) def testDot3(self): R = np.random.RandomState(0).randn xs = R(5, 8, 10) ys = R(10, 1) ans = vmap(jnp.dot, in_axes=(1, None))(xs, ys) expected = np.einsum('inj,jk->nik', xs, ys) self.assertAllClose(ans, expected, check_dtypes=False) def testDot4(self): R = np.random.RandomState(0).randn xs = R(3, 2) ys = R(3) ans = vmap(jnp.dot, in_axes=(1, None))(xs, ys) expected = np.einsum('ij,i->j', xs, ys) self.assertAllClose(ans, expected, check_dtypes=False) def testPad(self): R = np.random.RandomState(0).randn fun = lambda x: lax.pad(x, np.float32(0), [(1, 2, 1)]) x = R(5, 10).astype(np.float32) ans = vmap(fun)(x) expected_ans = jnp.stack(list(map(fun, x))) self.assertAllClose(ans, expected_ans, check_dtypes=False) fun = 
lambda x: lax.pad(x, np.float32(0), [(1, 2, 1), (0, 1, 0)]) x = R(5, 10, 3).astype(np.float32) ans = vmap(fun)(x) expected_ans = jnp.stack(list(map(fun, x))) self.assertAllClose(ans, expected_ans, check_dtypes=False) def testConcatenate(self): R = lambda *shape: np.random.RandomState(0).randn(*shape).astype(np.float32) fun = lambda *args: lax.concatenate(args, dimension=0) x, y, z = R(10, 2, 3), R(1, 10, 3), R(4, 3) ans = vmap(fun, in_axes=(0, 1, None))(x, y, z) expected_ans = np.concatenate([x, np.swapaxes(y, 0, 1), np.broadcast_to(z, (10, 4, 3))], 1) self.assertAllClose(ans, expected_ans, check_dtypes=False) fun = lambda *args: lax.concatenate(args, dimension=1) x, y, z = R(10, 2, 1), R(2, 3), R(2, 4, 10) ans = vmap(fun, in_axes=(0, None, 2))(x, y, z) expected_ans = np.concatenate([x, np.broadcast_to(y, (10, 2, 3)), np.moveaxis(z, 2, 0)], 2) self.assertAllClose(ans, expected_ans, check_dtypes=False) def testJacobianIssue54(self): # test modeling the code in https://github.com/google/jax/issues/54 def func(xs): return jnp.array(list(xs)) xs = jnp.ones((5, 1)) jacrev(func)(xs) # don't crash jacfwd(func)(xs) # don't crash def testAny(self): # test modeling the code in https://github.com/google/jax/issues/108 ans = vmap(jnp.any)(jnp.array([[True, False], [False, False]])) expected = jnp.array([True, False]) self.assertAllClose(ans, expected) def testHessian(self): # test based on code from sindhwani@google def fun(x, t): return jnp.sum(jnp.power(jnp.maximum(x, 0.0), 2)) + t x = np.array([-1., -0.5, 0., 0.5, 1.0]) ans = hessian(lambda x: fun(x, 0.0))(x) expected = np.array([[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0.,0.5, 0., 0.], [0., 0., 0., 2., 0.], [0., 0., 0., 0., 2.]]) self.assertAllClose(ans, expected, check_dtypes=False) def testDynamicSlice(self): # test dynamic_slice via numpy indexing syntax # see https://github.com/google/jax/issues/1613 for an explanation of why we # need to use np rather than np to create x and idx x = jnp.arange(30).reshape((10, 3)) ans = vmap(lambda x, i: x[i], in_axes=(0, None))(x, 1) expected = x[:, 1] self.assertAllClose(ans, expected, check_dtypes=False) idx = jnp.array([0, 1, 2, 1, 0] * 2) ans = vmap(lambda x, i: x[i], in_axes=(0, 0))(x, idx) expected = x[np.arange(10), idx] self.assertAllClose(ans, expected, check_dtypes=False) x = jnp.arange(3) idx = jnp.array([0, 1, 2, 1, 0] * 2) ans = vmap(lambda x, i: x[i], in_axes=(None, 0))(x, idx) expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testDynamicUpdateSlice(self): x = np.random.randn(10, 3) y = np.random.randn(10) ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0), in_axes=(0, 0, None))(x, y, 1) expected = x.copy() expected[:, 1] = y self.assertAllClose(ans, expected, check_dtypes=False) x = np.random.randn(3) idx = np.array([0, 1, 2, 1, 0] * 2) ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0), in_axes=(None, 0, 0))(x, y, idx) expected = np.broadcast_to(x, (10, 3)).copy() expected[np.arange(10), idx] = y self.assertAllClose(ans, expected, check_dtypes=False) def testRandom(self): seeds = vmap(random.PRNGKey)(np.arange(10)) ans = vmap(partial(random.normal, shape=(3, 2)))(seeds) expected = np.stack([random.normal(random.PRNGKey(seed), (3, 2)) for seed in np.arange(10)]) self.assertAllClose(ans, expected, check_dtypes=False) assert len(np.unique(ans)) == 10 * 3 * 2 def testSort(self): v = np.arange(12)[::-1].reshape(3, 4) sv = vmap(partial(lax.sort, dimension=0), (0,))(v) self.assertAllClose(sv, v[:, ::-1]) sv = 
vmap(partial(lax.sort, dimension=-1), (0,))(v) self.assertAllClose(sv, v[:, ::-1]) sv = vmap(partial(lax.sort, dimension=0), (1,))(v) self.assertAllClose(sv, v[::-1, :].T) sv = vmap(partial(lax.sort, dimension=0), (1,), 1)(v) self.assertAllClose(sv, v[::-1, :]) def testSortKeyVal(self): k = np.arange(12)[::-1].reshape(3, 4) v = np.random.RandomState(0).permutation(12).reshape(3, 4) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v) self.assertAllClose(sk, k[:, ::-1]) self.assertAllClose(sv, v[:, ::-1]) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v) self.assertAllClose(sk, k[::-1, :]) self.assertAllClose(sv, v[::-1, :]) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T) self.assertAllClose(sk, k[:, ::-1]) self.assertAllClose(sv, v[:, ::-1]) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v) self.assertAllClose(sk, k[:, ::-1]) self.assertAllClose(sv, v[:, ::-1]) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0], v) self.assertAllClose(sk, np.broadcast_to(k[0, ::-1], (3, 4))) self.assertAllClose(sv, v[:, ::-1]) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T, v[0]) self.assertAllClose(sk, k[:, ::-1]) self.assertAllClose(sv, np.broadcast_to(v[0, ::-1], (3, 4))) def testConvGeneralDilated(self): W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32) X = jnp.array(np.random.randn(10, 5, 5, 1), dtype=np.float32) def f(params, x): one = (1, 1) dimension_numbers = ('NHWC', 'HWIO', 'NHWC') y = lax.conv_general_dilated( x, params, one, 'SAME', one, one, dimension_numbers) return y grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2)) # Test forward prop. per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1))) per_example = jnp.reshape(per_example, (10, 5, 5, 5)) per_example_direct = f(W, X) self.assertAllClose(per_example, per_example_direct) # Test gradients. 
per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1))) per_example_direct = [] for i in range(10): g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1))) per_example_direct += [ jnp.reshape(g, (1,) + g.shape)] per_example_direct = jnp.concatenate(per_example_direct, axis=0) self.assertAllClose(per_example, per_example_direct, rtol=2e-2, atol=2e-3) def testConvGeneralDilatedBatchNotMajor(self): W = jnp.array(np.random.randn(3, 3, 1, 4), dtype=np.float32) x = jnp.array(np.random.randn(3, 5, 7, 5, 1), dtype=np.float32) def f(params, x): one = (1, 1) dimension_numbers = ('HNWC', 'HWIO', 'HWNC') y = lax.conv_general_dilated( x, params, one, 'SAME', one, one, dimension_numbers) return y per_example = vmap(partial(f, W))(x) per_example = jnp.reshape(jnp.transpose(per_example, (1, 2, 0, 3, 4)), (5, 5, 21, 4)) per_example_direct = f(W, jnp.reshape(jnp.transpose(x, (1, 0, 2, 3, 4)), (5, 21, 5, 1))) self.assertAllClose(per_example, per_example_direct) @parameterized.named_parameters( {"testcase_name": "_op={}".format(name), "op": op, "unit": unit} for name, op, unit in [("max", lax.max, -jnp.inf), ("min", lax.min, jnp.inf)]) def testMinMaxPool(self, op, unit): W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32) X = jnp.array(np.random.randn(10, 5, 5, 1), dtype=np.float32) def f(params, x): one = (1, 1) dimension_numbers = ('NHWC', 'HWIO', 'NHWC') y = lax.conv_general_dilated( x, params, one, 'SAME', one, one, dimension_numbers) y = lax.reduce_window( y, unit, op, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME') return y grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2)) # Test forward prop. per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1))) per_example = jnp.reshape(per_example, (10, 5, 5, 5)) per_example_direct = f(W, X) self.assertAllClose(per_example, per_example_direct) # Test gradients. per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1))) per_example_direct = [] for i in range(10): g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1))) per_example_direct += [ jnp.reshape(g, (1,) + g.shape)] per_example_direct = jnp.concatenate(per_example_direct, axis=0) self.assertAllClose(per_example, per_example_direct, rtol=5e-2, atol=1e-3) def testSumPool(self): W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32) X = jnp.array(np.random.randn(10, 5, 5, 1), dtype=np.float32) def f(params, x): one = (1, 1) dimension_numbers = ('NHWC', 'HWIO', 'NHWC') y = lax.conv_general_dilated( x, params, one, 'SAME', one, one, dimension_numbers) y = lax.reduce_window( y, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME') return y grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2)) # Test forward prop. per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1))) per_example = jnp.reshape(per_example, (10, 5, 5, 5)) per_example_direct = f(W, X) self.assertAllClose(per_example, per_example_direct) # Test gradients. 
per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1))) per_example_direct = [] for i in range(10): g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1))) per_example_direct += [ jnp.reshape(g, (1,) + g.shape)] per_example_direct = jnp.concatenate(per_example_direct, axis=0) self.assertAllClose(per_example, per_example_direct, rtol=3e-2, atol=1e-3) def testCumProd(self): x = jnp.arange(9).reshape(3, 3) + 1 y = vmap(lambda x: jnp.cumprod(x, axis=-1))(x) self.assertAllClose(np.cumprod(x, axis=1, dtype=int), y) def testSelect(self): pred = np.array([True, False]) on_true = np.array([0, 1]) on_false = np.array([2, 3]) ans = vmap(lax.select)(pred, on_true, on_false) expected = np.array([0, 3]) self.assertAllClose(ans, expected) pred = np.array([False, True]) on_true = np.array([0, 1]) on_false = np.array([2, 3]) ans = vmap(lax.select, (0, None, None))(pred, on_true, on_false) expected = np.array([[2, 3], [0, 1]]) self.assertAllClose(ans, expected) pred = True on_true = np.array([0, 1], np.float32) on_false = np.array(3, np.float32) ans = vmap(lax.select, (None, 0, None))(pred, on_true, on_false) expected = np.array([0, 1], np.float32) self.assertAllClose(ans, expected) pred = np.array([False, True]) on_true = np.array([0, 1], np.float32) on_false = np.array(3, np.float32) ans = vmap(lax.select, (0, 0, None))(pred, on_true, on_false) expected = np.array([3, 1], np.float32) self.assertAllClose(ans, expected) pred = np.array([False, True]) on_true = np.array([2], np.float32) on_false = np.array([[3, 4]], np.float32) ans = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false) expected = np.array([[3, 2]], np.float32) self.assertAllClose(ans, expected) def testLaxLinalgCholesky(self): a = np.random.RandomState(0).randn(10, 5, 5).astype(np.float32) a = np.matmul(a, np.conj(np.swapaxes(a, -1, -2))) ans = vmap(lax.linalg.cholesky)(a) expected = np.linalg.cholesky(a) self.assertAllClose(ans, expected, check_dtypes=False, rtol=1e-4) b = np.random.RandomState(0).randn(10, 5, 5).astype(np.float32) b = np.matmul(b, np.conj(np.swapaxes(b, -1, -2))) b_trans = np.swapaxes(b, 0, 1) # shape is (5, 10, 5) ans = vmap(lax.linalg.cholesky, in_axes=1, out_axes=0)(b_trans) expected = np.linalg.cholesky(b) self.assertAllClose(ans, expected, check_dtypes=False, rtol=1e-4) def testLaxLinalgTriangularSolve(self): a = np.random.RandomState(0).randn(4, 10, 4).astype(np.float32) a += np.eye(4, dtype=jnp.float32)[:, None, :] b = np.random.RandomState(0).randn(5, 4, 10).astype(np.float32) ans = vmap(lax.linalg.triangular_solve, in_axes=(1, 2))(a, b) expected = np.stack( [lax.linalg.triangular_solve(a[:, i], b[..., i]) for i in range(10)]) self.assertAllClose(ans, expected) ans = vmap(lax.linalg.triangular_solve, in_axes=(None, 2))(a[:, 0], b) expected = np.stack( [lax.linalg.triangular_solve(a[:, 0], b[..., i]) for i in range(10)]) self.assertAllClose(ans, expected) ans = vmap(lax.linalg.triangular_solve, in_axes=(1, None))(a, b[..., 0]) expected = np.stack( [lax.linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10)]) self.assertAllClose(ans, expected) @parameterized.named_parameters( {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes} for dtype in [np.float32, np.int32] for axis, shape, idxs, dnums, slice_sizes in [ (0, (3, 5), np.array([[0], [2]]), lax.GatherDimensionNumbers( offset_dims=(), 
collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)), (1, (10, 3), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)), (1, (10, 3, 5), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)), (2, (10, 5, 3), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), ]) def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes): rng = jtu.rand_default(self.rng()) fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) operand = rng(shape, dtype) ans = vmap(fun, (axis, None))(operand, idxs) expected = np.stack([fun(operand[(slice(None),) * axis + (i,)], idxs) for i in range(operand.shape[axis])]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes} for dtype in [np.float32, np.float64] for axis, shape, idxs, dnums, slice_sizes in [ (0, (3, 5), np.array([[0], [2]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)), (1, (10, 3), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)), (1, (10, 3, 5), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)), (2, (10, 5, 3), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)) ]) def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes): rng = jtu.rand_default(self.rng()) fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx)))) operand = rng(shape, dtype) ans = vmap(gfun, (axis, None))(operand, idxs) expected = np.stack([gfun(operand[(slice(None),) * axis + (i,)], idxs) for i in range(operand.shape[axis])]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes} for dtype in [np.float32, np.int32] for axis, shape, idxs, dnums, slice_sizes in [ (0, (5,), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)), (1, (10,), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None], lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)), (1, (10, 5), np.array([[0, 2, 1], [0, 3, 3]]).T[..., None], lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)), (0, (10, 5), np.array([[[0, 1], [2, 0]], [[1, 0], [2, 3]]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), ]) def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes): rng = jtu.rand_default(self.rng()) fun = partial(lax.gather, 
dimension_numbers=dnums, slice_sizes=slice_sizes) operand = rng(shape, dtype) ans = vmap(fun, (None, axis))(operand, idxs) expected = np.stack([fun(operand, idxs[(slice(None),) * axis + (i,)]) for i in range(idxs.shape[axis])]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes} for dtype in [np.float32, np.float64] for axis, shape, idxs, dnums, slice_sizes in [ (0, (5,), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)), (1, (10,), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None], lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)), (1, (10, 5), np.array([[0, 2, 1], [0, 3, 3]]).T[..., None], lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)), (0, (10, 5), np.array([[[0, 1], [2, 0]], [[1, 0], [2, 3]]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), ]) def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes): rng = jtu.rand_default(self.rng()) fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx)))) operand = rng(shape, dtype) ans = vmap(gfun, (None, axis))(operand, idxs) expected = np.stack([gfun(operand, idxs[(slice(None),) * axis + (i,)]) for i in range(idxs.shape[axis])]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs, dnums, slice_sizes), "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes} for dtype in [np.float32, np.int32] for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [ (0, 0, (2, 5), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)), (1, 1, (10, 2), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None], lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)), (0, 1, (2, 10, 5,), np.array([[[0, 2, 1], [0, 3, 3]]]).T, lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)), (2, 0, (10, 5, 2), np.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), ]) def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums, slice_sizes): rng = jtu.rand_default(self.rng()) fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) operand = rng(shape, dtype) assert operand.shape[op_axis] == idxs.shape[idxs_axis] ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs) expected = np.stack([fun(operand[(slice(None),) * op_axis + (i,)], idxs[(slice(None),) * idxs_axis + (i,)]) for i in range(idxs.shape[idxs_axis])]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( {"testcase_name": 
"_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs, dnums, slice_sizes), "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes} for dtype in [np.float32] for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [ (0, 0, (2, 5), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)), (1, 1, (10, 2), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None], lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)), (0, 1, (2, 10, 5,), np.array([[[0, 2, 1], [0, 3, 3]]]).T, lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)), (2, 0, (10, 5, 2), np.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), ]) def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums, slice_sizes): rng = jtu.rand_default(self.rng()) fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx)))) operand = rng(shape, dtype) assert operand.shape[op_axis] == idxs.shape[idxs_axis] ans = vmap(gfun, (op_axis, idxs_axis))(operand, idxs) expected = np.stack([gfun(operand[(slice(None),) * op_axis + (i,)], idxs[(slice(None),) * idxs_axis + (i,)]) for i in range(idxs.shape[idxs_axis])]) self.assertAllClose(ans, expected, check_dtypes=False) def testNumpyIndexing1(self): a = jnp.arange(2 * 3 * 4).reshape((2, 3, 4)) ind = np.array([[0, 1], [2, 0]]) def f(a, ind): return a[:, ind] expected = np.stack([f(a, ind[i, :]) for i in range(ind.shape[0])]) ans = vmap(f, (None, 0))(a, ind) assert np.all(ans == expected) def testNumpyIndexing2(self): a = jnp.arange(2 * 3 * 4).reshape((2, 3, 4)) def f(a): inds = jnp.array([0, 2]) return a[:, inds] ans = vmap(f)(a) expected = np.stack([f(a[:, i, :]) for i in range(a.shape[1])], axis=1) assert np.all(ans == expected) def testTranspose(self): x = np.arange(4 * 3 * 3).reshape((4, 3, 3)) ans = vmap(lambda x: x + x.T)(x) expected = x + np.swapaxes(x, -1, -2) self.assertAllClose(ans, expected, check_dtypes=False) def testTransposePermutation(self): x = np.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5)) ans = vmap(lambda x: jnp.transpose(x, (1, 0, 2)))(x) expected = np.transpose(x, (0, 2, 1, 3)) self.assertAllClose(ans, expected, check_dtypes=False) x = np.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5)) ans = vmap(lambda x: jnp.transpose(x, (1, 2, 0)))(x) expected = np.transpose(x, (0, 2, 3, 1)) self.assertAllClose(ans, expected, check_dtypes=False) x = np.arange(6 * 3 * 4 * 5).reshape((3, 4, 6, 5)) ans = vmap(lambda x: jnp.transpose(x, (1, 2, 0)), in_axes=2)(x) expected = np.transpose(x, (2, 1, 3, 0)) self.assertAllClose(ans, expected, check_dtypes=False) def testIssue354(self): psd_mat = np.random.randn(20, 10) psd_mat = psd_mat.T.dot(psd_mat) vec = np.random.randn(10) def f(scale): scaled_mat = scale * psd_mat chol = jnp.linalg.cholesky(scaled_mat) return -0.5 * jnp.sum((jnp.einsum('ij,j->i', chol, vec))**2) vmapped_f = vmap(f) vmapped_f_grad = grad(lambda x: jnp.sum(vmapped_f(x))) scales = np.array([[0.1], [0.2], [0.3], [0.4], [0.5]]) ans = vmapped_f_grad(scales) # don't crash! 
expected = np.stack([grad(f)(scale) for scale in scales]) self.assertAllClose(ans, expected, check_dtypes=False, rtol=jtu.default_gradient_tolerance) def testIssue387(self): # https://github.com/google/jax/issues/387 R = np.random.RandomState(0).rand(100, 2) def dist_sq(R): dR = R[:, jnp.newaxis, :] - R[jnp.newaxis, :, :] zero = jnp.zeros_like(dR) dR = dR - jnp.where(jnp.abs(dR) < 0.5, zero, 0.5 * jnp.sign(dR)) return jnp.sum(dR ** 2, axis=2) @jit def f(R): _ = dist_sq(R) return jnp.sum(R ** 2) _ = hessian(f)(R) # don't crash on UnshapedArray def testIssue489(self): def f(key): def body_fn(uk): key = uk[1] u = random.uniform(key, ()) key, _ = random.split(key) return u, key u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key)) return u print(vmap(f)(random.split(random.PRNGKey(0), 2))) # no crash def testEmptyTuples(self): # Ensure there is no crash when a vectorized input contains empty tuples. result = vmap(lambda x, _: x + 1)(np.array([0, 1]), ()) self.assertAllClose(result, np.array([1, 2]), check_dtypes=False) # Ensure there is no crash when a vectorized output contains empty tuples. result, empty_tuple = vmap(lambda x: (x + 1, ()))(np.array([0, 1])) self.assertAllClose(result, np.array([1, 2]), check_dtypes=False) self.assertEqual((), empty_tuple) def testIndexAddBatchedIndexesOnly(self): f = lambda x, idx, y: jax.ops.index_add(x, jax.ops.index[idx], y) result = vmap(f, (None, 0, None))(np.zeros((10,)), np.arange(10,), 1.) self.assertAllClose(result, np.eye(10), check_dtypes=False) def testIssue1170(self): def f(index1, index2): return jnp.arange(36).reshape(6, 6)[index1, index2] g = jax.jit(jax.pmap(f)) ans = g(index1=np.asarray([1]), index2=np.asarray([2])) expected = g(np.asarray([1]), np.asarray([2])) self.assertAllClose(ans, expected) def testIssue3883(self): def scalar_f(x): return lax.dynamic_slice(x, [], []) xs = jnp.array([1, 2, 3, 4]) ans = vmap(scalar_f)(xs) expected = jnp.array([scalar_f(x) for x in xs]) self.assertAllClose(ans, expected) def scalar_f2(x): return lax.dynamic_update_slice(x, 7, []) xs = jnp.array([1, 2, 3, 4]) ans = vmap(scalar_f2)(xs) expected = jnp.array([scalar_f2(x) for x in xs]) self.assertAllClose(ans, expected) @parameterized.named_parameters( {"testcase_name": "_{}_vmap_names={}_collective_names={}".format( collective.__name__.replace(" ", ""), vmap_names, collective_names), "collective": collective, "bulk_op": bulk_op, "vmap_names": vmap_names, "collective_names": collective_names} for collective, bulk_op in [(lax.psum, jnp.sum), (lax.pmax, jnp.max), (lax.pmin, jnp.min)] for vmap_names in [('i',), ('i', 'j'), ('j', 'i')] for collective_names in it.permutations(vmap_names)) @skipIf(not jax.config.omnistaging_enabled, "vmap collectives only supported when omnistaging is enabled") def testCommAssocCollective(self, collective, bulk_op, vmap_names, collective_names): x = jnp.arange(3 * 4 * 5).reshape((3, 4, 5)) # To test relative permutations of the order in which the axis names appear # in the primitive call versus the order the vmaps are applied, we always # apply vmaps in the order of the `vmap_names` argument, and apply the # collective with names according to the `collective_names` argument. 
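# For example, for the (psum, jnp.sum) pair with vmap_names=('i', 'j') and
# collective_names=('j', 'i'), the collective reduces over both mapped axes,
# so the result must match subtracting a bulk jnp.sum over the two leading
# axes regardless of the ordering of the names.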
f = lambda x: x - collective(x, collective_names) for axis_name in vmap_names: f = vmap(f, axis_name=axis_name) self.assertAllClose(f(x), x - bulk_op(x, axis=tuple(range(len(vmap_names))))) @skipIf(not jax.config.omnistaging_enabled, "vmap collectives only supported when omnistaging is enabled") def testPPermute(self): nelem = 10 ntests = 10 x = np.arange(nelem) rng = np.random.RandomState(1) for i in range(ntests): perm = np.arange(nelem) rng.shuffle(perm) perm_pairs = np.stack([np.arange(nelem), perm], axis=-1) rng.shuffle(perm_pairs) self.assertAllClose( vmap(lambda x: x - lax.ppermute(x, 'i', perm_pairs), axis_name='i')(x), x - x[perm]) @parameterized.named_parameters( {"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}", "split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis} for split_axis, concat_axis, vmap_axis in it.product(range(3), range(3), range(4))) @skipIf(not jax.config.omnistaging_enabled, "vmap collectives only supported when omnistaging is enabled") def testAllToAllShape(self, vmap_axis, split_axis, concat_axis): d = vmap_axis def shape_fun(x, out_d): shape = list(x.shape) vmap_dim_id = shape.pop(d) split_dim_id = shape.pop(split_axis) shape.insert(concat_axis, vmap_dim_id) shape.insert(out_d, split_dim_id) return tuple(shape) shape = (2, 3, 4, 5) x = np.arange(np.prod(shape)).reshape(shape) rule = batching.collective_rules[lax.all_to_all_p] y, out_d = rule(None, (x,), (d,), None, split_axis, concat_axis, None) exp_shape = shape_fun(x, out_d) self.assertEqual(y.shape, exp_shape) @parameterized.named_parameters( {"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}", "split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis} for split_axis, concat_axis, vmap_axis in it.product(range(2), range(2), range(3))) @skipIf(not jax.config.omnistaging_enabled, "vmap collectives only supported when omnistaging is enabled") def testAllToAllSplitAxis(self, vmap_axis, split_axis, concat_axis): raise SkipTest("all_to_all split axis broken after #4835") # TODO(mattjj,apaszke) shape = (4, 4, 4) x = np.arange(np.prod(shape)).reshape(shape) @partial(vmap, in_axes=vmap_axis, axis_name='i') @partial(vmap, in_axes=vmap_axis, axis_name='j') def f(x): return lax.all_to_all(x, ('i', 'j'), split_axis, concat_axis) unroll_shape = (2, 2, *shape[1:]) unroll_shape = list(shape) unroll_shape[vmap_axis:vmap_axis+1] = (2, 2) x_unroll = x.reshape(unroll_shape) y_unrolled = f(x_unroll) y = y_unrolled.reshape(shape) if vmap_axis <= split_axis: split_axis += 1 ref = jnp.moveaxis(x, (vmap_axis, split_axis), (concat_axis + 1, 0)) self.assertAllClose(y, ref) def testNegativeAxes(self): x = np.arange(3*4*5).reshape(3, 4, 5) self.assertAllClose(jax.vmap(jnp.sum, in_axes=-3)(x), jnp.sum(x, axis=(1, 2))) self.assertAllClose(jax.vmap(jnp.sum, in_axes=-2)(x), jnp.sum(x, axis=(0, 2))) self.assertAllClose(jax.vmap(jnp.sum, in_axes=-1)(x), jnp.sum(x, axis=(0, 1))) with self.assertRaisesRegex(ValueError, "vmap got arg 0 of rank 3 but axis to be mapped -4"): jax.vmap(jnp.sum, in_axes=-4)(x) id = lambda y: y self.assertAllClose(x, jax.vmap(id, in_axes=0, out_axes=-3)(x)) self.assertAllClose(x.transpose(1, 0, 2), jax.vmap(id, in_axes=0, out_axes=-2)(x)) self.assertAllClose(x.transpose(1, 2, 0), jax.vmap(id, in_axes=0, out_axes=-1)(x)) with self.assertRaisesRegex(ValueError, "axis -4 is out of bounds.*"): jax.vmap(id, in_axes=0, out_axes=-4)(x) self.assertAllClose( np.full((5,), 7), jax.vmap(lambda *xs: xs, in_axes=(0, None), 
out_axes=(0, -1))( np.arange(5), 7)[1]) with self.assertRaisesRegex(ValueError, "axis -2 is out of bounds.*"): jax.vmap(lambda *xs: xs, in_axes=(0, None), out_axes=(0, -2))( np.arange(5), 7) @skipIf(not jax.config.omnistaging_enabled, "vmap collectives only supported when omnistaging is enabled") def testAxisIndex(self): x = np.arange(10) self.assertAllClose( vmap(lambda x: x - lax.axis_index('i'), axis_name='i')(x), x - np.arange(x.shape[0])) @skipIf(not jax.config.omnistaging_enabled, "vmap collectives only supported when omnistaging is enabled") def testCollectivePdot(self): def f(x, y): return lax.pdot(x, y, 'i') rng = np.random.RandomState(0) x = rng.randn(3, 4) y = rng.randn(4, 5) z = vmap(f, axis_name='i', in_axes=(1, 0), out_axes=None)(x, y) self.assertAllClose(z, jnp.dot(x, y)) x = rng.randn(4, 3) y = rng.randn(4, 5) z = vmap(f, axis_name='i', in_axes=(0, 0), out_axes=None)(x, y) self.assertAllClose(z, jnp.dot(x.T, y)) @skipIf(not jax.config.omnistaging_enabled, "vmap collectives only supported when omnistaging is enabled") def testCollectivePdotBatching(self): def f(x, y): return lax.pdot(x, y, 'i') rng = np.random.RandomState(1) xs = rng.randn(2, 8, 3) ys = rng.randn(2, 3, 5) zs = vmap(vmap(f, axis_name='i', in_axes=(1, 0), out_axes=None))(xs, ys) self.assertAllClose(zs, jnp.einsum('nij,njk->nik', xs, ys)) @skipIf(not jax.config.omnistaging_enabled, "vmap collectives only supported when omnistaging is enabled") def testPdotJvp(self): def f(x, y): return lax.pdot(x, y, 'i') rng = np.random.RandomState(1) x = rng.randn(3, 4) x_dot = rng.randn(*x.shape) y = rng.randn(4, 5) y_dot = rng.randn(*y.shape) z, z_dot = vmap(lambda x, y, x_dot, y_dot: jvp(f, (x, y), (x_dot, y_dot)), axis_name='i', in_axes=(1, 0, 1, 0), out_axes=None)(x, y, x_dot, y_dot) self.assertAllClose(z, jnp.dot(x, y)) self.assertAllClose(z_dot, jnp.dot(x_dot, y) + jnp.dot(x, y_dot)) @skipIf(not jax.config.omnistaging_enabled, "vmap collectives only supported when omnistaging is enabled") def testPdotVjp(self): def f(x, y): return lax.pdot(x, y, 'i') rng = np.random.RandomState(1) x = rng.randn(3, 4) y = rng.randn(4, 5) z_bar = rng.randn(3, 5) x_bar, y_bar = vmap(lambda x, y, z_bar: vjp(f, x, y)[1](z_bar), axis_name='i', in_axes=(1, 0, None), out_axes=(1, 0))(x, y, z_bar) self.assertAllClose(x_bar, jnp.dot(z_bar, y.T)) self.assertAllClose(y_bar, jnp.dot(x.T, z_bar))
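# --- Illustrative sketch, not part of the test classes in this file ---
# A minimal restatement of what testCollectivePdot checks: under vmap with an
# axis name, lax.pdot contracts over the named mapped axis, so mapping the
# contracting dimensions of two matrices with out_axes=None recovers a plain
# matrix product. The axis name 'k', the shapes, and the helper name below are
# arbitrary; like the tests above, this relies on omnistaging being enabled.
def _pdot_matmul_sketch():
  import numpy as onp           # local imports keep the sketch self-contained
  import jax.numpy as jnp
  from jax import lax, vmap
  x = onp.arange(12.0).reshape(3, 4)   # contract over the length-4 axis
  y = onp.arange(20.0).reshape(4, 5)
  z = vmap(lambda a, b: lax.pdot(a, b, 'k'),
           axis_name='k', in_axes=(1, 0), out_axes=None)(x, y)
  assert jnp.allclose(z, jnp.dot(x, y))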
np.array([2], dtype=np.int32), np.array([2, 4], dtype=np.int32), np.array([[2, 4], [5, 6]], dtype=np.int32), np.array([0, 1, 10], dtype=np.int32), # Index out of bounds np.array([0, 1, 2, -1], dtype=np.int32), # Index out of bounds ] for axis in [0, 1, 2]] + # Directly from lax.gather in lax_test.py. [Harness( f"_shape={shape}_idxs_shape={idxs.shape}_dnums={dnums}_slice_sizes={slice_sizes}", lambda op, idxs, dnums, slice_sizes: lax.gather(op, idxs, dimension_numbers=dnums, slice_sizes=slice_sizes), [RandArg(shape, np.float32), idxs, StaticArg(dnums), StaticArg(slice_sizes)]) for shape, idxs, dnums, slice_sizes in [ ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)), ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)), ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)), ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), ] ] ) lax_scatter = tuple(
class BatchingTest(jtu.JaxTestCase): def testConstantFunction(self): ans = vmap(lambda x: 3)(onp.ones(4)) expected = 3 * onp.ones(4) self.assertAllClose(ans, expected, check_dtypes=False) def testNestedBatchingMatMat(self): matvec = vmap(np.vdot, in_axes=(0, None)) matmat = vmap(matvec, in_axes=(None, 1), out_axes=1) R = onp.random.RandomState(0).randn A = R(4, 3) B = R(3, 2) ans = matmat(A, B) expected = onp.dot(A, B) self.assertAllClose(ans, expected, check_dtypes=False, rtol={onp.float32: 1e-2} if jtu.device_under_test() == "tpu" else None) jaxpr = make_jaxpr(matmat)(A, B) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) def testPerExampleGradients(self): def predict(params, inputs): for W, b in params: outputs = np.dot(W, inputs) + b inputs = np.tanh(outputs) return outputs def loss(params, data): inputs, targets = data predictions = predict(params, inputs) return np.sum((predictions - targets)**2) batch_size = 5 layer_sizes = [3, 2, 4] R = onp.random.RandomState(0).randn params = [(R(m, n), R(m)) for m, n in zip(layer_sizes[1:], layer_sizes[:-1])] input_vec = R(3) target_vec = R(4) datum = (input_vec, target_vec) input_batch = R(5, 3) target_batch = R(5, 4) batch = (input_batch, target_batch) ans = vmap(partial(grad(loss), params))(batch) for ans_pair, param_pair in zip(ans, params): dW, db = ans_pair W, b = param_pair self.assertEqual(dW.shape, (batch_size, ) + W.shape) self.assertEqual(db.shape, (batch_size, ) + b.shape) def testJacobians(self): def jacbwd(f, x): y, pullback = vjp(f, x) std_basis = onp.eye(onp.size(y)).reshape((-1, ) + onp.shape(y)) jac_flat, = vmap(pullback, out_axes=onp.ndim(y))(std_basis) return jac_flat.reshape(onp.shape(y) + onp.shape(x)) def jacfwd(f, x): pushfwd = lambda v: jvp(f, (x, ), (v, )) std_basis = onp.eye(onp.size(x)).reshape((-1, ) + onp.shape(x)) y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis) return jac_flat.reshape(onp.shape(y) + onp.shape(x)) R = onp.random.RandomState(0).randn A = R(4, 3) b = R(4) f = lambda x: np.tanh(np.dot(A, x) + b) x = R(3) self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False) def testBatchOfCompile(self): side = [] @jit def f(x): side.append(None) return x + x g = jit(vmap(f)) self.assertAllClose(g(onp.ones(2)), 2 * onp.ones(2), check_dtypes=False) self.assertEqual(len(side), 1) self.assertAllClose(g(2 * onp.ones(2)), 4 * onp.ones(2), check_dtypes=False) self.assertEqual(len(side), 1) def testSliceLax(self): fun = lambda x: lax.slice(x, (2, ), (4, )) R = onp.random.RandomState(0).randn x = R(5, 10) ans = vmap(fun)(x) expected_ans = x[:, 2:4] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testSliceNumpy(self): fun = lambda x: x[:, 2] R = onp.random.RandomState(0).randn x = R(10, 5, 3, 7) ans = vmap(fun)(x) expected_ans = x[:, :, 2] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testRevLax(self): fun = lambda x: lax.rev(x, [0]) R = onp.random.RandomState(0).randn x = R(2, 3) ans = vmap(fun)(x) expected_ans = x[:, ::-1] self.assertAllClose(ans, expected_ans, check_dtypes=False) ans = vmap(fun, (1, ), 1)(x) expected_ans = x[::-1, :] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testRevNumpy(self): fun = lambda x: x[:, ::-1] R = onp.random.RandomState(0).randn x = R(3, 2, 4) ans = vmap(fun)(x) expected_ans = x[:, :, ::-1] self.assertAllClose(ans, expected_ans, check_dtypes=False) ans = vmap(fun, (1, ), 1)(x) expected_ans = x[:, :, ::-1] self.assertAllClose(ans, expected_ans, check_dtypes=False) ans = vmap(fun, (2, ), 2)(x) expected_ans = x[:, ::-1, :] 
self.assertAllClose(ans, expected_ans, check_dtypes=False) def testNpMaximum(self): fun = lambda x: np.maximum(x, 0.0) R = onp.random.RandomState(0).randn x = R(10, 5, 3, 7) ans = vmap(fun)(x) expected_ans = onp.maximum(x, 0.0) self.assertAllClose(ans, expected_ans, check_dtypes=False) def testNpGtrThan(self): R = onp.random.RandomState(0).randn x = R(10, 5, 3, 7) ans = vmap(lambda x: x > 1.0)(x) expected_ans = x > 1.0 self.assertAllClose(ans, expected_ans, check_dtypes=True) def testNpMaximumPerExampleGrad(self): R = onp.random.RandomState(0).randn x = R(10, 5) W = R(5, 5) fun = lambda W, x: np.sum(np.maximum(np.dot(x, W), 0.0)**2) ans = vmap(partial(grad(fun), W))(x) W_t = np.transpose(W) for i in range(10): x_ex = x[i:i + 1] expected_ans = 2.0 * np.dot( np.maximum(np.dot(W_t, np.transpose(x_ex)), 0.0), x_ex) expected_ans = np.transpose(expected_ans) self.assertAllClose(ans[i], expected_ans, check_dtypes=False, atol={onp.float32: 5e-2} if jtu.device_under_test() == "tpu" else None) def testDotGeneral(self): R = onp.random.RandomState(0).randn x = R(10, 3, 4, 5) y = R(10, 3, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ), (0, ))]) ans = vmap(fun)(x, y) expected = lax.dot_general(x, y, [((3, ), (2, )), ((0, 1), (0, 1))]) self.assertAllClose(ans, expected, check_dtypes=True) x = R(3, 4, 10, 5) y = R(3, 10, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ), (0, ))]) ans = vmap(fun, in_axes=(2, 1))(x, y) expected = onp.stack( [fun(x[..., i, :], y[:, i, ...]) for i in range(10)]) self.assertAllClose(ans, expected, check_dtypes=True) x = R(3, 4, 5, 10) y = R(3, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ), (0, ))]) ans = vmap(fun, in_axes=(3, None))(x, y) expected = onp.stack([fun(x[..., i], y) for i in range(10)]) self.assertAllClose(ans, expected, check_dtypes=True) x = R(3, 4, 5) y = R(3, 5, 10, 6) fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ), (0, ))]) ans = vmap(fun, in_axes=(None, 2))(x, y) expected = onp.stack([fun(x, y[..., i, :]) for i in range(10)]) self.assertAllClose(ans, expected, check_dtypes=True) x = R(4) y = R(4, 10) fun = lambda x, y: lax.dot_general(x, y, [((0, ), (0, )), ((), ())]) ans = vmap(fun, in_axes=(None, 1))(x, y) expected = onp.stack([fun(x, y[..., i]) for i in range(10)]) self.assertAllClose(ans, expected, check_dtypes=True) def testDot(self): # these tests are based on @shoyer's notebook studying gufuncs def vecvec(a, b): dot = np.dot for ndim in range(1, max(a.ndim, b.ndim)): a_ax = 0 if a.ndim > ndim else None b_ax = 0 if b.ndim > ndim else None dot = vmap(dot, in_axes=(a_ax, b_ax)) return dot(a, b) assert vecvec(np.zeros((3, )), np.zeros((3, ))).shape == () assert vecvec(np.zeros((2, 3)), np.zeros((3, ))).shape == (2, ) assert vecvec(np.zeros((4, 2, 3)), np.zeros((3, ))).shape == (4, 2) def testDot2(self): R = onp.random.RandomState(0).randn xs = R(10, 3) ys = R(10, 3) ans = vmap(np.dot)(xs, ys) expected = onp.einsum('ni,ni->n', xs, ys) self.assertAllClose(ans, expected, check_dtypes=False) def testDot3(self): R = onp.random.RandomState(0).randn xs = R(5, 8, 10) ys = R(10, 1) ans = vmap(np.dot, in_axes=(1, None))(xs, ys) expected = onp.einsum('inj,jk->nik', xs, ys) self.assertAllClose(ans, expected, check_dtypes=False) def testDot4(self): R = onp.random.RandomState(0).randn xs = R(3, 2) ys = R(3) ans = vmap(np.dot, in_axes=(1, None))(xs, ys) expected = onp.einsum('ij,i->j', xs, ys) self.assertAllClose(ans, expected, check_dtypes=False) def testDot5(self): f = 
vmap(partial(np.einsum, 'ij,j->i'), (None, 0)) jaxpr = make_jaxpr(f)(np.zeros((1000, 1000)), np.zeros((1000, 1000))) assert "broadcast" not in str(jaxpr) def testPad(self): R = onp.random.RandomState(0).randn fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1)]) x = R(5, 10).astype(onp.float32) ans = vmap(fun)(x) expected_ans = np.stack(list(map(fun, x))) self.assertAllClose(ans, expected_ans, check_dtypes=False) fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1), (0, 1, 0)]) x = R(5, 10, 3).astype(onp.float32) ans = vmap(fun)(x) expected_ans = np.stack(list(map(fun, x))) self.assertAllClose(ans, expected_ans, check_dtypes=False) def testConcatenate(self): R = lambda *shape: onp.random.RandomState(0).randn(*shape).astype( onp.float32) fun = lambda *args: lax.concatenate(args, dimension=0) x, y, z = R(10, 2, 3), R(1, 10, 3), R(4, 3) ans = vmap(fun, in_axes=(0, 1, None))(x, y, z) expected_ans = onp.concatenate( [x, onp.swapaxes(y, 0, 1), onp.broadcast_to(z, (10, 4, 3))], 1) self.assertAllClose(ans, expected_ans, check_dtypes=False) fun = lambda *args: lax.concatenate(args, dimension=1) x, y, z = R(10, 2, 1), R(2, 3), R(2, 4, 10) ans = vmap(fun, in_axes=(0, None, 2))(x, y, z) expected_ans = onp.concatenate( [x, onp.broadcast_to(y, (10, 2, 3)), onp.moveaxis(z, 2, 0)], 2) self.assertAllClose(ans, expected_ans, check_dtypes=False) def testJacobianIssue54(self): # test modeling the code in https://github.com/google/jax/issues/54 def func(xs): return np.array([x for x in xs]) xs = np.ones((5, 1)) jacrev(func)(xs) # don't crash jacfwd(func)(xs) # don't crash def testAny(self): # test modeling the code in https://github.com/google/jax/issues/108 ans = vmap(np.any)(np.array([[True, False], [False, False]])) expected = np.array([True, False]) self.assertAllClose(ans, expected, check_dtypes=True) @jtu.skip_on_devices("tpu") def testHessian(self): # test based on code from sindhwani@google def fun(x, t): return np.sum(np.power(np.maximum(x, 0.0), 2)) + t x = onp.array([-1., -0.5, 0., 0.5, 1.0]) ans = hessian(lambda x: fun(x, 0.0))(x) expected = onp.array([[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0.5, 0., 0.], [0., 0., 0., 2., 0.], [0., 0., 0., 0., 2.]]) self.assertAllClose(ans, expected, check_dtypes=False) def testDynamicSlice(self): # test dynamic_slice via numpy indexing syntax # see https://github.com/google/jax/issues/1613 for an explanation of why we # need to use np rather than onp to create x and idx x = np.arange(30).reshape((10, 3)) ans = vmap(lambda x, i: x[i], in_axes=(0, None))(x, 1) expected = x[:, 1] self.assertAllClose(ans, expected, check_dtypes=False) idx = np.array([0, 1, 2, 1, 0] * 2) ans = vmap(lambda x, i: x[i], in_axes=(0, 0))(x, idx) expected = x[onp.arange(10), idx] self.assertAllClose(ans, expected, check_dtypes=False) x = np.arange(3) idx = np.array([0, 1, 2, 1, 0] * 2) ans = vmap(lambda x, i: x[i], in_axes=(None, 0))(x, idx) expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testDynamicUpdateSlice(self): x = onp.random.randn(10, 3) y = onp.random.randn(10) ans = vmap( lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0), in_axes=(0, 0, None))(x, y, 1) expected = x.copy() expected[:, 1] = y self.assertAllClose(ans, expected, check_dtypes=False) x = onp.random.randn(3) idx = onp.array([0, 1, 2, 1, 0] * 2) ans = vmap( lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0), in_axes=(None, 0, 0))(x, y, idx) expected = onp.broadcast_to(x, (10, 3)).copy() expected[onp.arange(10), idx] = y self.assertAllClose(ans, 
expected, check_dtypes=False) def testRandom(self): seeds = vmap(random.PRNGKey)(onp.arange(10)) ans = vmap(partial(random.normal, shape=(3, 2)))(seeds) expected = onp.stack([ random.normal(random.PRNGKey(seed), (3, 2)) for seed in onp.arange(10) ]) self.assertAllClose(ans, expected, check_dtypes=False) assert len(onp.unique(ans)) == 10 * 3 * 2 def testSort(self): v = onp.arange(12)[::-1].reshape(3, 4) sv = vmap(partial(lax.sort, dimension=0), (0, ))(v) self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) sv = vmap(partial(lax.sort, dimension=-1), (0, ))(v) self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) sv = vmap(partial(lax.sort, dimension=0), (1, ))(v) self.assertAllClose(sv, v[::-1, :].T, check_dtypes=True) sv = vmap(partial(lax.sort, dimension=0), (1, ), 1)(v) self.assertAllClose(sv, v[::-1, :], check_dtypes=True) def testSortKeyVal(self): k = onp.arange(12)[::-1].reshape(3, 4) v = onp.random.RandomState(0).permutation(12).reshape(3, 4) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v) self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v) self.assertAllClose(sk, k[::-1, :], check_dtypes=True) self.assertAllClose(sv, v[::-1, :], check_dtypes=True) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T) self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v) self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0], v) self.assertAllClose(sk, onp.broadcast_to(k[0, ::-1], (3, 4)), check_dtypes=True) self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T, v[0]) self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) self.assertAllClose(sv, onp.broadcast_to(v[0, ::-1], (3, 4)), check_dtypes=True) def testConvGeneralDilated(self): W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32) X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32) def f(params, x): one = (1, 1) dimension_numbers = ('NHWC', 'HWIO', 'NHWC') y = lax.conv_general_dilated(x, params, one, 'SAME', one, one, dimension_numbers) return y grad_loss = grad(lambda params, x: np.mean(f(params, x)**2)) # Test forward prop. per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example = np.reshape(per_example, (10, 5, 5, 5)) per_example_direct = f(W, X) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) # Test gradients. 
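# Per-example gradients: hold the filter W fixed with partial, vmap grad_loss
# over the inputs reshaped to singleton batches, then compare against a Python
# loop that differentiates one example at a time.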
per_example = vmap(partial(grad_loss, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example_direct = [] for i in range(10): g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1))) per_example_direct += [np.reshape(g, (1, ) + g.shape)] per_example_direct = np.concatenate(per_example_direct, axis=0) self.assertAllClose(per_example, per_example_direct, check_dtypes=True, rtol=2e-2) def testConvGeneralDilatedBatchNotMajor(self): W = np.array(onp.random.randn(3, 3, 1, 4), dtype=onp.float32) x = np.array(onp.random.randn(3, 5, 7, 5, 1), dtype=onp.float32) def f(params, x): one = (1, 1) dimension_numbers = ('HNWC', 'HWIO', 'HWNC') y = lax.conv_general_dilated(x, params, one, 'SAME', one, one, dimension_numbers) return y per_example = vmap(partial(f, W))(x) per_example = np.reshape(np.transpose(per_example, (1, 2, 0, 3, 4)), (5, 5, 21, 4)) per_example_direct = f( W, np.reshape(np.transpose(x, (1, 0, 2, 3, 4)), (5, 21, 5, 1))) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) @parameterized.named_parameters({ "testcase_name": "_op={}".format(name), "op": op, "unit": unit } for name, op, unit in [("max", lax.max, -np.inf), ("min", lax.min, np.inf)]) def testMinMaxPool(self, op, unit): W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32) X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32) def f(params, x): one = (1, 1) dimension_numbers = ('NHWC', 'HWIO', 'NHWC') y = lax.conv_general_dilated(x, params, one, 'SAME', one, one, dimension_numbers) y = lax.reduce_window(y, unit, op, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME') return y grad_loss = grad(lambda params, x: np.mean(f(params, x)**2)) # Test forward prop. per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example = np.reshape(per_example, (10, 5, 5, 5)) per_example_direct = f(W, X) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) # Test gradients. per_example = vmap(partial(grad_loss, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example_direct = [] for i in range(10): g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1))) per_example_direct += [np.reshape(g, (1, ) + g.shape)] per_example_direct = np.concatenate(per_example_direct, axis=0) self.assertAllClose(per_example, per_example_direct, check_dtypes=True, rtol=5e-2) def testSumPool(self): W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32) X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32) def f(params, x): one = (1, 1) dimension_numbers = ('NHWC', 'HWIO', 'NHWC') y = lax.conv_general_dilated(x, params, one, 'SAME', one, one, dimension_numbers) y = lax.reduce_window(y, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME') return y grad_loss = grad(lambda params, x: np.mean(f(params, x)**2)) # Test forward prop. per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example = np.reshape(per_example, (10, 5, 5, 5)) per_example_direct = f(W, X) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) # Test gradients. 
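# The sum-pool gradient scatters each output cotangent uniformly over its
# pooling window, so the vmapped per-example gradients below should match the
# looped reference up to the stated rtol.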
per_example = vmap(partial(grad_loss, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example_direct = [] for i in range(10): g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1))) per_example_direct += [np.reshape(g, (1, ) + g.shape)] per_example_direct = np.concatenate(per_example_direct, axis=0) self.assertAllClose(per_example, per_example_direct, check_dtypes=True, rtol=3e-2) def testCumProd(self): x = np.arange(9).reshape(3, 3) + 1 y = vmap(lambda x: np.cumprod(x, axis=-1))(x) self.assertAllClose(onp.cumprod(x, axis=1, dtype=np.int_), y, check_dtypes=True) def testSelect(self): pred = onp.array([True, False]) on_true = onp.array([0, 1]) on_false = onp.array([2, 3]) ans = vmap(lax.select)(pred, on_true, on_false) expected = onp.array([0, 3]) self.assertAllClose(ans, expected, check_dtypes=True) pred = onp.array([False, True]) on_true = onp.array([0, 1]) on_false = onp.array([2, 3]) ans = vmap(lax.select, (0, None, None))(pred, on_true, on_false) expected = onp.array([[2, 3], [0, 1]]) self.assertAllClose(ans, expected, check_dtypes=True) pred = True on_true = onp.array([0, 1], onp.float32) on_false = onp.array(3, onp.float32) ans = vmap(lax.select, (None, 0, None))(pred, on_true, on_false) expected = onp.array([0, 1], onp.float32) self.assertAllClose(ans, expected, check_dtypes=True) pred = onp.array([False, True]) on_true = onp.array([0, 1], onp.float32) on_false = onp.array(3, onp.float32) ans = vmap(lax.select, (0, 0, None))(pred, on_true, on_false) expected = onp.array([3, 1], onp.float32) self.assertAllClose(ans, expected, check_dtypes=True) pred = onp.array([False, True]) on_true = onp.array([2], onp.float32) on_false = onp.array([[3, 4]], onp.float32) ans = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false) expected = onp.array([[3, 2]], onp.float32) self.assertAllClose(ans, expected, check_dtypes=True) def testLaxLinalgCholesky(self): a = onp.random.RandomState(0).randn(10, 5, 5).astype(onp.float32) a = onp.matmul(a, onp.conj(onp.swapaxes(a, -1, -2))) ans = vmap(lax_linalg.cholesky)(a) expected = onp.linalg.cholesky(a) self.assertAllClose(ans, expected, check_dtypes=False, rtol=1e-4) b = onp.random.RandomState(0).randn(10, 5, 5).astype(onp.float32) b = onp.matmul(b, onp.conj(onp.swapaxes(b, -1, -2))) b_trans = onp.swapaxes(b, 0, 1) # shape is (5, 10, 5) ans = vmap(lax_linalg.cholesky, in_axes=1, out_axes=0)(b_trans) expected = onp.linalg.cholesky(b) self.assertAllClose(ans, expected, check_dtypes=False, rtol=1e-4) def testLaxLinalgTriangularSolve(self): a = onp.random.RandomState(0).randn(4, 10, 4).astype(onp.float32) a += onp.eye(4, dtype=np.float32)[:, None, :] b = onp.random.RandomState(0).randn(5, 4, 10).astype(onp.float32) ans = vmap(lax_linalg.triangular_solve, in_axes=(1, 2))(a, b) expected = onp.stack([ lax_linalg.triangular_solve(a[:, i], b[..., i]) for i in range(10) ]) self.assertAllClose(ans, expected, check_dtypes=True) ans = vmap(lax_linalg.triangular_solve, in_axes=(None, 2))(a[:, 0], b) expected = onp.stack([ lax_linalg.triangular_solve(a[:, 0], b[..., i]) for i in range(10) ]) self.assertAllClose(ans, expected, check_dtypes=True) ans = vmap(lax_linalg.triangular_solve, in_axes=(1, None))(a, b[..., 0]) expected = onp.stack([ lax_linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10) ]) self.assertAllClose(ans, expected, check_dtypes=True) @parameterized.named_parameters( { "testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, 
"dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory } for dtype in [onp.float32, onp.int32] for axis, shape, idxs, dnums, slice_sizes in [ (0, (3, 5), onp.array([[0], [2]]), lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, )), (1, (10, 3), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(), start_index_map=(0, )), (2, )), (1, (10, 3, 5), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, 3)), (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, 1)), (1, 3)), ] for rng_idx_factory in [partial(jtu.rand_int, max(shape))] for rng_factory in [jtu.rand_default]) def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng_factory, rng_idx_factory): rng = rng_factory() rng_idx = rng_idx_factory() fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) operand = rng(shape, dtype) ans = vmap(fun, (axis, None))(operand, idxs) expected = onp.stack([ fun(operand[(slice(None), ) * axis + (i, )], idxs) for i in range(operand.shape[axis]) ]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( { "testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory } for dtype in [onp.float32, onp.float64] for axis, shape, idxs, dnums, slice_sizes in [ (0, (3, 5), onp.array([[0], [2]]), lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, )), (1, (10, 3), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(), start_index_map=(0, )), (2, )), (1, (10, 3, 5), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, 3)), (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, 1)), (1, 3)), ] for rng_idx_factory in [partial(jtu.rand_int, max(shape))] for rng_factory in [jtu.rand_default]) def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng_factory, rng_idx_factory): rng = rng_factory() rng_idx = rng_idx_factory() fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx)))) operand = rng(shape, dtype) ans = vmap(gfun, (axis, None))(operand, idxs) expected = onp.stack([ gfun(operand[(slice(None), ) * axis + (i, )], idxs) for i in range(operand.shape[axis]) ]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( { "testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory } for dtype in [onp.float32, onp.int32] for axis, shape, idxs, dnums, slice_sizes in [ (0, (5, ), onp.array([[[0], [2]], [[1], [3]]]), 
lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, )), (1, (10, ), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None], lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(), start_index_map=(0, )), (2, )), (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T[..., None], lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, 3)), (0, (10, 5), onp.array([[[0, 1], [2, 0]], [[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, 1)), (1, 3)), ] for rng_idx_factory in [partial(jtu.rand_int, max(shape))] for rng_factory in [jtu.rand_default]) def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng_factory, rng_idx_factory): rng = rng_factory() rng_idx = rng_idx_factory() fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) operand = rng(shape, dtype) ans = vmap(fun, (None, axis))(operand, idxs) expected = onp.stack([ fun(operand, idxs[(slice(None), ) * axis + (i, )]) for i in range(idxs.shape[axis]) ]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( { "testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory } for dtype in [onp.float32, onp.float64] for axis, shape, idxs, dnums, slice_sizes in [ (0, (5, ), onp.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, )), (1, (10, ), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None], lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(), start_index_map=(0, )), (2, )), (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T[..., None], lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, 3)), (0, (10, 5), onp.array([[[0, 1], [2, 0]], [[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, 1)), (1, 3)), ] for rng_idx_factory in [partial(jtu.rand_int, max(shape))] for rng_factory in [jtu.rand_default]) def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng_factory, rng_idx_factory): rng = rng_factory() rng_idx = rng_idx_factory() fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx)))) operand = rng(shape, dtype) ans = vmap(gfun, (None, axis))(operand, idxs) expected = onp.stack([ gfun(operand, idxs[(slice(None), ) * axis + (i, )]) for i in range(idxs.shape[axis]) ]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( { "testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}" .format(jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs, dnums, slice_sizes), "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory } for dtype in [onp.float32, onp.int32] for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [ (0, 0, (2, 5), onp.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, ), 
start_index_map=(0, )), (1, )), (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None], lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(), start_index_map=(0, )), (2, )), (0, 1, ( 2, 10, 5, ), onp.array([[[0, 2, 1], [0, 3, 3]]]).T, lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, 3)), (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, 1)), (1, 3)), ] for rng_idx_factory in [partial(jtu.rand_int, max(shape))] for rng_factory in [jtu.rand_default]) def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums, slice_sizes, rng_factory, rng_idx_factory): rng = rng_factory() rng_idx = rng_idx_factory() fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) operand = rng(shape, dtype) assert operand.shape[op_axis] == idxs.shape[idxs_axis] ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs) expected = onp.stack([ fun(operand[(slice(None), ) * op_axis + (i, )], idxs[(slice(None), ) * idxs_axis + (i, )]) for i in range(idxs.shape[idxs_axis]) ]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( { "testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}" .format(jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs, dnums, slice_sizes), "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory } for dtype in [onp.float32] for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [ (0, 0, (2, 5), onp.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, )), (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None], lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(), start_index_map=(0, )), (2, )), (0, 1, ( 2, 10, 5, ), onp.array([[[0, 2, 1], [0, 3, 3]]]).T, lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, 3)), (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, 1)), (1, 3)), ] for rng_idx_factory in [partial(jtu.rand_int, max(shape))] for rng_factory in [jtu.rand_default]) def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums, slice_sizes, rng_factory, rng_idx_factory): rng = rng_factory() rng_idx = rng_idx_factory() fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx)))) operand = rng(shape, dtype) assert operand.shape[op_axis] == idxs.shape[idxs_axis] ans = vmap(gfun, (op_axis, idxs_axis))(operand, idxs) expected = onp.stack([ gfun(operand[(slice(None), ) * op_axis + (i, )], idxs[(slice(None), ) * idxs_axis + (i, )]) for i in range(idxs.shape[idxs_axis]) ]) self.assertAllClose(ans, expected, check_dtypes=False) def testNumpyIndexing1(self): a = np.arange(2 * 3 * 4).reshape((2, 3, 4)) ind = onp.array([[0, 1], [2, 0]]) def f(a, ind): return a[:, ind] expected = onp.stack([f(a, ind[i, :]) for i in range(ind.shape[0])]) ans = vmap(f, (None, 0))(a, ind) assert onp.all(ans == expected) def testNumpyIndexing2(self): a = np.arange(2 * 3 * 4).reshape((2, 3, 4)) def f(a): inds = np.array([0, 2]) return 
a[:, inds] ans = vmap(f)(a) expected = onp.stack([f(a[:, i, :]) for i in range(a.shape[1])], axis=1) assert onp.all(ans == expected) def testTranspose(self): x = onp.arange(4 * 3 * 3).reshape((4, 3, 3)) ans = vmap(lambda x: x + x.T)(x) expected = x + onp.swapaxes(x, -1, -2) self.assertAllClose(ans, expected, check_dtypes=False) def testTransposePermutation(self): x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5)) ans = vmap(lambda x: np.transpose(x, (1, 0, 2)))(x) expected = onp.transpose(x, (0, 2, 1, 3)) self.assertAllClose(ans, expected, check_dtypes=False) x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5)) ans = vmap(lambda x: np.transpose(x, (1, 2, 0)))(x) expected = onp.transpose(x, (0, 2, 3, 1)) self.assertAllClose(ans, expected, check_dtypes=False) x = onp.arange(6 * 3 * 4 * 5).reshape((3, 4, 6, 5)) ans = vmap(lambda x: np.transpose(x, (1, 2, 0)), in_axes=2)(x) expected = onp.transpose(x, (2, 1, 3, 0)) self.assertAllClose(ans, expected, check_dtypes=False) def testIssue354(self): psd_mat = onp.random.randn(20, 10) psd_mat = psd_mat.T.dot(psd_mat) vec = onp.random.randn(10) def f(scale): scaled_mat = scale * psd_mat chol = np.linalg.cholesky(scaled_mat) return -0.5 * np.sum((np.einsum('ij,j->i', chol, vec))**2) vmapped_f = vmap(f) vmapped_f_grad = grad(lambda x: np.sum(vmapped_f(x))) scales = onp.array([[0.1], [0.2], [0.3], [0.4], [0.5]]) ans = vmapped_f_grad(scales) # don't crash! expected = onp.stack([grad(f)(scale) for scale in scales]) self.assertAllClose(ans, expected, check_dtypes=False, rtol=jtu.default_gradient_tolerance) def testIssue387(self): # https://github.com/google/jax/issues/387 R = onp.random.RandomState(0).rand(100, 2) def dist_sq(R): dR = R[:, np.newaxis, :] - R[np.newaxis, :, :] zero = np.zeros_like(dR) dR = dR - np.where(np.abs(dR) < 0.5, zero, 0.5 * np.sign(dR)) return np.sum(dR**2, axis=2) @jit def f(R): dr = dist_sq(R) return np.sum(R**2) H = hessian(f)(R) # don't crash on UnshapedArray def testIssue489(self): def f(key): def body_fn(uk): key = uk[1] u = random.uniform(key, (), dtype=np.float64) key, _ = random.split(key) return u, key u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (np.float64(1.), key)) return u print(vmap(f)(random.split(random.PRNGKey(0), 2))) # no crash def testEmptyTuples(self): # Ensure there is no crash when a vectorized input contains empty tuples. result = vmap(lambda x, _: x + 1)(onp.array([0, 1]), ()) self.assertAllClose(result, onp.array([1, 2]), check_dtypes=False) # Ensure there is no crash when a vectorized output contains empty tuples. result, empty_tuple = vmap(lambda x: (x + 1, ()))(onp.array([0, 1])) self.assertAllClose(result, onp.array([1, 2]), check_dtypes=False) self.assertEqual((), empty_tuple) def testIndexAddBatchedIndexesOnly(self): f = lambda x, idx, y: jax.ops.index_add(x, jax.ops.index[idx], y) result = vmap(f, (None, 0, None))(onp.zeros((10, )), onp.arange(10, ), 1.) self.assertAllClose(result, onp.eye(10), check_dtypes=False) def testIssue1170(self): def f(index1, index2): return np.arange(36).reshape(6, 6)[index1, index2] g = jax.jit(jax.pmap(f)) ans = g(index1=onp.asarray([1]), index2=onp.asarray([2])) expected = g(onp.asarray([1]), onp.asarray([2])) self.assertAllClose(ans, expected, check_dtypes=True)
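# --- Illustrative sketch, not part of the test classes in this file ---
# The Jacobian helpers in testJacobians above push a standard basis through
# jvp under vmap (and pull one back through vjp). A minimal forward-mode
# version with an arbitrary elementwise function, whose Jacobian is diagonal:
def _jacfwd_via_vmap_sketch():
  import numpy as onp           # local imports keep the sketch self-contained
  import jax.numpy as jnp
  from jax import jvp, vmap
  f = lambda x: 2.0 * jnp.sin(x)
  x = onp.arange(3.0, dtype=onp.float32)
  push = lambda v: jvp(f, (x,), (v,))[1]            # directional derivative J @ v
  jac = vmap(push)(onp.eye(3, dtype=onp.float32))   # row i is J @ e_i
  assert jnp.allclose(jac, jnp.diag(2.0 * jnp.cos(x)))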
class LaxAutodiffTest(jtu.JaxTestCase): @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix( rec.name, shapes, itertools.repeat(dtype)), "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype, "order": rec.order, "tol": rec.tol} for shape_group in compatible_shapes for shapes in CombosWithReplacement(shape_group, rec.nargs) for dtype in rec.dtypes) for rec in LAX_GRAD_OPS)) def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): rng = rng_factory(self.rng()) if jtu.device_under_test() == "tpu" and op is lax.pow: raise SkipTest("pow grad imprecise on tpu") tol = jtu.join_tolerance(1e-1, tol) if num_float_bits(dtype) == 32 else tol args = tuple(rng(shape, dtype) for shape in shapes) check_grads(op, args, order, ["fwd", "rev"], tol, tol) @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value), "op": rec.op, "special_value": special_value, "tol": rec.tol} for special_value in rec.values) for rec in LAX_GRAD_SPECIAL_VALUE_TESTS)) def testOpGradSpecialValue(self, op, special_value, tol): check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_from_dtype={}_to_dtype={}".format( jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)), "from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory} for from_dtype, to_dtype in itertools.product( float_dtypes + complex_dtypes, repeat=2) for rng_factory in [jtu.rand_default])) def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory): rng = rng_factory(self.rng()) tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance), jtu.tolerance(from_dtype, jtu.default_gradient_tolerance)) args = (rng((2, 3), from_dtype),) convert_element_type = lambda x: lax.convert_element_type(x, to_dtype) convert_element_type = jtu.ignore_warning(category=onp.ComplexWarning)( convert_element_type) check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format( jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng_factory": rng_factory} for shape in [(), (2, 3)] for dtype in grad_float_dtypes for rng_factory in [jtu.rand_default])) def testClampGrad(self, shape, dtype, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) low = operand - dtype(10) high = operand + dtype(10) # Avoids points near the boundary where the gradient may be inaccurate. 
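# The three calls below put `operand` in each argument slot of
# lax.clamp(min, x, max) in turn, so the gradient check exercises the
# min-active, interior, and max-active branches of the clamp.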
check_grads(lax.clamp, (operand, low, high), 2, ["fwd", "rev"], eps=1e-2) check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2) check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format( dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name, num_arrs), "dim": dim, "base_shape": base_shape, "dtype": dtype, "num_arrs": num_arrs, "rng_factory": rng_factory} for num_arrs in [3] for dtype in float_dtypes for base_shape in [(4,), (3, 4), (2, 3, 4)] for dim in range(len(base_shape)) for rng_factory in [jtu.rand_default])) def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs, rng_factory): rng = rng_factory(self.rng()) shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:] for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))] operands = tuple(rng(shape, dtype) for shape in shapes) concatenate = lambda *args: lax.concatenate(args, dim) check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "strides": strides, "padding": padding, "rng_factory": rng_factory,} for lhs_shape, rhs_shape, all_strides in itertools.chain( [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)]) for b, i, j in itertools.product([2, 3], repeat=3)], [((4, 2, 1), (3, 2, 1), [(1,)])]) for strides in all_strides for dtype in float_dtypes for padding in ["VALID", "SAME"] for rng_factory in [jtu.rand_small])) def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, rng_factory): rng = rng_factory(self.rng()) lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) conv = partial(lax.conv, window_strides=strides, padding=padding, precision=lax.Precision.HIGHEST) check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], atol=1e-2, rtol=1e-2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" "rhs_dilation={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, lhs_dil, rhs_dil), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "strides": strides, "padding": padding, "lhs_dil": lhs_dil, "rhs_dil": rhs_dil, "rng_factory": rng_factory} for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in itertools.chain( [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)], [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))], [(1, 1), (2, 1)], [(1, 1)]) for b, i, j in itertools.product([2, 3], repeat=3)], [((4, 2, 1), (3, 2, 1), [(1,)], [((1, 1),), ((0, 0),)], [(1,), (2,)], [(1,), (2,)])]) for strides in all_strides for rhs_dil in rhs_dils for lhs_dil in lhs_dils for dtype in float_dtypes for padding in all_pads for rng_factory in [jtu.rand_small])) def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil, rng_factory): rng = rng_factory(self.rng()) lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) conv = partial(lax.conv_with_general_padding, window_strides=strides, padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil, precision=lax.Precision.HIGHEST) 
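# check_grads_bilinear is assumed here to check derivatives with respect to
# lhs and rhs one at a time, holding the other operand fixed; conv is bilinear,
# so this keeps the finite-difference comparison well conditioned.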
check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], atol=1e-2, rtol=1e-2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums), feature_group_count, batch_group_count), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "strides": strides, "padding": padding, "lhs_dil": lhs_dil, "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums, "perms": perms, "feature_group_count": feature_group_count, "batch_group_count": batch_group_count} for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)]) for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [ ([(b * batch_group_count, i * feature_group_count, 6, 7), (b * batch_group_count, i * feature_group_count, 0, 4)], # lhs_shape (j * batch_group_count * feature_group_count, i, 1, 2), # rhs_shape [(1, 1), (1, 2), (2, 1)], # strides [(1, 1), (2, 1)], # lhs_dils [(1, 1), (2, 2)]) # rhs_dils for b, i, j in itertools.product([1, 2], repeat=3)] for lhs_shape in lhs_shapes for strides in all_strides for rhs_dil in rhs_dils for lhs_dil in lhs_dils for dtype in grad_float_dtypes for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] + ([((0, -1), (0, 0))] if lhs_shape[2] != 0 else [])) for dim_nums, perms in [ (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])), (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])), (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))] for rng_factory in [jtu.rand_default] )) def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil, dimension_numbers, perms, feature_group_count, batch_group_count, rng_factory): if dtype == onp.float16: raise SkipTest("float16 numerical issues") # TODO(mattjj): resolve rng = rng_factory(self.rng()) tol = {dtypes.bfloat16: 1e-0, onp.float16: 5e-1, onp.float32: 2e-4} # permute shapes to match dim_spec, scale by feature_group_count lhs_perm, rhs_perm = perms lhs_shape = list(onp.take(lhs_shape, lhs_perm)) rhs_shape = list(onp.take(rhs_shape, rhs_perm)) lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) conv = partial(lax.conv_general_dilated, window_strides=strides, padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil, dimension_numbers=dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count, precision=lax.Precision.HIGHEST) check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], atol=tol, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype)), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "rng_factory": jtu.rand_default} for lhs_shape in [(2,), (3, 2)] for rhs_shape in [(2,), (2, 4)] for dtype in float_dtypes)) def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory): rng = rng_factory(self.rng()) tol = {onp.float16: 1e-1, onp.float32: 1e-4} lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) dot = partial(lax.dot, precision=lax.Precision.HIGHEST) check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"], atol=tol, rtol=tol) # check that precision config is preserved result, pullback = 
api.vjp(dot, lhs, rhs) gresult = lax.zeros_like_array(result) s = str(api.make_jaxpr(pullback)(gresult)) assert "precision=HIGHEST" in s @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_dimension_numbers={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), dimension_numbers), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "dimension_numbers": dimension_numbers, "rng_factory": jtu.rand_small} for lhs_shape, rhs_shape, dimension_numbers in [ ((3, 2), (2, 4), (([1], [0]), ([], []))), ((3, 5), (2, 5), (([1], [1]), ([], []))), ((5, 3), (5, 2), (([0], [0]), ([], []))), ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))), ] for dtype in float_dtypes)) def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype, dimension_numbers, rng_factory): rng = rng_factory(self.rng()) lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers, precision=lax.Precision.HIGHEST) check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"]) # check that precision config is preserved result, pullback = api.vjp(dot_general, lhs, rhs) gresult = lax.zeros_like_array(result) s = str(api.make_jaxpr(pullback)(gresult)) assert "precision=HIGHEST" in s @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format( shape, onp.dtype(dtype).name, broadcast_sizes), "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes, "rng_factory": rng_factory} for shape in [(), (2, 3)] for dtype in float_dtypes for broadcast_sizes in [(), (2,), (1, 2)] for rng_factory in [jtu.rand_default])) def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory): rng = rng_factory(self.rng()) args = (rng(shape, dtype),) broadcast = lambda x: lax.broadcast(x, broadcast_sizes) check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format( jtu.format_shape_dtype_string(inshape, dtype), outshape, broadcast_dimensions), "inshape": inshape, "dtype": dtype, "outshape": outshape, "dimensions": broadcast_dimensions, "rng_factory": rng_factory} for inshape, outshape, broadcast_dimensions in [ ([2], [2, 2], [0]), ([2], [2, 2], [1]), ([2], [2, 3], [0]), ([], [2, 3], []), ] for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions, rng_factory): rng = rng_factory(self.rng()) operand = rng(inshape, dtype) broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions) check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.) 
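  # --- Illustrative sketch, not one of the parameterized tests above ---
  # check_grads (from jax.test_util) compares forward- and reverse-mode
  # derivatives against numerical finite differences up to the requested
  # order. The concrete shape, broadcast dimensions, and helper name below
  # are arbitrary examples, not taken from the parameterizations above.
  def _sketchBroadcastInDimGradConcrete(self):
    import numpy as onp         # local import keeps the sketch self-contained
    operand = onp.arange(2.0, dtype=onp.float32)
    broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, (2, 3), (0,))
    check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.)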
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_perm={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), jtu.format_shape_dtype_string(out_shape, dtype), permutation), "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, "rng_factory": rng_factory, "permutation": permutation} for dtype in float_dtypes for arg_shape, out_shape, permutation in [ [(3, 4), (12,), None], [(2, 1, 4), (8,), None], [(2, 2, 4), (2, 8), None], [(3, 4), (12,), (0, 1)], [(3, 4), (12,), (1, 0)], [(2, 1, 4), (8,), (0, 2, 1)], [(2, 1, 4), (8,), (2, 0, 1)], [(2, 2, 4), (2, 8), (0, 2, 1)], [(2, 2, 4), (2, 8), (2, 0, 1)], ] for rng_factory in [jtu.rand_default])) def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype, rng_factory): rng = rng_factory(self.rng()) operand = rng(arg_shape, dtype) reshape = lambda x: lax.reshape(x, out_shape, permutation) check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_pads={}" .format(jtu.format_shape_dtype_string(shape, dtype), pads), "shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small} for shape in [(2, 3)] for dtype in float_dtypes for pads in [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]])) def testPadGrad(self, shape, dtype, pads, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) pad = lambda operand: lax.pad(operand, onp.array(0, dtype), pads) check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.) operand = rng(shape, dtype) padding_value = onp.array(0., dtype) pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads) check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.) def testReverseGrad(self): rev = lambda operand: lax.rev(operand, dimensions) dimensions = [0] check_grads(rev, (onp.array([3., 2., 1.]),), 2) dimensions = [0, 1] check_grads(rev, (onp.array([[6., 5., 4.], [3., 2., 1.]]),), 2, rtol={onp.float32: 3e-3}) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_predshape={}_argshapes={}".format( jtu.format_shape_dtype_string(pred_shape, onp.bool_), jtu.format_shape_dtype_string(arg_shape, dtype)), "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype, "rng_factory": rng_factory} for arg_shape in [(), (3,), (2, 3)] for pred_shape in ([(), arg_shape] if arg_shape else [()]) for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testSelectGrad(self, pred_shape, arg_shape, dtype, rng_factory): rng = rng_factory(self.rng()) pred = rng(pred_shape, onp.bool_) on_true = rng(arg_shape, dtype) on_false = rng(arg_shape, dtype) select = lambda on_true, on_false: lax.select(pred, on_true, on_false) check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.) 
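  # --- Illustrative sketch, not one of the parameterized tests above ---
  # The gradient of lax.select routes each output cotangent to whichever
  # branch `pred` selected, so differentiating sum(select(pred, x, y)) with
  # respect to x yields pred as a 0/1 mask. Values and the helper name are
  # arbitrary examples.
  def _sketchSelectGradIsMask(self):
    import numpy as onp         # local imports keep the sketch self-contained
    import jax.numpy as jnp
    from jax import grad
    pred = onp.array([True, False, True])
    x = onp.ones(3, onp.float32)
    y = onp.zeros(3, onp.float32)
    gx = grad(lambda x: jnp.sum(lax.select(pred, x, y)))(x)
    assert onp.array_equal(onp.asarray(gx), pred.astype(onp.float32))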
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_start_indices={}_limit_indices={}_strides={}".format( jtu.format_shape_dtype_string(shape, dtype), start_indices, limit_indices, strides), "shape": shape, "dtype": dtype, "starts": start_indices, "limits": limit_indices, "strides": strides, "rng_factory": rng_factory} for shape, start_indices, limit_indices, strides in [ [(3,), (1,), (2,), None], [(7,), (4,), (7,), None], [(5,), (1,), (5,), (2,)], [(8,), (1,), (6,), (2,)], [(5, 3), (1, 1), (3, 2), None], [(5, 3), (1, 1), (3, 1), None], [(7, 5, 3), (4, 0, 1), (7, 1, 3), None], [(5, 3), (1, 1), (2, 1), (1, 1)], [(5, 3), (1, 1), (5, 3), (2, 1)], ] for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testSliceGrad(self, shape, dtype, starts, limits, strides, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) slice = lambda x: lax.slice(x, starts, limits, strides) check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format( jtu.format_shape_dtype_string(shape, dtype), start_indices, size_indices), "shape": shape, "dtype": dtype, "start_indices": start_indices, "size_indices": size_indices, "rng_factory": rng_factory} for shape, start_indices, size_indices in [ [(3,), (1,), (1,)], [(5, 3), (1, 1), (3, 1)], [(7, 5, 3), (4, 1, 0), (2, 0, 1)], ] for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices) check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format( jtu.format_shape_dtype_string(shape, dtype), start_indices, update_shape), "shape": shape, "dtype": dtype, "start_indices": start_indices, "update_shape": update_shape, "rng_factory": rng_factory} for shape, start_indices, update_shape in [ [(3,), (1,), (1,)], [(5, 3), (1, 1), (3, 1)], [(7, 5, 3), (4, 1, 0), (2, 0, 1)], ] for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, update_shape, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) update = rng(update_shape, dtype) start_indices = onp.array(start_indices) dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices) check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.) dus = lambda x: lax.dynamic_update_slice(x, update, start_indices) check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.) dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices) check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_perm={}".format( jtu.format_shape_dtype_string(shape, dtype), perm), "shape": shape, "dtype": dtype, "perm": perm, "rng_factory": rng_factory} for shape, perm in [ [(3, 4), (1, 0)], [(3, 4), (0, 1)], [(3, 4, 5), (2, 1, 0)], [(3, 4, 5), (1, 0, 2)], ] for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testTransposeGrad(self, shape, dtype, perm, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) transpose = lambda x: lax.transpose(x, perm) check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.) 
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_inshape={}_reducedims={}" .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims), "op": op, "init_val": init_val, "shape": shape, "dtype": dtype, "dims": dims, "rng_factory": rng_factory} for init_val, op, dtypes, rng_factory in [ (0, lax.add, inexact_dtypes, jtu.rand_default), (-onp.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int), (onp.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int), (1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)), ] for dtype in dtypes for shape, dims in [ [(3, 4, 5), ()], [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)], [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)], [(3, 1), (1,)], ])) def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory): rng = rng_factory(self.rng()) if jtu.device_under_test() == "tpu" and op is lax.mul: raise SkipTest("unimplemented case") tol = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 1e-1, onp.float64: 1e-3, onp.complex64: 1e-1} operand = rng(shape, dtype) init_val = onp.asarray(init_val, dtype=dtype) reduce = lambda operand: lax.reduce(operand, init_val, op, dims) eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else 1e-1 if dtype == dtypes.bfloat16 else 1e-2 if dtypes.finfo(dtype).bits == 32 else None) check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_dtype={}_padding={}" .format(op.__name__, onp.dtype(dtype).name, padding), "op": op, "init_val": init_val, "dtype": dtype, "padding": padding, "rng_factory": rng_factory} for init_val, op, dtypes, rng_factory in [ (0, lax.add, grad_float_dtypes, jtu.rand_small), (-onp.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int), (onp.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int), ] for dtype in dtypes for padding in ["VALID", "SAME"])) @jtu.skip_on_flag("jax_skip_slow_tests", True) @jtu.ignore_warning(category=UserWarning, message="Using reduced precision for gradient.*") def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory): rng = rng_factory(self.rng()) init_val = onp.asarray(init_val, dtype=dtype) # We need this conditional and the corresponding loop logic to be in the # test method, rather than at the parameterized test level, because it # depends on FLAGS for the device under test. # TODO(b/31565929): enable when fixed. if jtu.device_under_test() == "tpu" and op is not lax.add: all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))] # TODO(b/73062247): need variadic reduce-window for better precision. gradient_order = 1 else: all_configs = itertools.chain( itertools.product( [(4, 6)], # shapes [(2, 1), (1, 2)], # window_dimensions [(1, 1), (2, 1), (1, 2)] # strides ), itertools.product( [(3, 2, 4, 6)], # shapes [(1, 1, 2, 1), (2, 1, 2, 1)], # window_dimensions [(1, 2, 2, 1), (1, 1, 1, 1)]), # strides ) gradient_order = 3 def fun(operand): return lax.reduce_window(operand, init_val, op, dims, strides, padding) for shape, dims, strides in all_configs: operand = rng(shape, dtype) if op is lax.add: eps = 1. 
tol = None else: # this test can fail if there are duplicates in operand self.assertEqual(onp.unique(operand).size, operand.size, msg="test requires operand elements to be unique.") eps = 1e-2 tol = {onp.float16: 1e-1, onp.float32: 6e-2, onp.float64: 6e-2} check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol, eps) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_shape={}_axis={}" .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis), "op": op, "shape": shape, "dtype": dtype, "axis": axis, "rng_factory": rng_factory} for op, types in [ (lax.cumsum, [onp.float32, onp.float64]), (lax.cumprod, [onp.float32, onp.float64]), ] for dtype in types for shape in [[10], [3, 4, 5]] for axis in range(len(shape)) for rng_factory in [ jtu.rand_default if dtypes.issubdtype(dtype, onp.integer) else jtu.rand_small])) def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory): rng = rng_factory(self.rng()) check_grads(partial(op, axis=axis), (rng(shape, dtype),), order=2) # TODO(b/205052657): enable more tests when supported @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_axis={}_isstable={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, is_stable), "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis, "is_stable": is_stable} for dtype in [onp.float32] for shape in [(5,), (5, 7)] for axis in [len(shape) - 1] for is_stable in [False, True] for rng_factory in [jtu.rand_default])) def testSortGrad(self, shape, dtype, axis, is_stable, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) sort = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable) check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2) # TODO(b/205052657): enable more tests when supported @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_keyshape={}_valshape={}_axis={}_isstable={}".format( jtu.format_shape_dtype_string(shape, key_dtype), jtu.format_shape_dtype_string(shape, val_dtype), axis, is_stable), "rng_factory": rng_factory, "shape": shape, "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis, "is_stable": is_stable} for key_dtype in [onp.float32] for val_dtype in [onp.float32] for shape in [(3,), (5, 3)] for axis in [len(shape) - 1] for is_stable in [False, True] for rng_factory in [jtu.rand_default])) def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable, rng_factory): rng = rng_factory(self.rng()) # This test relies on the property that wherever keys are tied, values are # too, since we don't guarantee the same ordering of values with equal keys. # To avoid that case, we generate unique keys (globally in the key array). 
def args_maker(): flat_keys = onp.arange(onp.prod(shape, dtype=int), dtype=key_dtype) keys = self.rng().permutation(flat_keys).reshape(shape) values = rng(shape, val_dtype) return keys, values keys, values = args_maker() fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable) check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_k={}".format( jtu.format_shape_dtype_string(shape, dtype), k), "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k} for dtype in [onp.float32,] for shape in [(4,), (5, 5), (2, 1, 4)] for k in [1, 3] for rng_factory in [jtu.rand_default])) def testTopKGrad(self, shape, dtype, k, rng_factory): flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype) values = self.rng().permutation(flat_values).reshape(shape) fun = lambda vs: lax.top_k(vs, k=k)[0] check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_axes={}".format( jtu.format_shape_dtype_string(shape, dtype), idxs, axes), "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes, "rng_factory": rng_factory} for dtype in float_dtypes for shape, idxs, axes in [ [(3, 4, 5), (onp.array([0, 2, 1]),), (0,)], [(3, 4, 5), (onp.array([-1, -2]),), (0,)], [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)], [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)], ] for rng_factory in [jtu.rand_default])) def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng_factory): rng = rng_factory(self.rng()) src = rng(shape, dtype) index_take = lambda src: lax.index_take(src, idxs, axes) check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), idxs, dnums, slice_sizes), "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory} for dtype in grad_float_dtypes for shape, idxs, dnums, slice_sizes, max_idx in [ ((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,), 5), ((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,), 9), ((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3), 3), ] for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)] for rng_factory in [jtu.rand_default])) def testGatherGrad(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory, rng_idx_factory): rng = rng_factory(self.rng()) rng_idx = rng_idx_factory(self.rng()) idxs = rng_idx(idxs.shape, idxs.dtype) gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums, slice_sizes=slice_sizes) x = rng(shape, dtype) check_grads(gather, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) 
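# Illustrative sketch, not part of the original test class: for the third
# case in testGatherGrad above (operand shape (10, 5), slice_sizes (1, 3)),
# the dimension numbers make lax.gather behave like row selection followed by
# keeping the first three columns, i.e. out[i] == operand[idxs[i, 0], :3].
# The operand and index values below are assumptions for illustration only.
import numpy as onp
from jax import lax

_operand = onp.arange(50, dtype=onp.float32).reshape(10, 5)
_idxs = onp.array([[0], [2], [1]])
_dnums = lax.GatherDimensionNumbers(offset_dims=(1,),
                                    collapsed_slice_dims=(0,),
                                    start_index_map=(0,))
_out = lax.gather(_operand, _idxs, dimension_numbers=_dnums, slice_sizes=(1, 3))
assert onp.array_equal(_out, _operand[_idxs[:, 0], :3])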
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), idxs, update_shape, dnums), "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory} for dtype in grad_float_dtypes for arg_shape, idxs, update_shape, dnums, max_idx in [ ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)), 4), ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,)), 9), ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)), 3), ] for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)] for rng_factory in [jtu.rand_default])) def testScatterAddGrad(self, arg_shape, dtype, idxs, update_shape, dnums, rng_factory, rng_idx_factory): rng = rng_factory(self.rng()) rng_idx = rng_idx_factory(self.rng()) idxs = rng_idx(idxs.shape, idxs.dtype) scatter_add = lambda x, y: lax.scatter_add(x, idxs, y, dimension_numbers=dnums) x = rng(arg_shape, dtype) y = rng(update_shape, dtype) check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), idxs, update_shape, dnums), "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory} for dtype in grad_float_dtypes for arg_shape, idxs, update_shape, dnums, max_idx in [ ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)), 4), ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,)), 9), ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)), 3), ] # Scatters with conflicting indices are not deterministic on GPU, so we # use indices that do not collide. for rng_idx_factory in [partial(jtu.rand_unique_int, high=max_idx)] for rng_factory in [jtu.rand_default])) def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums, rng_factory, rng_idx_factory): rng = rng_factory(self.rng()) rng_idx = rng_idx_factory(self.rng()) idxs = rng_idx(idxs.shape, idxs.dtype) scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums) x = rng(arg_shape, dtype) y = rng(update_shape, dtype) check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) def testScatterGradSymbolicZeroUpdate(self): # https://github.com/google/jax/issues/1901 def f(x): n = x.shape[0] y = onp.arange(n, dtype=x.dtype) return jax.ops.index_update(x, onp.diag_indices(n), y) rng = jtu.rand_default(self.rng()) check_grads(f, (rng((5, 5), onp.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) 
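# Illustrative sketch, not part of the original test class: with the first
# set of dimension numbers in testScatterAddGrad above, lax.scatter_add adds
# updates[i] into operand[idxs[i, 0]], the same effect as numpy's
# onp.add.at(x, idxs[:, 0], y). The values below are assumptions chosen only
# for illustration.
import numpy as onp
from jax import lax

_x = onp.zeros(5, dtype=onp.float32)
_idxs = onp.array([[0], [2]])
_y = onp.array([1., 10.], dtype=onp.float32)
_dnums = lax.ScatterDimensionNumbers(update_window_dims=(),
                                     inserted_window_dims=(0,),
                                     scatter_dims_to_operand_dims=(0,))
_out = lax.scatter_add(_x, _idxs, _y, dimension_numbers=_dnums)
assert onp.array_equal(_out, onp.array([1., 0., 10., 0., 0.], dtype=onp.float32))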
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), idxs, update_shape, dnums), "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory} for dtype in grad_float_dtypes for arg_shape, idxs, update_shape, dnums in [ ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), ] for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))] for rng_factory in [jtu.rand_default])) def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums, rng_factory, rng_idx_factory): rng = rng_factory(self.rng()) rng_idx = rng_idx_factory(self.rng()) idxs = rng_idx(idxs.shape, idxs.dtype) scatter_max = lambda x, y: lax.scatter_max(x, idxs, y, dnums) x = rng(arg_shape, dtype) y = rng(update_shape, dtype) check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), idxs, update_shape, dnums), "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory} for dtype in grad_float_dtypes for arg_shape, idxs, update_shape, dnums in [ ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), ] for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))] for rng_factory in [jtu.rand_default])) def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums, rng_factory, rng_idx_factory): rng = rng_factory(self.rng()) rng_idx = rng_idx_factory(self.rng()) idxs = rng_idx(idxs.shape, idxs.dtype) scatter_min = lambda x, y: lax.scatter_min(x, idxs, y, dnums) x = rng(arg_shape, dtype) y = rng(update_shape, dtype) check_grads(scatter_min, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2) def testStopGradient(self): def f(x): return lax.sin(x) * lax.cos(lax.stop_gradient(x)) def f2(x, y): return lax.sin(x) * lax.cos(y) x = 3.14 ans = api.grad(f)(x) expected = api.grad(f2)(x, x) self.assertAllClose(ans, expected) ans = api.grad(api.grad(f))(x) expected = api.grad(api.grad(f2))(x, x) self.assertAllClose(ans, expected) ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.) 
expected = onp.array(0.0) self.assertAllClose(ans, expected, check_dtypes=False) with core.skipping_checks(): with self.assertRaises(TypeError): lax.stop_gradient(lambda x: x) # TODO(mattjj): make this a more systematic test def testRemainder(self): rng = onp.random.RandomState(0) x = rng.uniform(-0.9, 9, size=(3, 4)) y = rng.uniform(0.7, 1.9, size=(3, 1)) assert not set(onp.unique(x)) & set(onp.unique(y)) tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3 check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol) rng = onp.random.RandomState(0) x = rng.uniform(-0.9, 9, size=(1, 4)) y = rng.uniform(0.7, 1.9, size=(3, 4)) assert not set(onp.unique(x)) & set(onp.unique(y)) tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3 check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol) def testHigherOrderGradientOfReciprocal(self): # Regression test for https://github.com/google/jax/issues/3136 def inv(x): # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x) return 1 / x grad_fn = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv)))))) self.assertAllClose(onp.float32(0.0439453125), grad_fn(onp.float32(4.)))
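# Sanity check (added, not from the original test): the n-th derivative of
# 1/x is (-1)**n * n! / x**(n + 1), so the sixth derivative at x = 4 is
# 720 / 4**7 = 720 / 16384 = 0.0439453125, which is exactly the constant
# asserted in testHigherOrderGradientOfReciprocal above.
assert 720 / 4**7 == 0.0439453125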
] + # Directly from lax.gather in lax_test.py. [ Harness( f"_shape={shape}_idxs_shape={idxs.shape}_dnums={dnums}_slice_sizes={slice_sizes}", lambda op, idxs, dnums, slice_sizes: lax.gather( op, idxs, dimension_numbers=dnums, slice_sizes=slice_sizes), [ RandArg(shape, np.float32), idxs, StaticArg(dnums), StaticArg(slice_sizes) ]) for shape, idxs, dnums, slice_sizes in [ ((5, ), np.array([[0], [2]]), lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, )), ((10, ), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(), start_index_map=(0, )), (2, )), (( 10, 5, ), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, 3)), ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ),
class BatchingTest(jtu.JaxTestCase): def testConstantFunction(self): ans = vmap(lambda x: 3)(onp.ones(4)) expected = 3 * onp.ones(4) self.assertAllClose(ans, expected, check_dtypes=False) def testNestedBatchingMatMat(self): matvec = vmap(np.vdot, in_axes=(0, None)) matmat = vmap(matvec, in_axes=(None, 1), out_axes=1) R = onp.random.RandomState(0).randn A = R(4, 3) B = R(3, 2) ans = matmat(A, B) expected = onp.dot(A, B) self.assertAllClose(ans, expected, check_dtypes=False) # this is a crude check that we only call a single dot def pv_like(x): aval = ShapedArray(onp.shape(x), onp.result_type(x)) return pe.PartialVal((aval, unit)) def make_jaxpr(fun, example_args): jaxpr, _, _, _ = trace_to_jaxpr(fun, map(pv_like, example_args)) return jaxpr jaxpr = make_jaxpr(matmat, (A, B)) self.assertEqual(len(jaxpr.eqns), 1) def testPerExampleGradients(self): def predict(params, inputs): for W, b in params: outputs = np.dot(W, inputs) + b inputs = np.tanh(outputs) return outputs def loss(params, data): inputs, targets = data predictions = predict(params, inputs) return np.sum((predictions - targets)**2) batch_size = 5 layer_sizes = [3, 2, 4] R = onp.random.RandomState(0).randn params = [(R(m, n), R(m)) for m, n in zip(layer_sizes[1:], layer_sizes[:-1])] input_vec = R(3) target_vec = R(4) datum = (input_vec, target_vec) input_batch = R(5, 3) target_batch = R(5, 4) batch = (input_batch, target_batch) ans = vmap(partial(grad(loss), params))(batch) for ans_pair, param_pair in zip(ans, params): dW, db = ans_pair W, b = param_pair self.assertEqual(dW.shape, (batch_size, ) + W.shape) self.assertEqual(db.shape, (batch_size, ) + b.shape) def testJacobians(self): def jacbwd(f, x): y, pullback = vjp(f, x) std_basis = onp.eye(onp.size(y)).reshape((-1, ) + onp.shape(y)) jac_flat, = vmap(pullback, out_axes=onp.ndim(y))(std_basis) return jac_flat.reshape(onp.shape(y) + onp.shape(x)) def jacfwd(f, x): pushfwd = lambda v: jvp(f, (x, ), (v, )) std_basis = onp.eye(onp.size(x)).reshape((-1, ) + onp.shape(x)) y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis) return jac_flat.reshape(onp.shape(y) + onp.shape(x)) R = onp.random.RandomState(0).randn A = R(4, 3) b = R(4) f = lambda x: np.tanh(np.dot(A, x) + b) x = R(3) self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False) def testBatchOfCompile(self): side = [] @jit def f(x): side.append(None) return x + x g = jit(vmap(f)) self.assertAllClose(g(onp.ones(2)), 2 * onp.ones(2), check_dtypes=False) self.assertEqual(len(side), 1) self.assertAllClose(g(2 * onp.ones(2)), 4 * onp.ones(2), check_dtypes=False) self.assertEqual(len(side), 1) def testSliceLax(self): fun = lambda x: lax.slice(x, (2, ), (4, )) R = onp.random.RandomState(0).randn x = R(5, 10) ans = vmap(fun)(x) expected_ans = x[:, 2:4] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testSliceNumpy(self): fun = lambda x: x[:, 2] R = onp.random.RandomState(0).randn x = R(10, 5, 3, 7) ans = vmap(fun)(x) expected_ans = x[:, :, 2] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testRevLax(self): fun = lambda x: lax.rev(x, [0]) R = onp.random.RandomState(0).randn x = R(2, 3) ans = vmap(fun)(x) expected_ans = x[:, ::-1] self.assertAllClose(ans, expected_ans, check_dtypes=False) ans = vmap(fun, (1, ), 1)(x) expected_ans = x[::-1, :] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testRevNumpy(self): fun = lambda x: x[:, ::-1] R = onp.random.RandomState(0).randn x = R(3, 2, 4) ans = vmap(fun)(x) expected_ans = x[:, :, ::-1] self.assertAllClose(ans, expected_ans, 
check_dtypes=False) ans = vmap(fun, (1, ), 1)(x) expected_ans = x[:, :, ::-1] self.assertAllClose(ans, expected_ans, check_dtypes=False) ans = vmap(fun, (2, ), 2)(x) expected_ans = x[:, ::-1, :] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testNpMaximum(self): fun = lambda x: np.maximum(x, 0.0) R = onp.random.RandomState(0).randn x = R(10, 5, 3, 7) ans = vmap(fun)(x) expected_ans = onp.maximum(x, 0.0) self.assertAllClose(ans, expected_ans, check_dtypes=False) def testNpGtrThan(self): R = onp.random.RandomState(0).randn x = R(10, 5, 3, 7) ans = vmap(lambda x: x > 1.0)(x) expected_ans = x > 1.0 self.assertAllClose(ans, expected_ans, check_dtypes=True) def testNpMaximumPerExampleGrad(self): R = onp.random.RandomState(0).randn x = R(10, 5) W = R(5, 5) fun = lambda W, x: np.sum(np.maximum(np.dot(x, W), 0.0)**2) ans = vmap(partial(grad(fun), W))(x) W_t = np.transpose(W) for i in range(10): x_ex = x[i:i + 1] expected_ans = 2.0 * np.dot( np.maximum(np.dot(W_t, np.transpose(x_ex)), 0.0), x_ex) expected_ans = np.transpose(expected_ans) self.assertAllClose(ans[i], expected_ans, check_dtypes=False) def testDotGeneral(self): R = onp.random.RandomState(0).randn x = R(10, 3, 4, 5) y = R(10, 3, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ), (0, ))]) ans = vmap(fun)(x, y) expected = lax.dot_general(x, y, [((3, ), (2, )), ((0, 1), (0, 1))]) self.assertAllClose(ans, expected, check_dtypes=True) x = R(3, 4, 10, 5) y = R(3, 10, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ), (0, ))]) ans = vmap(fun, in_axes=(2, 1))(x, y) fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ), (0, ))]) expected = onp.stack( [fun(x[..., i, :], y[:, i, ...]) for i in range(10)]) self.assertAllClose(ans, expected, check_dtypes=True) x = R(3, 4, 5, 10) y = R(3, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ), (0, ))]) ans = vmap(fun, in_axes=(3, None))(x, y) fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ), (0, ))]) expected = onp.stack([fun(x[..., i], y) for i in range(10)]) self.assertAllClose(ans, expected, check_dtypes=True) x = R(3, 4, 5) y = R(3, 5, 10, 6) fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ), (0, ))]) ans = vmap(fun, in_axes=(None, 2))(x, y) fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ), (0, ))]) expected = onp.stack([fun(x, y[..., i, :]) for i in range(10)]) self.assertAllClose(ans, expected, check_dtypes=True) def testDot(self): # these tests are based on @shoyer's notebook studying gufuncs def vecvec(a, b): dot = np.dot for ndim in range(1, max(a.ndim, b.ndim)): a_ax = 0 if a.ndim > ndim else None b_ax = 0 if b.ndim > ndim else None dot = vmap(dot, in_axes=(a_ax, b_ax)) return dot(a, b) assert vecvec(np.zeros((3, )), np.zeros((3, ))).shape == () assert vecvec(np.zeros((2, 3)), np.zeros((3, ))).shape == (2, ) # TODO(mattjj): this fails due to an xla error in dot_general # assert vecvec(np.zeros((4, 2, 3)), np.zeros((3,))).shape == (4, 2) def testPad(self): R = onp.random.RandomState(0).randn fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1)]) x = R(5, 10).astype(onp.float32) ans = vmap(fun)(x) expected_ans = np.stack(list(map(fun, x))) self.assertAllClose(ans, expected_ans, check_dtypes=False) fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1), (0, 1, 0)]) x = R(5, 10, 3).astype(onp.float32) ans = vmap(fun)(x) expected_ans = np.stack(list(map(fun, x))) self.assertAllClose(ans, expected_ans, check_dtypes=False) def testConcatenate(self): R = lambda *shape: 
onp.random.RandomState(0).randn(*shape).astype( onp.float32) fun = lambda *args: lax.concatenate(args, dimension=0) x, y, z = R(10, 2, 3), R(1, 10, 3), R(4, 3) ans = vmap(fun, in_axes=(0, 1, None))(x, y, z) expected_ans = onp.concatenate( [x, onp.swapaxes(y, 0, 1), onp.broadcast_to(z, (10, 4, 3))], 1) self.assertAllClose(ans, expected_ans, check_dtypes=False) fun = lambda *args: lax.concatenate(args, dimension=1) x, y, z = R(10, 2, 1), R(2, 3), R(2, 4, 10) ans = vmap(fun, in_axes=(0, None, 2))(x, y, z) expected_ans = onp.concatenate( [x, onp.broadcast_to(y, (10, 2, 3)), onp.moveaxis(z, 2, 0)], 2) self.assertAllClose(ans, expected_ans, check_dtypes=False) def testJacobianIssue54(self): # test modeling the code in https://github.com/google/jax/issues/54 def func(xs): return np.array([x for x in xs]) xs = np.ones((5, 1)) jacrev(func)(xs) # don't crash jacfwd(func)(xs) # don't crash def testAny(self): # test modeling the code in https://github.com/google/jax/issues/108 ans = vmap(np.any)(np.array([[True, False], [False, False]])) expected = np.array([True, False]) self.assertAllClose(ans, expected, check_dtypes=True) @jtu.skip_on_devices("tpu") def testHessian(self): # test based on code from sindhwani@google def fun(x, t): return np.sum(np.power(np.maximum(x, 0.0), 2)) + t x = onp.array([-1., -0.5, 0., 0.5, 1.0]) ans = hessian(lambda x: fun(x, 0.0))(x) expected = onp.array([[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0.5, 0., 0.], [0., 0., 0., 2., 0.], [0., 0., 0., 0., 2.]]) self.assertAllClose(ans, expected, check_dtypes=False) def testDynamicSlice(self): # test dynamic_slice via numpy indexing syntax x = onp.arange(30).reshape((10, 3)) ans = vmap(lambda x, i: x[i], in_axes=(0, None))(x, 1) expected = x[:, 1] self.assertAllClose(ans, expected, check_dtypes=False) idx = onp.array([0, 1, 2, 1, 0] * 2) ans = vmap(lambda x, i: x[i], in_axes=(0, 0))(x, idx) expected = x[onp.arange(10), idx] self.assertAllClose(ans, expected, check_dtypes=False) x = onp.arange(3) idx = onp.array([0, 1, 2, 1, 0] * 2) ans = vmap(lambda x, i: x[i], in_axes=(None, 0))(x, idx) expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testRandom(self): seeds = vmap(random.PRNGKey)(onp.arange(10)) ans = vmap(partial(random.normal, shape=(3, 2)))(seeds) expected = onp.stack([ random.normal(random.PRNGKey(seed), (3, 2)) for seed in onp.arange(10) ]) self.assertAllClose(ans, expected, check_dtypes=False) assert len(onp.unique(ans)) == 10 * 3 * 2 def testSortKeyVal(self): k = onp.arange(12)[::-1].reshape(3, 4) v = onp.random.RandomState(0).permutation(12).reshape(3, 4) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v) self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v) self.assertAllClose(sk, k[::-1, :], check_dtypes=True) self.assertAllClose(sv, v[::-1, :], check_dtypes=True) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T) self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v) self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0], v) self.assertAllClose(sk, onp.broadcast_to(k[0, ::-1], (3, 4)), check_dtypes=True) self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) sk, sv = 
vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T, v[0]) self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) self.assertAllClose(sv, onp.broadcast_to(v[0, ::-1], (3, 4)), check_dtypes=True) def testConvGeneralDilated(self): W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32) X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32) def f(params, x): one = (1, 1) dimension_numbers = ('NHWC', 'HWIO', 'NHWC') y = lax.conv_general_dilated(x, params, one, 'SAME', one, one, dimension_numbers) return y grad_loss = grad(lambda params, x: np.mean(f(params, x)**2)) # Test forward prop. per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example = np.reshape(per_example, (10, 5, 5, 5)) per_example_direct = f(W, X) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) # Test gradients. per_example = vmap(partial(grad_loss, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example_direct = [] for i in range(10): g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1))) per_example_direct += [np.reshape(g, (1, ) + g.shape)] per_example_direct = np.concatenate(per_example_direct, axis=0) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) def testMaxPool(self): W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32) X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32) def f(params, x): one = (1, 1) dimension_numbers = ('NHWC', 'HWIO', 'NHWC') y = lax.conv_general_dilated(x, params, one, 'SAME', one, one, dimension_numbers) y = lax.reduce_window(y, -np.inf, lax.max, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME') return y grad_loss = grad(lambda params, x: np.mean(f(params, x)**2)) # Test forward prop. per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example = np.reshape(per_example, (10, 5, 5, 5)) per_example_direct = f(W, X) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) # Test gradients. per_example = vmap(partial(grad_loss, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example_direct = [] for i in range(10): g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1))) per_example_direct += [np.reshape(g, (1, ) + g.shape)] per_example_direct = np.concatenate(per_example_direct, axis=0) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) def testSumPool(self): W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32) X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32) def f(params, x): one = (1, 1) dimension_numbers = ('NHWC', 'HWIO', 'NHWC') y = lax.conv_general_dilated(x, params, one, 'SAME', one, one, dimension_numbers) y = lax.reduce_window(y, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME') return y grad_loss = grad(lambda params, x: np.mean(f(params, x)**2)) # Test forward prop. per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example = np.reshape(per_example, (10, 5, 5, 5)) per_example_direct = f(W, X) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) # Test gradients. 
per_example = vmap(partial(grad_loss, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example_direct = [] for i in range(10): g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1))) per_example_direct += [np.reshape(g, (1, ) + g.shape)] per_example_direct = np.concatenate(per_example_direct, axis=0) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) def testSelect(self): pred = onp.array([True, False]) on_true = onp.array([0, 1]) on_false = onp.array([2, 3]) ans = vmap(lax.select)(pred, on_true, on_false) expected = onp.array([0, 3]) self.assertAllClose(ans, expected, check_dtypes=True) pred = onp.array([False, True]) on_true = onp.array([0, 1]) on_false = onp.array([2, 3]) ans = vmap(lax.select, (0, None, None))(pred, on_true, on_false) expected = onp.array([[2, 3], [0, 1]]) self.assertAllClose(ans, expected, check_dtypes=True) pred = True on_true = onp.array([0, 1], onp.float32) on_false = onp.array(3, onp.float32) ans = vmap(lax.select, (None, 0, None))(pred, on_true, on_false) expected = onp.array([0, 1], onp.float32) self.assertAllClose(ans, expected, check_dtypes=True) pred = onp.array([False, True]) on_true = onp.array([0, 1], onp.float32) on_false = onp.array(3, onp.float32) ans = vmap(lax.select, (0, 0, None))(pred, on_true, on_false) expected = onp.array([3, 1], onp.float32) self.assertAllClose(ans, expected, check_dtypes=True) pred = onp.array([False, True]) on_true = onp.array([2], onp.float32) on_false = onp.array([[3, 4]], onp.float32) ans = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false) expected = onp.array([[3, 2]], onp.float32) self.assertAllClose(ans, expected, check_dtypes=True) def testLaxLinalgCholesky(self): a = onp.random.RandomState(0).randn(10, 5, 5).astype(onp.float32) a = onp.matmul(a, onp.conj(onp.swapaxes(a, -1, -2))) ans = vmap(lax_linalg.cholesky)(a) expected = onp.linalg.cholesky(a) self.assertAllClose(ans, expected, check_dtypes=False) b = onp.random.RandomState(0).randn(10, 5, 5).astype(onp.float32) b = onp.matmul(b, onp.conj(onp.swapaxes(b, -1, -2))) b_trans = onp.swapaxes(b, 0, 1) # shape is (5, 10, 5) ans = vmap(lax_linalg.cholesky, in_axes=1, out_axes=0)(b_trans) expected = onp.linalg.cholesky(b) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( { "testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx } for dtype in [onp.float32, onp.int32] for axis, shape, idxs, dnums, slice_sizes in [ (0, (3, 5), onp.array([0, 2]), lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, ), start_index_map=(0, ), index_vector_dim=1), (1, )), (1, (10, 3), onp.array([0, 0, 0]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(), start_index_map=(0, ), index_vector_dim=1), (2, )), (1, (10, 3, 5), onp.array([0, 2, 1]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, ), index_vector_dim=1), (1, 3)), (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, 1), index_vector_dim=1), (1, 3)), ] for rng_idx in [jtu.rand_int(max(shape))] for rng in [jtu.rand_default()]) def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng, rng_idx): fun = partial(lax.gather, dimension_numbers=dnums, 
slice_sizes=slice_sizes) operand = rng(shape, dtype) ans = vmap(fun, (axis, None))(operand, idxs) expected = onp.stack([ fun(operand[(slice(None), ) * axis + (i, )], idxs) for i in range(operand.shape[axis]) ]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( { "testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx } for dtype in [onp.float32, onp.float64] for axis, shape, idxs, dnums, slice_sizes in [ (0, (3, 5), onp.array([0, 2]), lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, ), start_index_map=(0, ), index_vector_dim=1), (1, )), (1, (10, 3), onp.array([0, 0, 0]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(), start_index_map=(0, ), index_vector_dim=1), (2, )), (1, (10, 3, 5), onp.array([0, 2, 1]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, ), index_vector_dim=1), (1, 3)), (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, 1), index_vector_dim=1), (1, 3)), ] for rng_idx in [jtu.rand_int(max(shape))] for rng in [jtu.rand_default()]) def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng, rng_idx): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx)))) operand = rng(shape, dtype) ans = vmap(gfun, (axis, None))(operand, idxs) expected = onp.stack([ gfun(operand[(slice(None), ) * axis + (i, )], idxs) for i in range(operand.shape[axis]) ]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( { "testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx } for dtype in [onp.float32, onp.int32] for axis, shape, idxs, dnums, slice_sizes in [ (0, (5, ), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, ), start_index_map=(0, ), index_vector_dim=1), (1, )), (1, (10, ), onp.array([[0, 0, 0], [0, 2, 1]]).T, lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(), start_index_map=(0, ), index_vector_dim=1), (2, )), (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T, lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, ), index_vector_dim=1), (1, 3)), (0, (10, 5), onp.array([[[0, 2], [1, 0]], [[1, 2], [0, 3]]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, 1), index_vector_dim=1), (1, 3)), ] for rng_idx in [jtu.rand_int(max(shape))] for rng in [jtu.rand_default()]) def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng, rng_idx): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) operand = rng(shape, dtype) ans = vmap(fun, (None, axis))(operand, idxs) expected = onp.stack([ fun(operand, idxs[(slice(None), ) * axis + (i, )]) for i in range(idxs.shape[axis]) ]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( { "testcase_name": 
"_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx } for dtype in [onp.float32, onp.float64] for axis, shape, idxs, dnums, slice_sizes in [ (0, (5, ), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, ), start_index_map=(0, ), index_vector_dim=1), (1, )), (1, (10, ), onp.array([[0, 0, 0], [0, 2, 1]]).T, lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(), start_index_map=(0, ), index_vector_dim=1), (2, )), (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T, lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, ), index_vector_dim=1), (1, 3)), (0, (10, 5), onp.array([[[0, 2], [1, 0]], [[1, 2], [0, 3]]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, 1), index_vector_dim=1), (1, 3)), ] for rng_idx in [jtu.rand_int(max(shape))] for rng in [jtu.rand_default()]) def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng, rng_idx): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx)))) operand = rng(shape, dtype) ans = vmap(gfun, (None, axis))(operand, idxs) expected = onp.stack([ gfun(operand, idxs[(slice(None), ) * axis + (i, )]) for i in range(idxs.shape[axis]) ]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( { "testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}" .format(jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs, dnums, slice_sizes), "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx } for dtype in [onp.float32, onp.int32] for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [ (0, 0, (2, 5), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, ), start_index_map=(0, ), index_vector_dim=1), (1, )), (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T, lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(), start_index_map=(0, ), index_vector_dim=1), (2, )), (0, 1, ( 2, 10, 5, ), onp.array([[0, 2, 1], [0, 3, 3]]).T, lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, ), index_vector_dim=1), (1, 3)), (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, 1), index_vector_dim=1), (1, 3)), ] for rng_idx in [jtu.rand_int(max(shape))] for rng in [jtu.rand_default()]) def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums, slice_sizes, rng, rng_idx): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) operand = rng(shape, dtype) assert operand.shape[op_axis] == idxs.shape[idxs_axis] ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs) expected = onp.stack([ fun(operand[(slice(None), ) * op_axis + (i, )], idxs[(slice(None), ) * idxs_axis + (i, )]) for i in range(idxs.shape[idxs_axis]) ]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( { "testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}" 
.format(jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs, dnums, slice_sizes), "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx } for dtype in [onp.float32, onp.int32] for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [ (0, 0, (2, 5), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, ), start_index_map=(0, ), index_vector_dim=1), (1, )), (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T, lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(), start_index_map=(0, ), index_vector_dim=1), (2, )), (0, 1, ( 2, 10, 5, ), onp.array([[0, 2, 1], [0, 3, 3]]).T, lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, ), index_vector_dim=1), (1, 3)), (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, 1), index_vector_dim=1), (1, 3)), ] for rng_idx in [jtu.rand_int(max(shape))] for rng in [jtu.rand_default()]) def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums, slice_sizes, rng, rng_idx): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx)))) operand = rng(shape, dtype) assert operand.shape[op_axis] == idxs.shape[idxs_axis] ans = vmap(gfun, (op_axis, idxs_axis))(operand, idxs) expected = onp.stack([ gfun(operand[(slice(None), ) * op_axis + (i, )], idxs[(slice(None), ) * idxs_axis + (i, )]) for i in range(idxs.shape[idxs_axis]) ]) self.assertAllClose(ans, expected, check_dtypes=False) def testNumpyIndexing1(self): a = np.arange(2 * 3 * 4).reshape((2, 3, 4)) ind = onp.array([[0, 1], [2, 0]]) def f(a, ind): return a[:, ind] expected = onp.stack([f(a, ind[i, :]) for i in range(ind.shape[0])]) ans = vmap(f, (None, 0))(a, ind) assert onp.all(ans == expected) def testNumpyIndexing2(self): a = np.arange(2 * 3 * 4).reshape((2, 3, 4)) def f(a): inds = np.array([0, 2]) return a[:, inds] ans = vmap(f)(a) expected = onp.stack([f(a[:, i, :]) for i in range(a.shape[1])], axis=1) assert onp.all(ans == expected) def testTranspose(self): x = onp.arange(4 * 3 * 3).reshape((4, 3, 3)) ans = vmap(lambda x: x + x.T)(x) expected = x + onp.swapaxes(x, -1, -2) self.assertAllClose(ans, expected, check_dtypes=False) def testTransposePermutation(self): x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5)) ans = vmap(lambda x: np.transpose(x, (1, 0, 2)))(x) expected = onp.transpose(x, (0, 2, 1, 3)) self.assertAllClose(ans, expected, check_dtypes=False) x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5)) ans = vmap(lambda x: np.transpose(x, (1, 2, 0)))(x) expected = onp.transpose(x, (0, 2, 3, 1)) self.assertAllClose(ans, expected, check_dtypes=False) x = onp.arange(6 * 3 * 4 * 5).reshape((3, 4, 6, 5)) ans = vmap(lambda x: np.transpose(x, (1, 2, 0)), in_axes=2)(x) expected = onp.transpose(x, (2, 1, 3, 0)) self.assertAllClose(ans, expected, check_dtypes=False) def testIssue354(self): psd_mat = onp.random.randn(20, 10) psd_mat = psd_mat.T.dot(psd_mat) vec = onp.random.randn(10) def f(scale): scaled_mat = scale * psd_mat chol = np.linalg.cholesky(scaled_mat) return -0.5 * np.sum((np.einsum('ij,j->i', chol, vec))**2) vmapped_f = vmap(f) vmapped_f_grad = grad(lambda x: np.sum(vmapped_f(x))) scales = onp.array([[0.1], [0.2], [0.3], [0.4], [0.5]]) ans = vmapped_f_grad(scales) # don't 
crash! expected = onp.stack([grad(f)(scale) for scale in scales]) self.assertAllClose(ans, expected, check_dtypes=False)
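# Illustrative sketch, not part of the original test class: testIssue354
# above relies on the fact that grad of the sum of a vmapped function gives
# per-example gradients whenever each output depends only on its own batch
# element. The toy function and values below are assumptions for illustration.
import numpy as onp
import jax.numpy as jnp
from jax import grad, vmap

_g = lambda s: jnp.sin(s) ** 2
_per_example_via_sum = grad(lambda s: jnp.sum(vmap(_g)(s)))
_scales = jnp.arange(1., 4.)
assert onp.allclose(_per_example_via_sum(_scales), vmap(grad(_g))(_scales))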
def predict(self, params, logits, context, target=None): context = jnp.expand_dims(jnp.expand_dims(jnp.expand_dims(context, axis=1), axis=1), axis=1) context_bias = params.get('context_bias', 0.0) context_index = (params['context_maps'] * context).sum(axis=-1) > context_bias context_map_values = jnp.asarray( [[[[1 << n for n in range(self.context_map_size)]]]]) context_index = jnp.where(context_index, context_map_values, 0) context_index = context_index.sum(axis=-1, keepdims=True) batch_size = logits.shape[0] class_neuron_index = jnp.asarray([[[[c, n] for n in range(self.size)] for c in range(self.num_classes)]]) class_neuron_index = jnp.tile(class_neuron_index, reps=(batch_size, 1, 1, 1)) context_index = jnp.concatenate([class_neuron_index, context_index], axis=-1) dims = lax.GatherDimensionNumbers(offset_dims=(3, ), collapsed_slice_dims=(0, 1, 2), start_index_map=(0, 1, 2)) weights = lax.gather(operand=params['weights'], start_indices=context_index, dimension_numbers=dims, slice_sizes=(1, 1, 1, self.input_size + int(self.bias))) if self.bias: bias = jnp.tile(params['bias'], reps=(batch_size, 1, 1)) logits = jnp.concatenate([logits, bias], axis=-1) logits = jnp.expand_dims(logits, axis=-1) output_logits = jnp.matmul(weights, logits) output_logits = jnp.clip(output_logits, a_min=jsp.special.logit(self.pred_clipping), a_max=jsp.special.logit(1.0 - self.pred_clipping)) if target is None: return jnp.squeeze(output_logits, axis=-1) else: logits = jnp.expand_dims(jnp.squeeze(logits, axis=-1), axis=-2) output_preds = jnn.sigmoid(output_logits) target = jnp.expand_dims(jnp.expand_dims(target, axis=-1), axis=-1) params['lr_step'], learning_rate = self.learning_rate.value( params['lr_step']) delta = learning_rate * (target - output_preds) * logits dims = lax.ScatterDimensionNumbers( update_window_dims=(3, ), inserted_window_dims=(0, 1, 2), scatter_dims_to_operand_dims=(0, 1, 2)) if self.weight_clipping is None: params['weights'] = lax.scatter_add( operand=params['weights'], scatter_indices=context_index, updates=delta, dimension_numbers=dims) else: weights = jnp.clip(weights + delta, a_min=-self.weight_clipping, a_max=self.weight_clipping) params['weights'] = lax.scatter(operand=params['weights'], scatter_indices=context_index, updates=weights, dimension_numbers=dims) return params, jnp.squeeze(output_logits, axis=-1)
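# Reading note (added, not part of the original implementation, and only a
# sketch of what the predict method above appears to do): each neuron's
# context_maps rows act as hyperplane tests on the context vector; the boolean
# outcomes are packed into an integer via the 1 << n weights, and that integer
# selects which weight row the gather reads and the scatter later updates.
# The toy values below are assumptions chosen only for illustration.
import numpy as onp

_context_map_size = 3
_hyperplane_outcomes = onp.array([0.5 > 0.0, -1.0 > 0.0, 2.0 > 0.0])
_index = int((_hyperplane_outcomes * (1 << onp.arange(_context_map_size))).sum())
assert _index == 0b101 == 5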