def test_convert(self):
    """Converts a 3x2 numeric table and checks the produced features."""
    max_seq_length = 12
    with tempfile.TemporaryDirectory() as tmp_dir:
        vocab_path = os.path.join(tmp_dir, 'vocab.txt')
        _create_vocab(vocab_path, range(10))
        conversion_config = tf_example_utils.ClassifierConversionConfig(
            vocab_file=vocab_path,
            max_seq_length=max_seq_length,
            max_column_id=max_seq_length,
            max_row_id=max_seq_length,
            strip_column_names=False,
            add_aggregation_candidates=False,
        )
        converter = tf_example_utils.ToClassifierTensorflowExample(
            config=conversion_config)

        header = [interaction_pb2.Cell(text=name) for name in ('A', 'B', 'C')]
        body = [
            interaction_pb2.Cells(
                cells=[interaction_pb2.Cell(text=value) for value in row])
            for row in (('0', '4', '5'), ('1', '3', '5'))
        ]
        interaction = interaction_pb2.Interaction(
            table=interaction_pb2.Table(columns=header, rows=body),
            questions=[interaction_pb2.Question(id='id', original_text='2')],
        )
        number_annotation_utils.add_numeric_values(interaction)

        example = converter.convert(interaction, 0)
        logging.info(example)

        # Checked in insertion order; the first mismatch fails the test.
        expected_int_features = {
            'input_ids': [2, 8, 3, 1, 1, 1, 6, 10, 11, 7, 9, 11],
            'row_ids': [0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2],
            'column_ids': [0, 0, 0, 1, 2, 3, 1, 2, 3, 1, 2, 3],
            'column_ranks': [0, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 1],
            'numeric_relations': [0, 0, 0, 0, 0, 0, 4, 2, 2, 4, 2, 2],
        }
        for feature_name, expected in expected_int_features.items():
            self.assertEqual(
                _get_int_feature(example, feature_name), expected)

        # The question "2" contributes one numeric value; the rest is NaN.
        expected_question_values = (
            [2.0] + [_NAN] * (_MAX_NUMERIC_VALUES - 1))
        self.assertEqual(
            _get_float_feature(example, 'question_numeric_values'),
            _clean_nans(expected_question_values))
Beispiel #2
0
 def test_convert_with_context_heading(self):
     """Converts a table with document title and context heading.

     Checks the resulting token ids and that `label_ids` marks the answer
     text 'B C'.
     """
     max_seq_length = 20
     with tempfile.TemporaryDirectory() as tmp_dir:
         vocab_path = os.path.join(tmp_dir, 'vocab.txt')
         _create_vocab(vocab_path, ['a', 'b', 'c', 'd', 'e'])
         conversion_config = tf_example_utils.ClassifierConversionConfig(
             vocab_file=vocab_path,
             max_seq_length=max_seq_length,
             max_column_id=max_seq_length,
             max_row_id=max_seq_length,
             strip_column_names=False,
             add_aggregation_candidates=False,
             use_document_title=True,
             use_context_title=True,
             update_answer_coordinates=True,
         )
         converter = tf_example_utils.ToClassifierTensorflowExample(
             config=conversion_config)

         table = interaction_pb2.Table(
             document_title='E E',
             columns=[
                 interaction_pb2.Cell(text=name) for name in ('A', 'A B C')
             ],
             rows=[
                 interaction_pb2.Cells(cells=[
                     interaction_pb2.Cell(text=value)
                     for value in ('A B', 'A B C')
                 ]),
             ],
             context_heading='B',
         )
         question = interaction_pb2.Question(
             id='id',
             original_text='D',
             answer=interaction_pb2.Answer(answer_texts=['B C']),
         )
         interaction = interaction_pb2.Interaction(
             table=table, questions=[question])

         example = converter.convert(interaction, 0)
         logging.info(example)

         self.assertEqual(
             _get_int_feature(example, 'input_ids'),
             [2, 5, 3, 10, 10, 3, 7, 3, 6, 6, 7, 8, 6, 7, 6, 7, 8, 0, 0, 0])
         self.assertEqual(
             _get_int_feature(example, 'label_ids'),
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0])
    def test_parse_cell(self):
        """A linked cell keeps its text and records the link as an annotation."""
        parsed = interaction_pb2.Cell()
        hybridqa_utils._parse_cell(
            parsed,
            text='Hello World',
            links=['/wiki/World'],
            descriptions={'/wiki/World': '...'},
        )

        expected = text_format.Parse(
            """
    text: "Hello World"
    [language.tapas.AnnotatedText.annotated_cell_ext] {
      annotations {
        identifier: "/wiki/World"
      }
    }
    """, interaction_pb2.Cell())
        self.assertEqual(parsed, expected)
