Example #1
def _scale_and_translate(x, output_shape, scale, translate, kernel, antialias):
    input_shape = x.shape
    assert len(input_shape) == len(output_shape)
    assert len(input_shape) == len(scale)
    assert len(input_shape) == len(translate)
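    # A dimension is resampled only if its size, scale, or translation
    # actually changes; every other dimension is passed through untouched.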
    spatial_dims = np.nonzero(
        np.not_equal(input_shape, output_shape) | np.not_equal(scale, 1)
        | np.not_equal(translate, 0))[0]
    if len(spatial_dims) == 0:
        return x
    output_spatial_shape = tuple(np.array(output_shape)[spatial_dims])
    indices = []
    contractions = []
    slice_shape = list(input_shape)
    in_indices = list(range(len(output_shape) + len(spatial_dims)))
    out_indices = list(range(len(output_shape)))
    for i, d in enumerate(spatial_dims):
        m = input_shape[d]
        n = output_shape[d]
        starts, weights = _compute_spans(m,
                                         n,
                                         scale[d],
                                         translate[d],
                                         kernel,
                                         antialias=antialias)
        starts = lax.broadcast_in_dim(starts, output_spatial_shape + (1, ),
                                      (i, ))
        slice_shape[d] = weights.shape[1]
        indices.append(starts.astype(np.int32))
        contractions.append(weights.astype(x.dtype))
        contractions.append([len(output_shape) + i, d])
        out_indices[d] = len(output_shape) + i
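    # For each output position, gather a span of source samples along every
    # resampled dimension, then contract the spans against the per-dimension
    # weight matrices in a single einsum.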
    index = lax.concatenate(indices, len(output_spatial_shape))
    dnums = lax.GatherDimensionNumbers(offset_dims=tuple(
        range(len(output_shape))),
                                       collapsed_slice_dims=(),
                                       start_index_map=tuple(spatial_dims))
    out = lax.gather(x, index, dnums, slice_shape)
    contractions.append(out_indices)
    return jnp.einsum(out,
                      in_indices,
                      *contractions,
                      precision=lax.Precision.HIGHEST)
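
For context, a minimal usage sketch of jax.image.resize, the public entry point that this kind of gather-plus-einsum resampling typically backs; the array, target shape, and method below are illustrative assumptions, not values taken from the code above.

import jax.numpy as jnp
from jax import image

x = jnp.arange(12.0).reshape(3, 4)            # a tiny 2-D "image"
y = image.resize(x, (6, 8), method="linear")  # resample both spatial dims
assert y.shape == (6, 8)
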
Example #2
class LaxVmapTest(jtu.JaxTestCase):
    def _CheckBatching(self,
                       op,
                       bdim_size,
                       bdims,
                       shapes,
                       dtypes,
                       rng,
                       rtol=None,
                       atol=None):
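        # Compares vmap(op, bdims) on batched inputs against a manual
        # reference: per-example applications of op stacked along axis 0,
        # or an empty array of the expected shape when bdim_size == 0.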
        batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
        args = [
            rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)
        ]
        args_slice = args_slicer(args, bdims)
        ans = api.vmap(op, bdims)(*args)
        if bdim_size == 0:
            args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
            out = op(*args)
            expected = np.zeros((0, ) + out.shape, out.dtype)
        else:
            expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)])
        self.assertAllClose(ans, expected, rtol=rtol, atol=atol)

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list(
                {
                    "testcase_name":
                    "{}_bdims={}".format(
                        jtu.format_test_name_suffix(
                            rec.op, shapes, itertools.repeat(dtype)), bdims),
                    "op_name":
                    rec.op,
                    "rng_factory":
                    rec.rng_factory,
                    "shapes":
                    shapes,
                    "dtype":
                    dtype,
                    "bdims":
                    bdims,
                    "tol":
                    rec.tol
                } for shape_group in compatible_shapes
                for shapes in itertools.combinations_with_replacement(
                    shape_group, rec.nargs) for bdims in all_bdims(*shapes)
                for dtype in rec.dtypes) for rec in LAX_OPS))
    def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol):
        rng = rng_factory(self.rng())
        op = getattr(lax, op_name)
        self._CheckBatching(op,
                            10,
                            bdims,
                            shapes, [dtype] * len(shapes),
                            rng,
                            atol=tol,
                            rtol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
            "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
            "_lhs_bdim={}_rhs_bdim={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), strides,
                padding, lhs_dil, rhs_dil, ",".join(dim_nums),
                feature_group_count, batch_group_count, lhs_bdim, rhs_bdim),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "strides":
            strides,
            "padding":
            padding,
            "lhs_dil":
            lhs_dil,
            "rhs_dil":
            rhs_dil,
            "dimension_numbers":
            dim_nums,
            "perms":
            perms,
            "lhs_bdim":
            lhs_bdim,
            "rhs_bdim":
            rhs_bdim,
            "feature_group_count":
            feature_group_count,
            "batch_group_count":
            batch_group_count,
        } for batch_group_count, feature_group_count in [(1, 1), (2, 1), (1, 2)]
                            for lhs_shape, rhs_shape, all_strides, all_pads,
                            lhs_dils, rhs_dils in [
                                (
                                    (b * batch_group_count,
                                     i * feature_group_count, 6,
                                     7),  # lhs_shape
                                    (j * batch_group_count *
                                     feature_group_count, i, 1,
                                     2),  # rhs_shape
                                    [(1, 1), (1, 2), (2, 1)],  # strides
                                    [((0, 0), (0, 0)), ((1, 0), (0, 1)),
                                     ((0, -1), (0, 0))],  # pads
                                    [(1, 1), (2, 1)],  # lhs_dils
                                    [(1, 1), (2, 2)])  # rhs_dils
                                for b, i, j in itertools.product([1, 2],
                                                                 repeat=3)
                            ] for strides in all_strides
                            for rhs_dil in rhs_dils for lhs_dil in lhs_dils
                            for dtype in [np.float32] for padding in all_pads
                            for dim_nums, perms in [(("NCHW", "OIHW", "NCHW"),
                                                     ([0, 1, 2, 3],
                                                      [0, 1, 2, 3])),
                                                    (("NHWC", "HWIO", "NHWC"),
                                                     ([0, 2, 3, 1],
                                                      [2, 3, 1, 0])),
                                                    (("NHWC", "OIHW", "NCHW"),
                                                     ([0, 2, 3, 1],
                                                      [0, 1, 2, 3]))]
                            for lhs_bdim in itertools.chain(
                                [cast(Optional[int], None)],
                                range(len(lhs_shape) + 1))
                            for rhs_bdim in itertools.chain(
                                [cast(Optional[int], None)],
                                range(len(rhs_shape) + 1))
                            if (lhs_bdim, rhs_bdim) != (None, None)))
    def testConvGeneralDilatedBatching(self, lhs_shape, rhs_shape, dtype,
                                       strides, padding, lhs_dil, rhs_dil,
                                       dimension_numbers, perms,
                                       feature_group_count, batch_group_count,
                                       lhs_bdim, rhs_bdim):
        rng = jtu.rand_default(self.rng())
        tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3

        # permute shapes to match dim_spec, scale by feature_group_count
        lhs_perm, rhs_perm = perms
        lhs_shape = list(np.take(lhs_shape, lhs_perm))
        rhs_shape = list(np.take(rhs_shape, rhs_perm))

        conv = partial(lax.conv_general_dilated,
                       window_strides=strides,
                       padding=padding,
                       lhs_dilation=lhs_dil,
                       rhs_dilation=rhs_dil,
                       dimension_numbers=dimension_numbers,
                       feature_group_count=feature_group_count,
                       batch_group_count=batch_group_count,
                       precision=lax.Precision.HIGHEST)
        self._CheckBatching(conv,
                            5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape),
                            (dtype, dtype),
                            rng,
                            rtol=tol,
                            atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
                shape, from_dtype, to_dtype, bdims),
            "shape":
            shape,
            "from_dtype":
            from_dtype,
            "to_dtype":
            to_dtype,
            "bdims":
            bdims
        } for from_dtype, to_dtype in itertools.product(
            [np.float32, np.int32, "float32", "int32"], repeat=2)
                            for shape in [(2, 3)]
                            for bdims in all_bdims(shape)))
    def testConvertElementType(self, shape, from_dtype, to_dtype, bdims):
        rng = jtu.rand_default(self.rng())
        op = lambda x: lax.convert_element_type(x, to_dtype)
        self._CheckBatching(op, 10, bdims, (shape, ), (from_dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
                shape, from_dtype, to_dtype, bdims),
            "shape":
            shape,
            "from_dtype":
            from_dtype,
            "to_dtype":
            to_dtype,
            "bdims":
            bdims
        } for from_dtype, to_dtype in itertools.product(
            [np.float32, np.int32, "float32", "int32"], repeat=2)
                            for shape in [(2, 3)]
                            for bdims in all_bdims(shape)))
    def testBitcastElementType(
        self,
        shape,
        from_dtype,
        to_dtype,
        bdims,
    ):
        rng = jtu.rand_default(self.rng())
        op = lambda x: lax.bitcast_convert_type(x, to_dtype)
        self._CheckBatching(op, 10, bdims, (shape, ), (from_dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_min_shape={}_operand_shape={}_max_shape={}_bdims={}".format(
                    jtu.format_shape_dtype_string(min_shape, dtype),
                    jtu.format_shape_dtype_string(operand_shape, dtype),
                    jtu.format_shape_dtype_string(max_shape, dtype), bdims),
                "min_shape":
                min_shape,
                "operand_shape":
                operand_shape,
                "max_shape":
                max_shape,
                "dtype":
                dtype,
                "bdims":
                bdims
            } for min_shape, operand_shape, max_shape in [
                [(), (2, 3), ()],
                [(2, 3), (2, 3), ()],
                [(), (2, 3), (2, 3)],
                [(2, 3), (2, 3), (2, 3)],
            ] for dtype in default_dtypes
            for bdims in all_bdims(min_shape, operand_shape, max_shape)))
    def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims):
        rng = jtu.rand_default(self.rng())
        raise SkipTest(
            "batching rule for clamp not implemented")  # TODO(mattj)
        shapes = [min_shape, operand_shape, max_shape]
        self._CheckBatching(lax.clamp, 10, bdims, shapes, [dtype] * 3, rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs_shape={}_rhs_shape={}_bdims={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), bdims),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "bdims":
            bdims
        } for lhs_shape in [(3, ), (4, 3)] for rhs_shape in [(3, ), (3, 6)]
                            for bdims in all_bdims(lhs_shape, rhs_shape)
                            for dtype in default_dtypes))
    def testDot(self, lhs_shape, rhs_shape, dtype, bdims):
        rng = jtu.rand_default(self.rng())
        op = partial(lax.dot, precision=lax.Precision.HIGHEST)
        self._CheckBatching(op,
                            5,
                            bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                            rng,
                            rtol={
                                np.float16: 5e-2,
                                np.float64: 5e-14
                            })

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_bdims={}"
            .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                    jtu.format_shape_dtype_string(rhs_shape, dtype),
                    lhs_contracting, rhs_contracting, bdims),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "lhs_contracting":
            lhs_contracting,
            "rhs_contracting":
            rhs_contracting,
            "bdims":
            bdims
        } for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
            [(5, ), (5, ), [0], [0]],
            [(5, 7), (5, ), [0], [0]],
            [(7, 5), (5, ), [1], [0]],
            [(3, 5), (2, 5), [1], [1]],
            [(5, 3), (5, 2), [0], [0]],
            [(5, 3, 2), (5, 2, 4), [0], [0]],
            [(5, 3, 2), (5, 2, 4), [0, 2], [0, 1]],
            [(5, 3, 2), (3, 5, 2, 4), [0, 2], [1, 2]],
            [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
            [(3, 2), (2, 4), [1], [0]],
        ] for bdims in all_bdims(lhs_shape, rhs_shape)
                            for dtype in default_dtypes))
    def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                                   lhs_contracting, rhs_contracting, bdims):
        rng = jtu.rand_small(self.rng())
        dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))
        dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
        self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape),
                            (dtype, dtype), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_bdims={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype),
                dimension_numbers, bdims),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "dimension_numbers":
            dimension_numbers,
            "bdims":
            bdims
        } for lhs_shape, rhs_shape, dimension_numbers in [
            ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
            ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
            ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
        ] for bdims in all_bdims(lhs_shape, rhs_shape)
                            for dtype in default_dtypes))
    def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                       dimension_numbers, bdims):
        rng = jtu.rand_small(self.rng())
        dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
        self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape),
                            (dtype, dtype), rng)

        # Checks that batching didn't introduce any transposes or broadcasts.
        jaxpr = api.make_jaxpr(dot)(np.zeros(lhs_shape, dtype),
                                    np.zeros(rhs_shape, dtype))
        for eqn in jtu.iter_eqns(jaxpr.jaxpr):
            self.assertNotIn(eqn.primitive.name, ["transpose", "broadcast"])

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format(
                shape,
                np.dtype(dtype).name, broadcast_sizes, bdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "broadcast_sizes":
            broadcast_sizes,
            "bdims":
            bdims
        } for shape in [(), (2, 3)] for dtype in default_dtypes
                            for broadcast_sizes in [(), (2, ), (1, 2)]
                            for bdims in all_bdims(shape)))
    def testBroadcast(self, shape, dtype, broadcast_sizes, bdims):
        rng = jtu.rand_default(self.rng())
        op = lambda x: lax.broadcast(x, broadcast_sizes)
        self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_outshape={}_bcdims={}_bdims={}".format(
                jtu.format_shape_dtype_string(inshape, dtype), outshape,
                broadcast_dimensions, bdims),
            "inshape":
            inshape,
            "dtype":
            dtype,
            "outshape":
            outshape,
            "dimensions":
            broadcast_dimensions,
            "bdims":
            bdims
        } for inshape, outshape, broadcast_dimensions in [
            ([2], [2, 2], [0]),
            ([2], [2, 2], [1]),
            ([2], [2, 3], [0]),
            ([], [2, 3], []),
        ] for dtype in default_dtypes for bdims in all_bdims(inshape)))
    def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims):
        rng = jtu.rand_default(self.rng())
        raise SkipTest("this test has failures in some cases")  # TODO(mattjj)
        op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
        self._CheckBatching(op, 5, bdims, (inshape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_dimensions={}_bdims={}".format(
                jtu.format_shape_dtype_string(arg_shape, np.float32),
                dimensions, bdims),
            "arg_shape":
            arg_shape,
            "dimensions":
            dimensions,
            "bdims":
            bdims
        } for arg_shape, dimensions in [
            [(1, ), (0, )],
            [(1, ), (-1, )],
            [(2, 1, 4), (1, )],
            [(2, 1, 4), (-2, )],
            [(2, 1, 3, 1), (1, )],
            [(2, 1, 3, 1), (1, 3)],
            [(2, 1, 3, 1), (3, )],
            [(2, 1, 3, 1), (1, -1)],
        ] for bdims in all_bdims(arg_shape)))
    def testSqueeze(self, arg_shape, dimensions, bdims):
        dtype = np.float32
        rng = jtu.rand_default(self.rng())
        op = lambda x: lax.squeeze(x, dimensions)
        self._CheckBatching(op, 10, bdims, (arg_shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_inshape={}_outshape={}_dims={}_bdims={}".format(
                    jtu.format_shape_dtype_string(arg_shape, dtype),
                    jtu.format_shape_dtype_string(out_shape, dtype),
                    dimensions, bdims),
                "arg_shape":
                arg_shape,
                "out_shape":
                out_shape,
                "dtype":
                dtype,
                "dimensions":
                dimensions,
                "bdims":
                bdims
            } for dtype in default_dtypes
            for arg_shape, dimensions, out_shape in [
                [(3, 4), None, (12,)],
                [(2, 1, 4), None, (8,)],
                [(2, 2, 4), None, (2, 8)],
                [(2, 2, 4), (0, 1, 2), (2, 8)],
                [(2, 2, 4), (1, 0, 2), (8, 2)],
                [(2, 2, 4), (2, 1, 0), (4, 2, 2)],
            ]
            for bdims in all_bdims(arg_shape)))
    def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims):
        rng = jtu.rand_default(self.rng())
        op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions)
        self._CheckBatching(op, 10, bdims, (arg_shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_pads={}_bdims={}".format(
                jtu.format_shape_dtype_string(shape, dtype), pads, bdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "pads":
            pads,
            "bdims":
            bdims
        } for shape in [(2, 3)] for bdims in all_bdims(shape)
                            for dtype in default_dtypes
                            for pads in [[(1, 2, 1), (0, 1, 0)]]))
    def testPad(self, shape, dtype, pads, bdims):
        rng = jtu.rand_small(self.rng())
        fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
        self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_predshape={}_argshapes={}_bdims={}".format(
                    jtu.format_shape_dtype_string(pred_shape, np.bool_),
                    jtu.format_shape_dtype_string(arg_shape, arg_dtype),
                    bdims),
                "pred_shape":
                pred_shape,
                "arg_shape":
                arg_shape,
                "arg_dtype":
                arg_dtype,
                "bdims":
                bdims
            }
            for arg_shape in [(), (3, ), (2, 3)]
            for pred_shape in ([(), arg_shape] if arg_shape else [()])
            for bdims in all_bdims(pred_shape, arg_shape, arg_shape)
            for arg_dtype in default_dtypes))
    def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims):
        rng = jtu.rand_default(self.rng())
        op = lambda c, x, y: lax.select(c < 0, x, y)
        self._CheckBatching(op, 5, bdims, (
            pred_shape,
            arg_shape,
            arg_shape,
        ), (np.bool_, arg_dtype, arg_dtype), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_start_indices={}_limit_indices={}_strides={}_bdims={}".
            format(jtu.format_shape_dtype_string(shape, dtype), start_indices,
                   limit_indices, strides, bdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "starts":
            start_indices,
            "limits":
            limit_indices,
            "strides":
            strides,
            "bdims":
            bdims
        } for shape, start_indices, limit_indices, strides in [
            [(3, ), (1, ), (2, ), None],
            [(7, ), (4, ), (7, ), None],
            [(5, ), (1, ), (5, ), (2, )],
            [(8, ), (1, ), (6, ), (2, )],
            [(5, 3), (1, 1), (3, 2), None],
            [(5, 3), (1, 1), (3, 1), None],
            [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
            [(5, 3), (1, 1), (2, 1), (1, 1)],
            [(5, 3), (1, 1), (5, 3), (2, 1)],
        ] for bdims in all_bdims(shape) for dtype in default_dtypes))
    def testSlice(self, shape, dtype, starts, limits, strides, bdims):
        rng = jtu.rand_default(self.rng())
        op = lambda x: lax.slice(x, starts, limits, strides)
        self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_perm={}_bdims={}".format(
                jtu.format_shape_dtype_string(shape, dtype), perm, bdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "perm":
            perm,
            "bdims":
            bdims
        } for shape, perm in [
            [(3, 4), (1, 0)],
            [(3, 4), (0, 1)],
            [(3, 4, 5), (2, 1, 0)],
            [(3, 4, 5), (1, 0, 2)],
        ] for bdims in all_bdims(shape) for dtype in default_dtypes))
    def testTranspose(self, shape, dtype, perm, bdims):
        rng = jtu.rand_default(self.rng())
        op = lambda x: lax.transpose(x, perm)
        self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_op={}_inshape={}_reducedims={}_initval={}_bdims={}".format(
                op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
                init_val, bdims),
            "op":
            op,
            "init_val":
            init_val,
            "shape":
            shape,
            "dtype":
            dtype,
            "dims":
            dims,
            "bdims":
            bdims
        } for init_val, op, dtypes in [
            (0, lax.add, default_dtypes),
            (1, lax.mul, default_dtypes),
            (0, lax.max, all_dtypes),  # non-monoidal
            (-np.inf, lax.max, float_dtypes),
            (dtypes.iinfo(np.int32).min, lax.max, [np.int32]),
            (dtypes.iinfo(np.int64).min, lax.max, [np.int64]),
            (dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]),
            (dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]),
            (np.inf, lax.min, float_dtypes),
            (dtypes.iinfo(np.int32).max, lax.min, [np.int32]),
            (dtypes.iinfo(np.int64).max, lax.min, [np.int64]),
            (dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]),
            (dtypes.iinfo(np.uint64).max, lax.min, [np.uint64]),
        ] for dtype in dtypes
                            for shape, dims in [[(3, 4, 5), (0,)],
                                                [(3, 4, 5), (1, 2)],
                                                [(3, 4, 5), (0, 2)],
                                                [(3, 4, 5), (0, 1, 2)]]
                            for bdims in all_bdims(shape)))
    def testReduce(self, op, init_val, shape, dtype, dims, bdims):
        rng = jtu.rand_small(self.rng())
        init_val = np.asarray(init_val, dtype=dtype)
        fun = lambda operand: lax.reduce(operand, init_val, op, dims)
        self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_op={}_inshape={}_reducedims={}_bdims={}".format(
                op.__name__, jtu.format_shape_dtype_string(shape, dtype), dim,
                bdims),
            "op":
            op,
            "shape":
            shape,
            "dtype":
            dtype,
            "dim":
            dim,
            "bdims":
            bdims
        } for op in [lax.argmin, lax.argmax] for dtype in default_dtypes
                            for shape in [(3, 4, 5)]
                            for dim in range(len(shape))
                            for bdims in all_bdims(shape)))
    def testArgminmax(self, op, shape, dtype, dim, bdims):
        rng = jtu.rand_default(self.rng())
        fun = lambda operand: op(operand, dim, np.int32)
        self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                ("_op={}_shape={}_dims={}_strides={}_padding={}"
                 "_basedilation={}_windowdilation={}").format(
                     op.__name__, jtu.format_shape_dtype_string(shape, dtype),
                     dims, strides, padding, base_dilation, window_dilation),
                "op":
                op,
                "init_val":
                init_val,
                "dtype":
                dtype,
                "shape":
                shape,
                "dims":
                dims,
                "strides":
                strides,
                "padding":
                padding,
                "base_dilation":
                base_dilation,
                "window_dilation":
                window_dilation
            } for init_val, op, dtypes in [
                (0, lax.add, [np.float32]),
                (-np.inf, lax.max, [np.float32]),
                (np.inf, lax.min, [np.float32]),
            ] for shape, dims, strides, padding, base_dilation, window_dilation
            in itertools.chain(
                itertools.product(
                    [(4, 6)],
                    [(2, 1), (1, 2)],
                    [(1, 1), (2, 1), (1, 2)],
                    ["VALID", "SAME", [(0, 3), (1, 2)]],
                    [(1, 1), (2, 3)],
                    [(1, 1), (1, 2)]),
                itertools.product(
                    [(3, 2, 4, 6)],
                    [(1, 1, 2, 1), (2, 1, 2, 1)],
                    [(1, 2, 2, 1), (1, 1, 1, 1)],
                    ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
                    [(1, 1, 1, 1), (2, 1, 3, 2)],
                    [(1, 1, 1, 1), (1, 2, 2, 1)]))
            for dtype in dtypes))
    def testReduceWindow(self, op, init_val, dtype, shape, dims, strides,
                         padding, base_dilation, window_dilation):
        rng = jtu.rand_small(self.rng())
        init_val = np.asarray(init_val, dtype=dtype)

        def fun(operand):
            return lax.reduce_window(operand, init_val, op, dims, strides,
                                     padding, base_dilation, window_dilation)

        for bdims in all_bdims(shape):
            self._CheckBatching(fun, 3, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_op={}_shape={}_axis={}_bdims={}_reverse={}".format(
                op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
                bdims, reverse),
            "op":
            op,
            "shape":
            shape,
            "dtype":
            dtype,
            "bdims":
            bdims,
            "axis":
            axis,
            "reverse":
            reverse
        } for op, types in [
            (lax.cumsum, [np.float32, np.float64]),
            (lax.cumprod, [np.float32, np.float64]),
        ] for dtype in types for shape in [[10], [3, 4, 5]]
                            for axis in range(len(shape))
                            for bdims in all_bdims(shape)
                            for reverse in [False, True]))
    def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse):
        rng_factory = (jtu.rand_default if dtypes.issubdtype(
            dtype, np.integer) else jtu.rand_small)
        rng = rng_factory(self.rng())
        self._CheckBatching(partial(op, axis=axis, reverse=reverse), 7, bdims,
                            (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_dtype={}_padding={}".format(np.dtype(dtype).name, padding),
            "dtype":
            dtype,
            "padding":
            padding
        } for dtype in float_dtypes for padding in ["VALID", "SAME"]))
    @jtu.skip_on_flag("jax_skip_slow_tests", True)
    @jtu.ignore_warning(message="Using reduced precision for gradient.*")
    def testSelectAndGatherAdd(self, dtype, padding):
        if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
            raise SkipTest(
                "bfloat16 _select_and_gather_add doesn't work on tpu")
        rng = jtu.rand_small(self.rng())
        all_configs = itertools.chain(
            itertools.product([(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1),
                                                           (1, 2)]),
            itertools.product([(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
                              [(1, 2, 2, 1), (1, 1, 1, 1)]))

        def fun(operand, tangents):
            pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
            ones = (1, ) * len(operand.shape)
            return lax._select_and_gather_add(operand, tangents, lax.ge_p,
                                              dims, strides, pads, ones, ones)

        for shape, dims, strides in all_configs:
            for bdims in all_bdims(shape, shape):
                self._CheckBatching(fun, 3, bdims, (shape, shape),
                                    (dtype, dtype), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_dtype={jtu.format_shape_dtype_string(shape, dtype)}"
            f"_padding={padding}_dims={dims}_strides={strides}",
            "dtype":
            dtype,
            "padding":
            padding,
            "shape":
            shape,
            "dims":
            dims,
            "strides":
            strides
        } for dtype in float_dtypes for padding in ["VALID", "SAME"]
                            for shape in [(3, 2, 4, 6)]
                            for dims in [(1, 1, 2, 1)]
                            for strides in [(1, 2, 2, 1), (1, 1, 1, 1)]))
    def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides):
        rng = jtu.rand_small(self.rng())

        pads = lax.padtype_to_pads(shape, dims, strides, padding)

        def fun(operand, cotangents):
            return lax._select_and_scatter_add(operand, cotangents, lax.ge_p,
                                               dims, strides, pads)

        ones = (1, ) * len(shape)
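        # Infer the cotangent shape of the forward _select_and_gather_add
        # abstractly via eval_shape, without executing it.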
        cotangent_shape = api.eval_shape(
            lambda x: lax._select_and_gather_add(x, x, lax.ge_p, dims, strides,
                                                 pads, ones, ones),
            np.ones(shape, dtype)).shape

        for bdims in all_bdims(cotangent_shape, shape):
            self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape),
                                (dtype, dtype), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_bdims={}_fft_ndims={}".format(shape, bdims, fft_ndims),
            "shape":
            shape,
            "bdims":
            bdims,
            "fft_ndims":
            fft_ndims
        } for shape in [(5, ), (3, 4, 5), (2, 3, 4, 5)]
                            for bdims in all_bdims(shape)
                            for fft_ndims in range(0,
                                                   min(3, len(shape)) + 1)))
    @jtu.skip_on_devices("tpu")  # TODO(b/137993701): unimplemented cases.
    def testFft(self, fft_ndims, shape, bdims):
        rng = jtu.rand_default(self.rng())
        ndims = len(shape)
        axes = range(ndims - fft_ndims, ndims)
        fft_lengths = [shape[axis] for axis in axes]
        op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths)
        self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}".format(
                jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
                slice_sizes, bdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "bdims":
            bdims
        } for dtype in all_dtypes for shape, idxs, dnums, slice_sizes in [
            ((5, ), np.array([[0], [2]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, )),
            ((10, ), np.array([[0], [0], [0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, )), (2, )),
            ((10, 5), np.array([[0], [2], [1]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, 3)),
            ((10, 5), np.array([[0, 2], [1, 0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1)), (1, 3)),
        ] for bdims in all_bdims(shape, idxs.shape)))
    def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
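        # Check both an empty batch (bdim_size == 0) and a regular batch.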
        self._CheckBatching(fun, 0, bdims, [shape, idxs.shape],
                            [dtype, idxs.dtype], jtu.rand_default(self.rng()))
        self._CheckBatching(fun, 5, bdims, [shape, idxs.shape],
                            [dtype, idxs.dtype], jtu.rand_default(self.rng()))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_shape={}_idxs={}_update={}_dnums={}_bdims={}".format(
                    jtu.format_shape_dtype_string(arg_shape, dtype), idxs,
                    update_shape, dnums, bdims),
                "arg_shape":
                arg_shape,
                "dtype":
                dtype,
                "idxs":
                idxs,
                "update_shape":
                update_shape,
                "dnums":
                dnums,
                "bdims":
                bdims
            } for dtype in float_dtypes
            for arg_shape, idxs, update_shape, dnums in [
                ((5, ), np.array([[0], [2]]), (2, ),
                 lax.ScatterDimensionNumbers(update_window_dims=(),
                                             inserted_window_dims=(0, ),
                                             scatter_dims_to_operand_dims=(
                                                 0, ))),
                ((10, ), np.array([[0], [0], [0]]), (3, 2),
                 lax.ScatterDimensionNumbers(update_window_dims=(1, ),
                                             inserted_window_dims=(),
                                             scatter_dims_to_operand_dims=(
                                                 0, ))),
                ((10, 5), np.array([[0], [2], [1]]), (3, 3),
                 lax.ScatterDimensionNumbers(update_window_dims=(1, ),
                                             inserted_window_dims=(0, ),
                                             scatter_dims_to_operand_dims=(
                                                 0, ))),
            ] for bdims in all_bdims(arg_shape, idxs.shape, update_shape)))
    def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums,
                       bdims):
        fun = partial(lax.scatter_add, dimension_numbers=dnums)
        self._CheckBatching(fun,
                            5,
                            bdims, [arg_shape, idxs.shape, update_shape],
                            [dtype, idxs.dtype, dtype],
                            jtu.rand_default(self.rng()),
                            rtol={np.float16: 5e-3})

    def testShapeUsesBuiltinInt(self):
        x = lax.iota(np.int32, 3) + 1
        self.assertIsInstance(x.shape[0], int)  # not np.int64

    def testBroadcastShapesReturnsPythonInts(self):
        shape1, shape2 = (1, 2, 3), (2, 3)
        out_shape = lax.broadcast_shapes(shape1, shape2)
        self.assertTrue(all(type(s) is int for s in out_shape))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_shape={}_k={}_bdims={}".format(
                    jtu.format_shape_dtype_string(shape, dtype), k, bdims),
                "shape":
                shape,
                "dtype":
                dtype,
                "k":
                k,
                "bdims":
                bdims,
                "rng_factory":
                rng_factory
            } for shape in [(4, ), (3, 5, 3)] for k in [1, 3]
            for bdims in all_bdims(shape)
            # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug is fixed:
            # The top_k indices for integer arrays with identical entries won't match between
            # vmap'd version and manual reference, so only test unique integer arrays for int_dtypes.
            # Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of
            # values a bfloat16 can represent exactly to avoid ties.
            for dtype, rng_factory in itertools.chain(
                unsafe_zip(default_dtypes, itertools.repeat(
                    jtu.rand_unique_int)))))
    def testTopK(self, shape, dtype, k, bdims, rng_factory):
        rng = rng_factory(self.rng())
        # _CheckBatching doesn't work with tuple outputs, so test outputs separately.
        op1 = lambda x: lax.top_k(x, k=k)[0]
        self._CheckBatching(op1, 5, bdims, (shape, ), (dtype, ), rng)
        op2 = lambda x: lax.top_k(x, k=k)[1]
        self._CheckBatching(op2, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_dimension={}_arity={}_bdims={}_isstable={}".format(
                jtu.format_shape_dtype_string(shape, np.float32), dimension,
                arity, bdims, is_stable),
            "shape":
            shape,
            "dimension":
            dimension,
            "arity":
            arity,
            "bdims":
            bdims,
            "is_stable":
            is_stable
        } for shape in [(2, 3)] for dimension in [0, 1] for arity in range(3)
                            for bdims in all_bdims(*((shape, ) * arity))
                            for is_stable in [False, True]))
    def testSort(self, shape, dimension, arity, bdims, is_stable):
        rng = jtu.rand_default(self.rng())
        if arity == 1:
            fun = partial(lax.sort, dimension=dimension)
            self._CheckBatching(fun, 5, bdims, (shape, ) * arity,
                                (np.float32, ) * arity, rng)
        else:
            for i in range(arity):
                fun = lambda *args, i=i: lax.sort(
                    args, dimension=dimension, is_stable=is_stable)[i]
                self._CheckBatching(fun, 5, bdims, (shape, ) * arity,
                                    (np.float32, ) * arity, rng)
Example #3
class BatchingTest(jtu.JaxTestCase):

  def testConstantFunction(self):
    ans = vmap(lambda x: 3)(np.ones(4))
    expected = 3 * np.ones(4)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testNestedBatchingMatMat(self):
    matvec = vmap(jnp.vdot, in_axes=(0, None))
    matmat = vmap(matvec, in_axes=(None, 1), out_axes=1)
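    # matvec maps vdot over the rows of A; mapping matvec over the columns
    # of B (in and out along axis 1) assembles a full matrix product.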

    R = np.random.RandomState(0).randn
    A = R(4, 3)
    B = R(3, 2)

    ans = matmat(A, B)
    expected = np.dot(A, B)
    self.assertAllClose(
        ans, expected, check_dtypes=False,
        rtol={np.float32:1e-2} if jtu.device_under_test() == "tpu" else None)

    jaxpr = make_jaxpr(matmat)(A, B)
    self.assertEqual(len(jaxpr.jaxpr.eqns), 1)

  def testPerExampleGradients(self):
    def predict(params, inputs):
      for W, b in params:
        outputs = jnp.dot(W, inputs) + b
        inputs = jnp.tanh(outputs)
      return outputs

    def loss(params, data):
      inputs, targets = data
      predictions = predict(params, inputs)
      return jnp.sum((predictions - targets)**2)

    batch_size = 5
    layer_sizes = [3, 2, 4]

    R = np.random.RandomState(0).randn
    params = [(R(m, n), R(m))
              for m, n in zip(layer_sizes[1:], layer_sizes[:-1])]

    input_batch = R(5, 3)
    target_batch = R(5, 4)
    batch = (input_batch, target_batch)

    ans = vmap(partial(grad(loss), params))(batch)

    for ans_pair, param_pair in zip(ans, params):
      dW, db = ans_pair
      W, b = param_pair

      self.assertEqual(dW.shape, (batch_size,) + W.shape)
      self.assertEqual(db.shape, (batch_size,) + b.shape)

  def testJacobians(self):
    def jacbwd(f, x):
      y, pullback = vjp(f, x)
      std_basis = np.eye(np.size(y)).reshape((-1,) + np.shape(y))
      jac_flat, = vmap(pullback, out_axes=np.ndim(y))(std_basis)
      return jac_flat.reshape(np.shape(y) + np.shape(x))

    def jacfwd(f, x):
      pushfwd = lambda v: jvp(f, (x,), (v,))
      std_basis = np.eye(np.size(x)).reshape((-1,) + np.shape(x))
      y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis)
      return jac_flat.reshape(np.shape(y) + np.shape(x))
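    # jacbwd pulls a basis of the output space back through vjp; jacfwd pushes
    # a basis of the input space forward through jvp; the Jacobians must agree.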

    R = np.random.RandomState(0).randn

    A = R(4, 3)
    b = R(4)
    f = lambda x: jnp.tanh(jnp.dot(A, x) + b)

    x = R(3)
    self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False)

  def testBatchOfCompile(self):
    side = []

    @jit
    def f(x):
      side.append(None)
      return x + x

    g = jit(vmap(f))
    self.assertAllClose(g(np.ones(2)), 2 * np.ones(2), check_dtypes=False)
    self.assertEqual(len(side), 1)
    self.assertAllClose(g(2 * np.ones(2)), 4 * np.ones(2),
                        check_dtypes=False)
    self.assertEqual(len(side), 1)

  def testSliceLax(self):
    fun = lambda x: lax.slice(x, (2,), (4,))
    R = np.random.RandomState(0).randn
    x = R(5, 10)

    ans = vmap(fun)(x)
    expected_ans = x[:, 2:4]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testSliceNumpy(self):
    fun = lambda x: x[:, 2]
    R = np.random.RandomState(0).randn
    x = R(10, 5, 3, 7)

    ans = vmap(fun)(x)
    expected_ans = x[:, :, 2]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testRevLax(self):
    fun = lambda x: lax.rev(x, [0])
    R = np.random.RandomState(0).randn
    x = R(2, 3)

    ans = vmap(fun)(x)
    expected_ans = x[:, ::-1]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    ans = vmap(fun, (1,), 1)(x)
    expected_ans = x[::-1, :]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testRevNumpy(self):
    fun = lambda x: x[:, ::-1]
    R = np.random.RandomState(0).randn
    x = R(3, 2, 4)

    ans = vmap(fun)(x)
    expected_ans = x[:, :, ::-1]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    ans = vmap(fun, (1,), 1)(x)
    expected_ans = x[:, :, ::-1]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    ans = vmap(fun, (2,), 2)(x)
    expected_ans = x[:, ::-1, :]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testNpMaximum(self):
    fun = lambda x: jnp.maximum(x, 0.0)
    R = np.random.RandomState(0).randn
    x = R(10, 5, 3, 7)

    ans = vmap(fun)(x)
    expected_ans = np.maximum(x, 0.0)
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testNpGtrThan(self):
    R = np.random.RandomState(0).randn
    x = R(10, 5, 3, 7)

    ans = vmap(lambda x: x > 1.0)(x)
    expected_ans = x > 1.0
    self.assertAllClose(ans, expected_ans)

  def testNpMaximumPerExampleGrad(self):
    R = np.random.RandomState(0).randn
    x = R(10, 5)
    W = R(5, 5)

    fun = lambda W, x: jnp.sum(jnp.maximum(jnp.dot(x, W), 0.0) ** 2)

    ans = vmap(partial(grad(fun), W))(x)

    W_t = jnp.transpose(W)
    for i in range(10):
      x_ex = x[i:i + 1]

      expected_ans = 2.0 * jnp.dot(
          jnp.maximum(jnp.dot(W_t, jnp.transpose(x_ex)), 0.0), x_ex)
      expected_ans = jnp.transpose(expected_ans)

      self.assertAllClose(
          ans[i], expected_ans, check_dtypes=False,
          atol={np.float32:5e-2} if jtu.device_under_test() == "tpu" else None)

  def testDotGeneral(self):
    R = np.random.RandomState(0).randn

    x = R(10, 3, 4, 5)
    y = R(10, 3, 5, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun)(x, y)
    expected = lax.dot_general(x, y, [((3,), (2,)), ((0, 1), (0, 1))])
    self.assertAllClose(ans, expected)

    x = R(3, 4, 10, 5)
    y = R(3, 10, 5, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun, in_axes=(2, 1))(x, y)
    expected = np.stack([fun(x[..., i, :], y[:, i, ...]) for i in range(10)])
    self.assertAllClose(ans, expected)

    x = R(3, 4, 5, 10)
    y = R(3, 5, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun, in_axes=(3, None))(x, y)
    expected = np.stack([fun(x[..., i], y) for i in range(10)])
    self.assertAllClose(ans, expected)

    x = R(3, 4, 5)
    y = R(3, 5, 10, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun, in_axes=(None, 2))(x, y)
    expected = np.stack([fun(x, y[..., i, :]) for i in range(10)])
    self.assertAllClose(ans, expected)

    x = R(4)
    y = R(4, 10)
    fun = lambda x, y: lax.dot_general(x, y, [((0,), (0,)), ((), ())])
    ans = vmap(fun, in_axes=(None, 1))(x, y)
    expected = np.stack([fun(x, y[..., i]) for i in range(10)])
    self.assertAllClose(ans, expected)

  def testDot(self):
    # these tests are based on @shoyer's notebook studying gufuncs

    def vecvec(a, b):
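      # Emulate a gufunc-style vdot: add one vmap per extra leading dimension,
      # broadcasting (in_axes=None) whichever argument lacks that dimension.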
      dot = jnp.dot
      for ndim in range(1, max(a.ndim, b.ndim)):
        a_ax = 0 if a.ndim > ndim else None
        b_ax = 0 if b.ndim > ndim else None
        dot = vmap(dot, in_axes=(a_ax, b_ax))
      return dot(a, b)

    assert vecvec(jnp.zeros((3,)), jnp.zeros((3,))).shape == ()
    assert vecvec(jnp.zeros((2, 3)), jnp.zeros((3,))).shape == (2,)
    assert vecvec(jnp.zeros((4, 2, 3)), jnp.zeros((3,))).shape == (4, 2)

  def testDot2(self):
    R = np.random.RandomState(0).randn
    xs = R(10, 3)
    ys = R(10, 3)
    ans = vmap(jnp.dot)(xs, ys)
    expected = np.einsum('ni,ni->n', xs, ys)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testDot3(self):
    R = np.random.RandomState(0).randn
    xs = R(5, 8, 10)
    ys = R(10, 1)
    ans = vmap(jnp.dot, in_axes=(1, None))(xs, ys)
    expected = np.einsum('inj,jk->nik', xs, ys)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testDot4(self):
    R = np.random.RandomState(0).randn
    xs = R(3, 2)
    ys = R(3)
    ans = vmap(jnp.dot, in_axes=(1, None))(xs, ys)
    expected = np.einsum('ij,i->j', xs, ys)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testPad(self):
    R = np.random.RandomState(0).randn

    fun = lambda x: lax.pad(x, np.float32(0), [(1, 2, 1)])
    x = R(5, 10).astype(np.float32)
    ans = vmap(fun)(x)
    expected_ans = jnp.stack(list(map(fun, x)))
    self.assertAllClose(ans, expected_ans, check_dtypes=False)


    fun = lambda x: lax.pad(x, np.float32(0), [(1, 2, 1), (0, 1, 0)])
    x = R(5, 10, 3).astype(np.float32)
    ans = vmap(fun)(x)
    expected_ans = jnp.stack(list(map(fun, x)))
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testConcatenate(self):
    R = lambda *shape: np.random.RandomState(0).randn(*shape).astype(np.float32)

    fun = lambda *args: lax.concatenate(args, dimension=0)
    x, y, z = R(10, 2, 3), R(1, 10, 3), R(4, 3)
    ans = vmap(fun, in_axes=(0, 1, None))(x, y, z)
    expected_ans = np.concatenate([x, np.swapaxes(y, 0, 1),
                                    np.broadcast_to(z, (10, 4, 3))], 1)
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    fun = lambda *args: lax.concatenate(args, dimension=1)
    x, y, z = R(10, 2, 1), R(2, 3), R(2, 4, 10)
    ans = vmap(fun, in_axes=(0, None, 2))(x, y, z)
    expected_ans = np.concatenate([x, np.broadcast_to(y, (10, 2, 3)),
                                    np.moveaxis(z, 2, 0)], 2)
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testJacobianIssue54(self):
    # test modeling the code in https://github.com/google/jax/issues/54

    def func(xs):
      return jnp.array(list(xs))

    xs = jnp.ones((5, 1))
    jacrev(func)(xs)  # don't crash
    jacfwd(func)(xs)  # don't crash

  def testAny(self):
    # test modeling the code in https://github.com/google/jax/issues/108

    ans = vmap(jnp.any)(jnp.array([[True, False], [False, False]]))
    expected = jnp.array([True, False])
    self.assertAllClose(ans, expected)

  def testHessian(self):
    # test based on code from sindhwani@google
    def fun(x, t):
      return jnp.sum(jnp.power(jnp.maximum(x, 0.0), 2)) + t

    x = np.array([-1., -0.5, 0., 0.5, 1.0])

    ans = hessian(lambda x: fun(x, 0.0))(x)
    expected = np.array([[0., 0., 0., 0., 0.],
                          [0., 0., 0., 0., 0.],
                          [0., 0.,0.5, 0., 0.],
                          [0., 0., 0., 2., 0.],
                          [0., 0., 0., 0., 2.]])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testDynamicSlice(self):
    # test dynamic_slice via numpy indexing syntax
    # see https://github.com/google/jax/issues/1613 for an explanation of why we
    # need to use jnp rather than np to create x and idx
    x = jnp.arange(30).reshape((10, 3))

    ans = vmap(lambda x, i: x[i], in_axes=(0, None))(x, 1)
    expected = x[:, 1]
    self.assertAllClose(ans, expected, check_dtypes=False)


    idx = jnp.array([0, 1, 2, 1, 0] * 2)
    ans = vmap(lambda x, i: x[i], in_axes=(0, 0))(x, idx)
    expected = x[np.arange(10), idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = jnp.arange(3)
    idx = jnp.array([0, 1, 2, 1, 0] * 2)
    ans = vmap(lambda x, i: x[i], in_axes=(None, 0))(x, idx)
    expected = x[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testDynamicUpdateSlice(self):
    x = np.random.randn(10, 3)
    y = np.random.randn(10)
    ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
               in_axes=(0, 0, None))(x, y, 1)
    expected = x.copy()
    expected[:, 1] = y
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = np.random.randn(3)
    idx = np.array([0, 1, 2, 1, 0] * 2)
    ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
               in_axes=(None, 0, 0))(x, y, idx)
    expected = np.broadcast_to(x, (10, 3)).copy()
    expected[np.arange(10), idx] = y
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testRandom(self):
    seeds = vmap(random.PRNGKey)(np.arange(10))
    ans = vmap(partial(random.normal, shape=(3, 2)))(seeds)
    expected = np.stack([random.normal(random.PRNGKey(seed), (3, 2))
                          for seed in np.arange(10)])
    self.assertAllClose(ans, expected, check_dtypes=False)
    assert len(np.unique(ans)) == 10 * 3 * 2

  def testSort(self):
    v = np.arange(12)[::-1].reshape(3, 4)

    sv = vmap(partial(lax.sort, dimension=0), (0,))(v)
    self.assertAllClose(sv, v[:, ::-1])

    sv = vmap(partial(lax.sort, dimension=-1), (0,))(v)
    self.assertAllClose(sv, v[:, ::-1])

    sv = vmap(partial(lax.sort, dimension=0), (1,))(v)
    self.assertAllClose(sv, v[::-1, :].T)

    sv = vmap(partial(lax.sort, dimension=0), (1,), 1)(v)
    self.assertAllClose(sv, v[::-1, :])

  def testSortKeyVal(self):
    k = np.arange(12)[::-1].reshape(3, 4)
    v = np.random.RandomState(0).permutation(12).reshape(3, 4)

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v)
    self.assertAllClose(sk, k[:, ::-1])
    self.assertAllClose(sv, v[:, ::-1])

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v)
    self.assertAllClose(sk, k[::-1, :])
    self.assertAllClose(sv, v[::-1, :])

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T)
    self.assertAllClose(sk, k[:, ::-1])
    self.assertAllClose(sv, v[:, ::-1])

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v)
    self.assertAllClose(sk, k[:, ::-1])
    self.assertAllClose(sv, v[:, ::-1])

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0], v)
    self.assertAllClose(sk, np.broadcast_to(k[0, ::-1], (3, 4)))
    self.assertAllClose(sv, v[:, ::-1])

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T, v[0])
    self.assertAllClose(sk, k[:, ::-1])
    self.assertAllClose(sv, np.broadcast_to(v[0, ::-1], (3, 4)))

  def testConvGeneralDilated(self):
    W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32)
    X = jnp.array(np.random.randn(10, 5, 5, 1), dtype=np.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
      y = lax.conv_general_dilated(
          x, params, one, 'SAME', one, one, dimension_numbers)
      return y
    grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2))

    # Test forward prop.
    per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example = jnp.reshape(per_example, (10, 5, 5, 5))
    per_example_direct = f(W, X)
    self.assertAllClose(per_example, per_example_direct)

    # Test gradients.
    per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example_direct = []
    for i in range(10):
      g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1)))
      per_example_direct += [
          jnp.reshape(g, (1,) + g.shape)]
    per_example_direct = jnp.concatenate(per_example_direct, axis=0)
    self.assertAllClose(per_example, per_example_direct,
                        rtol=2e-2, atol=2e-3)
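
  # Illustrative sketch (not a test): the per-example gradient recipe in
  # testConvGeneralDilated keeps f written for an unbatched call and lets vmap
  # supply the example axis, so each example retains a leading batch dim of 1.
  # Assumes the vmap/grad/jnp/partial names imported by this module; f, W and X
  # stand in for the function and arrays defined in the test above.
  def _perExampleGradSketch(self, f, W, X):
    grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2))
    # X[:, None] turns shape (10, 5, 5, 1) into (10, 1, 5, 5, 1); vmap then
    # maps grad_loss over the leading example axis.
    return vmap(partial(grad_loss, W))(X[:, None])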

  def testConvGeneralDilatedBatchNotMajor(self):
    W = jnp.array(np.random.randn(3, 3, 1, 4), dtype=np.float32)
    x = jnp.array(np.random.randn(3, 5, 7, 5, 1), dtype=np.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('HNWC', 'HWIO', 'HWNC')
      y = lax.conv_general_dilated(
          x, params, one, 'SAME', one, one, dimension_numbers)
      return y

    per_example = vmap(partial(f, W))(x)
    per_example = jnp.reshape(jnp.transpose(per_example, (1, 2, 0, 3, 4)),
                             (5, 5, 21, 4))
    per_example_direct = f(W, jnp.reshape(jnp.transpose(x, (1, 0, 2, 3, 4)),
                                         (5, 21, 5, 1)))
    self.assertAllClose(per_example, per_example_direct)

  @parameterized.named_parameters(
    {"testcase_name": "_op={}".format(name), "op": op, "unit": unit}
    for name, op, unit in [("max", lax.max, -jnp.inf), ("min", lax.min, jnp.inf)])
  def testMinMaxPool(self, op, unit):
    W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32)
    X = jnp.array(np.random.randn(10, 5, 5, 1), dtype=np.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
      y = lax.conv_general_dilated(
          x, params, one, 'SAME', one, one, dimension_numbers)
      y = lax.reduce_window(
          y, unit, op, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')
      return y
    grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2))

    # Test forward prop.
    per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example = jnp.reshape(per_example, (10, 5, 5, 5))
    per_example_direct = f(W, X)
    self.assertAllClose(per_example, per_example_direct)

    # Test gradients.
    per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example_direct = []
    for i in range(10):
      g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1)))
      per_example_direct += [
          jnp.reshape(g, (1,) + g.shape)]
    per_example_direct = jnp.concatenate(per_example_direct, axis=0)
    self.assertAllClose(per_example, per_example_direct, rtol=5e-2, atol=1e-3)

  def testSumPool(self):
    W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32)
    X = jnp.array(np.random.randn(10, 5, 5, 1), dtype=np.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
      y = lax.conv_general_dilated(
          x, params, one, 'SAME', one, one, dimension_numbers)
      y = lax.reduce_window(
          y, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')
      return y
    grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2))

    # Test forward prop.
    per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example = jnp.reshape(per_example, (10, 5, 5, 5))
    per_example_direct = f(W, X)
    self.assertAllClose(per_example, per_example_direct)

    # Test gradients.
    per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example_direct = []
    for i in range(10):
      g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1)))
      per_example_direct += [
          jnp.reshape(g, (1,) + g.shape)]
    per_example_direct = jnp.concatenate(per_example_direct, axis=0)
    self.assertAllClose(per_example, per_example_direct,
                        rtol=3e-2, atol=1e-3)

  def testCumProd(self):
    x = jnp.arange(9).reshape(3, 3) + 1
    y = vmap(lambda x: jnp.cumprod(x, axis=-1))(x)
    self.assertAllClose(np.cumprod(x, axis=1, dtype=int), y)

  def testSelect(self):
    pred = np.array([True, False])
    on_true = np.array([0, 1])
    on_false = np.array([2, 3])
    ans = vmap(lax.select)(pred, on_true, on_false)
    expected = np.array([0, 3])
    self.assertAllClose(ans, expected)

    pred = np.array([False, True])
    on_true = np.array([0, 1])
    on_false = np.array([2, 3])
    ans = vmap(lax.select, (0, None, None))(pred, on_true, on_false)
    expected = np.array([[2, 3],
                          [0, 1]])
    self.assertAllClose(ans, expected)

    pred = True
    on_true = np.array([0, 1], np.float32)
    on_false = np.array(3, np.float32)
    ans = vmap(lax.select, (None, 0, None))(pred, on_true, on_false)
    expected = np.array([0, 1], np.float32)
    self.assertAllClose(ans, expected)

    pred = np.array([False, True])
    on_true = np.array([0, 1], np.float32)
    on_false = np.array(3, np.float32)
    ans = vmap(lax.select, (0, 0, None))(pred, on_true, on_false)
    expected = np.array([3, 1], np.float32)
    self.assertAllClose(ans, expected)

    pred = np.array([False, True])
    on_true = np.array([2], np.float32)
    on_false = np.array([[3, 4]], np.float32)
    ans = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false)
    expected = np.array([[3, 2]], np.float32)
    self.assertAllClose(ans, expected)

  def testLaxLinalgCholesky(self):
    a = np.random.RandomState(0).randn(10, 5, 5).astype(np.float32)
    a = np.matmul(a, np.conj(np.swapaxes(a, -1, -2)))

    ans = vmap(lax.linalg.cholesky)(a)
    expected = np.linalg.cholesky(a)
    self.assertAllClose(ans, expected, check_dtypes=False, rtol=1e-4)

    b = np.random.RandomState(0).randn(10, 5, 5).astype(np.float32)
    b = np.matmul(b, np.conj(np.swapaxes(b, -1, -2)))
    b_trans = np.swapaxes(b, 0, 1)  # shape is (5, 10, 5)

    ans = vmap(lax.linalg.cholesky, in_axes=1, out_axes=0)(b_trans)
    expected = np.linalg.cholesky(b)
    self.assertAllClose(ans, expected, check_dtypes=False, rtol=1e-4)
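    # Note: in_axes=1 marks the second axis of b_trans (size 10) as the batch
    # dimension and out_axes=0 moves it back to the front, which is why the
    # result lines up with np.linalg.cholesky(b) in the original layout.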

  def testLaxLinalgTriangularSolve(self):
    a = np.random.RandomState(0).randn(4, 10, 4).astype(np.float32)
    a += np.eye(4, dtype=jnp.float32)[:, None, :]
    b = np.random.RandomState(0).randn(5, 4, 10).astype(np.float32)

    ans = vmap(lax.linalg.triangular_solve, in_axes=(1, 2))(a, b)
    expected = np.stack(
      [lax.linalg.triangular_solve(a[:, i], b[..., i]) for i in range(10)])
    self.assertAllClose(ans, expected)

    ans = vmap(lax.linalg.triangular_solve, in_axes=(None, 2))(a[:, 0], b)
    expected = np.stack(
      [lax.linalg.triangular_solve(a[:, 0], b[..., i]) for i in range(10)])
    self.assertAllClose(ans, expected)

    ans = vmap(lax.linalg.triangular_solve, in_axes=(1, None))(a, b[..., 0])
    expected = np.stack(
      [lax.linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10)])
    self.assertAllClose(ans, expected)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes}
      for dtype in [np.float32, np.int32]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (3, 5), np.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,)),
          (1, (10, 3), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,)),
          (1, (10, 3, 5), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3)),
          (2, (10, 5, 3), np.array([[0, 2], [1, 0]]),
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,),
             start_index_map=(0, 1)),
            (1, 3)),
      ])
  def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes):
    rng = jtu.rand_default(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    operand = rng(shape, dtype)
    ans = vmap(fun, (axis, None))(operand, idxs)
    expected = np.stack([fun(operand[(slice(None),) * axis + (i,)], idxs)
                          for i in range(operand.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes}
      for dtype in [np.float32, np.float64]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (3, 5), np.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,)),
          (1, (10, 3), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,)),
          (1, (10, 3, 5), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3)),
          (2, (10, 5, 3), np.array([[0, 2], [1, 0]]),
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,),
             start_index_map=(0, 1)),
            (1, 3))
      ])
  def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes):
    rng = jtu.rand_default(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
    operand = rng(shape, dtype)
    ans = vmap(gfun, (axis, None))(operand, idxs)
    expected = np.stack([gfun(operand[(slice(None),) * axis + (i,)], idxs)
                          for i in range(operand.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes}
      for dtype in [np.float32, np.int32]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (5,), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
              offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)),
          (1, (10,), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)),
          (1, (10, 5), np.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)),
          (0, (10, 5), np.array([[[0, 1], [2, 0]],
                                  [[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
      ])
  def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes):
    rng = jtu.rand_default(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    operand = rng(shape, dtype)
    ans = vmap(fun, (None, axis))(operand, idxs)
    expected = np.stack([fun(operand, idxs[(slice(None),) * axis + (i,)])
                          for i in range(idxs.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes}
      for dtype in [np.float32, np.float64]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (5,), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
              offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)),
          (1, (10,), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)),
          (1, (10, 5), np.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)),
          (0, (10, 5), np.array([[[0, 1], [2, 0]],
                                  [[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
      ])
  def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes):
    rng = jtu.rand_default(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
    operand = rng(shape, dtype)
    ans = vmap(gfun, (None, axis))(operand, idxs)
    expected = np.stack([gfun(operand, idxs[(slice(None),) * axis + (i,)])
                          for i in range(idxs.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs,
          dnums, slice_sizes),
       "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
       dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes}
      for dtype in [np.float32, np.int32]
      for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
          (0, 0, (2, 5), np.array([[[0], [2]], [[1], [3]]]),
           lax.GatherDimensionNumbers(
             offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
           (1,)),
          (1, 1, (10, 2), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
           (2,)),
          (0, 1, (2, 10, 5,), np.array([[[0, 2, 1], [0, 3, 3]]]).T,
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
           (1, 3)),
          (2, 0, (10, 5, 2), np.array([[[0, 2], [1, 0]],
                                        [[1, 0], [2, 0]]]),
          lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
           (1, 3)),
      ])
  def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums, slice_sizes):
    rng = jtu.rand_default(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    operand = rng(shape, dtype)
    assert operand.shape[op_axis] == idxs.shape[idxs_axis]
    ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs)
    expected = np.stack([fun(operand[(slice(None),) * op_axis + (i,)],
                              idxs[(slice(None),) * idxs_axis + (i,)])
                          for i in range(idxs.shape[idxs_axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs,
          dnums, slice_sizes),
       "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
       dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes}
      for dtype in [np.float32]
      for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
          (0, 0, (2, 5), np.array([[[0], [2]], [[1], [3]]]),
           lax.GatherDimensionNumbers(
             offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
           (1,)),
          (1, 1, (10, 2), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
           (2,)),
          (0, 1, (2, 10, 5,), np.array([[[0, 2, 1], [0, 3, 3]]]).T,
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
           (1, 3)),
          (2, 0, (10, 5, 2), np.array([[[0, 2], [1, 0]],
                                        [[1, 0], [2, 0]]]),
          lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
           (1, 3)),
      ])
  def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
                                slice_sizes):
    rng = jtu.rand_default(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
    operand = rng(shape, dtype)
    assert operand.shape[op_axis] == idxs.shape[idxs_axis]
    ans = vmap(gfun, (op_axis, idxs_axis))(operand, idxs)
    expected = np.stack([gfun(operand[(slice(None),) * op_axis + (i,)],
                              idxs[(slice(None),) * idxs_axis + (i,)])
                          for i in range(idxs.shape[idxs_axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testNumpyIndexing1(self):
    a = jnp.arange(2 * 3 * 4).reshape((2, 3, 4))
    ind = np.array([[0, 1],
                    [2, 0]])
    def f(a, ind):
      return a[:, ind]
    expected = np.stack([f(a, ind[i, :]) for i in range(ind.shape[0])])
    ans = vmap(f, (None, 0))(a, ind)
    assert np.all(ans == expected)

  def testNumpyIndexing2(self):
    a = jnp.arange(2 * 3 * 4).reshape((2, 3, 4))
    def f(a):
      inds = jnp.array([0, 2])
      return a[:, inds]
    ans = vmap(f)(a)
    expected = np.stack([f(a[:, i, :]) for i in range(a.shape[1])], axis=1)
    assert np.all(ans == expected)

  def testTranspose(self):
    x = np.arange(4 * 3 * 3).reshape((4, 3, 3))
    ans = vmap(lambda x: x + x.T)(x)
    expected = x + np.swapaxes(x, -1, -2)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testTransposePermutation(self):
    x = np.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
    ans = vmap(lambda x: jnp.transpose(x, (1, 0, 2)))(x)
    expected = np.transpose(x, (0, 2, 1, 3))
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = np.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
    ans = vmap(lambda x: jnp.transpose(x, (1, 2, 0)))(x)
    expected = np.transpose(x, (0, 2, 3, 1))
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = np.arange(6 * 3 * 4 * 5).reshape((3, 4, 6, 5))
    ans = vmap(lambda x: jnp.transpose(x, (1, 2, 0)), in_axes=2)(x)
    expected = np.transpose(x, (2, 1, 3, 0))
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testIssue354(self):
    psd_mat = np.random.randn(20, 10)
    psd_mat = psd_mat.T.dot(psd_mat)
    vec = np.random.randn(10)

    def f(scale):
      scaled_mat = scale * psd_mat
      chol = jnp.linalg.cholesky(scaled_mat)
      return -0.5 * jnp.sum((jnp.einsum('ij,j->i', chol, vec))**2)
    vmapped_f = vmap(f)
    vmapped_f_grad = grad(lambda x: jnp.sum(vmapped_f(x)))

    scales = np.array([[0.1], [0.2], [0.3], [0.4], [0.5]])
    ans = vmapped_f_grad(scales)  # don't crash!
    expected = np.stack([grad(f)(scale) for scale in scales])
    self.assertAllClose(ans, expected, check_dtypes=False,
                        rtol=jtu.default_gradient_tolerance)

  def testIssue387(self):
    # https://github.com/google/jax/issues/387
    R = np.random.RandomState(0).rand(100, 2)

    def dist_sq(R):
      dR = R[:, jnp.newaxis, :] - R[jnp.newaxis, :, :]
      zero = jnp.zeros_like(dR)
      dR = dR - jnp.where(jnp.abs(dR) < 0.5, zero, 0.5 * jnp.sign(dR))
      return jnp.sum(dR ** 2, axis=2)

    @jit
    def f(R):
      _ = dist_sq(R)
      return jnp.sum(R ** 2)

    _ = hessian(f)(R)  # don't crash on UnshapedArray

  def testIssue489(self):
    def f(key):
      def body_fn(uk):
        key = uk[1]
        u = random.uniform(key, ())
        key, _ = random.split(key)
        return u, key

      u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
      return u

    print(vmap(f)(random.split(random.PRNGKey(0), 2)))  # no crash

  def testEmptyTuples(self):
    # Ensure there is no crash when a vectorized input contains empty tuples.
    result = vmap(lambda x, _: x + 1)(np.array([0, 1]), ())
    self.assertAllClose(result, np.array([1, 2]), check_dtypes=False)
    # Ensure there is no crash when a vectorized output contains empty tuples.
    result, empty_tuple = vmap(lambda x: (x + 1, ()))(np.array([0, 1]))
    self.assertAllClose(result, np.array([1, 2]), check_dtypes=False)
    self.assertEqual((), empty_tuple)

  def testIndexAddBatchedIndexesOnly(self):
    f = lambda x, idx, y: jax.ops.index_add(x, jax.ops.index[idx], y)
    result = vmap(f, (None, 0, None))(np.zeros((10,)), np.arange(10), 1.)
    self.assertAllClose(result, np.eye(10), check_dtypes=False)

  def testIssue1170(self):
    def f(index1, index2):
      return jnp.arange(36).reshape(6, 6)[index1, index2]
    g = jax.jit(jax.pmap(f))
    ans = g(index1=np.asarray([1]), index2=np.asarray([2]))
    expected = g(np.asarray([1]), np.asarray([2]))
    self.assertAllClose(ans, expected)

  def testIssue3883(self):
    def scalar_f(x):
      return lax.dynamic_slice(x, [], [])

    xs = jnp.array([1, 2, 3, 4])
    ans = vmap(scalar_f)(xs)
    expected = jnp.array([scalar_f(x) for x in xs])
    self.assertAllClose(ans, expected)

    def scalar_f2(x):
      return lax.dynamic_update_slice(x, 7, [])

    xs = jnp.array([1, 2, 3, 4])
    ans = vmap(scalar_f2)(xs)
    expected = jnp.array([scalar_f2(x) for x in xs])
    self.assertAllClose(ans, expected)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_vmap_names={}_collective_names={}".format(
          collective.__name__.replace(" ", ""), vmap_names, collective_names),
       "collective": collective, "bulk_op": bulk_op, "vmap_names": vmap_names,
       "collective_names": collective_names}
      for collective, bulk_op in [(lax.psum, jnp.sum),
                                  (lax.pmax, jnp.max),
                                  (lax.pmin, jnp.min)]
      for vmap_names in [('i',), ('i', 'j'), ('j', 'i')]
      for collective_names in it.permutations(vmap_names))
  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testCommAssocCollective(self, collective, bulk_op, vmap_names, collective_names):
    x = jnp.arange(3 * 4 * 5).reshape((3, 4, 5))

    # To test relative permutations of the order in which the axis names appear
    # in the primitive call versus the order the vmaps are applied, we always
    # apply vmaps in the order of the `vmap_names` argument, and apply the
    # collective with names according to the `collective_names` argument.
    f = lambda x: x - collective(x, collective_names)
    for axis_name in vmap_names:
      f = vmap(f, axis_name=axis_name)
    self.assertAllClose(f(x), x - bulk_op(x, axis=tuple(range(len(vmap_names)))))
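
  # Illustrative sketch (not a test): with an axis_name, collectives such as
  # lax.psum reduce over the vmapped axis, mirroring the bulk_op comparison in
  # testCommAssocCollective above. Assumes the vmap/lax/jnp names imported by
  # this module and, as above, omnistaging enabled.
  def _axisNameCollectiveSketch(self):
    x = jnp.arange(6.)
    # psum reduces over the single mapped axis 'i', so this equals x - x.sum()
    return vmap(lambda v: v - lax.psum(v, 'i'), axis_name='i')(x)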

  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testPPermute(self):
    nelem = 10
    ntests = 10
    x = np.arange(nelem)
    rng = np.random.RandomState(1)
    for i in range(ntests):
      perm = np.arange(nelem)
      rng.shuffle(perm)
      perm_pairs = np.stack([np.arange(nelem), perm], axis=-1)
      rng.shuffle(perm_pairs)
      self.assertAllClose(
        vmap(lambda x: x - lax.ppermute(x, 'i', perm_pairs), axis_name='i')(x),
        x - x[perm])

  @parameterized.named_parameters(
      {"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
       "split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
      for split_axis, concat_axis, vmap_axis in it.product(range(3), range(3), range(4)))
  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testAllToAllShape(self, vmap_axis, split_axis, concat_axis):
    d = vmap_axis

    def shape_fun(x, out_d):
      shape = list(x.shape)
      vmap_dim_id = shape.pop(d)
      split_dim_id = shape.pop(split_axis)
      shape.insert(concat_axis, vmap_dim_id)
      shape.insert(out_d, split_dim_id)
      return tuple(shape)

    shape = (2, 3, 4, 5)
    x = np.arange(np.prod(shape)).reshape(shape)
    rule = batching.collective_rules[lax.all_to_all_p]
    y, out_d = rule(None, (x,), (d,), None, split_axis, concat_axis, None)
    exp_shape = shape_fun(x, out_d)
    self.assertEqual(y.shape, exp_shape)

  @parameterized.named_parameters(
      {"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
       "split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
      for split_axis, concat_axis, vmap_axis in it.product(range(2), range(2), range(3)))
  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testAllToAllSplitAxis(self, vmap_axis, split_axis, concat_axis):
    raise SkipTest("all_to_all split axis broken after #4835")  # TODO(mattjj,apaszke)
    shape = (4, 4, 4)
    x = np.arange(np.prod(shape)).reshape(shape)

    @partial(vmap, in_axes=vmap_axis, axis_name='i')
    @partial(vmap, in_axes=vmap_axis, axis_name='j')
    def f(x):
      return lax.all_to_all(x, ('i', 'j'), split_axis, concat_axis)

    unroll_shape = (2, 2, *shape[1:])
    unroll_shape = list(shape)
    unroll_shape[vmap_axis:vmap_axis+1] = (2, 2)
    x_unroll = x.reshape(unroll_shape)
    y_unrolled = f(x_unroll)
    y = y_unrolled.reshape(shape)

    if vmap_axis <= split_axis:
      split_axis += 1
    ref = jnp.moveaxis(x, (vmap_axis, split_axis),
                          (concat_axis + 1, 0))
    self.assertAllClose(y, ref)

  def testNegativeAxes(self):
    x = np.arange(3*4*5).reshape(3, 4, 5)
    self.assertAllClose(jax.vmap(jnp.sum, in_axes=-3)(x),
                        jnp.sum(x, axis=(1, 2)))
    self.assertAllClose(jax.vmap(jnp.sum, in_axes=-2)(x),
                        jnp.sum(x, axis=(0, 2)))
    self.assertAllClose(jax.vmap(jnp.sum, in_axes=-1)(x),
                        jnp.sum(x, axis=(0, 1)))

    with self.assertRaisesRegex(ValueError, "vmap got arg 0 of rank 3 but axis to be mapped -4"):
      jax.vmap(jnp.sum, in_axes=-4)(x)

    id = lambda y: y
    self.assertAllClose(x, jax.vmap(id, in_axes=0, out_axes=-3)(x))
    self.assertAllClose(x.transpose(1, 0, 2),
                        jax.vmap(id, in_axes=0, out_axes=-2)(x))
    self.assertAllClose(x.transpose(1, 2, 0),
                        jax.vmap(id, in_axes=0, out_axes=-1)(x))

    with self.assertRaisesRegex(ValueError, "axis -4 is out of bounds.*"):
      jax.vmap(id, in_axes=0, out_axes=-4)(x)

    self.assertAllClose(
      np.full((5,), 7),
      jax.vmap(lambda *xs: xs, in_axes=(0, None), out_axes=(0, -1))(
        np.arange(5), 7)[1])

    with self.assertRaisesRegex(ValueError, "axis -2 is out of bounds.*"):
      jax.vmap(lambda *xs: xs, in_axes=(0, None), out_axes=(0, -2))(
        np.arange(5), 7)
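
  # Illustrative sketch (not a test): negative in_axes/out_axes count from the
  # end of the shape, so in_axes=-1 on a (3, 4, 5) input maps over the trailing
  # axis of length 5. Assumes the jax/jnp/np names imported by this module.
  def _negativeInAxesSketch(self):
    x = np.arange(3 * 4 * 5).reshape(3, 4, 5)
    return jax.vmap(jnp.sum, in_axes=-1)(x)  # shape (5,); sums over axes 0 and 1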

  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testAxisIndex(self):
    x = np.arange(10)
    self.assertAllClose(
      vmap(lambda x: x - lax.axis_index('i'), axis_name='i')(x),
      x - np.arange(x.shape[0]))

  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testCollectivePdot(self):
    def f(x, y):
      return lax.pdot(x, y, 'i')

    rng = np.random.RandomState(0)

    x = rng.randn(3, 4)
    y = rng.randn(4, 5)
    z = vmap(f, axis_name='i', in_axes=(1, 0), out_axes=None)(x, y)
    self.assertAllClose(z, jnp.dot(x, y))

    x = rng.randn(4, 3)
    y = rng.randn(4, 5)
    z = vmap(f, axis_name='i', in_axes=(0, 0), out_axes=None)(x, y)
    self.assertAllClose(z, jnp.dot(x.T, y))
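    # Note: lax.pdot contracts over the named mapped axis, so with
    # in_axes=(1, 0) above the 'i' axis runs over columns of x and rows of y
    # (and with in_axes=(0, 0) over rows of both), while out_axes=None returns
    # a single unbatched (3, 5) result matching the jnp.dot references.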

  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testCollectivePdotBatching(self):
    def f(x, y):
      return lax.pdot(x, y, 'i')

    rng = np.random.RandomState(1)
    xs = rng.randn(2, 8, 3)
    ys = rng.randn(2, 3, 5)
    zs = vmap(vmap(f, axis_name='i', in_axes=(1, 0), out_axes=None))(xs, ys)
    self.assertAllClose(zs, jnp.einsum('nij,njk->nik', xs, ys))

  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testPdotJvp(self):
    def f(x, y):
      return lax.pdot(x, y, 'i')

    rng = np.random.RandomState(1)
    x = rng.randn(3, 4)
    x_dot = rng.randn(*x.shape)
    y = rng.randn(4, 5)
    y_dot = rng.randn(*y.shape)

    z, z_dot = vmap(lambda x, y, x_dot, y_dot: jvp(f, (x, y), (x_dot, y_dot)),
                    axis_name='i', in_axes=(1, 0, 1, 0), out_axes=None)(x, y, x_dot, y_dot)
    self.assertAllClose(z, jnp.dot(x, y))
    self.assertAllClose(z_dot, jnp.dot(x_dot, y) + jnp.dot(x, y_dot))

  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testPdotVjp(self):
    def f(x, y):
      return lax.pdot(x, y, 'i')

    rng = np.random.RandomState(1)
    x = rng.randn(3, 4)
    y = rng.randn(4, 5)
    z_bar = rng.randn(3, 5)

    x_bar, y_bar = vmap(lambda x, y, z_bar: vjp(f, x, y)[1](z_bar),
                        axis_name='i', in_axes=(1, 0, None), out_axes=(1, 0))(x, y, z_bar)
    self.assertAllClose(x_bar, jnp.dot(z_bar, y.T))
    self.assertAllClose(y_bar, jnp.dot(x.T, z_bar))
Example #4
    np.array([2], dtype=np.int32),
    np.array([2, 4], dtype=np.int32),
    np.array([[2, 4], [5, 6]], dtype=np.int32),
    np.array([0, 1, 10], dtype=np.int32),  # Index out of bounds
    np.array([0, 1, 2, -1], dtype=np.int32),  # Index out of bounds
  ]
  for axis in [0, 1, 2]] +

  # Directly from lax.gather in lax_test.py.
  [Harness(
    f"_shape={shape}_idxs_shape={idxs.shape}_dnums={dnums}_slice_sizes={slice_sizes}",
    lambda op, idxs, dnums, slice_sizes: lax.gather(op, idxs, dimension_numbers=dnums, slice_sizes=slice_sizes),
    [RandArg(shape, np.float32),
     idxs, StaticArg(dnums), StaticArg(slice_sizes)])
    for shape, idxs, dnums, slice_sizes in [
    ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
      offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
     (1,)),
    ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
      offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
     (2,)),
    ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
      offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
     (1, 3)),
    ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
      offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
     (1, 3)),
    ]
  ]
)

lax_scatter = tuple(
Example #5
class BatchingTest(jtu.JaxTestCase):
    def testConstantFunction(self):
        ans = vmap(lambda x: 3)(onp.ones(4))
        expected = 3 * onp.ones(4)
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testNestedBatchingMatMat(self):
        matvec = vmap(np.vdot, in_axes=(0, None))
        matmat = vmap(matvec, in_axes=(None, 1), out_axes=1)

        R = onp.random.RandomState(0).randn
        A = R(4, 3)
        B = R(3, 2)

        ans = matmat(A, B)
        expected = onp.dot(A, B)
        self.assertAllClose(ans,
                            expected,
                            check_dtypes=False,
                            rtol={onp.float32: 1e-2}
                            if jtu.device_under_test() == "tpu" else None)

        jaxpr = make_jaxpr(matmat)(A, B)
        self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
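        # Note: the single-equation jaxpr is the point of this check; batching
        # np.vdot twice should fuse into one dot-style primitive instead of an
        # explicit loop of vector products.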

    def testPerExampleGradients(self):
        def predict(params, inputs):
            for W, b in params:
                outputs = np.dot(W, inputs) + b
                inputs = np.tanh(outputs)
            return outputs

        def loss(params, data):
            inputs, targets = data
            predictions = predict(params, inputs)
            return np.sum((predictions - targets)**2)

        batch_size = 5
        layer_sizes = [3, 2, 4]

        R = onp.random.RandomState(0).randn
        params = [(R(m, n), R(m))
                  for m, n in zip(layer_sizes[1:], layer_sizes[:-1])]

        input_vec = R(3)
        target_vec = R(4)
        datum = (input_vec, target_vec)

        input_batch = R(5, 3)
        target_batch = R(5, 4)
        batch = (input_batch, target_batch)

        ans = vmap(partial(grad(loss), params))(batch)

        for ans_pair, param_pair in zip(ans, params):
            dW, db = ans_pair
            W, b = param_pair

            self.assertEqual(dW.shape, (batch_size, ) + W.shape)
            self.assertEqual(db.shape, (batch_size, ) + b.shape)

    def testJacobians(self):
        def jacbwd(f, x):
            y, pullback = vjp(f, x)
            std_basis = onp.eye(onp.size(y)).reshape((-1, ) + onp.shape(y))
            jac_flat, = vmap(pullback, out_axes=onp.ndim(y))(std_basis)
            return jac_flat.reshape(onp.shape(y) + onp.shape(x))

        def jacfwd(f, x):
            pushfwd = lambda v: jvp(f, (x, ), (v, ))
            std_basis = onp.eye(onp.size(x)).reshape((-1, ) + onp.shape(x))
            y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis)
            return jac_flat.reshape(onp.shape(y) + onp.shape(x))

        R = onp.random.RandomState(0).randn

        A = R(4, 3)
        b = R(4)
        f = lambda x: np.tanh(np.dot(A, x) + b)

        x = R(3)
        self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False)
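        # Note: both helpers build the Jacobian by mapping jvp/vjp over the
        # rows of an identity ("standard basis") matrix; forward mode pushes
        # basis vectors through the input space while reverse mode pulls them
        # back from the output space, so the two results should agree.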

    def testBatchOfCompile(self):
        side = []

        @jit
        def f(x):
            side.append(None)
            return x + x

        g = jit(vmap(f))
        self.assertAllClose(g(onp.ones(2)),
                            2 * onp.ones(2),
                            check_dtypes=False)
        self.assertEqual(len(side), 1)
        self.assertAllClose(g(2 * onp.ones(2)),
                            4 * onp.ones(2),
                            check_dtypes=False)
        self.assertEqual(len(side), 1)
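        # Note: `side` only grows when f is traced and compiled, so checking
        # len(side) == 1 after the second call (with different values) verifies
        # that the vmapped, jitted function reuses its cached executable.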

    def testSliceLax(self):
        fun = lambda x: lax.slice(x, (2, ), (4, ))
        R = onp.random.RandomState(0).randn
        x = R(5, 10)

        ans = vmap(fun)(x)
        expected_ans = x[:, 2:4]
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testSliceNumpy(self):
        fun = lambda x: x[:, 2]
        R = onp.random.RandomState(0).randn
        x = R(10, 5, 3, 7)

        ans = vmap(fun)(x)
        expected_ans = x[:, :, 2]
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testRevLax(self):
        fun = lambda x: lax.rev(x, [0])
        R = onp.random.RandomState(0).randn
        x = R(2, 3)

        ans = vmap(fun)(x)
        expected_ans = x[:, ::-1]
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

        ans = vmap(fun, (1, ), 1)(x)
        expected_ans = x[::-1, :]
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testRevNumpy(self):
        fun = lambda x: x[:, ::-1]
        R = onp.random.RandomState(0).randn
        x = R(3, 2, 4)

        ans = vmap(fun)(x)
        expected_ans = x[:, :, ::-1]
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

        ans = vmap(fun, (1, ), 1)(x)
        expected_ans = x[:, :, ::-1]
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

        ans = vmap(fun, (2, ), 2)(x)
        expected_ans = x[:, ::-1, :]
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testNpMaximum(self):
        fun = lambda x: np.maximum(x, 0.0)
        R = onp.random.RandomState(0).randn
        x = R(10, 5, 3, 7)

        ans = vmap(fun)(x)
        expected_ans = onp.maximum(x, 0.0)
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testNpGtrThan(self):
        R = onp.random.RandomState(0).randn
        x = R(10, 5, 3, 7)

        ans = vmap(lambda x: x > 1.0)(x)
        expected_ans = x > 1.0
        self.assertAllClose(ans, expected_ans, check_dtypes=True)

    def testNpMaximumPerExampleGrad(self):
        R = onp.random.RandomState(0).randn
        x = R(10, 5)
        W = R(5, 5)

        fun = lambda W, x: np.sum(np.maximum(np.dot(x, W), 0.0)**2)

        ans = vmap(partial(grad(fun), W))(x)

        W_t = np.transpose(W)
        for i in range(10):
            x_ex = x[i:i + 1]

            expected_ans = 2.0 * np.dot(
                np.maximum(np.dot(W_t, np.transpose(x_ex)), 0.0), x_ex)
            expected_ans = np.transpose(expected_ans)

            self.assertAllClose(ans[i],
                                expected_ans,
                                check_dtypes=False,
                                atol={onp.float32: 5e-2}
                                if jtu.device_under_test() == "tpu" else None)

    def testDotGeneral(self):
        R = onp.random.RandomState(0).randn

        x = R(10, 3, 4, 5)
        y = R(10, 3, 5, 6)
        fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ),
                                                                   (0, ))])
        ans = vmap(fun)(x, y)
        expected = lax.dot_general(x, y, [((3, ), (2, )), ((0, 1), (0, 1))])
        self.assertAllClose(ans, expected, check_dtypes=True)

        x = R(3, 4, 10, 5)
        y = R(3, 10, 5, 6)
        fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ),
                                                                   (0, ))])
        ans = vmap(fun, in_axes=(2, 1))(x, y)
        expected = onp.stack(
            [fun(x[..., i, :], y[:, i, ...]) for i in range(10)])
        self.assertAllClose(ans, expected, check_dtypes=True)

        x = R(3, 4, 5, 10)
        y = R(3, 5, 6)
        fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ),
                                                                   (0, ))])
        ans = vmap(fun, in_axes=(3, None))(x, y)
        expected = onp.stack([fun(x[..., i], y) for i in range(10)])
        self.assertAllClose(ans, expected, check_dtypes=True)

        x = R(3, 4, 5)
        y = R(3, 5, 10, 6)
        fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ),
                                                                   (0, ))])
        ans = vmap(fun, in_axes=(None, 2))(x, y)
        expected = onp.stack([fun(x, y[..., i, :]) for i in range(10)])
        self.assertAllClose(ans, expected, check_dtypes=True)

        x = R(4)
        y = R(4, 10)
        fun = lambda x, y: lax.dot_general(x, y, [((0, ), (0, )), ((), ())])
        ans = vmap(fun, in_axes=(None, 1))(x, y)
        expected = onp.stack([fun(x, y[..., i]) for i in range(10)])
        self.assertAllClose(ans, expected, check_dtypes=True)

    def testDot(self):
        # these tests are based on @shoyer's notebook studying gufuncs

        def vecvec(a, b):
            dot = np.dot
            for ndim in range(1, max(a.ndim, b.ndim)):
                a_ax = 0 if a.ndim > ndim else None
                b_ax = 0 if b.ndim > ndim else None
                dot = vmap(dot, in_axes=(a_ax, b_ax))
            return dot(a, b)

        assert vecvec(np.zeros((3, )), np.zeros((3, ))).shape == ()
        assert vecvec(np.zeros((2, 3)), np.zeros((3, ))).shape == (2, )
        assert vecvec(np.zeros((4, 2, 3)), np.zeros((3, ))).shape == (4, 2)

    def testDot2(self):
        R = onp.random.RandomState(0).randn
        xs = R(10, 3)
        ys = R(10, 3)
        ans = vmap(np.dot)(xs, ys)
        expected = onp.einsum('ni,ni->n', xs, ys)
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testDot3(self):
        R = onp.random.RandomState(0).randn
        xs = R(5, 8, 10)
        ys = R(10, 1)
        ans = vmap(np.dot, in_axes=(1, None))(xs, ys)
        expected = onp.einsum('inj,jk->nik', xs, ys)
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testDot4(self):
        R = onp.random.RandomState(0).randn
        xs = R(3, 2)
        ys = R(3)
        ans = vmap(np.dot, in_axes=(1, None))(xs, ys)
        expected = onp.einsum('ij,i->j', xs, ys)
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testDot5(self):
        f = vmap(partial(np.einsum, 'ij,j->i'), (None, 0))
        jaxpr = make_jaxpr(f)(np.zeros((1000, 1000)), np.zeros((1000, 1000)))
        assert "broadcast" not in str(jaxpr)

    def testPad(self):
        R = onp.random.RandomState(0).randn

        fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1)])
        x = R(5, 10).astype(onp.float32)
        ans = vmap(fun)(x)
        expected_ans = np.stack(list(map(fun, x)))
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

        fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1), (0, 1, 0)])
        x = R(5, 10, 3).astype(onp.float32)
        ans = vmap(fun)(x)
        expected_ans = np.stack(list(map(fun, x)))
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testConcatenate(self):
        R = lambda *shape: onp.random.RandomState(0).randn(*shape).astype(
            onp.float32)

        fun = lambda *args: lax.concatenate(args, dimension=0)
        x, y, z = R(10, 2, 3), R(1, 10, 3), R(4, 3)
        ans = vmap(fun, in_axes=(0, 1, None))(x, y, z)
        expected_ans = onp.concatenate(
            [x, onp.swapaxes(y, 0, 1),
             onp.broadcast_to(z, (10, 4, 3))], 1)
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

        fun = lambda *args: lax.concatenate(args, dimension=1)
        x, y, z = R(10, 2, 1), R(2, 3), R(2, 4, 10)
        ans = vmap(fun, in_axes=(0, None, 2))(x, y, z)
        expected_ans = onp.concatenate(
            [x, onp.broadcast_to(y, (10, 2, 3)),
             onp.moveaxis(z, 2, 0)], 2)
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testJacobianIssue54(self):
        # test modeling the code in https://github.com/google/jax/issues/54

        def func(xs):
            return np.array([x for x in xs])

        xs = np.ones((5, 1))
        jacrev(func)(xs)  # don't crash
        jacfwd(func)(xs)  # don't crash

    def testAny(self):
        # test modeling the code in https://github.com/google/jax/issues/108

        ans = vmap(np.any)(np.array([[True, False], [False, False]]))
        expected = np.array([True, False])
        self.assertAllClose(ans, expected, check_dtypes=True)

    @jtu.skip_on_devices("tpu")
    def testHessian(self):
        # test based on code from sindhwani@google
        def fun(x, t):
            return np.sum(np.power(np.maximum(x, 0.0), 2)) + t

        x = onp.array([-1., -0.5, 0., 0.5, 1.0])

        ans = hessian(lambda x: fun(x, 0.0))(x)
        expected = onp.array([[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.],
                              [0., 0., 0.5, 0., 0.], [0., 0., 0., 2., 0.],
                              [0., 0., 0., 0., 2.]])
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testDynamicSlice(self):
        # test dynamic_slice via numpy indexing syntax
        # see https://github.com/google/jax/issues/1613 for an explanation of why we
        # need to use np rather than onp to create x and idx
        x = np.arange(30).reshape((10, 3))

        ans = vmap(lambda x, i: x[i], in_axes=(0, None))(x, 1)
        expected = x[:, 1]
        self.assertAllClose(ans, expected, check_dtypes=False)

        idx = np.array([0, 1, 2, 1, 0] * 2)
        ans = vmap(lambda x, i: x[i], in_axes=(0, 0))(x, idx)
        expected = x[onp.arange(10), idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

        x = np.arange(3)
        idx = np.array([0, 1, 2, 1, 0] * 2)
        ans = vmap(lambda x, i: x[i], in_axes=(None, 0))(x, idx)
        expected = x[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testDynamicUpdateSlice(self):
        x = onp.random.randn(10, 3)
        y = onp.random.randn(10)
        ans = vmap(
            lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
            in_axes=(0, 0, None))(x, y, 1)
        expected = x.copy()
        expected[:, 1] = y
        self.assertAllClose(ans, expected, check_dtypes=False)

        x = onp.random.randn(3)
        idx = onp.array([0, 1, 2, 1, 0] * 2)
        ans = vmap(
            lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
            in_axes=(None, 0, 0))(x, y, idx)
        expected = onp.broadcast_to(x, (10, 3)).copy()
        expected[onp.arange(10), idx] = y
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testRandom(self):
        seeds = vmap(random.PRNGKey)(onp.arange(10))
        ans = vmap(partial(random.normal, shape=(3, 2)))(seeds)
        expected = onp.stack([
            random.normal(random.PRNGKey(seed), (3, 2))
            for seed in onp.arange(10)
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)
        assert len(onp.unique(ans)) == 10 * 3 * 2

    def testSort(self):
        v = onp.arange(12)[::-1].reshape(3, 4)

        sv = vmap(partial(lax.sort, dimension=0), (0, ))(v)
        self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

        sv = vmap(partial(lax.sort, dimension=-1), (0, ))(v)
        self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

        sv = vmap(partial(lax.sort, dimension=0), (1, ))(v)
        self.assertAllClose(sv, v[::-1, :].T, check_dtypes=True)

        sv = vmap(partial(lax.sort, dimension=0), (1, ), 1)(v)
        self.assertAllClose(sv, v[::-1, :], check_dtypes=True)

    def testSortKeyVal(self):
        k = onp.arange(12)[::-1].reshape(3, 4)
        v = onp.random.RandomState(0).permutation(12).reshape(3, 4)

        sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v)
        self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
        self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

        sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v)
        self.assertAllClose(sk, k[::-1, :], check_dtypes=True)
        self.assertAllClose(sv, v[::-1, :], check_dtypes=True)

        sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T)
        self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
        self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

        sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v)
        self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
        self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

        sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0],
                                                                         v)
        self.assertAllClose(sk,
                            onp.broadcast_to(k[0, ::-1], (3, 4)),
                            check_dtypes=True)
        self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

        sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T,
                                                                         v[0])
        self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
        self.assertAllClose(sv,
                            onp.broadcast_to(v[0, ::-1], (3, 4)),
                            check_dtypes=True)

    def testConvGeneralDilated(self):
        W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32)
        X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32)

        def f(params, x):
            one = (1, 1)
            dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
            y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                         dimension_numbers)
            return y

        grad_loss = grad(lambda params, x: np.mean(f(params, x)**2))

        # Test forward prop.
        per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1)))
        per_example = np.reshape(per_example, (10, 5, 5, 5))
        per_example_direct = f(W, X)
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

        # Test gradients.
        per_example = vmap(partial(grad_loss,
                                   W))(np.reshape(X, (10, 1, 5, 5, 1)))
        per_example_direct = []
        for i in range(10):
            g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1)))
            per_example_direct += [np.reshape(g, (1, ) + g.shape)]
        per_example_direct = np.concatenate(per_example_direct, axis=0)
        self.assertAllClose(per_example,
                            per_example_direct,
                            check_dtypes=True,
                            rtol=2e-2)

    def testConvGeneralDilatedBatchNotMajor(self):
        W = np.array(onp.random.randn(3, 3, 1, 4), dtype=onp.float32)
        x = np.array(onp.random.randn(3, 5, 7, 5, 1), dtype=onp.float32)

        def f(params, x):
            one = (1, 1)
            dimension_numbers = ('HNWC', 'HWIO', 'HWNC')
            y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                         dimension_numbers)
            return y

        per_example = vmap(partial(f, W))(x)
        per_example = np.reshape(np.transpose(per_example, (1, 2, 0, 3, 4)),
                                 (5, 5, 21, 4))
        per_example_direct = f(
            W, np.reshape(np.transpose(x, (1, 0, 2, 3, 4)), (5, 21, 5, 1)))
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

    @parameterized.named_parameters({
        "testcase_name": "_op={}".format(name),
        "op": op,
        "unit": unit
    } for name, op, unit in [("max", lax.max, -np.inf), ("min", lax.min,
                                                         np.inf)])
    def testMinMaxPool(self, op, unit):
        W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32)
        X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32)

        def f(params, x):
            one = (1, 1)
            dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
            y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                         dimension_numbers)
            y = lax.reduce_window(y, unit, op, (1, 2, 2, 1), (1, 1, 1, 1),
                                  'SAME')
            return y

        grad_loss = grad(lambda params, x: np.mean(f(params, x)**2))

        # Test forward prop.
        per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1)))
        per_example = np.reshape(per_example, (10, 5, 5, 5))
        per_example_direct = f(W, X)
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

        # Test gradients.
        per_example = vmap(partial(grad_loss,
                                   W))(np.reshape(X, (10, 1, 5, 5, 1)))
        per_example_direct = []
        for i in range(10):
            g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1)))
            per_example_direct += [np.reshape(g, (1, ) + g.shape)]
        per_example_direct = np.concatenate(per_example_direct, axis=0)
        self.assertAllClose(per_example,
                            per_example_direct,
                            check_dtypes=True,
                            rtol=5e-2)

    def testSumPool(self):
        W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32)
        X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32)

        def f(params, x):
            one = (1, 1)
            dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
            y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                         dimension_numbers)
            y = lax.reduce_window(y, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1),
                                  'SAME')
            return y

        grad_loss = grad(lambda params, x: np.mean(f(params, x)**2))

        # Test forward prop.
        per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1)))
        per_example = np.reshape(per_example, (10, 5, 5, 5))
        per_example_direct = f(W, X)
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

        # Test gradients.
        per_example = vmap(partial(grad_loss,
                                   W))(np.reshape(X, (10, 1, 5, 5, 1)))
        per_example_direct = []
        for i in range(10):
            g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1)))
            per_example_direct += [np.reshape(g, (1, ) + g.shape)]
        per_example_direct = np.concatenate(per_example_direct, axis=0)
        self.assertAllClose(per_example,
                            per_example_direct,
                            check_dtypes=True,
                            rtol=3e-2)

    def testCumProd(self):
        x = np.arange(9).reshape(3, 3) + 1
        y = vmap(lambda x: np.cumprod(x, axis=-1))(x)
        self.assertAllClose(onp.cumprod(x, axis=1, dtype=np.int_),
                            y,
                            check_dtypes=True)

    def testSelect(self):
        pred = onp.array([True, False])
        on_true = onp.array([0, 1])
        on_false = onp.array([2, 3])
        ans = vmap(lax.select)(pred, on_true, on_false)
        expected = onp.array([0, 3])
        self.assertAllClose(ans, expected, check_dtypes=True)

        pred = onp.array([False, True])
        on_true = onp.array([0, 1])
        on_false = onp.array([2, 3])
        ans = vmap(lax.select, (0, None, None))(pred, on_true, on_false)
        expected = onp.array([[2, 3], [0, 1]])
        self.assertAllClose(ans, expected, check_dtypes=True)

        pred = True
        on_true = onp.array([0, 1], onp.float32)
        on_false = onp.array(3, onp.float32)
        ans = vmap(lax.select, (None, 0, None))(pred, on_true, on_false)
        expected = onp.array([0, 1], onp.float32)
        self.assertAllClose(ans, expected, check_dtypes=True)

        pred = onp.array([False, True])
        on_true = onp.array([0, 1], onp.float32)
        on_false = onp.array(3, onp.float32)
        ans = vmap(lax.select, (0, 0, None))(pred, on_true, on_false)
        expected = onp.array([3, 1], onp.float32)
        self.assertAllClose(ans, expected, check_dtypes=True)

        pred = onp.array([False, True])
        on_true = onp.array([2], onp.float32)
        on_false = onp.array([[3, 4]], onp.float32)
        ans = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false)
        expected = onp.array([[3, 2]], onp.float32)
        self.assertAllClose(ans, expected, check_dtypes=True)

    def testLaxLinalgCholesky(self):
        a = onp.random.RandomState(0).randn(10, 5, 5).astype(onp.float32)
        a = onp.matmul(a, onp.conj(onp.swapaxes(a, -1, -2)))

        ans = vmap(lax_linalg.cholesky)(a)
        expected = onp.linalg.cholesky(a)
        self.assertAllClose(ans, expected, check_dtypes=False, rtol=1e-4)

        b = onp.random.RandomState(0).randn(10, 5, 5).astype(onp.float32)
        b = onp.matmul(b, onp.conj(onp.swapaxes(b, -1, -2)))
        b_trans = onp.swapaxes(b, 0, 1)  # shape is (5, 10, 5)

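        # in_axes=1 maps over axis 1 of b_trans (the original batch axis of
        # b), and out_axes=0 stacks the factors along a new leading axis, so
        # the result is directly comparable with onp.linalg.cholesky(b).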
        ans = vmap(lax_linalg.cholesky, in_axes=1, out_axes=0)(b_trans)
        expected = onp.linalg.cholesky(b)
        self.assertAllClose(ans, expected, check_dtypes=False, rtol=1e-4)

    def testLaxLinalgTriangularSolve(self):
        a = onp.random.RandomState(0).randn(4, 10, 4).astype(onp.float32)
        a += onp.eye(4, dtype=np.float32)[:, None, :]
        b = onp.random.RandomState(0).randn(5, 4, 10).astype(onp.float32)

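        # in_axes=(1, 2) pairs a[:, i] (shape (4, 4)) with b[..., i]
        # (shape (5, 4)) for each i along the size-10 batch axes.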
        ans = vmap(lax_linalg.triangular_solve, in_axes=(1, 2))(a, b)
        expected = onp.stack([
            lax_linalg.triangular_solve(a[:, i], b[..., i]) for i in range(10)
        ])
        self.assertAllClose(ans, expected, check_dtypes=True)

        ans = vmap(lax_linalg.triangular_solve, in_axes=(None, 2))(a[:, 0], b)
        expected = onp.stack([
            lax_linalg.triangular_solve(a[:, 0], b[..., i]) for i in range(10)
        ])
        self.assertAllClose(ans, expected, check_dtypes=True)

        ans = vmap(lax_linalg.triangular_solve,
                   in_axes=(1, None))(a, b[..., 0])
        expected = onp.stack([
            lax_linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10)
        ])
        self.assertAllClose(ans, expected, check_dtypes=True)

    @parameterized.named_parameters(
        {
            "testcase_name":
            "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
                slice_sizes),
            "axis":
            axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "rng_factory":
            rng_factory,
            "rng_idx_factory":
            rng_idx_factory
        } for dtype in [onp.float32, onp.int32]
        for axis, shape, idxs, dnums, slice_sizes in [
            (0, (3, 5), onp.array([[0], [2]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, )),
            (1, (10, 3), onp.array([[0], [0], [0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, )), (2, )),
            (1, (10, 3, 5), onp.array([[0], [2], [1]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, 3)),
            (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1)), (1, 3)),
        ] for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
        for rng_factory in [jtu.rand_default])
    def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums,
                                 slice_sizes, rng_factory, rng_idx_factory):
        rng = rng_factory()
        rng_idx = rng_idx_factory()
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        operand = rng(shape, dtype)
        ans = vmap(fun, (axis, None))(operand, idxs)
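        # Reference result: slice the operand along the batched axis and apply
        # the un-vmapped gather to each slice, then stack.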
        expected = onp.stack([
            fun(operand[(slice(None), ) * axis + (i, )], idxs)
            for i in range(operand.shape[axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    @parameterized.named_parameters(
        {
            "testcase_name":
            "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
                slice_sizes),
            "axis":
            axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "rng_factory":
            rng_factory,
            "rng_idx_factory":
            rng_idx_factory
        } for dtype in [onp.float32, onp.float64]
        for axis, shape, idxs, dnums, slice_sizes in [
            (0, (3, 5), onp.array([[0], [2]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, )),
            (1, (10, 3), onp.array([[0], [0], [0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, )), (2, )),
            (1, (10, 3, 5), onp.array([[0], [2], [1]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, 3)),
            (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1)), (1, 3)),
        ] for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
        for rng_factory in [jtu.rand_default])
    def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums,
                                     slice_sizes, rng_factory,
                                     rng_idx_factory):
        rng = rng_factory()
        rng_idx = rng_idx_factory()
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
        operand = rng(shape, dtype)
        ans = vmap(gfun, (axis, None))(operand, idxs)
        expected = onp.stack([
            gfun(operand[(slice(None), ) * axis + (i, )], idxs)
            for i in range(operand.shape[axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    @parameterized.named_parameters(
        {
            "testcase_name":
            "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
                slice_sizes),
            "axis":
            axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "rng_factory":
            rng_factory,
            "rng_idx_factory":
            rng_idx_factory
        } for dtype in [onp.float32, onp.int32]
        for axis, shape, idxs, dnums, slice_sizes in [
            (0, (5, ), onp.array([[[0], [2]], [[1], [3]]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, )),
            (1, (10, ), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, )), (2, )),
            (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, 3)),
            (0, (10, 5), onp.array([[[0, 1], [2, 0]], [[1, 0], [2, 3]]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1)), (1, 3)),
        ] for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
        for rng_factory in [jtu.rand_default])
    def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums,
                                 slice_sizes, rng_factory, rng_idx_factory):
        rng = rng_factory()
        rng_idx = rng_idx_factory()
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        operand = rng(shape, dtype)
        ans = vmap(fun, (None, axis))(operand, idxs)
        expected = onp.stack([
            fun(operand, idxs[(slice(None), ) * axis + (i, )])
            for i in range(idxs.shape[axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    @parameterized.named_parameters(
        {
            "testcase_name":
            "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
                slice_sizes),
            "axis":
            axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "rng_factory":
            rng_factory,
            "rng_idx_factory":
            rng_idx_factory
        } for dtype in [onp.float32, onp.float64]
        for axis, shape, idxs, dnums, slice_sizes in [
            (0, (5, ), onp.array([[[0], [2]], [[1], [3]]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, )),
            (1, (10, ), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, )), (2, )),
            (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, 3)),
            (0, (10, 5), onp.array([[[0, 1], [2, 0]], [[1, 0], [2, 3]]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1)), (1, 3)),
        ] for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
        for rng_factory in [jtu.rand_default])
    def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums,
                                     slice_sizes, rng_factory,
                                     rng_idx_factory):
        rng = rng_factory()
        rng_idx = rng_idx_factory()
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
        operand = rng(shape, dtype)
        ans = vmap(gfun, (None, axis))(operand, idxs)
        expected = onp.stack([
            gfun(operand, idxs[(slice(None), ) * axis + (i, )])
            for i in range(idxs.shape[axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    @parameterized.named_parameters(
        {
            "testcase_name":
            "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}"
            .format(jtu.format_shape_dtype_string(shape, dtype), op_axis,
                    idxs_axis, idxs, dnums, slice_sizes),
            "op_axis":
            op_axis,
            "idxs_axis":
            idxs_axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "rng_factory":
            rng_factory,
            "rng_idx_factory":
            rng_idx_factory
        } for dtype in [onp.float32, onp.int32]
        for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
            (0, 0, (2, 5), onp.array([[[0], [2]], [[1], [3]]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, )),
            (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, )), (2, )),
            (0, 1, (2, 10, 5), onp.array([[[0, 2, 1], [0, 3, 3]]]).T,
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, 3)),
            (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1)), (1, 3)),
        ] for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
        for rng_factory in [jtu.rand_default])
    def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs,
                              dnums, slice_sizes, rng_factory,
                              rng_idx_factory):
        rng = rng_factory()
        rng_idx = rng_idx_factory()
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        operand = rng(shape, dtype)
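        # vmap requires the mapped axes of both arguments to have the same
        # size; the assert below makes that precondition explicit.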
        assert operand.shape[op_axis] == idxs.shape[idxs_axis]
        ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs)
        expected = onp.stack([
            fun(operand[(slice(None), ) * op_axis + (i, )],
                idxs[(slice(None), ) * idxs_axis + (i, )])
            for i in range(idxs.shape[idxs_axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    @parameterized.named_parameters(
        {
            "testcase_name":
            "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}"
            .format(jtu.format_shape_dtype_string(shape, dtype), op_axis,
                    idxs_axis, idxs, dnums, slice_sizes),
            "op_axis":
            op_axis,
            "idxs_axis":
            idxs_axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "rng_factory":
            rng_factory,
            "rng_idx_factory":
            rng_idx_factory
        } for dtype in [onp.float32]
        for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
            (0, 0, (2, 5), onp.array([[[0], [2]], [[1], [3]]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, )),
            (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, )), (2, )),
            (0, 1, (2, 10, 5), onp.array([[[0, 2, 1], [0, 3, 3]]]).T,
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, 3)),
            (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1)), (1, 3)),
        ] for rng_idx_factory in [partial(jtu.rand_int, max(shape))]
        for rng_factory in [jtu.rand_default])
    def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs,
                                  dnums, slice_sizes, rng_factory,
                                  rng_idx_factory):
        rng = rng_factory()
        rng_idx = rng_idx_factory()
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
        operand = rng(shape, dtype)
        assert operand.shape[op_axis] == idxs.shape[idxs_axis]
        ans = vmap(gfun, (op_axis, idxs_axis))(operand, idxs)
        expected = onp.stack([
            gfun(operand[(slice(None), ) * op_axis + (i, )],
                 idxs[(slice(None), ) * idxs_axis + (i, )])
            for i in range(idxs.shape[idxs_axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testNumpyIndexing1(self):
        a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
        ind = onp.array([[0, 1], [2, 0]])

        def f(a, ind):
            return a[:, ind]

        expected = onp.stack([f(a, ind[i, :]) for i in range(ind.shape[0])])
        ans = vmap(f, (None, 0))(a, ind)
        assert onp.all(ans == expected)

    def testNumpyIndexing2(self):
        a = np.arange(2 * 3 * 4).reshape((2, 3, 4))

        def f(a):
            inds = np.array([0, 2])
            return a[:, inds]

        ans = vmap(f)(a)
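        # f keeps its first axis and gathers columns 0 and 2 of its second
        # axis, so mapping over axis 0 of `a` matches applying f to each
        # a[:, i, :] slice and stacking the results along axis 1.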
        expected = onp.stack([f(a[:, i, :]) for i in range(a.shape[1])],
                             axis=1)
        assert onp.all(ans == expected)

    def testTranspose(self):
        x = onp.arange(4 * 3 * 3).reshape((4, 3, 3))
        ans = vmap(lambda x: x + x.T)(x)
        expected = x + onp.swapaxes(x, -1, -2)
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testTransposePermutation(self):
        x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
        ans = vmap(lambda x: np.transpose(x, (1, 0, 2)))(x)
        expected = onp.transpose(x, (0, 2, 1, 3))
        self.assertAllClose(ans, expected, check_dtypes=False)

        x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
        ans = vmap(lambda x: np.transpose(x, (1, 2, 0)))(x)
        expected = onp.transpose(x, (0, 2, 3, 1))
        self.assertAllClose(ans, expected, check_dtypes=False)

        x = onp.arange(6 * 3 * 4 * 5).reshape((3, 4, 6, 5))
        ans = vmap(lambda x: np.transpose(x, (1, 2, 0)), in_axes=2)(x)
        expected = onp.transpose(x, (2, 1, 3, 0))
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testIssue354(self):
        psd_mat = onp.random.randn(20, 10)
        psd_mat = psd_mat.T.dot(psd_mat)
        vec = onp.random.randn(10)

        def f(scale):
            scaled_mat = scale * psd_mat
            chol = np.linalg.cholesky(scaled_mat)
            return -0.5 * np.sum((np.einsum('ij,j->i', chol, vec))**2)

        vmapped_f = vmap(f)
        vmapped_f_grad = grad(lambda x: np.sum(vmapped_f(x)))

        scales = onp.array([[0.1], [0.2], [0.3], [0.4], [0.5]])
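        # Each row of `scales` has shape (1,), so grad(f) applied row by row
        # is the per-example reference for the vmapped gradient.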
        ans = vmapped_f_grad(scales)  # don't crash!
        expected = onp.stack([grad(f)(scale) for scale in scales])
        self.assertAllClose(ans,
                            expected,
                            check_dtypes=False,
                            rtol=jtu.default_gradient_tolerance)

    def testIssue387(self):
        # https://github.com/google/jax/issues/387
        R = onp.random.RandomState(0).rand(100, 2)

        def dist_sq(R):
            dR = R[:, np.newaxis, :] - R[np.newaxis, :, :]
            zero = np.zeros_like(dR)
            dR = dR - np.where(np.abs(dR) < 0.5, zero, 0.5 * np.sign(dR))
            return np.sum(dR**2, axis=2)

        @jit
        def f(R):
            dr = dist_sq(R)
            return np.sum(R**2)

        H = hessian(f)(R)  # don't crash on UnshapedArray

    def testIssue489(self):
        def f(key):
            def body_fn(uk):
                key = uk[1]
                u = random.uniform(key, (), dtype=np.float64)
                key, _ = random.split(key)
                return u, key

            u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn,
                                  (np.float64(1.), key))
            return u

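        # The while_loop's trip count depends on the mapped PRNG keys; the
        # test only checks that vmap can trace and run it without crashing.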
        print(vmap(f)(random.split(random.PRNGKey(0), 2)))  # no crash

    def testEmptyTuples(self):
        # Ensure there is no crash when a vectorized input contains empty tuples.
        result = vmap(lambda x, _: x + 1)(onp.array([0, 1]), ())
        self.assertAllClose(result, onp.array([1, 2]), check_dtypes=False)
        # Ensure there is no crash when a vectorized output contains empty tuples.
        result, empty_tuple = vmap(lambda x: (x + 1, ()))(onp.array([0, 1]))
        self.assertAllClose(result, onp.array([1, 2]), check_dtypes=False)
        self.assertEqual((), empty_tuple)

    def testIndexAddBatchedIndexesOnly(self):
        f = lambda x, idx, y: jax.ops.index_add(x, jax.ops.index[idx], y)
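        # Only the index argument is batched: example i adds 1. at position i
        # of a fresh zero vector, so stacking the ten results gives onp.eye(10).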
        result = vmap(f, (None, 0, None))(onp.zeros((10, )), onp.arange(10, ),
                                          1.)
        self.assertAllClose(result, onp.eye(10), check_dtypes=False)

    def testIssue1170(self):
        def f(index1, index2):
            return np.arange(36).reshape(6, 6)[index1, index2]

        g = jax.jit(jax.pmap(f))
        ans = g(index1=onp.asarray([1]), index2=onp.asarray([2]))
        expected = g(onp.asarray([1]), onp.asarray([2]))
        self.assertAllClose(ans, expected, check_dtypes=True)
Ejemplo n.º 6
0
class LaxAutodiffTest(jtu.JaxTestCase):

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.name, shapes, itertools.repeat(dtype)),
         "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype,
         "order": rec.order, "tol": rec.tol}
        for shape_group in compatible_shapes
        for shapes in CombosWithReplacement(shape_group, rec.nargs)
        for dtype in rec.dtypes)
      for rec in LAX_GRAD_OPS))
  def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
    rng = rng_factory(self.rng())
    if jtu.device_under_test() == "tpu" and op is lax.pow:
      raise SkipTest("pow grad imprecise on tpu")
    tol = jtu.join_tolerance(1e-1, tol) if num_float_bits(dtype) == 32 else tol
    args = tuple(rng(shape, dtype) for shape in shapes)
    check_grads(op, args, order, ["fwd", "rev"], tol, tol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
           "op": rec.op, "special_value": special_value, "tol": rec.tol}
          for special_value in rec.values)
      for rec in LAX_GRAD_SPECIAL_VALUE_TESTS))
  def testOpGradSpecialValue(self, op, special_value, tol):
    check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_from_dtype={}_to_dtype={}".format(
          jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)),
       "from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory}
      for from_dtype, to_dtype in itertools.product(
          float_dtypes + complex_dtypes, repeat=2)
      for rng_factory in [jtu.rand_default]))
  def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory):
    rng = rng_factory(self.rng())
    tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance),
              jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
    args = (rng((2, 3), from_dtype),)
    convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
    convert_element_type = jtu.ignore_warning(category=onp.ComplexWarning)(
      convert_element_type)
    check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype,
       "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in grad_float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testClampGrad(self, shape, dtype, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    low = operand - dtype(10)
    high = operand + dtype(10)
    # Avoids points near the boundary where the gradient may be inaccurate.
    check_grads(lax.clamp, (operand, low, high), 2, ["fwd", "rev"], eps=1e-2)
    check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2)
    check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
          dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name,
          num_arrs),
       "dim": dim, "base_shape": base_shape, "dtype": dtype,
       "num_arrs": num_arrs, "rng_factory": rng_factory}
      for num_arrs in [3]
      for dtype in float_dtypes
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for dim in range(len(base_shape))
      for rng_factory in [jtu.rand_default]))
  def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs, rng_factory):
    rng = rng_factory(self.rng())
    shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
    operands = tuple(rng(shape, dtype) for shape in shapes)
    concatenate = lambda *args: lax.concatenate(args, dim)
    check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "rng_factory": rng_factory,}
       for lhs_shape, rhs_shape, all_strides in itertools.chain(
           [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)])
            for b, i, j in itertools.product([2, 3], repeat=3)],
           [((4, 2, 1), (3, 2, 1), [(1,)])])
       for strides in all_strides
       for dtype in float_dtypes
       for padding in ["VALID", "SAME"]
       for rng_factory in [jtu.rand_small]))
  def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, rng_factory):
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv, window_strides=strides, padding=padding,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=1e-2, rtol=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory}
       for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in
       itertools.chain(
           [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)],
             [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))],
             [(1, 1), (2, 1)], [(1, 1)])
            for b, i, j in itertools.product([2, 3], repeat=3)],
           [((4, 2, 1), (3, 2, 1), [(1,)], [((1, 1),), ((0, 0),)],
             [(1,), (2,)], [(1,), (2,)])])
       for strides in all_strides
       for rhs_dil in rhs_dils
       for lhs_dil in lhs_dils
       for dtype in float_dtypes
       for padding in all_pads
       for rng_factory in [jtu.rand_small]))
  def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                     padding, lhs_dil, rhs_dil, rng_factory):
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_with_general_padding, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=1e-2, rtol=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
               feature_group_count, batch_group_count),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums,
       "perms": perms, "feature_group_count": feature_group_count,
       "batch_group_count": batch_group_count}
      for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)])
      for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [
          ([(b * batch_group_count, i * feature_group_count, 6, 7),
            (b * batch_group_count, i * feature_group_count, 0, 4)],  # lhs_shape
           (j * batch_group_count * feature_group_count, i, 1, 2),  # rhs_shape
           [(1, 1), (1, 2), (2, 1)],  # strides
           [(1, 1), (2, 1)],  # lhs_dils
           [(1, 1), (2, 2)])  # rhs_dils
          for b, i, j in itertools.product([1, 2], repeat=3)]
      for lhs_shape in lhs_shapes
      for strides in all_strides
      for rhs_dil in rhs_dils
      for lhs_dil in lhs_dils
      for dtype in grad_float_dtypes
      for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] +
        ([((0, -1), (0, 0))] if lhs_shape[2] != 0 else []))
      for dim_nums, perms in [
          (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
      for rng_factory in [jtu.rand_default]
  ))
  def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                 padding, lhs_dil, rhs_dil, dimension_numbers,
                                 perms, feature_group_count, batch_group_count,
                                 rng_factory):
    if dtype == onp.float16:
      raise SkipTest("float16 numerical issues")  # TODO(mattjj): resolve

    rng = rng_factory(self.rng())
    tol = {dtypes.bfloat16: 1e-0, onp.float16: 5e-1, onp.float32: 2e-4}

    # permute lhs/rhs shapes to match dimension_numbers; the
    # feature_group_count scaling was already applied when the shapes were
    # generated above
    lhs_perm, rhs_perm = perms
    lhs_shape = list(onp.take(lhs_shape, lhs_perm))
    rhs_shape = list(onp.take(rhs_shape, rhs_perm))

    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_general_dilated, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   dimension_numbers=dimension_numbers,
                   feature_group_count=feature_group_count,
                   batch_group_count=batch_group_count,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),
          jtu.format_shape_dtype_string(rhs_shape, dtype)),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng_factory": jtu.rand_default}
      for lhs_shape in [(2,), (3, 2)] for rhs_shape in [(2,), (2, 4)]
      for dtype in float_dtypes))
  def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory):
    rng = rng_factory(self.rng())
    tol = {onp.float16: 1e-1, onp.float32: 1e-4}
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)
    # check that precision config is preserved
    result, pullback = api.vjp(dot, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    assert "precision=HIGHEST" in s

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers, "rng_factory": jtu.rand_small}
      for lhs_shape, rhs_shape, dimension_numbers in [
          ((3, 2), (2, 4), (([1], [0]), ([], []))),
          ((3, 5), (2, 5), (([1], [1]), ([], []))),
          ((5, 3), (5, 2), (([0], [0]), ([], []))),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
      ]
      for dtype in float_dtypes))
  def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                          dimension_numbers, rng_factory):
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers,
                          precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"])
    # check that precision config is preserved
    result, pullback = api.vjp(dot_general, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    assert "precision=HIGHEST" in s

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
          shape, onp.dtype(dtype).name, broadcast_sizes),
       "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
       "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in float_dtypes
      for broadcast_sizes in [(), (2,), (1, 2)]
      for rng_factory in [jtu.rand_default]))
  def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory):
    rng = rng_factory(self.rng())
    args = (rng(shape, dtype),)
    broadcast = lambda x: lax.broadcast(x, broadcast_sizes)
    check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
          jtu.format_shape_dtype_string(inshape, dtype),
          outshape, broadcast_dimensions),
       "inshape": inshape, "dtype": dtype, "outshape": outshape,
       "dimensions": broadcast_dimensions, "rng_factory": rng_factory}
      for inshape, outshape, broadcast_dimensions in [
          ([2], [2, 2], [0]),
          ([2], [2, 2], [1]),
          ([2], [2, 3], [0]),
          ([], [2, 3], []),
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(inshape, dtype)
    broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_perm={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype),
          permutation),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "rng_factory": rng_factory, "permutation": permutation}
      for dtype in float_dtypes
      for arg_shape, out_shape, permutation in [
          [(3, 4), (12,), None],
          [(2, 1, 4), (8,), None],
          [(2, 2, 4), (2, 8), None],
          [(3, 4), (12,), (0, 1)],
          [(3, 4), (12,), (1, 0)],
          [(2, 1, 4), (8,), (0, 2, 1)],
          [(2, 1, 4), (8,), (2, 0, 1)],
          [(2, 2, 4), (2, 8), (0, 2, 1)],
          [(2, 2, 4), (2, 8), (2, 0, 1)],
      ]
      for rng_factory in [jtu.rand_default]))
  def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(arg_shape, dtype)
    reshape = lambda x: lax.reshape(x, out_shape, permutation)
    check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_pads={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), pads),
       "shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small}
      for shape in [(2, 3)]
      for dtype in float_dtypes
      for pads in [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]]))
  def testPadGrad(self, shape, dtype, pads, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    pad = lambda operand: lax.pad(operand, onp.array(0, dtype), pads)
    check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.)

    operand = rng(shape, dtype)
    padding_value = onp.array(0., dtype)
    pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads)
    check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.)

  def testReverseGrad(self):
    rev = lambda operand: lax.rev(operand, dimensions)

    dimensions = [0]
    check_grads(rev, (onp.array([3., 2., 1.]),), 2)

    dimensions = [0, 1]
    check_grads(rev, (onp.array([[6., 5., 4.], [3., 2., 1.]]),), 2,
                rtol={onp.float32: 3e-3})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_predshape={}_argshapes={}".format(
          jtu.format_shape_dtype_string(pred_shape, onp.bool_),
          jtu.format_shape_dtype_string(arg_shape, dtype)),
       "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype,
       "rng_factory": rng_factory}
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSelectGrad(self, pred_shape, arg_shape, dtype, rng_factory):
    rng = rng_factory(self.rng())
    pred = rng(pred_shape, onp.bool_)
    on_true = rng(arg_shape, dtype)
    on_false = rng(arg_shape, dtype)
    select = lambda on_true, on_false: lax.select(pred, on_true, on_false)
    check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_start_indices={}_limit_indices={}_strides={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, limit_indices, strides),
       "shape": shape, "dtype": dtype, "starts": start_indices,
       "limits": limit_indices, "strides": strides, "rng_factory": rng_factory}
      for shape, start_indices, limit_indices, strides in [
        [(3,), (1,), (2,), None],
        [(7,), (4,), (7,), None],
        [(5,), (1,), (5,), (2,)],
        [(8,), (1,), (6,), (2,)],
        [(5, 3), (1, 1), (3, 2), None],
        [(5, 3), (1, 1), (3, 1), None],
        [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
        [(5, 3), (1, 1), (2, 1), (1, 1)],
        [(5, 3), (1, 1), (5, 3), (2, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSliceGrad(self, shape, dtype, starts, limits, strides, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    slice = lambda x: lax.slice(x, starts, limits, strides)
    check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, size_indices),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "size_indices": size_indices, "rng_factory": rng_factory}
      for shape, start_indices, size_indices in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices,
                           rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices)
    check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, update_shape),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "update_shape": update_shape, "rng_factory": rng_factory}
      for shape, start_indices, update_shape in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices,
                                 update_shape, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    update = rng(update_shape, dtype)
    start_indices = onp.array(start_indices)

    dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices)
    check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.)

    dus = lambda x: lax.dynamic_update_slice(x, update, start_indices)
    check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.)

    dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices)
    check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_perm={}".format(
          jtu.format_shape_dtype_string(shape, dtype), perm),
       "shape": shape, "dtype": dtype, "perm": perm, "rng_factory": rng_factory}
      for shape, perm in [
        [(3, 4), (1, 0)],
        [(3, 4), (0, 1)],
        [(3, 4, 5), (2, 1, 0)],
        [(3, 4, 5), (1, 0, 2)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testTransposeGrad(self, shape, dtype, perm, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    transpose = lambda x: lax.transpose(x, perm)
    check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims),
       "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
       "dims": dims, "rng_factory": rng_factory}
      for init_val, op, dtypes, rng_factory in [
          (0, lax.add, inexact_dtypes, jtu.rand_default),
          (-onp.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int),
          (onp.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int),
          (1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)),
      ]
      for dtype in dtypes
      for shape, dims in [
          [(3, 4, 5), ()],
          [(3, 4, 5), (0,)],
          [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)],
          [(3, 4, 5), (0, 1, 2)],
          [(3, 1), (1,)],
      ]))
  def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
    rng = rng_factory(self.rng())
    if jtu.device_under_test() == "tpu" and op is lax.mul:
      raise SkipTest("unimplemented case")
    tol = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 1e-1,
           onp.float64: 1e-3, onp.complex64: 1e-1}
    operand = rng(shape, dtype)
    init_val = onp.asarray(init_val, dtype=dtype)
    reduce = lambda operand: lax.reduce(operand, init_val, op, dims)
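    # Scale the finite-difference step to the dtype: 1.0 for 16-bit floats
    # with lax.add, 1e-1 for bfloat16, 1e-2 for 32-bit floats, otherwise the
    # check_grads default.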
    eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else
           1e-1 if dtype == dtypes.bfloat16 else
           1e-2 if dtypes.finfo(dtype).bits == 32 else None)
    check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_dtype={}_padding={}"
       .format(op.__name__, onp.dtype(dtype).name, padding),
       "op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
       "rng_factory": rng_factory}
      for init_val, op, dtypes, rng_factory in [
          (0, lax.add, grad_float_dtypes, jtu.rand_small),
          (-onp.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
          (onp.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
      ]
      for dtype in dtypes
      for padding in ["VALID", "SAME"]))
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jtu.ignore_warning(category=UserWarning,
                      message="Using reduced precision for gradient.*")
  def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory):
    rng = rng_factory(self.rng())
    init_val = onp.asarray(init_val, dtype=dtype)

    # We need this conditional and the corresponding loop logic to be in the
    # test method, rather than at the parameterized test level, because it
    # depends on FLAGS for the device under test.
    # TODO(b/31565929): enable when fixed.
    if jtu.device_under_test() == "tpu" and op is not lax.add:
      all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))]

      # TODO(b/73062247): need variadic reduce-window for better precision.
      gradient_order = 1
    else:
      all_configs = itertools.chain(
          itertools.product(
              [(4, 6)],  # shapes
              [(2, 1), (1, 2)],  # window_dimensions
              [(1, 1), (2, 1), (1, 2)]  # strides
          ),
          itertools.product(
              [(3, 2, 4, 6)],  # shapes
              [(1, 1, 2, 1), (2, 1, 2, 1)],  # window_dimensions
              [(1, 2, 2, 1), (1, 1, 1, 1)]),  # strides
      )
      gradient_order = 3

    def fun(operand):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding)

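    # `dims` and `strides` are free variables of `fun`, bound by the loop
    # below; fun is only called inside that loop, so each call sees the
    # current window configuration.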
    for shape, dims, strides in all_configs:
      operand = rng(shape, dtype)
      if op is lax.add:
        eps = 1.
        tol = None
      else:
        # this test can fail if there are duplicates in operand
        self.assertEqual(onp.unique(operand).size, operand.size,
                         msg="test requires operand elements to be unique.")
        eps = 1e-2
        tol = {onp.float16: 1e-1, onp.float32: 6e-2, onp.float64: 6e-2}
      check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol,
                  eps)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_shape={}_axis={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis),
       "op": op, "shape": shape, "dtype": dtype,
       "axis": axis, "rng_factory": rng_factory}
      for op, types in [
          (lax.cumsum, [onp.float32, onp.float64]),
          (lax.cumprod, [onp.float32, onp.float64]),
      ]
      for dtype in types
      for shape in [[10], [3, 4, 5]]
      for axis in range(len(shape))
      for rng_factory in [
          jtu.rand_default if dtypes.issubdtype(dtype, onp.integer)
          else jtu.rand_small]))
  def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory):
    rng = rng_factory(self.rng())
    check_grads(partial(op, axis=axis), (rng(shape, dtype),), order=2)


  # TODO(b/205052657): enable more tests when supported
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}_isstable={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),
       "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis,
       "is_stable": is_stable}
      for dtype in [onp.float32]
      for shape in [(5,), (5, 7)]
      for axis in [len(shape) - 1]
      for is_stable in [False, True]
      for rng_factory in [jtu.rand_default]))
  def testSortGrad(self, shape, dtype, axis, is_stable, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    sort = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable)
    check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2)

  # TODO(b/205052657): enable more tests when supported
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_keyshape={}_valshape={}_axis={}_isstable={}".format(
          jtu.format_shape_dtype_string(shape, key_dtype),
          jtu.format_shape_dtype_string(shape, val_dtype),
          axis, is_stable),
       "rng_factory": rng_factory, "shape": shape,
       "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis,
       "is_stable": is_stable}
      for key_dtype in [onp.float32]
      for val_dtype in [onp.float32]
      for shape in [(3,), (5, 3)]
      for axis in [len(shape) - 1]
      for is_stable in [False, True]
      for rng_factory in [jtu.rand_default]))
  def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable,
                         rng_factory):
    rng = rng_factory(self.rng())
    # This test relies on the property that wherever keys are tied, values are
    # too, since we don't guarantee the same ordering of values with equal keys.
    # To avoid that case, we generate unique keys (globally in the key array).
    def args_maker():
      flat_keys = onp.arange(onp.prod(shape, dtype=int), dtype=key_dtype)
      keys = self.rng().permutation(flat_keys).reshape(shape)
      values = rng(shape, val_dtype)
      return keys, values
    keys, values = args_maker()

    fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable)
    check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k),
       "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k}
      for dtype in [onp.float32,]
      for shape in [(4,), (5, 5), (2, 1, 4)]
      for k in [1, 3]
      for rng_factory in [jtu.rand_default]))
  def testTopKGrad(self, shape, dtype, k, rng_factory):
    flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype)
    values = self.rng().permutation(flat_values).reshape(shape)
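    # Permuting a range gives all-unique values, so the top-k selection (and
    # therefore its gradient) is unambiguous.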
    fun = lambda vs: lax.top_k(vs, k=k)[0]
    check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes,
       "rng_factory": rng_factory}
      for dtype in float_dtypes
      for shape, idxs, axes in [
          [(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
          [(3, 4, 5), (onp.array([-1, -2]),), (0,)],
          [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
          [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
      ]
      for rng_factory in [jtu.rand_default]))
  def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng_factory):
    rng = rng_factory(self.rng())
    src = rng(shape, dtype)
    index_take = lambda src: lax.index_take(src, idxs, axes)
    check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
          slice_sizes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for shape, idxs, dnums, slice_sizes, max_idx in [
          ((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,), 5),
          ((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,), 9),
          ((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3), 3),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)]
      for rng_factory in [jtu.rand_default]))
  def testGatherGrad(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory,
                     rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums,
                                  slice_sizes=slice_sizes)
    x = rng(shape, dtype)
    check_grads(gather, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums, max_idx in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 4),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,)), 9),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 3),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)]
      for rng_factory in [jtu.rand_default]))
  def testScatterAddGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                         rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_add = lambda x, y: lax.scatter_add(x, idxs, y,
                                               dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums, max_idx in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 4),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,)), 9),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 3),
      ]
      # Scatters with conflicting indices are not deterministic on GPU, so we
      # use indices that do not collide.
      for rng_idx_factory in [partial(jtu.rand_unique_int, high=max_idx)]
      for rng_factory in [jtu.rand_default]))
  def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                      rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  def testScatterGradSymbolicZeroUpdate(self):
    # https://github.com/google/jax/issues/1901
    def f(x):
      n = x.shape[0]
      y = onp.arange(n, dtype=x.dtype)
      return jax.ops.index_update(x, onp.diag_indices(n), y)
    rng = jtu.rand_default(self.rng())
    check_grads(f, (rng((5, 5), onp.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2,
                1.)
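    # Note: the update `y` above does not depend on `x`, so its tangent is a
    # symbolic zero; the linked issue concerns differentiating a scatter with
    # exactly such an update, which this test exercises end to end.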

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums,
                     rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_max = lambda x, y: lax.scatter_max(x, idxs, y, dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums,
                     rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_min = lambda x, y: lax.scatter_min(x, idxs, y, dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_min, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

  def testStopGradient(self):
    def f(x):
      return lax.sin(x) * lax.cos(lax.stop_gradient(x))

    def f2(x, y):
      return lax.sin(x) * lax.cos(y)

    x = 3.14
    ans = api.grad(f)(x)
    expected = api.grad(f2)(x, x)
    self.assertAllClose(ans, expected)

    ans = api.grad(api.grad(f))(x)
    expected = api.grad(api.grad(f2))(x, x)
    self.assertAllClose(ans, expected)

    ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.)
    expected = onp.array(0.0)
    self.assertAllClose(ans, expected, check_dtypes=False)

    with core.skipping_checks():
      with self.assertRaises(TypeError):
        lax.stop_gradient(lambda x: x)

  # TODO(mattjj): make this a more systematic test
  def testRemainder(self):
    rng = onp.random.RandomState(0)
    x = rng.uniform(-0.9, 9, size=(3, 4))
    y = rng.uniform(0.7, 1.9, size=(3, 1))
    assert not set(onp.unique(x)) & set(onp.unique(y))
    tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

    rng = onp.random.RandomState(0)
    x = rng.uniform(-0.9, 9, size=(1, 4))
    y = rng.uniform(0.7, 1.9, size=(3, 4))
    assert not set(onp.unique(x)) & set(onp.unique(y))
    tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

  def testHigherOrderGradientOfReciprocal(self):
    # Regression test for https://github.com/google/jax/issues/3136
    def inv(x):
      # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x)
      return 1 / x
    grad_fn = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv))))))
    self.assertAllClose(onp.float32(0.0439453125), grad_fn(onp.float32(4.)))
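
# A minimal standalone sketch of the recipe the gradient tests above follow:
# close over fixed integer indices and run jax.test_util.check_grads in both
# forward and reverse mode.  (Hedged example: the operand, indices and
# tolerances are illustrative, not taken from the test suite.)
import numpy as onp
from jax import lax
from jax.test_util import check_grads

operand = onp.arange(10.0, dtype=onp.float32).reshape(5, 2)
idxs = onp.array([[0], [3]], dtype=onp.int32)
dnums = lax.GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,),
                                   start_index_map=(0,))
gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums,
                              slice_sizes=(1, 2))
check_grads(gather, (operand,), order=2, modes=["fwd", "rev"],
            atol=1e-2, rtol=1e-2)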
Example No. 7
    ] +

    # Directly from lax.gather in lax_test.py.
    [
        Harness(
            f"_shape={shape}_idxs_shape={idxs.shape}_dnums={dnums}_slice_sizes={slice_sizes}",
            lambda op, idxs, dnums, slice_sizes: lax.gather(
                op, idxs, dimension_numbers=dnums, slice_sizes=slice_sizes), [
                    RandArg(shape, np.float32), idxs,
                    StaticArg(dnums),
                    StaticArg(slice_sizes)
                ])
        for shape, idxs, dnums, slice_sizes in [
            ((5, ), np.array([[0], [2]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, )),
            ((10, ), np.array([[0], [0], [0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, )), (2, )),
            ((10, 5), np.array([[0], [2], [1]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, 3)),
            ((10, 5), np.array([[0, 2], [1, 0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1)), (1, 3)),
        ]
    ]
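
# A hedged illustration (separate from the harness list above) of what one of
# these configurations computes: with offset_dims=(1,), collapsed_slice_dims=(0,)
# and start_index_map=(0,), gathering (1, 3)-sized slices from a (10, 5) operand
# amounts to fancy-indexing rows and keeping the first three columns.
import numpy as np
from jax import lax

op = np.arange(50, dtype=np.float32).reshape(10, 5)
idxs = np.array([[0], [2], [1]], dtype=np.int32)
dnums = lax.GatherDimensionNumbers(offset_dims=(1,),
                                   collapsed_slice_dims=(0,),
                                   start_index_map=(0,))
out = lax.gather(op, idxs, dimension_numbers=dnums, slice_sizes=(1, 3))
assert np.array_equal(np.asarray(out), op[idxs[:, 0], :3])  # shape (3, 3)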
Example No. 8
class BatchingTest(jtu.JaxTestCase):
    def testConstantFunction(self):
        ans = vmap(lambda x: 3)(onp.ones(4))
        expected = 3 * onp.ones(4)
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testNestedBatchingMatMat(self):
        matvec = vmap(np.vdot, in_axes=(0, None))
        matmat = vmap(matvec, in_axes=(None, 1), out_axes=1)

        R = onp.random.RandomState(0).randn
        A = R(4, 3)
        B = R(3, 2)

        ans = matmat(A, B)
        expected = onp.dot(A, B)
        self.assertAllClose(ans, expected, check_dtypes=False)

        # this is a crude check that we only call a single dot
        def pv_like(x):
            aval = ShapedArray(onp.shape(x), onp.result_type(x))
            return pe.PartialVal((aval, unit))

        def make_jaxpr(fun, example_args):
            jaxpr, _, _, _ = trace_to_jaxpr(fun, map(pv_like, example_args))
            return jaxpr

        jaxpr = make_jaxpr(matmat, (A, B))
        self.assertEqual(len(jaxpr.eqns), 1)

    def testPerExampleGradients(self):
        def predict(params, inputs):
            for W, b in params:
                outputs = np.dot(W, inputs) + b
                inputs = np.tanh(outputs)
            return outputs

        def loss(params, data):
            inputs, targets = data
            predictions = predict(params, inputs)
            return np.sum((predictions - targets)**2)

        batch_size = 5
        layer_sizes = [3, 2, 4]

        R = onp.random.RandomState(0).randn
        params = [(R(m, n), R(m))
                  for m, n in zip(layer_sizes[1:], layer_sizes[:-1])]

        input_vec = R(3)
        target_vec = R(4)
        datum = (input_vec, target_vec)

        input_batch = R(5, 3)
        target_batch = R(5, 4)
        batch = (input_batch, target_batch)

        ans = vmap(partial(grad(loss), params))(batch)

        for ans_pair, param_pair in zip(ans, params):
            dW, db = ans_pair
            W, b = param_pair

            self.assertEqual(dW.shape, (batch_size, ) + W.shape)
            self.assertEqual(db.shape, (batch_size, ) + b.shape)
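
        # The pattern above -- vmap(partial(grad(loss), params)) applied to a
        # batch of data -- computes per-example gradients without a Python
        # loop: every leaf of the gradient pytree gains a leading batch_size
        # axis, as the shape assertions verify.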

    def testJacobians(self):
        def jacbwd(f, x):
            y, pullback = vjp(f, x)
            std_basis = onp.eye(onp.size(y)).reshape((-1, ) + onp.shape(y))
            jac_flat, = vmap(pullback, out_axes=onp.ndim(y))(std_basis)
            return jac_flat.reshape(onp.shape(y) + onp.shape(x))

        def jacfwd(f, x):
            pushfwd = lambda v: jvp(f, (x, ), (v, ))
            std_basis = onp.eye(onp.size(x)).reshape((-1, ) + onp.shape(x))
            y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis)
            return jac_flat.reshape(onp.shape(y) + onp.shape(x))

        R = onp.random.RandomState(0).randn

        A = R(4, 3)
        b = R(4)
        f = lambda x: np.tanh(np.dot(A, x) + b)

        x = R(3)
        self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False)

    def testBatchOfCompile(self):
        side = []

        @jit
        def f(x):
            side.append(None)
            return x + x

        g = jit(vmap(f))
        self.assertAllClose(g(onp.ones(2)),
                            2 * onp.ones(2),
                            check_dtypes=False)
        self.assertEqual(len(side), 1)
        self.assertAllClose(g(2 * onp.ones(2)),
                            4 * onp.ones(2),
                            check_dtypes=False)
        self.assertEqual(len(side), 1)

    def testSliceLax(self):
        fun = lambda x: lax.slice(x, (2, ), (4, ))
        R = onp.random.RandomState(0).randn
        x = R(5, 10)

        ans = vmap(fun)(x)
        expected_ans = x[:, 2:4]
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testSliceNumpy(self):
        fun = lambda x: x[:, 2]
        R = onp.random.RandomState(0).randn
        x = R(10, 5, 3, 7)

        ans = vmap(fun)(x)
        expected_ans = x[:, :, 2]
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testRevLax(self):
        fun = lambda x: lax.rev(x, [0])
        R = onp.random.RandomState(0).randn
        x = R(2, 3)

        ans = vmap(fun)(x)
        expected_ans = x[:, ::-1]
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

        ans = vmap(fun, (1, ), 1)(x)
        expected_ans = x[::-1, :]
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testRevNumpy(self):
        fun = lambda x: x[:, ::-1]
        R = onp.random.RandomState(0).randn
        x = R(3, 2, 4)

        ans = vmap(fun)(x)
        expected_ans = x[:, :, ::-1]
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

        ans = vmap(fun, (1, ), 1)(x)
        expected_ans = x[:, :, ::-1]
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

        ans = vmap(fun, (2, ), 2)(x)
        expected_ans = x[:, ::-1, :]
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testNpMaximum(self):
        fun = lambda x: np.maximum(x, 0.0)
        R = onp.random.RandomState(0).randn
        x = R(10, 5, 3, 7)

        ans = vmap(fun)(x)
        expected_ans = onp.maximum(x, 0.0)
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testNpGtrThan(self):
        R = onp.random.RandomState(0).randn
        x = R(10, 5, 3, 7)

        ans = vmap(lambda x: x > 1.0)(x)
        expected_ans = x > 1.0
        self.assertAllClose(ans, expected_ans, check_dtypes=True)

    def testNpMaximumPerExampleGrad(self):
        R = onp.random.RandomState(0).randn
        x = R(10, 5)
        W = R(5, 5)

        fun = lambda W, x: np.sum(np.maximum(np.dot(x, W), 0.0)**2)

        ans = vmap(partial(grad(fun), W))(x)

        W_t = np.transpose(W)
        for i in range(10):
            x_ex = x[i:i + 1]

            expected_ans = 2.0 * np.dot(
                np.maximum(np.dot(W_t, np.transpose(x_ex)), 0.0), x_ex)
            expected_ans = np.transpose(expected_ans)

            self.assertAllClose(ans[i], expected_ans, check_dtypes=False)

    def testDotGeneral(self):
        R = onp.random.RandomState(0).randn

        x = R(10, 3, 4, 5)
        y = R(10, 3, 5, 6)
        fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ),
                                                                   (0, ))])
        ans = vmap(fun)(x, y)
        expected = lax.dot_general(x, y, [((3, ), (2, )), ((0, 1), (0, 1))])
        self.assertAllClose(ans, expected, check_dtypes=True)

        x = R(3, 4, 10, 5)
        y = R(3, 10, 5, 6)
        fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ),
                                                                   (0, ))])
        ans = vmap(fun, in_axes=(2, 1))(x, y)
        fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ),
                                                                   (0, ))])
        expected = onp.stack(
            [fun(x[..., i, :], y[:, i, ...]) for i in range(10)])
        self.assertAllClose(ans, expected, check_dtypes=True)

        x = R(3, 4, 5, 10)
        y = R(3, 5, 6)
        fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ),
                                                                   (0, ))])
        ans = vmap(fun, in_axes=(3, None))(x, y)
        fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ),
                                                                   (0, ))])
        expected = onp.stack([fun(x[..., i], y) for i in range(10)])
        self.assertAllClose(ans, expected, check_dtypes=True)

        x = R(3, 4, 5)
        y = R(3, 5, 10, 6)
        fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ),
                                                                   (0, ))])
        ans = vmap(fun, in_axes=(None, 2))(x, y)
        fun = lambda x, y: lax.dot_general(x, y, [((2, ), (1, )), ((0, ),
                                                                   (0, ))])
        expected = onp.stack([fun(x, y[..., i, :]) for i in range(10)])
        self.assertAllClose(ans, expected, check_dtypes=True)
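
        # dot_general's dimension_numbers are ((lhs_contracting, rhs_contracting),
        # (lhs_batch, rhs_batch)); the first case above checks that vmapping over
        # a new leading axis matches promoting that axis to an extra batch
        # dimension of the un-vmapped call.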

    def testDot(self):
        # these tests are based on @shoyer's notebook studying gufuncs

        def vecvec(a, b):
            dot = np.dot
            for ndim in range(1, max(a.ndim, b.ndim)):
                a_ax = 0 if a.ndim > ndim else None
                b_ax = 0 if b.ndim > ndim else None
                dot = vmap(dot, in_axes=(a_ax, b_ax))
            return dot(a, b)

        assert vecvec(np.zeros((3, )), np.zeros((3, ))).shape == ()
        assert vecvec(np.zeros((2, 3)), np.zeros((3, ))).shape == (2, )
        # TODO(mattjj): this fails due to an xla error in dot_general
        # assert vecvec(np.zeros((4, 2, 3)), np.zeros((3,))).shape == (4, 2)

    def testPad(self):
        R = onp.random.RandomState(0).randn

        fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1)])
        x = R(5, 10).astype(onp.float32)
        ans = vmap(fun)(x)
        expected_ans = np.stack(list(map(fun, x)))
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

        fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1), (0, 1, 0)])
        x = R(5, 10, 3).astype(onp.float32)
        ans = vmap(fun)(x)
        expected_ans = np.stack(list(map(fun, x)))
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testConcatenate(self):
        R = lambda *shape: onp.random.RandomState(0).randn(*shape).astype(
            onp.float32)

        fun = lambda *args: lax.concatenate(args, dimension=0)
        x, y, z = R(10, 2, 3), R(1, 10, 3), R(4, 3)
        ans = vmap(fun, in_axes=(0, 1, None))(x, y, z)
        expected_ans = onp.concatenate(
            [x, onp.swapaxes(y, 0, 1),
             onp.broadcast_to(z, (10, 4, 3))], 1)
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

        fun = lambda *args: lax.concatenate(args, dimension=1)
        x, y, z = R(10, 2, 1), R(2, 3), R(2, 4, 10)
        ans = vmap(fun, in_axes=(0, None, 2))(x, y, z)
        expected_ans = onp.concatenate(
            [x, onp.broadcast_to(y, (10, 2, 3)),
             onp.moveaxis(z, 2, 0)], 2)
        self.assertAllClose(ans, expected_ans, check_dtypes=False)

    def testJacobianIssue54(self):
        # test modeling the code in https://github.com/google/jax/issues/54

        def func(xs):
            return np.array([x for x in xs])

        xs = np.ones((5, 1))
        jacrev(func)(xs)  # don't crash
        jacfwd(func)(xs)  # don't crash

    def testAny(self):
        # test modeling the code in https://github.com/google/jax/issues/108

        ans = vmap(np.any)(np.array([[True, False], [False, False]]))
        expected = np.array([True, False])
        self.assertAllClose(ans, expected, check_dtypes=True)

    @jtu.skip_on_devices("tpu")
    def testHessian(self):
        # test based on code from sindhwani@google
        def fun(x, t):
            return np.sum(np.power(np.maximum(x, 0.0), 2)) + t

        x = onp.array([-1., -0.5, 0., 0.5, 1.0])

        ans = hessian(lambda x: fun(x, 0.0))(x)
        expected = onp.array([[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.],
                              [0., 0., 0.5, 0., 0.], [0., 0., 0., 2., 0.],
                              [0., 0., 0., 0., 2.]])
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testDynamicSlice(self):
        # test dynamic_slice via numpy indexing syntax
        x = onp.arange(30).reshape((10, 3))

        ans = vmap(lambda x, i: x[i], in_axes=(0, None))(x, 1)
        expected = x[:, 1]
        self.assertAllClose(ans, expected, check_dtypes=False)

        idx = onp.array([0, 1, 2, 1, 0] * 2)
        ans = vmap(lambda x, i: x[i], in_axes=(0, 0))(x, idx)
        expected = x[onp.arange(10), idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

        x = onp.arange(3)
        idx = onp.array([0, 1, 2, 1, 0] * 2)
        ans = vmap(lambda x, i: x[i], in_axes=(None, 0))(x, idx)
        expected = x[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testRandom(self):
        seeds = vmap(random.PRNGKey)(onp.arange(10))
        ans = vmap(partial(random.normal, shape=(3, 2)))(seeds)
        expected = onp.stack([
            random.normal(random.PRNGKey(seed), (3, 2))
            for seed in onp.arange(10)
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)
        assert len(onp.unique(ans)) == 10 * 3 * 2

    def testSortKeyVal(self):
        k = onp.arange(12)[::-1].reshape(3, 4)
        v = onp.random.RandomState(0).permutation(12).reshape(3, 4)

        sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v)
        self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
        self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

        sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v)
        self.assertAllClose(sk, k[::-1, :], check_dtypes=True)
        self.assertAllClose(sv, v[::-1, :], check_dtypes=True)

        sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T)
        self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
        self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

        sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v)
        self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
        self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

        sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0],
                                                                         v)
        self.assertAllClose(sk,
                            onp.broadcast_to(k[0, ::-1], (3, 4)),
                            check_dtypes=True)
        self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

        sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T,
                                                                         v[0])
        self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
        self.assertAllClose(sv,
                            onp.broadcast_to(v[0, ::-1], (3, 4)),
                            check_dtypes=True)

    def testConvGeneralDilated(self):
        W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32)
        X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32)

        def f(params, x):
            one = (1, 1)
            dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
            y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                         dimension_numbers)
            return y

        grad_loss = grad(lambda params, x: np.mean(f(params, x)**2))

        # Test forward prop.
        per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1)))
        per_example = np.reshape(per_example, (10, 5, 5, 5))
        per_example_direct = f(W, X)
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

        # Test gradients.
        per_example = vmap(partial(grad_loss,
                                   W))(np.reshape(X, (10, 1, 5, 5, 1)))
        per_example_direct = []
        for i in range(10):
            g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1)))
            per_example_direct += [np.reshape(g, (1, ) + g.shape)]
        per_example_direct = np.concatenate(per_example_direct, axis=0)
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)
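
        # Note: the reshape to (10, 1, 5, 5, 1) gives each vmapped call a valid
        # NHWC input of batch size 1; vmap then maps over the leading axis of
        # 10 examples.  The same trick is reused in the pooling tests below.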

    def testMaxPool(self):
        W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32)
        X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32)

        def f(params, x):
            one = (1, 1)
            dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
            y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                         dimension_numbers)
            y = lax.reduce_window(y, -np.inf, lax.max, (1, 2, 2, 1),
                                  (1, 1, 1, 1), 'SAME')
            return y

        grad_loss = grad(lambda params, x: np.mean(f(params, x)**2))

        # Test forward prop.
        per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1)))
        per_example = np.reshape(per_example, (10, 5, 5, 5))
        per_example_direct = f(W, X)
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

        # Test gradients.
        per_example = vmap(partial(grad_loss,
                                   W))(np.reshape(X, (10, 1, 5, 5, 1)))
        per_example_direct = []
        for i in range(10):
            g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1)))
            per_example_direct += [np.reshape(g, (1, ) + g.shape)]
        per_example_direct = np.concatenate(per_example_direct, axis=0)
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

    def testSumPool(self):
        W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32)
        X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32)

        def f(params, x):
            one = (1, 1)
            dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
            y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                         dimension_numbers)
            y = lax.reduce_window(y, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1),
                                  'SAME')
            return y

        grad_loss = grad(lambda params, x: np.mean(f(params, x)**2))

        # Test forward prop.
        per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1)))
        per_example = np.reshape(per_example, (10, 5, 5, 5))
        per_example_direct = f(W, X)
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

        # Test gradients.
        per_example = vmap(partial(grad_loss,
                                   W))(np.reshape(X, (10, 1, 5, 5, 1)))
        per_example_direct = []
        for i in range(10):
            g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1)))
            per_example_direct += [np.reshape(g, (1, ) + g.shape)]
        per_example_direct = np.concatenate(per_example_direct, axis=0)
        self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

    def testSelect(self):
        pred = onp.array([True, False])
        on_true = onp.array([0, 1])
        on_false = onp.array([2, 3])
        ans = vmap(lax.select)(pred, on_true, on_false)
        expected = onp.array([0, 3])
        self.assertAllClose(ans, expected, check_dtypes=True)

        pred = onp.array([False, True])
        on_true = onp.array([0, 1])
        on_false = onp.array([2, 3])
        ans = vmap(lax.select, (0, None, None))(pred, on_true, on_false)
        expected = onp.array([[2, 3], [0, 1]])
        self.assertAllClose(ans, expected, check_dtypes=True)

        pred = True
        on_true = onp.array([0, 1], onp.float32)
        on_false = onp.array(3, onp.float32)
        ans = vmap(lax.select, (None, 0, None))(pred, on_true, on_false)
        expected = onp.array([0, 1], onp.float32)
        self.assertAllClose(ans, expected, check_dtypes=True)

        pred = onp.array([False, True])
        on_true = onp.array([0, 1], onp.float32)
        on_false = onp.array(3, onp.float32)
        ans = vmap(lax.select, (0, 0, None))(pred, on_true, on_false)
        expected = onp.array([3, 1], onp.float32)
        self.assertAllClose(ans, expected, check_dtypes=True)

        pred = onp.array([False, True])
        on_true = onp.array([2], onp.float32)
        on_false = onp.array([[3, 4]], onp.float32)
        ans = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false)
        expected = onp.array([[3, 2]], onp.float32)
        self.assertAllClose(ans, expected, check_dtypes=True)

    def testLaxLinalgCholesky(self):
        a = onp.random.RandomState(0).randn(10, 5, 5).astype(onp.float32)
        a = onp.matmul(a, onp.conj(onp.swapaxes(a, -1, -2)))

        ans = vmap(lax_linalg.cholesky)(a)
        expected = onp.linalg.cholesky(a)
        self.assertAllClose(ans, expected, check_dtypes=False)

        b = onp.random.RandomState(0).randn(10, 5, 5).astype(onp.float32)
        b = onp.matmul(b, onp.conj(onp.swapaxes(b, -1, -2)))
        b_trans = onp.swapaxes(b, 0, 1)  # shape is (5, 10, 5)

        ans = vmap(lax_linalg.cholesky, in_axes=1, out_axes=0)(b_trans)
        expected = onp.linalg.cholesky(b)
        self.assertAllClose(ans, expected, check_dtypes=False)

    @parameterized.named_parameters(
        {
            "testcase_name":
            "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
                slice_sizes),
            "axis":
            axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "rng":
            rng,
            "rng_idx":
            rng_idx
        } for dtype in [onp.float32, onp.int32]
        for axis, shape, idxs, dnums, slice_sizes in [
            (0, (3, 5), onp.array([0, 2]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (1, )),
            (1, (10, 3), onp.array([0, 0, 0]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (2, )),
            (1, (10, 3, 5), onp.array([0, 2, 1]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (1, 3)),
            (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1),
                                        index_vector_dim=1), (1, 3)),
        ] for rng_idx in [jtu.rand_int(max(shape))]
        for rng in [jtu.rand_default()])
    def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums,
                                 slice_sizes, rng, rng_idx):
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        operand = rng(shape, dtype)
        ans = vmap(fun, (axis, None))(operand, idxs)
        expected = onp.stack([
            fun(operand[(slice(None), ) * axis + (i, )], idxs)
            for i in range(operand.shape[axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    @parameterized.named_parameters(
        {
            "testcase_name":
            "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
                slice_sizes),
            "axis":
            axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "rng":
            rng,
            "rng_idx":
            rng_idx
        } for dtype in [onp.float32, onp.float64]
        for axis, shape, idxs, dnums, slice_sizes in [
            (0, (3, 5), onp.array([0, 2]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (1, )),
            (1, (10, 3), onp.array([0, 0, 0]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (2, )),
            (1, (10, 3, 5), onp.array([0, 2, 1]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (1, 3)),
            (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1),
                                        index_vector_dim=1), (1, 3)),
        ] for rng_idx in [jtu.rand_int(max(shape))]
        for rng in [jtu.rand_default()])
    def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums,
                                     slice_sizes, rng, rng_idx):
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
        operand = rng(shape, dtype)
        ans = vmap(gfun, (axis, None))(operand, idxs)
        expected = onp.stack([
            gfun(operand[(slice(None), ) * axis + (i, )], idxs)
            for i in range(operand.shape[axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    @parameterized.named_parameters(
        {
            "testcase_name":
            "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
                slice_sizes),
            "axis":
            axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "rng":
            rng,
            "rng_idx":
            rng_idx
        } for dtype in [onp.float32, onp.int32]
        for axis, shape, idxs, dnums, slice_sizes in [
            (0, (5, ), onp.array([[0, 2], [1, 3]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (1, )),
            (1, (10, ), onp.array([[0, 0, 0], [0, 2, 1]]).T,
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (2, )),
            (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T,
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (1, 3)),
            (0, (10, 5), onp.array([[[0, 2], [1, 0]], [[1, 2], [0, 3]]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1),
                                        index_vector_dim=1), (1, 3)),
        ] for rng_idx in [jtu.rand_int(max(shape))]
        for rng in [jtu.rand_default()])
    def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums,
                                 slice_sizes, rng, rng_idx):
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        operand = rng(shape, dtype)
        ans = vmap(fun, (None, axis))(operand, idxs)
        expected = onp.stack([
            fun(operand, idxs[(slice(None), ) * axis + (i, )])
            for i in range(idxs.shape[axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    @parameterized.named_parameters(
        {
            "testcase_name":
            "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
                slice_sizes),
            "axis":
            axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "rng":
            rng,
            "rng_idx":
            rng_idx
        } for dtype in [onp.float32, onp.float64]
        for axis, shape, idxs, dnums, slice_sizes in [
            (0, (5, ), onp.array([[0, 2], [1, 3]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (1, )),
            (1, (10, ), onp.array([[0, 0, 0], [0, 2, 1]]).T,
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (2, )),
            (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T,
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (1, 3)),
            (0, (10, 5), onp.array([[[0, 2], [1, 0]], [[1, 2], [0, 3]]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1),
                                        index_vector_dim=1), (1, 3)),
        ] for rng_idx in [jtu.rand_int(max(shape))]
        for rng in [jtu.rand_default()])
    def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums,
                                     slice_sizes, rng, rng_idx):
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
        operand = rng(shape, dtype)
        ans = vmap(gfun, (None, axis))(operand, idxs)
        expected = onp.stack([
            gfun(operand, idxs[(slice(None), ) * axis + (i, )])
            for i in range(idxs.shape[axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    @parameterized.named_parameters(
        {
            "testcase_name":
            "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}"
            .format(jtu.format_shape_dtype_string(shape, dtype), op_axis,
                    idxs_axis, idxs, dnums, slice_sizes),
            "op_axis":
            op_axis,
            "idxs_axis":
            idxs_axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "rng":
            rng,
            "rng_idx":
            rng_idx
        } for dtype in [onp.float32, onp.int32]
        for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
            (0, 0, (2, 5), onp.array([[0, 2], [1, 3]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (1, )),
            (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T,
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (2, )),
            (0, 1, (2, 10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T,
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (1, 3)),
            (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1),
                                        index_vector_dim=1), (1, 3)),
        ] for rng_idx in [jtu.rand_int(max(shape))]
        for rng in [jtu.rand_default()])
    def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs,
                              dnums, slice_sizes, rng, rng_idx):
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        operand = rng(shape, dtype)
        assert operand.shape[op_axis] == idxs.shape[idxs_axis]
        ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs)
        expected = onp.stack([
            fun(operand[(slice(None), ) * op_axis + (i, )],
                idxs[(slice(None), ) * idxs_axis + (i, )])
            for i in range(idxs.shape[idxs_axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    @parameterized.named_parameters(
        {
            "testcase_name":
            "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}"
            .format(jtu.format_shape_dtype_string(shape, dtype), op_axis,
                    idxs_axis, idxs, dnums, slice_sizes),
            "op_axis":
            op_axis,
            "idxs_axis":
            idxs_axis,
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "rng":
            rng,
            "rng_idx":
            rng_idx
        } for dtype in [onp.float32, onp.int32]
        for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
            (0, 0, (2, 5), onp.array([[0, 2], [1, 3]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (1, )),
            (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T,
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (2, )),
            (0, 1, (2, 10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T,
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, ),
                                        index_vector_dim=1), (1, 3)),
            (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1),
                                        index_vector_dim=1), (1, 3)),
        ] for rng_idx in [jtu.rand_int(max(shape))]
        for rng in [jtu.rand_default()])
    def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs,
                                  dnums, slice_sizes, rng, rng_idx):
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
        operand = rng(shape, dtype)
        assert operand.shape[op_axis] == idxs.shape[idxs_axis]
        ans = vmap(gfun, (op_axis, idxs_axis))(operand, idxs)
        expected = onp.stack([
            gfun(operand[(slice(None), ) * op_axis + (i, )],
                 idxs[(slice(None), ) * idxs_axis + (i, )])
            for i in range(idxs.shape[idxs_axis])
        ])
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testNumpyIndexing1(self):
        a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
        ind = onp.array([[0, 1], [2, 0]])

        def f(a, ind):
            return a[:, ind]

        expected = onp.stack([f(a, ind[i, :]) for i in range(ind.shape[0])])
        ans = vmap(f, (None, 0))(a, ind)
        assert onp.all(ans == expected)

    def testNumpyIndexing2(self):
        a = np.arange(2 * 3 * 4).reshape((2, 3, 4))

        def f(a):
            inds = np.array([0, 2])
            return a[:, inds]

        ans = vmap(f)(a)
        expected = onp.stack([f(a[:, i, :]) for i in range(a.shape[1])],
                             axis=1)
        assert onp.all(ans == expected)

    def testTranspose(self):
        x = onp.arange(4 * 3 * 3).reshape((4, 3, 3))
        ans = vmap(lambda x: x + x.T)(x)
        expected = x + onp.swapaxes(x, -1, -2)
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testTransposePermutation(self):
        x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
        ans = vmap(lambda x: np.transpose(x, (1, 0, 2)))(x)
        expected = onp.transpose(x, (0, 2, 1, 3))
        self.assertAllClose(ans, expected, check_dtypes=False)

        x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
        ans = vmap(lambda x: np.transpose(x, (1, 2, 0)))(x)
        expected = onp.transpose(x, (0, 2, 3, 1))
        self.assertAllClose(ans, expected, check_dtypes=False)

        x = onp.arange(6 * 3 * 4 * 5).reshape((3, 4, 6, 5))
        ans = vmap(lambda x: np.transpose(x, (1, 2, 0)), in_axes=2)(x)
        expected = onp.transpose(x, (2, 1, 3, 0))
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testIssue354(self):
        psd_mat = onp.random.randn(20, 10)
        psd_mat = psd_mat.T.dot(psd_mat)
        vec = onp.random.randn(10)

        def f(scale):
            scaled_mat = scale * psd_mat
            chol = np.linalg.cholesky(scaled_mat)
            return -0.5 * np.sum((np.einsum('ij,j->i', chol, vec))**2)

        vmapped_f = vmap(f)
        vmapped_f_grad = grad(lambda x: np.sum(vmapped_f(x)))

        scales = onp.array([[0.1], [0.2], [0.3], [0.4], [0.5]])
        ans = vmapped_f_grad(scales)  # don't crash!
        expected = onp.stack([grad(f)(scale) for scale in scales])
        self.assertAllClose(ans, expected, check_dtypes=False)
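
# A compact, hedged sketch of the batched-gather pattern exercised by the
# tests above: vmapping lax.gather over a batch of index arrays should agree
# with stacking the per-batch results.  (Standalone example written against the
# trailing-index-vector convention; values are illustrative only.)
from functools import partial
import numpy as onp
from jax import lax, vmap

operand = onp.arange(10.0, dtype=onp.float32)
batched_idxs = onp.array([[[0], [2]], [[1], [3]]], dtype=onp.int32)  # (2, 2, 1)
dnums = lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,),
                                   start_index_map=(0,))
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=(1,))
ans = vmap(fun, in_axes=(None, 0))(operand, batched_idxs)   # shape (2, 2)
expected = onp.stack([fun(operand, idx) for idx in batched_idxs])
assert onp.allclose(onp.asarray(ans), onp.asarray(expected))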
Example No. 9
    def predict(self, params, logits, context, target=None):
        context = jnp.expand_dims(jnp.expand_dims(jnp.expand_dims(context,
                                                                  axis=1),
                                                  axis=1),
                                  axis=1)
        context_bias = params.get('context_bias', 0.0)
        context_index = (params['context_maps'] *
                         context).sum(axis=-1) > context_bias

        context_map_values = jnp.asarray(
            [[[[1 << n for n in range(self.context_map_size)]]]])
        context_index = jnp.where(context_index, context_map_values, 0)
        context_index = context_index.sum(axis=-1, keepdims=True)

        batch_size = logits.shape[0]
        class_neuron_index = jnp.asarray([[[[c, n] for n in range(self.size)]
                                           for c in range(self.num_classes)]])
        class_neuron_index = jnp.tile(class_neuron_index,
                                      reps=(batch_size, 1, 1, 1))
        context_index = jnp.concatenate([class_neuron_index, context_index],
                                        axis=-1)

        dims = lax.GatherDimensionNumbers(offset_dims=(3, ),
                                          collapsed_slice_dims=(0, 1, 2),
                                          start_index_map=(0, 1, 2))
        weights = lax.gather(operand=params['weights'],
                             start_indices=context_index,
                             dimension_numbers=dims,
                             slice_sizes=(1, 1, 1,
                                          self.input_size + int(self.bias)))
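
        # context_index has shape (batch, num_classes, size, 3): each trailing
        # triple (class, neuron, context) picks, per neuron, the single weight
        # row matching the active context, so the gather above returns weights
        # of shape (batch, num_classes, size, input_size + bias).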

        if self.bias:
            bias = jnp.tile(params['bias'], reps=(batch_size, 1, 1))
            logits = jnp.concatenate([logits, bias], axis=-1)
        logits = jnp.expand_dims(logits, axis=-1)

        output_logits = jnp.matmul(weights, logits)
        output_logits = jnp.clip(output_logits,
                                 a_min=jsp.special.logit(self.pred_clipping),
                                 a_max=jsp.special.logit(1.0 -
                                                         self.pred_clipping))

        if target is None:
            return jnp.squeeze(output_logits, axis=-1)

        else:
            logits = jnp.expand_dims(jnp.squeeze(logits, axis=-1), axis=-2)
            output_preds = jnn.sigmoid(output_logits)
            target = jnp.expand_dims(jnp.expand_dims(target, axis=-1), axis=-1)
            params['lr_step'], learning_rate = self.learning_rate.value(
                params['lr_step'])
            delta = learning_rate * (target - output_preds) * logits

            dims = lax.ScatterDimensionNumbers(
                update_window_dims=(3, ),
                inserted_window_dims=(0, 1, 2),
                scatter_dims_to_operand_dims=(0, 1, 2))
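
            # The scatter mirrors the gather above: the updates have shape
            # (batch, num_classes, size, input_size + bias) and are written back
            # to the same (class, neuron, context) rows of params['weights'].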

            if self.weight_clipping is None:
                params['weights'] = lax.scatter_add(
                    operand=params['weights'],
                    scatter_indices=context_index,
                    updates=delta,
                    dimension_numbers=dims)
            else:
                weights = jnp.clip(weights + delta,
                                   a_min=-self.weight_clipping,
                                   a_max=self.weight_clipping)
                params['weights'] = lax.scatter(operand=params['weights'],
                                                scatter_indices=context_index,
                                                updates=weights,
                                                dimension_numbers=dims)

            return params, jnp.squeeze(output_logits, axis=-1)