def test_gather_vectorized(self):
        # Batched gather: with batch_dims=1, each batch entry of `values` is
        # expected to be re-ordered by its own row of `indices`.
        values = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
        indices = tf.convert_to_tensor([[0, 1], [1, 0]])
        index = IndexMap(indices=indices, num_segments=2, batch_dims=1)

        # First batch keeps its row order ([0, 1]); second batch is swapped ([1, 0]).
        expected = [[[1, 2], [3, 4]], [[7, 8], [5, 6]]]
        result = gather(values, index)

        # We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual
        np.testing.assert_array_equal(result.numpy(), expected)
    def test_gather(self):
        # Round trip: reduce the table to one sum per cell, then gather those
        # sums back through the same cell index.
        values, row_index, col_index = self._prepare_tables()
        cell_index = ProductIndexMap(row_index, col_index)

        # After summing and gathering, the result should have the same shape as
        # the original table and each element should contain the sum of the
        # values in its cell.
        per_cell_sums, _ = reduce_sum(values, cell_index)
        cell_sum = gather(per_cell_sums, cell_index)
        assert cell_sum.shape == values.shape

        expected = [
            [[3.0, 3.0, 3.0], [2.0, 2.0, 1.0], [4.0, 4.0, 4.0]],
            [[1.0, 2.0, 3.0], [2.0, 0.0, 1.0], [1.0, 3.0, 4.0]],
        ]

        # We use np.testing.assert_allclose rather than Tensorflow's assertAllEqual
        np.testing.assert_allclose(cell_sum.numpy(), expected)