Beispiel #4
0
 def test_parse_answer_cell_unicode_without_unquote(self):
     """Expects ValueError when a percent-encoded link is not unquoted.

     The descriptions are keyed by the unquoted '/wiki/Zürich', while the
     link is given as '/wiki/Z%C3%BCrich' with url_unquote=False.
     """
     cell = interaction_pb2.Cell()
     call_kwargs = dict(
         text='Zurich, Switzerland',
         links=['/wiki/Z%C3%BCrich'],
         descriptions={'/wiki/Zürich': '...'},
         url_unquote=False,
     )
     with self.assertRaises(ValueError):
         hybridqa_rc_utils._parse_answer_cell(cell, **call_kwargs)
Beispiel #5
0
    def test_parse_answer_cell_unicode_1(self):
        """With url_unquote=True the encoded link resolves to '/wiki/Zürich'.

        The cell text gets the matching description appended after ' : '.
        """
        parsed = interaction_pb2.Cell()
        hybridqa_rc_utils._parse_answer_cell(
            parsed,
            text='Zurich, Switzerland',
            links=['/wiki/Z%C3%BCrich'],
            descriptions={'/wiki/Zürich': '...'},
            url_unquote=True,
        )

        expected = text_format.Parse(
            """
    text: "Zurich, Switzerland : ..."
    [language.tapas.AnnotatedText.annotated_cell_ext] {
      annotations {
        identifier: "/wiki/Zürich"
      }
    }
    """, interaction_pb2.Cell())
        self.assertEqual(parsed, expected)
Beispiel #6
0
 def test_convert_with_trimmed_cell(self):
     """Checks cell trimming and row dropping for a table that does not fit.

     With cell_trim_length=2 and drop_rows_to_fit=True, every cell is cut to
     at most two tokens and the second data row is dropped so that the
     sequence fits into max_seq_length.
     """
     max_seq_length = 16
     with tempfile.TemporaryDirectory() as input_dir:
         vocab_file = os.path.join(input_dir, 'vocab.txt')
         _create_vocab(vocab_file, range(10))
         converter = tf_example_utils.ToClassifierTensorflowExample(
             config=tf_example_utils.ClassifierConversionConfig(
                 vocab_file=vocab_file,
                 max_seq_length=max_seq_length,
                 max_column_id=max_seq_length,
                 max_row_id=max_seq_length,
                 strip_column_names=False,
                 add_aggregation_candidates=False,
                 cell_trim_length=2,
                 drop_rows_to_fit=True))
         interaction = interaction_pb2.Interaction(
             table=interaction_pb2.Table(
                 columns=[
                     interaction_pb2.Cell(text='A'),
                     interaction_pb2.Cell(text='A A'),
                     interaction_pb2.Cell(text='A A A A'),
                 ],
                 rows=[
                     interaction_pb2.Cells(cells=[
                         interaction_pb2.Cell(text='A A A'),
                         interaction_pb2.Cell(text='A A A'),
                         interaction_pb2.Cell(text='A A A'),
                     ]),
                     interaction_pb2.Cells(cells=[
                         interaction_pb2.Cell(text='A A A'),
                         interaction_pb2.Cell(text='A A A'),
                         interaction_pb2.Cell(text='A A A'),
                     ]),
                 ],
             ),
             questions=[
                 interaction_pb2.Question(id='id', original_text='A')
             ],
         )
         number_annotation_utils.add_numeric_values(interaction)
         example = converter.convert(interaction, 0)
         logging.info(example)
         # Every cell is trimmed to at most 2 tokens (cell_trim_length=2), so
         # the 4-token header 'A A A A' keeps only two tokens (column id 3).
         # The question part (3 positions), header (5) and first row (6) use
         # 14 of the 16 positions. The second row would need 6 more, so it
         # is dropped and the last two positions are padding.
         self.assertEqual(_get_int_feature(example, 'column_ids'),
                          [0, 0, 0, 1, 2, 2, 3, 3, 1, 1, 2, 2, 3, 3, 0, 0])
