Esempio n. 1
0
def _convert_data(
    all_questions,
    input_file,
    tables,
):
    """Converts TabFact data to interactions format.

    Args:
      all_questions: Mapping from table id to a 3-tuple whose first two
        elements are sequences of statement texts and labels (paired up with
        `zip`; the third element is ignored). Each label is stored as the
        answer's `class_index`.
      input_file: Path to a JSON file; iterating over its top-level value
        yields the table ids to convert.
      tables: Mapping from table id to a `Table` message. Each yielded
        interaction receives its own copy.

    Yields:
      One `Interaction` per (table id, statement) pair, with id
      `<table_id>_<i>-0` and a single question with id `<table_id>_<i>-0_0`.
    """
    logging.info('Converting data from: %s...', input_file)

    # Load the ids and close the file before yielding anything. This function
    # is a generator: keeping the `with` block open around the loop would hold
    # the file handle for as long as the caller iterates (or indefinitely if
    # the generator is abandoned part-way through).
    with tf.io.gfile.GFile(input_file) as file_in:
        table_ids = json.load(file_in)

    counter = collections.Counter()  # Counter for stats.
    for table_id in table_ids:
        questions, labels, _ = all_questions[table_id]
        for i, (text, label) in enumerate(zip(questions, labels)):
            # The extra zeros are there to match SQA id format.
            question_id = f'{table_id}_{i}-0'
            question = interaction_pb2.Question(
                id=f'{question_id}_0',
                original_text=text,
                answer=interaction_pb2.Answer(class_index=label))
            table = interaction_pb2.Table()
            table.CopyFrom(tables[table_id])
            yield interaction_pb2.Interaction(id=question_id,
                                              questions=[question],
                                              table=table)

            counter['questions'] += 1
            if counter['questions'] % 1000 == 0:
                logging.info('Processed %d questions...',
                             counter['questions'])

    _log_stats(counter, input_file)
Esempio n. 2
0
def _parse_questions(interaction_dict, supervision_modes, report_filename):
    """Re-parses every question in place and writes a summary report.

    Each question is passed to `interaction_utils_parser.parse_question`
    together with its interaction's table and the supervision mode registered
    for its key. A question that raises ValueError is kept as a copy whose
    `answer.is_valid` is set to False.

    Args:
      interaction_dict: Mapping from a key to an iterable of interactions.
        Every interaction's `questions` field is replaced in place.
      supervision_modes: Mapping from the same keys to the supervision mode
        handed to the parser; also forwarded to `_write_report`.
      report_filename: Forwarded to `_write_report`.
    """
    counters = collections.defaultdict(collections.Counter)
    for key, interactions in interaction_dict.items():
        for interaction in interactions:
            parsed = []
            for source in interaction.questions:
                try:
                    result = interaction_utils_parser.parse_question(
                        interaction.table, source, supervision_modes[key])
                except ValueError as error:
                    # Keep the question, but flag its answer as unusable and
                    # record the failure reason.
                    result = interaction_pb2.Question()
                    result.CopyFrom(source)
                    result.answer.is_valid = False
                    counters[key]['failed'] += 1
                    counters[key]['failed-' + str(error)] += 1
                else:
                    counters[key]['valid'] += 1
                parsed.append(result)

            # Swap the original questions for their parsed counterparts.
            del interaction.questions[:]
            interaction.questions.extend(parsed)

    _write_report(report_filename, supervision_modes, counters)
