def test_batch_gather_by_one_hot_with_indices_out_of_range(self):
        """Out-of-range indices are expected to gather all-zero vectors.

        `params` has 4 vectors per batch entry, so valid indices are in
        `[0, 4)`. Row `b` of `indices` selects vectors from `params[b]`; the
        indices `-1` and `4` fall outside the valid range and the expected
        result for each of them is `[0.0, 0.0]`.
        """
        params = tf.constant([
            [
                [0.0, 0.0],  #
                [0.1, -0.1],  #
                [0.2, -0.2],  #
                [0.3, -0.3],  #
            ],
            [
                [1.0, -1.0],  #
                [1.1, -1.1],  #
                [1.2, -1.2],  #
                [1.3, -1.3],  #
            ]
        ])
        indices = [[2, -1, 1], [0, 3, 4]]
        zeros = [0.0, 0.0]

        expected = [
            [[0.2, -0.2], zeros, [0.1, -0.1]],  # -1 is below the range.
            [[1.0, -1.0], [1.3, -1.3], zeros],  # 4 is past the end.
        ]
        actual = tensor_utils.batch_gather_by_one_hot(params, indices)
        self.assertAllClose(expected, actual)
    def test_batch_gather_by_one_hot_2_batch_dims_plus_2d_params(
            self, gather_fn):
        """Gathers with `batch_dims=2` from a 4-level nested `params` tensor.

        The first two axes of `params` are treated as batch axes. The test
        checks one index per batch position and a pair of indices per batch
        position.

        Args:
          gather_fn: A gather callable invoked as
            `gather_fn(params, indices, batch_dims=2)`. NOTE(review): this is
            presumably supplied by test parameterization (possibly
            `tf.gather`); confirm against the decorator.
        """
        params = tf.constant([
            [
                [
                    [0.0, 0.0],  #
                    [0.1, -0.1],  #
                    [0.2, -0.2],  #
                    [0.3, -0.3],  #
                ],
                [
                    [1.0, -1.0],  #
                    [1.1, -1.1],  #
                    [1.2, -1.2],  #
                    [1.3, -1.3],  #
                ]
            ],
            [
                [
                    [2.0, -2.0],  #
                    [2.1, -2.1],  #
                    [2.2, -2.2],  #
                    [2.3, -2.3],  #
                ],
                [
                    [3.0, -3.0],  #
                    [3.1, -3.1],  #
                    [3.2, -3.2],  #
                    [3.3, -3.3],  #
                ]
            ]
        ])

        # One index per (2-D) batch position.
        single_indices = [[2, 1], [2, 3]]
        single_expected = [
            [[0.2, -0.2], [1.1, -1.1]],  #
            [[2.2, -2.2], [3.3, -3.3]]
        ]
        # `tf.gather` does not handle this case, so `gather_fn` is not used.
        self.assertAllClose(
            single_expected,
            tensor_utils.batch_gather_by_one_hot(
                params, single_indices, batch_dims=2))

        # Two indices per (2-D) batch position.
        pair_indices = [
            [[2, 0], [2, 3]],  #
            [[1, 0], [2, 1]]
        ]
        pair_expected = [
            [
                [[0.2, -0.2], [0.0, 0.0]],  #
                [[1.2, -1.2], [1.3, -1.3]]
            ],
            [
                [[2.1, -2.1], [2.0, -2.0]],  #
                [[3.2, -3.2], [3.1, -3.1]]
            ]
        ]
        self.assertAllClose(
            pair_expected, gather_fn(params, pair_indices, batch_dims=2))
    def test_batch_gather_by_one_hot_1_batch_dim_plus_1d_params(
            self, gather_fn):
        """Gathers with `batch_dims=1` from a 2-level nested `params` tensor.

        Each batch entry is a row of 4 scalars. The test checks three index
        layouts per batch entry: a single index, a vector of indices, and a
        matrix of indices.

        Args:
          gather_fn: A gather callable invoked as
            `gather_fn(params, indices, batch_dims=1)`. NOTE(review): this is
            presumably supplied by test parameterization (possibly
            `tf.gather`); confirm against the decorator.
        """
        params = tf.constant([
            [0.0, 0.1, 0.2, 0.3],  #
            [1.0, 1.1, 1.2, 1.3],  #
        ])

        # `tf.gather` does not handle this case, so `gather_fn` is not used.
        one_per_row = tensor_utils.batch_gather_by_one_hot(
            params, [2, 1], batch_dims=1)
        self.assertAllClose([0.2, 1.1], one_per_row)

        vector_indices = [[2, 0], [2, 3]]
        self.assertAllClose([[0.2, 0.0], [1.2, 1.3]],
                            gather_fn(params, vector_indices, batch_dims=1))

        matrix_indices = [
            [[2, 0], [2, 3]],  #
            [[1, 0], [2, 1]]
        ]
        matrix_expected = [
            [[0.2, 0.0], [0.2, 0.3]],  #
            [[1.1, 1.0], [1.2, 1.1]]
        ]
        self.assertAllClose(
            matrix_expected, gather_fn(params, matrix_indices, batch_dims=1))
    def test_batch_gather_by_one_hot_with_default_batch_dim(self):
        """Omitting `batch_dims` is expected to behave like `batch_dims=1`.

        Row `b` of `indices` selects vectors out of `params[b]`.
        """
        params = tf.constant([
            [
                [0.0, 0.0],  #
                [0.1, -0.1],  #
                [0.2, -0.2],  #
                [0.3, -0.3],  #
            ],
            [
                [1.0, -1.0],  #
                [1.1, -1.1],  #
                [1.2, -1.2],  #
                [1.3, -1.3],  #
            ]
        ])
        indices = [[2, 0], [2, 3]]
        expected = [
            [[0.2, -0.2], [0.0, 0.0]],  #
            [[1.2, -1.2], [1.3, -1.3]]
        ]

        # `batch_dim` should be inferred as 1.
        self.assertAllClose(
            expected, tensor_utils.batch_gather_by_one_hot(params, indices))