Beispiel #7
0
def _convert_table(table_id, table_text):
    """Parses a table from '#'-separated values into an interaction_pb2.Table.

    The first line of `table_text` is the header; every following line is a
    data row. Empty input yields a table with no columns and no rows. (The
    previous implementation raised UnboundLocalError here because `columns`
    was only assigned inside the loop.)

    Args:
      table_id: Identifier appended to _TABLE_DIR_NAME to form the table id.
      table_text: Table contents, one row per line, cells separated by '#'.

    Returns:
      An interaction_pb2.Table with `table_id`, `columns` and `rows` set.
    """
    with six.StringIO(table_text) as csv_in:
        reader = csv.reader(csv_in, delimiter='#')
        # The header row becomes the columns; a missing header means no columns.
        header = next(reader, [])
        columns = [interaction_pb2.Cell(text=text) for text in header]
        rows = [
            interaction_pb2.Cells(
                cells=[interaction_pb2.Cell(text=text) for text in row])
            for row in reader
        ]
    return interaction_pb2.Table(table_id=f'{_TABLE_DIR_NAME}/{table_id}',
                                 columns=columns,
                                 rows=rows)
Beispiel #8
0
 def _get_interaction(self):
     """Builds a two-question interaction over a 3-column, 2-row table.

     NOTE(review): the first answer coordinate uses row_index=2, but the table
     only has rows 0 and 1. Confirm whether this out-of-range coordinate is
     intentional.
     """
     def make_cells(*texts):
         # One Cell proto per text, in order.
         return [interaction_pb2.Cell(text=text) for text in texts]

     table = interaction_pb2.Table(
         columns=make_cells("A:/, c", "B", "C"),
         rows=[
             interaction_pb2.Cells(cells=make_cells("0", "4", "6")),
             interaction_pb2.Cells(cells=make_cells("1", "3", "5")),
         ],
     )
     first_answer = interaction_pb2.Answer(answer_coordinates=[
         interaction_pb2.AnswerCoordinate(row_index=2, column_index=2),
         interaction_pb2.AnswerCoordinate(row_index=0, column_index=2),
     ])
     questions = [
         interaction_pb2.Question(
             id="id-1",
             original_text="A is 5",
             text="A is 5",
             answer=first_answer,
         ),
         interaction_pb2.Question(
             id="id-2",
             original_text="B is A",
             text="A is 5 B is A",
         ),
     ]
     return interaction_pb2.Interaction(table=table, questions=questions)