def _parse_question(table, original_question, clear_fields):
    """Returns a copy of a question with answer fields derived from its texts.

    Args:
      table: Table message used to compute the answer coordinates.
      original_question: Question message whose `answer.answer_texts` drive
        the parsing. It is copied, not modified.
      clear_fields: Names of `Answer` fields to clear and possibly
        repopulate.

    Returns:
      A new Question message whose answer has `answer_coordinates` and/or
      `float_value` populated.

    Raises:
      ValueError: If no answer texts are available, or if neither
        `answer_coordinates` nor `float_value` could be set.
    """
    question = interaction_pb2.Question()
    question.CopyFrom(original_question)
    answer = question.answer

    # float_value is about to be cleared: unless answer_texts already holds
    # exactly that one number, replace the texts with the number's string
    # form so the value can be recovered from the text further down.
    if ("float_value" in clear_fields and answer.HasField("float_value") and
            not _has_single_float_answer_equal_to(question,
                                                  answer.float_value)):
        value = float(answer.float_value)
        del answer.answer_texts[:]
        answer.answer_texts.append(
            str(int(value)) if value.is_integer() else str(value))

    if not answer.answer_texts:
        raise ValueError("No answer_texts provided")

    for name in clear_fields:
        answer.ClearField(name)

    errors = []
    if not answer.answer_coordinates:
        try:
            _parse_answer_coordinates(table, answer)
        except ValueError as exc:
            errors.append("[answer_coordinates: {}]".format(str(exc)))
    if not answer.HasField("float_value"):
        try:
            _parse_answer_float(answer)
        except ValueError as exc:
            errors.append("[float_value: {}]".format(str(exc)))

    # Success as long as at least one of the two targets got populated.
    if answer.answer_coordinates or answer.HasField("float_value"):
        return question
    raise ValueError("Cannot parse answer: {}".format("".join(errors)))
Esempio n. 4
0
def read_from_tsv_file(file_handle):
    """Parses a TSV file in SQA format into a list of interactions.

    Rows are grouped by (sequence id, table file); inside a group, questions
    are ordered by the integer value of their position column.

    Args:
      file_handle: File handle of a TSV file in SQA format.

    Returns:
      A list of `Interaction` messages sorted by (sequence id, table file).
      Each interaction's table only carries `table_id`, set to the table file.
    """
    grouped = {}
    for row in csv.DictReader(file_handle, delimiter='\t'):
        sequence_id = text_utils.get_sequence_id(row[_ID], row[_ANNOTATOR])
        by_position = grouped.setdefault((sequence_id, row[_TABLE_FILE]), {})
        position = int(row[_POSITION])

        answer = interaction_pb2.Answer()
        _parse_answer_coordinates(row[_ANSWER_COORDINATES], answer)
        _parse_answer_text(row[_ANSWER_TEXT], answer)

        # Optional columns: an absent column or an empty cell leaves the
        # corresponding answer field unset.
        agg_func = (row[_AGGREGATION].upper().strip()
                    if _AGGREGATION in row else '')
        if agg_func:
            answer.aggregation_function = _AggregationFunction.Value(agg_func)

        float_text = (row[_ANSWER_FLOAT_VALUE]
                      if _ANSWER_FLOAT_VALUE in row else '')
        if float_text:
            answer.float_value = float(float_text)

        class_text = (row[_ANSWER_CLASS_INDEX]
                      if _ANSWER_CLASS_INDEX in row else '')
        if class_text:
            answer.class_index = int(class_text)

        by_position[position] = interaction_pb2.Question(
            id=text_utils.get_question_id(sequence_id, position),
            original_text=row[_QUESTION],
            answer=answer)

    interactions = []
    for sequence_id, table_file in sorted(grouped):
        by_position = grouped[(sequence_id, table_file)]
        interactions.append(
            interaction_pb2.Interaction(
                id=sequence_id,
                questions=[by_position[p] for p in sorted(by_position)],
                table=interaction_pb2.Table(table_id=table_file)))
    return interactions