Beispiel #9
0
    def test_convert_with_token_selection(self):
        """Only the tokens listed in the TableSelection extension are encoded."""
        max_seq_length = 12
        with tempfile.TemporaryDirectory() as tmp_dir:
            vocab_path = os.path.join(tmp_dir, 'vocab.txt')
            _create_vocab(vocab_path, range(10))
            conversion_config = tf_example_utils.ClassifierConversionConfig(
                vocab_file=vocab_path,
                max_seq_length=max_seq_length,
                max_column_id=max_seq_length,
                max_row_id=max_seq_length,
                strip_column_names=False,
                add_aggregation_candidates=False,
            )
            converter = tf_example_utils.ToClassifierTensorflowExample(
                config=conversion_config)

            cell_texts = [('0 6', '4 7', '5 6'), ('1 7', '3 6', '5 5')]
            interaction = interaction_pb2.Interaction(
                table=interaction_pb2.Table(
                    columns=[
                        interaction_pb2.Cell(text=name)
                        for name in ('A', 'B', 'C')
                    ],
                    rows=[
                        interaction_pb2.Cells(cells=[
                            interaction_pb2.Cell(text=text) for text in row
                        ]) for row in cell_texts
                    ],
                ),
                questions=[
                    interaction_pb2.Question(id='id', original_text='2')
                ],
            )

            # (row, column, token) triples. Judging by the expected row_ids
            # below, row 0 addresses the header row.
            selected = [
                table_selection_pb2.TableSelection.TokenCoordinates(
                    row_index=row, column_index=column, token_index=token)
                for row, column, token in [(0, 0, 0), (1, 0, 0), (1, 2, 0),
                                           (2, 0, 0), (2, 2, 0), (2, 2, 1)]
            ]
            selection_ext = (
                table_selection_pb2.TableSelection.table_selection_ext)
            interaction.questions[0].Extensions[selection_ext].CopyFrom(
                table_selection_pb2.TableSelection(selected_tokens=selected))

            number_annotation_utils.add_numeric_values(interaction)
            example = converter.convert(interaction, 0)
            logging.info(example)

            expected_int_features = {
                'input_ids': [2, 8, 3, 1, 6, 11, 7, 11, 11, 0, 0, 0],
                'row_ids': [0, 0, 0, 0, 1, 1, 2, 2, 2, 0, 0, 0],
                'column_ids': [0, 0, 0, 1, 1, 3, 1, 3, 3, 0, 0, 0],
                'column_ranks': [0, 0, 0, 0, 1, 1, 2, 1, 1, 0, 0, 0],
                'numeric_relations': [0, 0, 0, 0, 4, 2, 4, 2, 2, 0, 0, 0],
            }
            for feature_name, expected in expected_int_features.items():
                self.assertEqual(
                    _get_int_feature(example, feature_name), expected)
def _get_table(table_id):
  """Returns a three-row Position/Player/Team table fixture with `table_id`."""
  def to_cells(*texts):
    # One Cell proto per text, in order.
    return [interaction_pb2.Cell(text=text) for text in texts]

  records = [
      ("1", "player 1", "team 1"),
      ("2", "player 2", "team 2"),
      ("1", "player 3", "team 2"),
  ]
  return interaction_pb2.Table(
      columns=to_cells("Position", "Player", "Team"),
      rows=[
          interaction_pb2.Cells(cells=to_cells(*record)) for record in records
      ],
      table_id=table_id,
  )
Beispiel #11
0
 def test_interaction_duplicate_column_name(self):
     """Synthesis must not crash on a table whose column names repeat."""
     config = synthesize_entablement.SynthesizationConfig(attempts=10)
     # Two columns share the name 'Name' and hold identical values.
     columns = [
         interaction_pb2.Cell(text=header)
         for header in ('Name', 'Name', 'Height')
     ]
     people = [('Peter', '100'), ('Bob', '150'), ('Tina', '200')]
     rows = [
         interaction_pb2.Cells(cells=[
             interaction_pb2.Cell(text=name),
             interaction_pb2.Cell(text=name),
             interaction_pb2.Cell(text=height),
         ]) for name, height in people
     ]
     interaction = interaction_pb2.Interaction(
         id='i_id',
         table=interaction_pb2.Table(
             table_id='t_id', columns=columns, rows=rows),
         questions=[],
     )
     for seed in range(20):
         rng = np.random.RandomState(seed)
         synthesize_entablement.synthesize_from_interaction(
             config, rng, interaction, synthesize_entablement.Counter())
Beispiel #12
0
    def test_interaction(self, prob_count_aggregation):
        """Synthesizes statements from a small table and checks coverage.

        Args:
          prob_count_aggregation: Passed to SynthesizationConfig. The value
            1.0 switches the test to the COUNT-only expectations.
        """
        is_count_test = prob_count_aggregation == 1.0
        config = synthesize_entablement.SynthesizationConfig(
            prob_count_aggregation=prob_count_aggregation, attempts=10)
        interaction = interaction_pb2.Interaction(
            id='i_id',
            table=interaction_pb2.Table(
                table_id='t_id',
                columns=[
                    interaction_pb2.Cell(text='Name'),
                    interaction_pb2.Cell(text='Height'),
                    interaction_pb2.Cell(text='Age')
                ],
                rows=[
                    interaction_pb2.Cells(cells=[
                        interaction_pb2.Cell(text='Peter'),
                        interaction_pb2.Cell(text='100'),
                        interaction_pb2.Cell(text='15')
                    ]),
                    interaction_pb2.Cells(cells=[
                        interaction_pb2.Cell(text='Bob'),
                        interaction_pb2.Cell(text='150'),
                        interaction_pb2.Cell(text='15')
                    ]),
                    interaction_pb2.Cells(cells=[
                        interaction_pb2.Cell(text='Tina'),
                        interaction_pb2.Cell(text='200'),
                        interaction_pb2.Cell(text='17')
                    ]),
                ]),
            questions=[])

        pos_statements = set()
        neg_statements = set()

        counter = TestCounter()

        for i in range(100):
            rng = np.random.RandomState(i)
            interactions = synthesize_entablement.synthesize_from_interaction(
                config, rng, interaction, counter)
            for new_interaction in interactions:
                question = new_interaction.questions[0]
                if question.answer.class_index == 1:
                    pos_statements.add(question.text)
                else:
                    # A unittest assertion rather than a bare `assert`, which
                    # is stripped under `python -O` and gives no diff.
                    self.assertEqual(question.answer.class_index, 0)
                    neg_statements.add(question.text)

        # Log before asserting so that a failing comparison still leaves the
        # synthesized statements in the test output.
        logging.info('pos_statements: %s', pos_statements)
        # Every synthesized statement must occur both as a positive and as a
        # negative example.
        self.assertEqual(neg_statements, pos_statements)

        counts = counter.get_counts()
        logging.info('counts: %s', counts)

        if is_count_test:
            self.assertGreater(len(pos_statements), 10)
            expected_statements = {
                '1 is less than the count when age is 15 and height is greater than 100',
                '1 is less than the count when height is less than 200 and age is 15',
                '1 is the count when height is greater than 100 and age is less than 17',
                '2 is the count when age is less than 17 and height is less than 200',
            }
        else:
            self.assertGreater(len(pos_statements), 100)
            expected_statements = {
                '0 is the range of age when height is greater than 100',
                '100 is less than the last height when height is less than 200',
                '125 is greater than height when name is peter',
                '15 is age when height is less than 150',
                '15 is the last age when height is less than 200',
                '150 is the last height when age is 15',
                '175 is the average height when age is less than 17',
                '200 is greater than the greatest height when age is less than 17',
                '250 is less than the total height when age is 15',
                '30 is less than the total age when height is greater than 100',
                'bob is name when age is greater than 15',
                'bob is the first name when age is 15 and height is less than 200',
                'peter is name when age is 15 and height is less than 150',
                'the average height when age is 15 is less than 175',
                'the first height when height is greater than 100 is 150',
                'the first height when height is less than 200 is 150',
                'the first name when age is 15 is name when name is peter',
                'the greatest height when age is 15 is less than 200',
                'the last age when height is greater than 100 is greater than 15',
                'the last name when age is 15 is bob',
                'the last name when age is less than 17 is peter',
                'the last name when height is greater than 100 is bob',
                'the last name when height is less than 200 is bob',
                'the lowest height when age is 15 is 150',
                'the range of age when height is greater than 100 is greater than 0',
                'the range of height when age is 15 is 100',
                'tina is name when age is greater than 15 and height is 200',
                'tina is the first name when age is 15',
                'tina is the last name when age is 15',
                'tina is the last name when height is greater than 100',
            }

        for statement in expected_statements:
            self.assertIn(statement, pos_statements)

        # Lower bound for the per-comparator counters; the COUNT-only
        # configuration produces far fewer statements. Hoisted out of the loop
        # and derived from is_count_test instead of re-testing the float.
        min_count = 1 if is_count_test else 10
        for name in ['pos', 'neg']:
            if is_count_test:
                self.assertGreater(counts[f'{name}: Synthesization success'],
                                   10)
                self.assertGreater(counts[f'{name}: Select: COUNT'], 10)
            else:
                self.assertEqual(counts[f'{name}: Synthesization success'],
                                 100)
                for aggregation in Aggregation:
                    self.assertGreater(
                        counts[f'{name}: Select: {aggregation.name}'], 0)
            for comparator in Comparator:
                self.assertGreater(
                    counts[f'{name}: Comparator {comparator.name}'], min_count)
                self.assertGreater(
                    counts[f'{name}: where: Comparator {comparator.name}'],
                    min_count)