Esempio n. 5
0
 def test_convert(self):
     """Checks the classifier features produced for a small numeric table."""
     seq_len = 12
     header = ['A', 'B', 'C']
     body = [['0', '4', '5'], ['1', '3', '5']]
     with tempfile.TemporaryDirectory() as tmp_dir:
         vocab_path = os.path.join(tmp_dir, 'vocab.txt')
         _create_vocab(vocab_path, range(10))
         config = tf_example_utils.ClassifierConversionConfig(
             vocab_file=vocab_path,
             max_seq_length=seq_len,
             max_column_id=seq_len,
             max_row_id=seq_len,
             strip_column_names=False,
             add_aggregation_candidates=False,
         )
         converter = tf_example_utils.ToClassifierTensorflowExample(
             config=config)

         table = interaction_pb2.Table(
             columns=[interaction_pb2.Cell(text=name) for name in header],
             rows=[
                 interaction_pb2.Cells(
                     cells=[interaction_pb2.Cell(text=value) for value in row])
                 for row in body
             ],
         )
         interaction = interaction_pb2.Interaction(
             table=table,
             questions=[interaction_pb2.Question(id='id', original_text='2')],
         )
         number_annotation_utils.add_numeric_values(interaction)

         example = converter.convert(interaction, 0)
         logging.info(example)

         expected_int_features = [
             ('input_ids', [2, 8, 3, 1, 1, 1, 6, 10, 11, 7, 9, 11]),
             ('row_ids', [0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2]),
             ('column_ids', [0, 0, 0, 1, 2, 3, 1, 2, 3, 1, 2, 3]),
             ('column_ranks', [0, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 1]),
             ('numeric_relations', [0, 0, 0, 0, 0, 0, 4, 2, 2, 4, 2, 2]),
         ]
         for feature_name, expected in expected_int_features:
             self.assertEqual(_get_int_feature(example, feature_name),
                              expected)

         # Only the question number '2' is a numeric value; the rest is NaN.
         expected_question_values = [2.0] + [_NAN] * (_MAX_NUMERIC_VALUES - 1)
         self.assertEqual(
             _get_float_feature(example, 'question_numeric_values'),
             _clean_nans(expected_question_values))
Esempio n. 6
0
 def _get_interaction(self):
     """Builds a 2x3 table interaction with two questions.

     The first question's answer points at coordinates (2, 2) and (0, 2);
     the second question has no answer set.
     """
     header = ["A:/, c", "B", "C"]
     body = [["0", "4", "6"], ["1", "3", "5"]]
     table = interaction_pb2.Table(
         columns=[interaction_pb2.Cell(text=text) for text in header],
         rows=[
             interaction_pb2.Cells(
                 cells=[interaction_pb2.Cell(text=text) for text in row])
             for row in body
         ],
     )
     coordinates = [
         interaction_pb2.AnswerCoordinate(row_index=row, column_index=column)
         for row, column in [(2, 2), (0, 2)]
     ]
     first = interaction_pb2.Question(
         id="id-1",
         original_text="A is 5",
         text="A is 5",
         answer=interaction_pb2.Answer(answer_coordinates=coordinates),
     )
     second = interaction_pb2.Question(
         id="id-2", original_text="B is A", text="A is 5 B is A")
     return interaction_pb2.Interaction(table=table, questions=[first, second])
Esempio n. 7
0
 def test_convert_with_context_heading(self):
     """Converts a table with a document title and a context heading.

     Checks the resulting `input_ids` and the `label_ids` derived from the
     answer text 'B C'.
     """
     seq_len = 20
     with tempfile.TemporaryDirectory() as tmp_dir:
         vocab_path = os.path.join(tmp_dir, 'vocab.txt')
         _create_vocab(vocab_path, ['a', 'b', 'c', 'd', 'e'])
         config = tf_example_utils.ClassifierConversionConfig(
             vocab_file=vocab_path,
             max_seq_length=seq_len,
             max_column_id=seq_len,
             max_row_id=seq_len,
             strip_column_names=False,
             add_aggregation_candidates=False,
             use_document_title=True,
             use_context_title=True,
             update_answer_coordinates=True,
         )
         converter = tf_example_utils.ToClassifierTensorflowExample(
             config=config)

         table = interaction_pb2.Table(
             document_title='E E',
             columns=[interaction_pb2.Cell(text=t) for t in ('A', 'A B C')],
             rows=[
                 interaction_pb2.Cells(cells=[
                     interaction_pb2.Cell(text=t) for t in ('A B', 'A B C')
                 ]),
             ],
             context_heading='B',
         )
         question = interaction_pb2.Question(
             id='id',
             original_text='D',
             answer=interaction_pb2.Answer(answer_texts=['B C']),
         )
         interaction = interaction_pb2.Interaction(table=table,
                                                   questions=[question])

         example = converter.convert(interaction, 0)
         logging.info(example)

         expected_input_ids = [
             2, 5, 3, 10, 10, 3, 7, 3, 6, 6, 7, 8, 6, 7, 6, 7, 8, 0, 0, 0
         ]
         self.assertEqual(_get_int_feature(example, 'input_ids'),
                          expected_input_ids)
         # Only positions 15 and 16 are labelled.
         expected_label_ids = [0] * 15 + [1, 1] + [0] * 3
         self.assertEqual(_get_int_feature(example, 'label_ids'),
                          expected_label_ids)
Esempio n. 8
0
 def test_convert_with_trimmed_cell(self):
     """Checks cell trimming combined with dropping rows to fit.

     With `cell_trim_length=2` and `drop_rows_to_fit=True`, the 16-position
     budget only fits the header and the first row, with each cell cut to at
     most two tokens.
     """
     max_seq_length = 16
     with tempfile.TemporaryDirectory() as input_dir:
         vocab_file = os.path.join(input_dir, 'vocab.txt')
         _create_vocab(vocab_file, range(10))
         converter = tf_example_utils.ToClassifierTensorflowExample(
             config=tf_example_utils.ClassifierConversionConfig(
                 vocab_file=vocab_file,
                 max_seq_length=max_seq_length,
                 max_column_id=max_seq_length,
                 max_row_id=max_seq_length,
                 strip_column_names=False,
                 add_aggregation_candidates=False,
                 cell_trim_length=2,
                 drop_rows_to_fit=True))
         interaction = interaction_pb2.Interaction(
             table=interaction_pb2.Table(
                 columns=[
                     interaction_pb2.Cell(text='A'),
                     interaction_pb2.Cell(text='A A'),
                     interaction_pb2.Cell(text='A A A A'),
                 ],
                 rows=[
                     interaction_pb2.Cells(cells=[
                         interaction_pb2.Cell(text='A A A'),
                         interaction_pb2.Cell(text='A A A'),
                         interaction_pb2.Cell(text='A A A'),
                     ]),
                     interaction_pb2.Cells(cells=[
                         interaction_pb2.Cell(text='A A A'),
                         interaction_pb2.Cell(text='A A A'),
                         interaction_pb2.Cell(text='A A A'),
                     ]),
                 ],
             ),
             questions=[
                 interaction_pb2.Question(id='id', original_text='A')
             ],
         )
         number_annotation_utils.add_numeric_values(interaction)
         example = converter.convert(interaction, 0)
         logging.info(example)
         # We expect the second row to be dropped and every cell to be trimmed
         # to at most 2 tokens (e.g. 'A A A A' and 'A A A' both keep 2 tokens,
         # while 'A' keeps its single token).
         self.assertEqual(_get_int_feature(example, 'column_ids'),
                          [0, 0, 0, 1, 2, 2, 3, 3, 1, 1, 2, 2, 3, 3, 0, 0])
def _get_interaction(interaction_id, question_id, table):
  """Returns an interaction with one answered question over a copy of `table`.

  The question's answer points at cell (0, 0) and carries two answer texts.

  Args:
    interaction_id: Id assigned to the interaction.
    question_id: Id assigned to its single question.
    table: Table message copied into the interaction.
  """
  answer = interaction_pb2.Answer(
      answer_coordinates=[
          interaction_pb2.AnswerCoordinate(row_index=0, column_index=0),
      ],
      answer_texts=["first_answer", "second_answer"],
  )
  question = interaction_pb2.Question(
      id=question_id,
      text="",
      original_text="What position does the player who played for team 1?",
      answer=answer,
  )
  result = interaction_pb2.Interaction(id=interaction_id, questions=[question])
  result.table.CopyFrom(table)
  return result
Esempio n. 10
0
 def get_empty_example(self):
     """Returns `self.convert` applied to an interaction with only a padding question."""
     padding_question = interaction_pb2.Question(
         id=text_utils.get_padded_question_id())
     empty_interaction = interaction_pb2.Interaction(
         questions=[padding_question])
     return self.convert(empty_interaction, index=0)