Beispiel #13
0
 def test_convert_with_negative_tables(self):
     """Converts a question with its table plus one negative table.

     The expected features hold two max_seq_length segments: the positive
     table ('table_0') first, then the negative table ('table_1').
     """
     max_seq_length = 12
     with tempfile.TemporaryDirectory() as tmp_dir:
         vocab_path = os.path.join(tmp_dir, 'vocab.txt')
         _create_vocab(vocab_path, range(10))
         conversion_config = tf_example_utils.RetrievalConversionConfig(
             vocab_file=vocab_path,
             max_seq_length=max_seq_length,
             max_column_id=max_seq_length,
             max_row_id=max_seq_length,
             strip_column_names=False,
         )
         converter = tf_example_utils.ToRetrievalTensorflowExample(
             config=conversion_config)

         def build_table(headers, rows, table_id):
             # Table proto from header texts and row tuples of cell texts.
             return interaction_pb2.Table(
                 columns=[interaction_pb2.Cell(text=text) for text in headers],
                 rows=[
                     interaction_pb2.Cells(cells=[
                         interaction_pb2.Cell(text=text) for text in row
                     ]) for row in rows
                 ],
                 table_id=table_id,
             )

         interaction = interaction_pb2.Interaction(
             table=build_table(
                 ('A', 'B', 'C'),
                 [('0 6', '4 7', '5 6'), ('1 7', '3 6', '5 5')],
                 'table_0'),
             questions=[
                 interaction_pb2.Question(id='id', original_text='2')
             ],
         )
         number_annotation_utils.add_numeric_values(interaction)

         negative_table = build_table(
             ('A', 'B'), [('0 6', '4 7'), ('1 7', '3 6')], 'table_1')
         number_annotation_utils.add_numeric_table_values(negative_table)
         negative = _NegativeRetrievalExample()
         negative.table.CopyFrom(negative_table)
         negative.score = -82.0
         negative.rank = 1

         example = converter.convert(interaction, 0, negative)
         logging.info(example)

         # Checked in insertion order; the first mismatch fails the test.
         expected_int_features = {
             'input_ids': [
                 2, 5, 3, 1, 1, 1, 6, 10, 11, 7, 9, 11, 2, 5, 3, 1, 1, 6, 10,
                 7, 9, 0, 0, 0
             ],
             'row_ids': [
                 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0, 1, 1, 2, 2,
                 0, 0, 0
             ],
             'column_ids': [
                 0, 0, 0, 1, 2, 3, 1, 2, 3, 1, 2, 3, 0, 0, 0, 1, 2, 1, 2, 1, 2,
                 0, 0, 0
             ],
             'segment_ids': [
                 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1,
                 0, 0, 0
             ],
             'input_mask': [
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 0, 0, 0
             ],
             'inv_column_ranks': [
                 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 2, 1, 0, 0, 0, 0, 0, 2, 1, 1, 2,
                 0, 0, 0
             ],
             'column_ranks': [
                 0, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 1, 0, 0, 0, 0, 0, 1, 2, 2, 1,
                 0, 0, 0
             ],
             'numeric_relations': [
                 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                 0, 0, 0
             ],
             'table_id_hash': [911224864, 1294380046],
         }
         for feature_name, expected in expected_int_features.items():
             self.assertEqual(
                 _get_int_feature(example, feature_name), expected)

         self.assertEqual(_get_float_feature(example, 'numeric_values'), [
             'nan', 'nan', 'nan', 'nan', 'nan', 'nan', 0.0, 4.0, 5.0, 1.0,
             3.0, 5.0, 'nan', 'nan', 'nan', 'nan', 'nan', 0.0, 4.0, 1.0,
             3.0, 'nan', 'nan', 'nan'
         ])
         self.assertEqual(
             _get_float_feature(example, 'numeric_values_scale'), [
                 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
                 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0
             ])

         table_ids = [
             raw.decode('utf-8')
             for raw in _get_byte_feature(example, 'table_id')
         ]
         self.assertEqual(table_ids, ['table_0', 'table_1'])
Beispiel #14
0
  def test_add_entity_descriptions_to_table(self):
    """Checks _add_entity_descriptions_to_table with and without entity titles.

    Cells are annotated with '/wiki/<entity>' links. The expected tables show
    only the '3 6' cell receiving description text when num_results=2.
    """
    annotated_cell_ext = annotated_text_pb2.AnnotatedText.annotated_cell_ext
    cell_texts = [('0 6', '4 7', '5 6'), ('1 7', '3 6', '5 5')]
    table = interaction_pb2.Table(
        columns=[interaction_pb2.Cell(text=name) for name in ('A', 'B', 'C')],
        rows=[
            interaction_pb2.Cells(
                cells=[interaction_pb2.Cell(text=text) for text in row])
            for row in cell_texts
        ],
    )
    # Annotate every body cell whose text contains one of the entity strings.
    entities = ['0', '3']
    for row in table.rows:
      for cell in row.cells:
        matching = [entity for entity in entities if entity in cell.text]
        for entity in matching:
          cell.Extensions[annotated_cell_ext].annotations.add(
              identifier=f'/wiki/{entity}')

    question = interaction_pb2.Question(
        id='id', text='What prime number has religious meaning?')
    descriptions = {
        '/wiki/0': ('0 (zero) is a number, and the digit used to represent '
                    'that number in numerals. It fulfills a central role in '
                    'mathematics as the additive identity of the integers.'),
        '/wiki/3': ('3 (three) is a number, numeral, and glyph. It is the '
                    'natural number following 2 and preceding 4, and is the '
                    'smallest odd prime number. It has religious or cultural '
                    'significance.'),
    }

    def describe(use_entity_title):
      # Runs the function under test on a fresh copy of `table`.
      described = interaction_pb2.Table()
      described.CopyFrom(table)
      tf_example_utils._add_entity_descriptions_to_table(
          question,
          descriptions,
          described,
          num_results=2,
          use_entity_title=use_entity_title)
      return described

    expected_table = interaction_pb2.Table()
    expected_table.CopyFrom(table)
    cell_to_check = expected_table.rows[1].cells[1]

    # Only the top two sentences are used, based on tf-idf score.
    cell_to_check.text = (
        '3 6 ( It is the natural number following 2 and preceding 4, and is '
        'the smallest odd prime number. It has religious or cultural '
        'significance. )')
    self.assertEqual(describe(use_entity_title=False), expected_table)

    # With titles, the entity title and ' : ' precede the sentences.
    cell_to_check.text = (
        '3 6 ( 3 : It is the natural number following 2 and preceding 4, and '
        'is the smallest odd prime number. It has religious or cultural '
        'significance. )')
    self.assertEqual(describe(use_entity_title=True), expected_table)