Esempio n. 11
0
    def test_convert_with_token_selection(self):
        """Checks that only selected table tokens end up in the features.

        The question's `TableSelection` extension lists (row, column, token)
        coordinates; the expected features contain only those tokens, followed
        by padding.
        """
        seq_len = 12
        header = ['A', 'B', 'C']
        body = [['0 6', '4 7', '5 6'], ['1 7', '3 6', '5 5']]
        with tempfile.TemporaryDirectory() as tmp_dir:
            vocab_path = os.path.join(tmp_dir, 'vocab.txt')
            _create_vocab(vocab_path, range(10))
            config = tf_example_utils.ClassifierConversionConfig(
                vocab_file=vocab_path,
                max_seq_length=seq_len,
                max_column_id=seq_len,
                max_row_id=seq_len,
                strip_column_names=False,
                add_aggregation_candidates=False,
            )
            converter = tf_example_utils.ToClassifierTensorflowExample(
                config=config)

            table = interaction_pb2.Table(
                columns=[interaction_pb2.Cell(text=name) for name in header],
                rows=[
                    interaction_pb2.Cells(
                        cells=[interaction_pb2.Cell(text=text) for text in row])
                    for row in body
                ],
            )
            interaction = interaction_pb2.Interaction(
                table=table,
                questions=[interaction_pb2.Question(id='id', original_text='2')],
            )

            selected = [(0, 0, 0), (1, 0, 0), (1, 2, 0), (2, 0, 0), (2, 2, 0),
                        (2, 2, 1)]
            token_coordinates = [
                table_selection_pb2.TableSelection.TokenCoordinates(
                    row_index=row, column_index=column, token_index=token)
                for row, column, token in selected
            ]
            selection_ext = table_selection_pb2.TableSelection.table_selection_ext
            interaction.questions[0].Extensions[selection_ext].CopyFrom(
                table_selection_pb2.TableSelection(
                    selected_tokens=token_coordinates))

            number_annotation_utils.add_numeric_values(interaction)
            example = converter.convert(interaction, 0)
            logging.info(example)

            expected_int_features = [
                ('input_ids', [2, 8, 3, 1, 6, 11, 7, 11, 11, 0, 0, 0]),
                ('row_ids', [0, 0, 0, 0, 1, 1, 2, 2, 2, 0, 0, 0]),
                ('column_ids', [0, 0, 0, 1, 1, 3, 1, 3, 3, 0, 0, 0]),
                ('column_ranks', [0, 0, 0, 0, 1, 1, 2, 1, 1, 0, 0, 0]),
                ('numeric_relations', [0, 0, 0, 0, 4, 2, 4, 2, 2, 0, 0, 0]),
            ]
            for feature_name, expected in expected_int_features:
                self.assertEqual(_get_int_feature(example, feature_name),
                                 expected)
Esempio n. 12
0
 def test_convert_with_negative_tables(self):
     """Checks retrieval features for a question table plus a negative table.

     The 2x3 question table ('table_0') and the 2x2 negative table
     ('table_1') each occupy 12 positions, so every feature has 24 entries.
     """
     seq_len = 12

     def make_table(header, body, table_id):
         # Builds a Table from a header row and a list of body rows.
         return interaction_pb2.Table(
             columns=[interaction_pb2.Cell(text=text) for text in header],
             rows=[
                 interaction_pb2.Cells(
                     cells=[interaction_pb2.Cell(text=text) for text in row])
                 for row in body
             ],
             table_id=table_id,
         )

     with tempfile.TemporaryDirectory() as tmp_dir:
         vocab_path = os.path.join(tmp_dir, 'vocab.txt')
         _create_vocab(vocab_path, range(10))
         config = tf_example_utils.RetrievalConversionConfig(
             vocab_file=vocab_path,
             max_seq_length=seq_len,
             max_column_id=seq_len,
             max_row_id=seq_len,
             strip_column_names=False,
         )
         converter = tf_example_utils.ToRetrievalTensorflowExample(
             config=config)

         interaction = interaction_pb2.Interaction(
             table=make_table(['A', 'B', 'C'],
                              [['0 6', '4 7', '5 6'], ['1 7', '3 6', '5 5']],
                              'table_0'),
             questions=[interaction_pb2.Question(id='id', original_text='2')],
         )
         number_annotation_utils.add_numeric_values(interaction)

         negative_table = make_table(['A', 'B'],
                                     [['0 6', '4 7'], ['1 7', '3 6']],
                                     'table_1')
         number_annotation_utils.add_numeric_table_values(negative_table)
         negative = _NegativeRetrievalExample()
         negative.table.CopyFrom(negative_table)
         negative.score = -82.0
         negative.rank = 1

         example = converter.convert(interaction, 0, negative)
         logging.info(example)

         expected_int_features = {
             'input_ids': [
                 2, 5, 3, 1, 1, 1, 6, 10, 11, 7, 9, 11,
                 2, 5, 3, 1, 1, 6, 10, 7, 9, 0, 0, 0,
             ],
             'row_ids': [
                 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2,
                 0, 0, 0, 0, 0, 1, 1, 2, 2, 0, 0, 0,
             ],
             'column_ids': [
                 0, 0, 0, 1, 2, 3, 1, 2, 3, 1, 2, 3,
                 0, 0, 0, 1, 2, 1, 2, 1, 2, 0, 0, 0,
             ],
             'segment_ids': [
                 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
             ],
             'input_mask': [1] * 21 + [0] * 3,
             'inv_column_ranks': [
                 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 2, 1,
                 0, 0, 0, 0, 0, 2, 1, 1, 2, 0, 0, 0,
             ],
             'column_ranks': [
                 0, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 1,
                 0, 0, 0, 0, 0, 1, 2, 2, 1, 0, 0, 0,
             ],
             'numeric_relations': [0] * 24,
             'table_id_hash': [911224864, 1294380046],
         }
         for feature_name, expected in expected_int_features.items():
             self.assertEqual(_get_int_feature(example, feature_name),
                              expected)

         nan = 'nan'
         expected_numeric_values = (
             [nan] * 6 + [0.0, 4.0, 5.0, 1.0, 3.0, 5.0] +
             [nan] * 5 + [0.0, 4.0, 1.0, 3.0] + [nan] * 3)
         self.assertEqual(_get_float_feature(example, 'numeric_values'),
                          expected_numeric_values)
         self.assertEqual(
             _get_float_feature(example, 'numeric_values_scale'), [1.0] * 24)

         table_ids = [
             raw.decode('utf-8')
             for raw in _get_byte_feature(example, 'table_id')
         ]
         self.assertEqual(table_ids, ['table_0', 'table_1'])
Esempio n. 13
0
    def test_eval_cell_selection(self):
        """Checks cell-selection metrics broken down by answer type.

        One prediction row is written for question 'meaning-1_0'; question
        'meaning-0_0' has no prediction row.
        """
        # Function-scope import: `os` is only needed here to write through
        # and later remove the temporary file.
        import os

        # `mkstemp` returns an already-open OS-level descriptor. Discarding it
        # leaks the descriptor, and nothing deletes the file. Write through the
        # descriptor instead and schedule removal once the test finishes.
        fd, tempfile_name = tempfile.mkstemp()
        self.addCleanup(os.remove, tempfile_name)
        answer_coordinates = str(['(5, 6)', '(1, 1)', '(1, 2)'])
        token_probabilities = json.dumps([(1, 1, 0.8), (2, 1, 0.7),
                                          (2, 3, 0.4), (6, 5, 0.9)])
        with os.fdopen(fd, 'w') as f:
            f.write('question_id\tanswer_coordinates\ttoken_probabilities\n')
            f.write(
                f'meaning-1_0\t{answer_coordinates}\t{token_probabilities}\n')

        question_1 = interaction_pb2.Question(
            id='meaning-0_0', alternative_answers=[interaction_pb2.Answer()])
        question_2 = text_format.Parse(
            """
    id: "meaning-1_0"
    original_text: "Meaning of life"
    answer {
      answer_coordinates {
        row_index: 5
        column_index: 6
      }
      answer_coordinates {
        row_index: 2
        column_index: 3
      }
      answer_texts: "42"
    }
    alternative_answers {
    }
    """, interaction_pb2.Question())
        questions = {question_1.id: question_1, question_2.id: question_2}
        metrics = dict(
            hybridqa_utils.eval_cell_selection(questions, tempfile_name))

        self.assertEqual(
            metrics, {
                _AnswerType.ALL:
                _CellSelectionMetrics(recall=0.5,
                                      precision=1 / 3,
                                      non_empty=0.5,
                                      answer_len=1.5,
                                      coverage=0.5,
                                      recall_at_1=0.5,
                                      recall_at_3=0.5,
                                      recall_at_5=0.5),
                _AnswerType.MANY_IN_TEXT:
                _CellSelectionMetrics(recall=1.0,
                                      precision=1 / 3,
                                      non_empty=1.0,
                                      answer_len=3.0,
                                      coverage=1.0,
                                      recall_at_1=1.0,
                                      recall_at_3=1.0,
                                      recall_at_5=1.0),
                _AnswerType.NO_ANSWER:
                _CellSelectionMetrics(recall=0.0,
                                      precision=None,
                                      non_empty=0.0,
                                      answer_len=0.0,
                                      coverage=0.0,
                                      recall_at_1=0.0,
                                      recall_at_3=0.0,
                                      recall_at_5=0.0),
            })
Esempio n. 14
0
  def test_add_entity_descriptions_to_table(self):
    """Checks that entity descriptions are appended to annotated cells.

    Cells whose text contains '0' or '3' are annotated with '/wiki/0' or
    '/wiki/3'. With `num_results=2`, only cell (1, 1) ('3 6') is expected to
    change, both with and without entity titles.
    """
    annotated_cell_ext = annotated_text_pb2.AnnotatedText.annotated_cell_ext
    header = ['A', 'B', 'C']
    body = [['0 6', '4 7', '5 6'], ['1 7', '3 6', '5 5']]
    table = interaction_pb2.Table(
        columns=[interaction_pb2.Cell(text=text) for text in header],
        rows=[
            interaction_pb2.Cells(
                cells=[interaction_pb2.Cell(text=text) for text in row])
            for row in body
        ],
    )

    # Annotate every cell whose text mentions one of the entities.
    entities = ['0', '3']
    all_cells = [cell for row in table.rows for cell in row.cells]
    for cell in all_cells:
      for entity in entities:
        if entity in cell.text:
          cell.Extensions[annotated_cell_ext].annotations.add(
              identifier=f'/wiki/{entity}')

    question = interaction_pb2.Question(
        id='id', text='What prime number has religious meaning?')
    descriptions = {
        '/wiki/0': ('0 (zero) is a number, and the digit used to represent '
                    'that number in numerals. It fulfills a central role in '
                    'mathematics as the additive identity of the integers.'),
        '/wiki/3': ('3 (three) is a number, numeral, and glyph. It is the natural '
                    'number following 2 and preceding 4, and is the smallest odd '
                    'prime number. It has religious or cultural significance.'),
    }

    # Only the top two sentences are used, based on tf-idf score.
    expected_cell_texts = {
        False: ('3 6 ( It is the natural number following 2 and preceding 4, and is '
                'the smallest odd prime number. It has religious or cultural '
                'significance. )'),
        True: ('3 6 ( 3 : It is the natural number following 2 and preceding 4, and '
               'is the smallest odd prime number. It has religious or cultural '
               'significance. )'),
    }
    for use_entity_title in (False, True):
      actual_table = interaction_pb2.Table()
      actual_table.CopyFrom(table)
      tf_example_utils._add_entity_descriptions_to_table(
          question,
          descriptions,
          actual_table,
          num_results=2,
          use_entity_title=use_entity_title)

      expected_table = interaction_pb2.Table()
      expected_table.CopyFrom(table)
      expected_table.rows[1].cells[1].text = (
          expected_cell_texts[use_entity_title])
      self.assertEqual(actual_table, expected_table)