def test_float_value(self):
        interaction = text_format.Parse(
            """
      table {
        columns { text: "Column0" }
        rows { cells { text: "a" } }
        rows { cells { text: "a" } }
      }
      questions {
        answer {
          answer_texts: "1.0"
        }
      }""", interaction_pb2.Interaction())

        question = interaction_utils_parser.parse_question(
            interaction.table, interaction.questions[0], _Mode.REMOVE_ALL)

        expected_answer = text_format.Parse(
            """
      answer_texts: "1.0"
      float_value: 1.0
    """, interaction_pb2.Answer())

        self.assertEqual(expected_answer, question.answer)
 def test_simple_token_answer(self):
   with tempfile.TemporaryDirectory() as input_dir:
     vocab_file = os.path.join(input_dir, "vocab.txt")
     _create_vocab(vocab_file, ["answer", "wrong"])
     interactions = [
         text_format.Parse(
             """
           table {
             rows {
               cells { text: "WRONG WRONG" }
             }
           }
           questions {
             id: "example_id-0_0"
             answer {
               class_index: 1
               answer_texts: "OTHER"
             }
             alternative_answers {
               answer_texts: "ANSWER"
             }
           }
         """, interaction_pb2.Interaction()),
         text_format.Parse(
             """
           table {
             rows {
               cells { text: "ANSWER WRONG" }
             }
           }
           questions {
             id: "example_id-1_0"
             answer {
               class_index: 1
               answer_texts: "ANSWER"
             }
           }
         """, interaction_pb2.Interaction())
     ]
     predictions = [
         {
             "question_id":
                 "example_id-0_0",
             "logits_cls":
                 "0",
             "answers":
                 prediction_utils.token_answers_to_text([
                     prediction_utils.TokenAnswer(
                         row_index=-1,
                         column_index=-1,
                         begin_token_index=-1,
                         end_token_index=-1,
                         token_ids=[1],
                         score=10.0,
                     )
                 ]),
         },
         {
             "question_id":
                 "example_id-1_0",
             "logits_cls":
                 "1",
             "answers":
                 prediction_utils.token_answers_to_text([
                     prediction_utils.TokenAnswer(
                         row_index=0,
                         column_index=0,
                         begin_token_index=-1,
                         end_token_index=-1,
                         token_ids=[6, 7],
                         score=10.0,
                     )
                 ]),
         },
     ]
     result = e2e_eval_utils._evaluate_retrieval_e2e(
         vocab_file,
         interactions,
         predictions,
     )
     logging.info("result: %s", result.to_dict())
     self.assertEqual(
         {
             "answer_accuracy": 0.0,
             "answer_precision": 0.0,
             "answer_token_f1": 0.6666666666666666,
             "answer_token_precision": 0.5,
             "answer_token_recall": 1.0,
             "oracle_answer_accuracy": 0.0,
             "oracle_answer_token_f1": 0.6666666666666666,
             "table_accuracy": 1.0,
             "table_precision": 1.0,
             "table_recall": 1.0,
             "answer_accuracy_table": None,
             "answer_accuracy_passage": None,
             "answer_token_f1_table": None,
             "answer_token_f1_passage": None,
         }, result.to_dict())
Example #3
 def test_parse_answer_interactions_basic(self):
     table_json = {
         'uid':
         0,
         'url':
         'https://en.wikipedia.org/wiki/Earth',
         'title':
         'Earth',
         'header': [['Name', []], ['Number', []]],
         'data': [
             [['U.K', ['/wiki/UK']], ['1.2', []]],
             [['Globe', ['/wiki/World']], ['3.2', []]],
         ]
     }
     interactions = hybridqa_rc_utils._parse_answer_interactions(
         table_json,
         descriptions={
             '/wiki/World':
             'The World is the Earth and all life on it, ...',
             '/wiki/UK': 'The United Kingdom',
         },
         example={
             'question_id': 'abcd',
             'question': 'Meaning of life',
             'answer-text': '42',
             'answer-node': [[
                 'Earth',
                 [1, 0],
                 '/wiki/World',
                 'passage',
             ]]
         },
         single_cell_examples=False)
     interactions = list(interactions)
     self.assertLen(interactions, 1)
     expected_interaction = text_format.Parse(
         """
 id: "abcd/0-0"
 table {
   columns {
       text: ""
   }
   rows {
     cells {
       text: "Globe : The World is the Earth and all life on it, ..."
       [language.tapas.AnnotatedText.annotated_cell_ext] {
         annotations {
           identifier: "/wiki/World"
         }
       }
     }
   }
   table_id: "0"
   document_title: "Earth"
   document_url: "https://en.wikipedia.org/wiki/Earth"
 }
 questions {
   id: "abcd/0-0_0"
   original_text: "Meaning of life"
   answer {
     answer_texts: "42"
   }
 }
 [language.tapas.AnnotationDescription.annotation_descriptions_ext] {
   descriptions {
     key: "/wiki/World"
     value: "The World is the Earth and all life on it, ..."
   }
 }
 """, interaction_pb2.Interaction())
     self.assertEqual(interactions[0], expected_interaction)
Example #4
    def test_numeric_relations(self, mock_read):
        input_file = 'interaction_00.pbtxt'
        expected_counters = {
            'Conversion success': 1,
            'Example emitted': 1,
            'Input question': 1,
            'Relation Set Index: 2': 5,
            'Relation Set Index: 4': 13,
            'Found answers: <= 4': 1,
        }

        with tf.gfile.Open(os.path.join(self._test_dir,
                                        input_file)) as input_file:
            interaction = text_format.ParseLines(input_file,
                                                 interaction_pb2.Interaction())

        _set_mock_read(mock_read, [interaction])

        max_seq_length = 512

        pipeline = create_data.build_classifier_pipeline(
            input_files=['input.tfrecord'],
            output_files=[self._output_path],
            config=_ClassifierConfig(
                vocab_file=os.path.join(self._test_dir, 'vocab.txt'),
                max_seq_length=max_seq_length,
                max_column_id=512,
                max_row_id=512,
                strip_column_names=False,
                add_aggregation_candidates=False,
            ))

        result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
        result.wait_until_finish()

        self.assertEqual(
            {
                metric_result.key.metric.name: metric_result.committed
                for metric_result in result.metrics().query()['counters']
            }, expected_counters)

        output = _read_examples(self._output_path)

        self.assertLen(output, 1)
        actual_example = output[0]

        self.assertIn('numeric_relations',
                      actual_example.features.feature.keys())
        relations = actual_example.features.feature[
            'numeric_relations'].int64_list.value

        with tf.gfile.Open(os.path.join(self._test_dir,
                                        'vocab.txt')) as vocab_file:
            vocab = [line.strip() for line in vocab_file]
        inputs = actual_example.features.feature['input_ids'].int64_list.value
        pairs = [(vocab[input_id], relation)
                 for (input_id, relation) in zip(inputs, relations)
                 if input_id > 0]
        logging.info('pairs: %s', pairs)
        self.assertSequenceEqual(pairs, [('[CLS]', 0), ('which', 0),
                                         ('cities', 0),
                                         ('had', 0), ('less', 0), ('than', 0),
                                         ('2', 0), (',', 0), ('000', 0),
                                         ('pass', 0), ('##en', 0), ('##ge', 0),
                                         ('##rs', 0), ('?', 0), ('[SEP]', 0),
                                         ('ran', 0), ('##k', 0), ('city', 0),
                                         ('pass', 0), ('##en', 0), ('##ge', 0),
                                         ('##rs', 0), ('ran', 0), ('##ki', 0),
                                         ('##ng', 0), ('air', 0), ('##li', 0),
                                         ('##ne', 0), ('1', 4), ('united', 0),
                                         ('states', 0), (',', 0), ('los', 0),
                                         ('angeles', 0), ('14', 2), (',', 2),
                                         ('7', 2), ('##4', 2), ('##9', 2),
                                         ('[EMPTY]', 0),
                                         ('al', 0), ('##as', 0), ('##ka', 0),
                                         ('air', 0), ('##li', 0), ('##ne', 0),
                                         ('##s', 0), ('2', 4), ('united', 0),
                                         ('states', 0), (',', 0), ('h', 0),
                                         ('##ous', 0), ('##ton', 0), ('5', 2),
                                         (',', 2), ('4', 2), ('##6', 2),
                                         ('##5', 2), ('[EMPTY]', 0),
                                         ('united', 0), ('e', 0), ('##x', 0),
                                         ('##p', 0), ('##re', 0), ('##s', 0),
                                         ('##s', 0), ('3', 4), ('canada', 0),
                                         (',', 0), ('c', 0), ('##al', 0),
                                         ('##ga', 0), ('##ry', 0), ('3', 2),
                                         (',', 2), ('7', 2), ('##6', 2),
                                         ('##1', 2), ('[EMPTY]', 0),
                                         ('air', 0), ('t', 0), ('##ra', 0),
                                         ('##ns', 0), ('##a', 0), ('##t', 0),
                                         (',', 0), ('west', 0), ('##j', 0),
                                         ('##et', 0), ('4', 4), ('canada', 0),
                                         (',', 0), ('s', 0), ('##as', 0),
                                         ('##ka', 0), ('##to', 0), ('##on', 0),
                                         ('2', 2), (',', 2), ('28', 2),
                                         ('##2', 2), ('4', 0), ('[EMPTY]', 0),
                                         ('5', 4), ('canada', 0), (',', 0),
                                         ('van', 0), ('##co', 0), ('##u', 0),
                                         ('##ve', 0), ('##r', 0), ('2', 2),
                                         (',', 2), ('10', 2), ('##3', 2),
                                         ('[EMPTY]', 0), ('air', 0), ('t', 0),
                                         ('##ra', 0), ('##ns', 0), ('##a', 0),
                                         ('##t', 0), ('6', 4), ('united', 0),
                                         ('states', 0), (',', 0), ('p', 0),
                                         ('##h', 0), ('##o', 0), ('##en', 0),
                                         ('##i', 0), ('##x', 0), ('1', 4),
                                         (',', 4), ('8', 4), ('##2', 4),
                                         ('##9', 4), ('1', 0), ('us', 0),
                                         ('air', 0), ('##w', 0), ('##a', 0),
                                         ('##y', 0), ('##s', 0), ('7', 4),
                                         ('canada', 0), (',', 0), ('to', 0),
                                         ('##ro', 0), ('##nt', 0), ('##o', 0),
                                         ('1', 4), (',', 4), ('20', 4),
                                         ('##2', 4), ('1', 0), ('air', 0),
                                         ('t', 0), ('##ra', 0), ('##ns', 0),
                                         ('##a', 0), ('##t', 0), (',', 0),
                                         ('can', 0), ('##j', 0), ('##et', 0),
                                         ('8', 4), ('canada', 0), (',', 0),
                                         ('ed', 0), ('##m', 0), ('##on', 0),
                                         ('##ton', 0), ('11', 4), ('##0', 4),
                                         ('[EMPTY]', 0), ('[EMPTY]', 0),
                                         ('9', 4), ('united', 0),
                                         ('states', 0), (',', 0), ('o', 0),
                                         ('##a', 0), ('##k', 0), ('##land', 0),
                                         ('10', 4), ('##7', 4), ('[EMPTY]', 0),
                                         ('[EMPTY]', 0)])
Example #5
    def test_end_to_end_multiple_interactions(self, mock_read):
        with tf.gfile.Open(os.path.join(self._test_dir,
                                        'interaction_01.pbtxt')) as input_file:
            interaction = text_format.ParseLines(input_file,
                                                 interaction_pb2.Interaction())

        interactions = []
        for trial in range(100):
            table_id = f'table_id_{trial}'
            new_interaction = interaction_pb2.Interaction()
            new_interaction.CopyFrom(interaction)
            new_interaction.table.table_id = table_id
            new_interaction.id = table_id
            interactions.append(new_interaction)

        _set_mock_read(mock_read, interactions)

        self._create_vocab(
            list(_RESERVED_SYMBOLS) + list(string.ascii_lowercase) +
            ['##' + letter for letter in string.ascii_lowercase])

        pipeline = create_data.build_pretraining_pipeline(
            input_file='input.tfrecord',
            output_suffix='.tfrecord',
            output_dir=self._output_path,
            config=_PretrainConfig(
                vocab_file=self._vocab_path,
                max_seq_length=40,
                max_predictions_per_seq=10,
                random_seed=5,
                masked_lm_prob=0.5,
                max_column_id=5,
                max_row_id=5,
                min_question_length=5,
                max_question_length=10,
                always_continue_cells=True,
                strip_column_names=False,
            ),
            dupe_factor=1,
            min_num_columns=0,
            min_num_rows=0,
            num_random_table_bins=10,
            num_corpus_bins=
            100000,  # High number sends all examples to train set.
            add_random_table=True,
        )

        result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
        result.wait_until_finish()

        counters = {
            metric_result.key.metric.name: metric_result.committed
            for metric_result in result.metrics().query()['counters']
        }

        self.assertEqual(
            counters, {
                'Examples': 100,
                'Examples with tables': 100,
                'Interactions': 100,
                'Interactions without random interaction': 11,
                'Question Length: < inf': 31,
                'Question Length: <= 10': 53,
                'Question Length: <= 7': 16,
                'Real Table Size: <= 8': 100,
                'Trimmed Table Size: <= 8': 100,
                'Column Sizes: <= 8': 100,
                'Row Sizes: <= 8': 100,
                'Table Token Sizes: <= 8': 100,
                'Inputs': 100,
            })

        output = _read_examples(
            os.path.join(self._output_path, 'train.tfrecord'))
        self.assertLen(output, 100)
def add_numeric_values_fn(element):
    """Returns a copy of the keyed interaction with numeric value annotations."""
    key, interaction = element
    new_interaction = interaction_pb2.Interaction()
    new_interaction.CopyFrom(interaction)
    number_annotation_utils.add_numeric_values(new_interaction)
    return key, new_interaction
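
A minimal sketch (not from the original source) of wiring this keyed helper into an Apache Beam pipeline; the inline table and pipeline shape are illustrative assumptions:

import apache_beam as beam

# A tiny keyed interaction; text_format and interaction_pb2 come from the
# surrounding module's imports.
_KEYED_INTERACTION = (
    "interaction-0",
    text_format.Parse(
        'table { columns { text: "Number" } rows { cells { text: "7" } } }',
        interaction_pb2.Interaction()))

with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create([_KEYED_INTERACTION])
        | "AddNumericValues" >> beam.Map(add_numeric_values_fn))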
 def test_simple_questions(self):
     with open(os.path.join(self.test_data_dir, 'questions.tsv'),
               'r') as file_handle:
         interactions = interaction_utils.read_from_tsv_file(file_handle)
     self.assertLen(interactions, 2)
     self.assertEqual(
         text_format.Parse(
             """
   id: "nt-14053-1"
   table {
     table_id: "table_csv/203_386.csv"
   }
   questions {
     id: "nt-14053-1_0"
     original_text: "who were the captains?"
     answer {
       answer_coordinates {
         row_index: 0
         column_index: 3
       }
       answer_coordinates {
         row_index: 1
         column_index: 3
       }
       answer_texts: "Heinrich Brodda"
       answer_texts: "Oskar Staudinger"
     }
   }
   questions {
     id: "nt-14053-1_1"
     original_text: "which ones lost their u-boat on may 5?"
     answer {
       answer_coordinates {
         row_index: 1
         column_index: 3
       }
       answer_coordinates {
         row_index: 2
         column_index: 3
       }
       answer_texts: "Oskar Staudinger"
       answer_texts: "Herbert Neckel"
     }
   }
   questions {
     id: "nt-14053-1_2"
     original_text: "of those, which one is not oskar staudinger?"
     answer {
       answer_coordinates {
         row_index: 2
         column_index: 3
       }
       answer_texts: "Herbert Neckel"
     }
   }
 """, interaction_pb2.Interaction()), interactions[0])
     self.assertEqual(
         text_format.Parse(
             """
   id: "nt-5431-0"
   table {
     table_id: "table_csv/204_703.csv"
   }
   questions {
     id: "nt-5431-0_0"
     original_text: "what are all the countries?"
     answer {
       answer_coordinates {
         row_index: 0
         column_index: 1
       }
       answer_coordinates {
         row_index: 1
         column_index: 1
       }
       answer_texts: "Canada (CAN)"
       answer_texts: "Russia (RUS)"
     }
   }
 """, interaction_pb2.Interaction()), interactions[1])
Example #8
    def test_convert_with_token_selection(self):
        max_seq_length = 12
        with tempfile.TemporaryDirectory() as input_dir:
            vocab_file = os.path.join(input_dir, 'vocab.txt')
            _create_vocab(vocab_file, range(10))
            converter = tf_example_utils.ToClassifierTensorflowExample(
                config=tf_example_utils.ClassifierConversionConfig(
                    vocab_file=vocab_file,
                    max_seq_length=max_seq_length,
                    max_column_id=max_seq_length,
                    max_row_id=max_seq_length,
                    strip_column_names=False,
                    add_aggregation_candidates=False,
                ))
            interaction = interaction_pb2.Interaction(
                table=interaction_pb2.Table(
                    columns=[
                        interaction_pb2.Cell(text='A'),
                        interaction_pb2.Cell(text='B'),
                        interaction_pb2.Cell(text='C'),
                    ],
                    rows=[
                        interaction_pb2.Cells(cells=[
                            interaction_pb2.Cell(text='0 6'),
                            interaction_pb2.Cell(text='4 7'),
                            interaction_pb2.Cell(text='5 6'),
                        ]),
                        interaction_pb2.Cells(cells=[
                            interaction_pb2.Cell(text='1 7'),
                            interaction_pb2.Cell(text='3 6'),
                            interaction_pb2.Cell(text='5 5'),
                        ]),
                    ],
                ),
                questions=[
                    interaction_pb2.Question(id='id', original_text='2')
                ],
            )
            table_coordinates = []
            for r, c, t in [(0, 0, 0), (1, 0, 0), (1, 2, 0), (2, 0, 0),
                            (2, 2, 0), (2, 2, 1)]:
                table_coordinates.append(
                    table_selection_pb2.TableSelection.TokenCoordinates(
                        row_index=r, column_index=c, token_index=t))
            interaction.questions[0].Extensions[
                table_selection_pb2.TableSelection.
                table_selection_ext].CopyFrom(
                    table_selection_pb2.TableSelection(
                        selected_tokens=table_coordinates))

            number_annotation_utils.add_numeric_values(interaction)
            example = converter.convert(interaction, 0)
            logging.info(example)
            self.assertEqual(_get_int_feature(example, 'input_ids'),
                             [2, 8, 3, 1, 6, 11, 7, 11, 11, 0, 0, 0])
            self.assertEqual(_get_int_feature(example, 'row_ids'),
                             [0, 0, 0, 0, 1, 1, 2, 2, 2, 0, 0, 0])
            self.assertEqual(_get_int_feature(example, 'column_ids'),
                             [0, 0, 0, 1, 1, 3, 1, 3, 3, 0, 0, 0])
            self.assertEqual(_get_int_feature(example, 'column_ranks'),
                             [0, 0, 0, 0, 1, 1, 2, 1, 1, 0, 0, 0])
            self.assertEqual(_get_int_feature(example, 'numeric_relations'),
                             [0, 0, 0, 0, 4, 2, 4, 2, 2, 0, 0, 0])
Example #9
def iterate_interaction_tables(interaction_file):
    """Yields the table of every interaction stored in a TFRecord file."""
    for value in tf.python_io.tf_record_iterator(interaction_file):
        interaction = interaction_pb2.Interaction()
        interaction.ParseFromString(value)
        yield interaction.table
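
A short usage sketch with a hypothetical file name, collecting the distinct tables referenced by an interaction TFRecord file:

tables = {}
for table in iterate_interaction_tables("interactions.tfrecord"):  # hypothetical path
    tables[table.table_id] = table
print("distinct tables:", len(tables))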
Example #10
def _parse_interaction(
    table,
    descriptions,
    example,
    counters,
):
    """Converts a single example to an interaction with a single question.

  Args:
    table: Table proto for this interaction.
    descriptions: The Wikipedia intro for each entity in the Table annotations.
    example: Question parsed from input JSON file.
    counters: Used for logging events as the interactions are parsed.

  Returns:
    Interaction proto.
  """

    interaction = interaction_pb2.Interaction()

    # We append '-0', which corresponds to the position annotator field.
    interaction.id = example['question_id'] + '-0'
    interaction.table.CopyFrom(table)

    desc_map = interaction.Extensions[_annotation_descriptions].descriptions
    for key, value in descriptions.items():
        desc_map[key] = value

    question = interaction.questions.add()
    # We append '_0', which corresponds to the SQA position field.
    question.id = f'{interaction.id}_0'
    question.original_text = example['question']

    # Reference answer for the question. The test set answers are hidden.
    if 'answer-text' in example:
        true_coordinates, table_only_coordinates, matched_identifiers = find_answer_coordinates(
            example['answer-text'], table, desc_map)

        question.answer.answer_texts.append(example['answer-text'])

        # We use this field to store just the table answers
        table_only_answer = question.alternative_answers.add()
        for row_index, column_index in table_only_coordinates:
            table_only_answer.answer_coordinates.add(row_index=row_index,
                                                     column_index=column_index)

        for row_index, column_index in true_coordinates:
            question.answer.answer_coordinates.add(row_index=row_index,
                                                   column_index=column_index)

        # This is used to compare the examples we find against the ones in the data.
        dataset_coordinates = frozenset(find_dataset_coordinates(example))

        if true_coordinates > dataset_coordinates:
            counters['Missing answers in dataset'] += 1
        elif true_coordinates < dataset_coordinates:
            counters['Missing answers in extraction'] += 1
        elif true_coordinates == dataset_coordinates:
            counters['Same answers'] += 1
        else:
            counters['Disjoint answers'] += 1

        counters[f'Answer type is {_get_answer_type(question).value}'] += 1

        if any(count > 1 for count in matched_identifiers.values()):
            counters['Answers in repeated identifier'] += 1

        if len(true_coordinates) > 1:
            counters['Multiple answers'] += 1

    return interaction
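
A hypothetical invocation sketch for the converter above; the example dict deliberately omits 'answer-text', so the answer-matching helpers (find_answer_coordinates, find_dataset_coordinates, _get_answer_type) defined elsewhere in the module are never called:

import collections

# Hypothetical inputs: a bare table proto, one description, and a question
# without an answer. Because 'answer-text' is absent, only the id, table,
# descriptions and question fields are populated.
counters = collections.Counter()
interaction = _parse_interaction(
    table=interaction_pb2.Table(table_id="0"),
    descriptions={"/wiki/World": "The World is the Earth and all life on it."},
    example={"question_id": "abcd", "question": "Meaning of life"},
    counters=counters,
)
print(interaction.id)               # abcd-0
print(interaction.questions[0].id)  # abcd-0_0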
Example #11
def iterate_interactions(interactions_file):
    """Get interactions from file."""
    for value in tf.python_io.tf_record_iterator(interactions_file):
        interaction = interaction_pb2.Interaction()
        interaction.ParseFromString(value)
        yield interaction
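
A small usage sketch with a hypothetical file name, counting questions while streaming over the file:

num_questions = sum(
    len(interaction.questions)
    for interaction in iterate_interactions("train_interactions.tfrecord"))
print("total questions:", num_questions)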
Example #12
 def setUp(self):
     super().setUp()
     self.input_interaction: interaction_pb2.Interaction = text_format.Parse(
         """
 id: "1234/1-0"
 table {
   columns {text: "Name"}
   columns {text: "Subject"}
   columns {text: "Score"}
   rows {
     cells {text: "Jessica"}
     cells {
       text: "Biology"
       [language.tapas.AnnotatedText.annotated_cell_ext] {
         annotations {
           identifier: "/wiki/Biology"
         }
       }
     }
     cells {text:"98.1"}
   }
   rows {
     cells {text: "Ralph"}
     cells {
       text: "Mathematics"
       [language.tapas.AnnotatedText.annotated_cell_ext] {
         annotations {
           identifier: "/wiki/Mathematics"
         }
       }
     }
     cells {text:"67.8"}
   }
   rows {
     cells {
       text: "Erwin"
       [language.tapas.AnnotatedText.annotated_cell_ext] {
         annotations {
           identifier: "/wiki/Erwin_Schrödinger"
         }
       }
     }
     cells {
       text: "Physics"
       [language.tapas.AnnotatedText.annotated_cell_ext] {
         annotations {
           identifier: "/wiki/Physics"
         }
       }
     }
     cells {text:"99.2"}
   }
   table_id: "0"
   document_title: "Earth"
   document_url: "https://en.wikipedia.org/wiki/Earth"
 }
 questions {
   id: "1234/1-0_0"
   original_text: "How much did the developer of the theory of relativity scored in Physics in his high school?"
   answer {
     answer_texts: "99.2"
   }
 }
 [language.tapas.AnnotationDescription.annotation_descriptions_ext] {
   descriptions {
     key: "/wiki/Biology"
     value: "Biology is the natural science that studies life and living organisms."
   }
   descriptions {
     key: "/wiki/Mathematics"
     value: "The abstract science of number, quantity, and space."
   }
   descriptions {
     key: "/wiki/Physics"
     value: "The branch of science concerned with the nature and properties of matter and energy."
   }
   descriptions {
     key: "/wiki/Erwin_Schrödinger"
     value: "Nobel Prize-winning Austrian-Irish physicist who developed a number of fundamental results in quantum theory."
   }
 }
 """, interaction_pb2.Interaction())
Example #13
 def test_parse_answer_interactions_with_all_coordinates(self):
     table_json = {
         'uid':
         0,
         'url':
         'https://en.wikipedia.org/wiki/Earth',
         'title':
         'Earth',
         'header': [['Name', []], ['Number', []]],
         'data': [
             [['U.K', ['/wiki/UK']], ['1.2', []]],
             [['France', ['/wiki/France']], ['1.2', []]],
             [['Globe', ['/wiki/World']], ['3.2', []]],
         ]
     }
     interactions = hybridqa_rc_utils._parse_answer_interactions(
         table_json,
         descriptions={
             '/wiki/World':
             'The World is the Earth and all life on it, ...',
             '/wiki/UK':
              'The United Kingdom is a sovereign country located off the northwestern coast of the European mainland.',
             '/wiki/France': 'France is a country in the Europe continent.',
         },
         example={
             'question_id': '1234',
             'question': 'What is the Number for European countries?',
             'answer-text': '1.2',
             'answer-node': [[
                 '1.2',
                 [0, 1],
                 None,
                 'table',
             ]]
         },
         single_cell_examples=True,
         use_original_coordinates=False)
     interactions = list(interactions)
     self.assertLen(interactions, 2)
     expected_interaction = text_format.Parse(
         """
 id: "1234/1-0"
 table {
   columns {
       text: ""
   }
   rows {
     cells {
       text: "1.2"
     }
   }
   table_id: "0"
   document_title: "Earth"
   document_url: "https://en.wikipedia.org/wiki/Earth"
 }
 questions {
   id: "1234/1-0_0"
   original_text: "What is the Number for European countries?"
   answer {
     answer_texts: "1.2"
   }
 }
 [language.tapas.AnnotationDescription.annotation_descriptions_ext] {
 }
 """, interaction_pb2.Interaction())
     self.assertEqual(interactions[1], expected_interaction)
     self.assertEqual(interactions[0].id, '1234/0-0')
     self.assertLen(interactions[0].table.columns, 1)
     self.assertLen(interactions[0].table.rows, 1)
     self.assertLen(interactions[0].table.rows[0].cells, 1)
  def test_empty_predictions(self):
    with tempfile.TemporaryDirectory() as input_dir:
      vocab_file = os.path.join(input_dir, "vocab.txt")
      _create_vocab(vocab_file, [])
      interactions = [
          text_format.Parse(
              """
            table {
              rows {
                cells { text: "ANSWER" }
                cells { text: "OTHER" }
              }
            }
            questions {
              id: "example_id-0_0"
              answer {
                class_index: 1
                answer_texts: "ANSWER"
              }
            }
          """, interaction_pb2.Interaction()),
          text_format.Parse(
              """
            table {
              rows {
                cells { text: "ANSWER" }
                cells { text: "OTHER" }
              }
            }
            questions {
              id: "example_id-1_0"
              answer {
                class_index: 1
                answer_texts: "ANSWER"
              }
            }
          """, interaction_pb2.Interaction())
      ]
      result = e2e_eval_utils._evaluate_retrieval_e2e(
          vocab_file,
          interactions,
          predictions=[],
      )

    logging.info("result: %s", result)
    for name, value in result.to_dict().items():
      if name in {
          "answer_accuracy_table",
          "answer_accuracy_passage",
          "answer_token_f1_table",
          "answer_token_f1_passage",
      }:
        self.assertIsNone(value)
      elif name in [
          "table_precision",
          "answer_precision",
          "answer_token_precision",
      ]:
        self.assertEqual(value, 1.0)
      else:
        self.assertEqual(value, 0.0)
def _to_contrastive_statements_fn(
    key_interaction,
    use_fake_table,
    drop_without_support_rate,
):
    """Converts pretraining interaction to contrastive interaction."""

    # Make a copy since beam functions should not manipulate inputs.
    new_interaction = interaction_pb2.Interaction()
    new_interaction.CopyFrom(key_interaction[1])
    interaction = new_interaction

    iid = interaction.table.table_id
    rng = random.Random(beam_utils.to_numpy_seed(iid))

    generated_statements = set()

    for result in contrastive_statements.get_contrastive_statements(
            rng, interaction, count_fn=_count):

        has_support, statement, contrastive_statement = result

        beam.metrics.Metrics.counter(_NS, "Pairs").inc()

        if not has_support and rng.random() < drop_without_support_rate:
            beam.metrics.Metrics.counter(
                _NS, "Pairs: Down-sampled pairs without support").inc()
            continue

        if contrastive_statement in generated_statements:
            beam.metrics.Metrics.counter(_NS, "Pairs: Duplicates").inc()
            continue

        generated_statements.add(contrastive_statement)

        new_interaction = interaction_pb2.Interaction()
        new_interaction.CopyFrom(interaction)
        del new_interaction.questions[:]

        new_interaction.id = _to_id((
            iid,
            (statement, contrastive_statement),
        ))

        if use_fake_table:
            _clear_table(new_interaction)

        new_interaction.table.table_id = new_interaction.id

        new_question = new_interaction.questions.add()
        new_question.id = _to_id((iid, statement))
        new_question.original_text = statement
        new_question.answer.class_index = 1

        new_question = new_interaction.questions.add()
        new_question.id = _to_id((iid, contrastive_statement))
        new_question.original_text = contrastive_statement
        new_question.answer.class_index = 0

        beam.metrics.Metrics.counter(_NS, "Pairs emitted").inc()
        yield new_interaction.id, new_interaction
Example #16
def _set_float32_safe_interaction(interaction):
  """Round-trips the proto through serialization so float fields hold float32-precision values."""
  new_interaction = interaction_pb2.Interaction()
  new_interaction.ParseFromString(interaction.SerializeToString())
  interaction.CopyFrom(new_interaction)
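
The serialize/parse round trip above normalizes in-memory values to the precision they will have on disk. A small sketch of the effect, assuming Answer.float_value is a 32-bit float field:

interaction = interaction_pb2.Interaction()
question = interaction.questions.add()
question.answer.float_value = 1.0 / 3.0  # a double in Python
_set_float32_safe_interaction(interaction)
# The stored value now matches what a reader of the serialized proto sees,
# so equality checks against re-parsed interactions do not fail on rounding.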
Example #17
 def test_convert_with_negative_tables(self):
     max_seq_length = 12
     with tempfile.TemporaryDirectory() as input_dir:
         vocab_file = os.path.join(input_dir, 'vocab.txt')
         _create_vocab(vocab_file, range(10))
         converter = tf_example_utils.ToRetrievalTensorflowExample(
             config=tf_example_utils.RetrievalConversionConfig(
                 vocab_file=vocab_file,
                 max_seq_length=max_seq_length,
                 max_column_id=max_seq_length,
                 max_row_id=max_seq_length,
                 strip_column_names=False,
             ))
         interaction = interaction_pb2.Interaction(
             table=interaction_pb2.Table(
                 columns=[
                     interaction_pb2.Cell(text='A'),
                     interaction_pb2.Cell(text='B'),
                     interaction_pb2.Cell(text='C'),
                 ],
                 rows=[
                     interaction_pb2.Cells(cells=[
                         interaction_pb2.Cell(text='0 6'),
                         interaction_pb2.Cell(text='4 7'),
                         interaction_pb2.Cell(text='5 6'),
                     ]),
                     interaction_pb2.Cells(cells=[
                         interaction_pb2.Cell(text='1 7'),
                         interaction_pb2.Cell(text='3 6'),
                         interaction_pb2.Cell(text='5 5'),
                     ]),
                 ],
                 table_id='table_0',
             ),
             questions=[
                 interaction_pb2.Question(
                     id='id',
                     original_text='2',
                 )
             ],
         )
         number_annotation_utils.add_numeric_values(interaction)
         n_table = interaction_pb2.Table(
             columns=[
                 interaction_pb2.Cell(text='A'),
                 interaction_pb2.Cell(text='B'),
             ],
             rows=[
                 interaction_pb2.Cells(cells=[
                     interaction_pb2.Cell(text='0 6'),
                     interaction_pb2.Cell(text='4 7'),
                 ]),
                 interaction_pb2.Cells(cells=[
                     interaction_pb2.Cell(text='1 7'),
                     interaction_pb2.Cell(text='3 6'),
                 ]),
             ],
             table_id='table_1',
         )
         number_annotation_utils.add_numeric_table_values(n_table)
         n_example = _NegativeRetrievalExample()
         n_example.table.CopyFrom(n_table)
         n_example.score = -82.0
         n_example.rank = 1
         example = converter.convert(interaction, 0, n_example)
         logging.info(example)
         self.assertEqual(_get_int_feature(example, 'input_ids'), [
             2, 5, 3, 1, 1, 1, 6, 10, 11, 7, 9, 11, 2, 5, 3, 1, 1, 6, 10, 7,
             9, 0, 0, 0
         ])
         self.assertEqual(_get_int_feature(example, 'row_ids'), [
             0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0, 1, 1, 2, 2,
             0, 0, 0
         ])
         self.assertEqual(_get_int_feature(example, 'column_ids'), [
             0, 0, 0, 1, 2, 3, 1, 2, 3, 1, 2, 3, 0, 0, 0, 1, 2, 1, 2, 1, 2,
             0, 0, 0
         ])
         self.assertEqual(_get_int_feature(example, 'segment_ids'), [
             0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1,
             0, 0, 0
         ])
         self.assertEqual(_get_int_feature(example, 'input_mask'), [
             1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
             0, 0, 0
         ])
         self.assertEqual(_get_int_feature(example, 'inv_column_ranks'), [
             0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 2, 1, 0, 0, 0, 0, 0, 2, 1, 1, 2,
             0, 0, 0
         ])
         self.assertEqual(_get_int_feature(example, 'column_ranks'), [
             0, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 1, 0, 0, 0, 0, 0, 1, 2, 2, 1,
             0, 0, 0
         ])
         self.assertEqual(_get_int_feature(example, 'numeric_relations'), [
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0
         ])
         self.assertEqual(_get_int_feature(example, 'table_id_hash'),
                          [911224864, 1294380046])
         self.assertEqual(_get_float_feature(example, 'numeric_values'), [
             'nan', 'nan', 'nan', 'nan', 'nan', 'nan', 0.0, 4.0, 5.0, 1.0,
             3.0, 5.0, 'nan', 'nan', 'nan', 'nan', 'nan', 0.0, 4.0, 1.0,
             3.0, 'nan', 'nan', 'nan'
         ])
         self.assertEqual(
             _get_float_feature(example, 'numeric_values_scale'), [
                 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
                 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0
             ])
         self.assertEqual([
             i.decode('utf-8')
             for i in _get_byte_feature(example, 'table_id')
         ], ['table_0', 'table_1'])
 def test_questions_with_aggregation(self):
     """Tests that the most important function names can be parsed."""
     with open(
             os.path.join(self.test_data_dir, 'questions_aggregation.tsv'),
             'r') as file_handle:
         interactions = interaction_utils.read_from_tsv_file(file_handle)
     self.assertLen(interactions, 2)
     self.assertEqual(
         text_format.Parse(
             """
   id: "nt-14053-1"
   table {
     table_id: "table_csv/203_386.csv"
   }
   questions {
     id: "nt-14053-1_0"
     original_text: "who were the captains?"
     answer {
       answer_coordinates {
         row_index: 0
         column_index: 3
       }
       answer_coordinates {
         row_index: 1
         column_index: 3
       }
       answer_texts: "Heinrich Brodda"
       answer_texts: "Oskar Staudinger"
     }
   }
   questions {
     id: "nt-14053-1_1"
     original_text: "which ones lost their u-boat on may 5?"
     answer {
       answer_coordinates {
         row_index: 1
         column_index: 3
       }
       answer_texts: "Oskar Staudinger"
       aggregation_function: NONE
     }
   }
   questions {
     id: "nt-14053-1_2"
     original_text: "of those, which one is not oskar staudinger?"
     answer {
       answer_coordinates {
         row_index: 2
         column_index: 3
       }
       answer_texts: "Herbert Neckel"
       aggregation_function: NONE
     }
   }
 """, interaction_pb2.Interaction()), interactions[0])
     self.assertEqual(
         text_format.Parse(
             """
   id: "nt-4436-0"
   table {
     table_id: "table_csv/203_88.csv"
   }
   questions {
     id: "nt-4436-0_0"
     original_text: "which language has more males then females?"
     answer {
       answer_coordinates {
         row_index: 2
         column_index: 0
       }
       aggregation_function: SUM
       answer_texts: "Russian"
     }
   }
   questions {
     id: "nt-4436-0_1"
     original_text: "which of those have less than 500 males?"
     answer {
       answer_coordinates {
         row_index: 5
         column_index: 0
       }
       aggregation_function: COUNT
       answer_texts: "Romanian"
     }
   }
   questions {
     id: "nt-4436-0_2"
     original_text: "the ones have less than 20 females?"
     answer {
       answer_coordinates {
         row_index: 5
         column_index: 0
       }
       answer_coordinates {
         row_index: 7
         column_index: 0
       }
       answer_texts: "Romanian"
       answer_texts: "Estonian"
       aggregation_function: AVERAGE
     }
   }
 """, interaction_pb2.Interaction()), interactions[1])
def main(argv):
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    print("Creating output dir ...")
    tf.io.gfile.makedirs(FLAGS.output_dir)

    interaction_files = []
    for filename in tf.io.gfile.listdir(FLAGS.input_dir):
        interaction_files.append(os.path.join(FLAGS.input_dir, filename))

    tables = {}
    if FLAGS.table_file:
        print("Reading tables ...")
        tables.update({
            table.table_id: table
            for table in tfidf_baseline_utils.iterate_tables(FLAGS.table_file)
        })

    print("Adding interactions tables ...")
    for interaction_file in interaction_files:
        interactions = prediction_utils.iterate_interactions(interaction_file)
        for interaction in interactions:
            tables[interaction.table.table_id] = interaction.table

    print("Creating index ...")

    if FLAGS.index_files_pattern:
        neighbors = _get_neural_nearest_neighbors(FLAGS.index_files_pattern)
        retrieve_fn = lambda question: neighbors.get(question.id, [])
    else:
        index = tfidf_baseline_utils.create_bm25_index(
            tables=tables.values(),
            title_multiplicator=FLAGS.title_multiplicator,
            num_tables=len(tables),
        )
        retrieve_fn = lambda question: index.retrieve(question.original_text)

    print("Processing interactions ...")
    for interaction_file in interaction_files:
        interactions = list(
            prediction_utils.iterate_interactions(interaction_file))

        examples = collections.defaultdict(list)
        for interaction in interactions:
            example_id, _ = preprocess_nq_utils.parse_interaction_id(
                interaction.id)
            examples[example_id].append(interaction)

        filename = os.path.basename(interaction_file)
        is_train = "train" in filename
        output = os.path.join(FLAGS.output_dir, filename)
        with tf.io.TFRecordWriter(output) as writer:
            num_correct = 0
            with tqdm.tqdm(
                    examples.items(),
                    total=len(examples),
                    desc=filename,
                    postfix=[{
                        "prec": "0.00",
                        "multiple_tables": 0,
                        "multiple_answers": 0,
                        "no_hits": 0,
                    }],
            ) as pbar:
                for nr, example in enumerate(pbar):
                    example_id, interaction_list = example

                    questions = []
                    for interaction in interaction_list:
                        if len(interaction.questions) != 1:
                            raise ValueError(
                                f"Unexpected question in {interaction}")
                        questions.append(interaction.questions[0])

                    answers = get_answer_texts(questions)

                    if len(set(q.original_text for q in questions)) != 1:
                        raise ValueError(f"Different questions {questions}")
                    question_text = questions[0].original_text
                    scored_hits = retrieve_fn(questions[0])
                    if not scored_hits:
                        pbar.postfix[0]["no_hits"] += 1
                    candidate_hits = scored_hits[:FLAGS.max_rank]

                    correct_table_ids = {
                        interaction.table.table_id
                        for interaction in interaction_list
                    }

                    table_ids = {table_id for table_id, _ in candidate_hits}

                    if correct_table_ids & table_ids:
                        num_correct += 1
                    prec = num_correct / (nr + 1)
                    pbar.postfix[0]["prec"] = f"{prec:.2f}"
                    if len(correct_table_ids) > 1:
                        pbar.postfix[0]["multiple_tables"] += 1

                    if is_train or FLAGS.oracle_retrieval:
                        table_ids.update(correct_table_ids)

                    for table_index, table_id in enumerate(sorted(table_ids)):
                        table = tables[table_id]
                        new_interaction = interaction_pb2.Interaction()
                        new_interaction.table.CopyFrom(table)
                        new_question = new_interaction.questions.add()
                        new_question.original_text = question_text
                        _try_to_set_answer(table, answers, new_question)
                        _set_retriever_info(new_question, scored_hits,
                                            table_id)
                        new_question.answer.is_valid = True
                        if new_question.alternative_answers:
                            pbar.postfix[0]["multiple_answers"] += 1
                        if table_id in correct_table_ids:
                            new_question.answer.class_index = 1
                        else:
                            new_question.answer.class_index = 0
                            if not FLAGS.add_negatives:
                                continue
                        new_interaction.id = text_utils.get_sequence_id(
                            example_id, str(table_index))
                        new_question.id = text_utils.get_question_id(
                            new_interaction.id, position=0)
                        writer.write(new_interaction.SerializeToString())
def _parse_interaction(text_proto_line):
    """Parses a one-line text-format Interaction into an (id, interaction) pair."""
    interaction = text_format.Parse(text_proto_line,
                                    interaction_pb2.Interaction())
    return (interaction.id, interaction)
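
A small usage sketch with a hypothetical input path; each non-empty line is assumed to hold one complete text-format Interaction:

with open("interactions.txtpb") as lines:  # hypothetical path
    keyed_interactions = [
        _parse_interaction(line) for line in lines if line.strip()
    ]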
Example #21
    def test_interaction(self, prob_count_aggregation):
        config = synthesize_entablement.SynthesizationConfig(
            prob_count_aggregation=prob_count_aggregation, attempts=10)
        interaction = interaction_pb2.Interaction(
            id='i_id',
            table=interaction_pb2.Table(
                table_id='t_id',
                columns=[
                    interaction_pb2.Cell(text='Name'),
                    interaction_pb2.Cell(text='Height'),
                    interaction_pb2.Cell(text='Age')
                ],
                rows=[
                    interaction_pb2.Cells(cells=[
                        interaction_pb2.Cell(text='Peter'),
                        interaction_pb2.Cell(text='100'),
                        interaction_pb2.Cell(text='15')
                    ]),
                    interaction_pb2.Cells(cells=[
                        interaction_pb2.Cell(text='Bob'),
                        interaction_pb2.Cell(text='150'),
                        interaction_pb2.Cell(text='15')
                    ]),
                    interaction_pb2.Cells(cells=[
                        interaction_pb2.Cell(text='Tina'),
                        interaction_pb2.Cell(text='200'),
                        interaction_pb2.Cell(text='17')
                    ]),
                ]),
            questions=[])

        pos_statements = set()
        neg_statements = set()

        counter = TestCounter()

        for i in range(100):
            rng = np.random.RandomState(i)
            interactions = synthesize_entablement.synthesize_from_interaction(
                config, rng, interaction, counter)
            for new_interaction in interactions:
                question = new_interaction.questions[0]
                if question.answer.class_index == 1:
                    pos_statements.add(question.text)
                else:
                    assert question.answer.class_index == 0
                    neg_statements.add(question.text)
        self.assertEqual(neg_statements, pos_statements)
        logging.info('pos_statements: %s', pos_statements)

        counts = counter.get_counts()
        logging.info('counts: %s', counts)

        is_count_test = prob_count_aggregation == 1.0

        if is_count_test:
            self.assertGreater(len(pos_statements), 10)
            expected_statements = {
                '1 is less than the count when age is 15 and height is greater than 100',
                '1 is less than the count when height is less than 200 and age is 15',
                '1 is the count when height is greater than 100 and age is less than 17',
                '2 is the count when age is less than 17 and height is less than 200',
            }
        else:
            self.assertGreater(len(pos_statements), 100)
            expected_statements = {
                '0 is the range of age when height is greater than 100',
                '100 is less than the last height when height is less than 200',
                '125 is greater than height when name is peter',
                '15 is age when height is less than 150',
                '15 is the last age when height is less than 200',
                '150 is the last height when age is 15',
                '175 is the average height when age is less than 17',
                '200 is greater than the greatest height when age is less than 17',
                '250 is less than the total height when age is 15',
                '30 is less than the total age when height is greater than 100',
                'bob is name when age is greater than 15',
                'bob is the first name when age is 15 and height is less than 200',
                'peter is name when age is 15 and height is less than 150',
                'the average height when age is 15 is less than 175',
                'the first height when height is greater than 100 is 150',
                'the first height when height is less than 200 is 150',
                'the first name when age is 15 is name when name is peter',
                'the greatest height when age is 15 is less than 200',
                'the last age when height is greater than 100 is greater than 15',
                'the last name when age is 15 is bob',
                'the last name when age is less than 17 is peter',
                'the last name when height is greater than 100 is bob',
                'the last name when height is less than 200 is bob',
                'the lowest height when age is 15 is 150',
                'the range of age when height is greater than 100 is greater than 0',
                'the range of height when age is 15 is 100',
                'tina is name when age is greater than 15 and height is 200',
                'tina is the first name when age is 15',
                'tina is the last name when age is 15',
                'tina is the last name when height is greater than 100',
            }

        for statement in expected_statements:
            self.assertIn(statement, pos_statements)

        for name in ['pos', 'neg']:
            if is_count_test:
                self.assertGreater(counts[f'{name}: Synthesization success'],
                                   10)
                self.assertGreater(counts[f'{name}: Select: COUNT'], 10)
            else:
                self.assertEqual(counts[f'{name}: Synthesization success'],
                                 100)
                for aggregation in Aggregation:
                    self.assertGreater(
                        counts[f'{name}: Select: {aggregation.name}'], 0)
            for comparator in Comparator:
                min_count = 1 if prob_count_aggregation == 1.0 else 10
                self.assertGreater(
                    counts[f'{name}: Comparator {comparator.name}'], min_count)
                self.assertGreater(
                    counts[f'{name}: where: Comparator {comparator.name}'],
                    min_count)
Example #22
  def test_add_numeric_values_to_questions(self):
    actual_interaction = text_format.Parse(
        """
          questions {
            original_text: 'What are all the buildings in canada?'
          }
          questions {
            original_text: 'Which building has more than 17 floors?'
          }
          questions {
            original_text:
              'Are there one or two buildings build on March 17, 2015?'
          }""", interaction_pb2.Interaction())
    number_annotation_utils.add_numeric_values_to_questions(actual_interaction)

    expected_interaction = text_format.Parse(
        """
          questions {
            original_text: 'What are all the buildings in canada?'
            text: 'what are all the buildings in canada?'
            annotations {
            }
          }
          questions {
            original_text: 'Which building has more than 17 floors?'
            text: 'which building has more than 17 floors?'
            annotations {
             spans {
               begin_index: 29
               end_index: 31
               values {
                 float_value: 17.0
               }
             }
            }
          }
          questions {
            original_text:
              'Are there one or two buildings build on March 17, 2015?'
            text: 'are there one or two buildings build on march 17, 2015?'
            annotations {
             spans {
               begin_index: 10
               end_index: 13
               values {
                 float_value: 1.0
               }
             }
             spans {
               begin_index: 17
               end_index: 20
               values {
                 float_value: 2.0
               }
             }
             spans {
               begin_index: 40
               end_index: 54
               values {
                 date {
                   year: 2015
                   month: 3
                   day: 17
                 }
               }
             }
            }
          }""", interaction_pb2.Interaction())

    self.assertEqual(expected_interaction, actual_interaction)
Example #23
def _to_interaction_fn(element):
    """Wraps a keyed Table proto into an Interaction whose id is the table id."""
    key, table = element
    interaction = interaction_pb2.Interaction()
    interaction.table.CopyFrom(table)
    interaction.id = table.table_id
    return key, interaction
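
A minimal usage sketch applying the wrapper to a single keyed table:

key, interaction = _to_interaction_fn(
    ("some-key", interaction_pb2.Table(table_id="table_0")))
assert key == "some-key"
assert interaction.id == "table_0"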
Example #24
    def test_gracefully_handle_big_examples(self, max_seq_length,
                                            max_column_id, max_row_id,
                                            expected_counters, mock_read):

        with tf.gfile.Open(os.path.join(self._test_dir,
                                        'interaction_02.pbtxt')) as input_file:
            interaction = text_format.ParseLines(input_file,
                                                 interaction_pb2.Interaction())

        _set_mock_read(mock_read, [interaction])

        pipeline = create_data.build_classifier_pipeline(
            input_files=['input.tfrecord'],
            output_files=[self._output_path],
            config=_ClassifierConfig(
                vocab_file=self._vocab_path,
                max_seq_length=60
                if max_seq_length is None else max_seq_length,
                max_column_id=5 if max_column_id is None else max_column_id,
                max_row_id=10 if max_row_id is None else max_row_id,
                strip_column_names=False,
                add_aggregation_candidates=False,
            ))

        result = beam.runners.direct.direct_runner.DirectRunner().run(pipeline)
        result.wait_until_finish()

        self.assertEqual(
            {
                metric_result.key.metric.name: metric_result.committed
                for metric_result in result.metrics().query()['counters']
            }, expected_counters)

        if max_seq_length is None and max_column_id is None and max_row_id is None:
            output = _read_examples(self._output_path)

            with tf.gfile.Open(
                    os.path.join(self._test_dir,
                                 'tf_example_02.pbtxt')) as input_file:
                expected_example = text_format.ParseLines(
                    input_file, tf.train.Example())
            with tf.gfile.Open(
                    os.path.join(self._test_dir,
                                 'tf_example_02_conv.pbtxt')) as input_file:
                expected_conversational_example = text_format.ParseLines(
                    input_file, tf.train.Example())

            self.assertLen(output, 2)

            actual_example = output[0]
            del actual_example.features.feature['column_ranks']
            del actual_example.features.feature['inv_column_ranks']
            del actual_example.features.feature['numeric_relations']
            del actual_example.features.feature['numeric_values']
            del actual_example.features.feature['numeric_values_scale']
            del actual_example.features.feature['question_id_ints']
            # assertEqual struggles with NaNs inside protos
            del actual_example.features.feature['answer']

            self.assertEqual(actual_example, expected_example)

            actual_example = output[1]
            del actual_example.features.feature['column_ranks']
            del actual_example.features.feature['inv_column_ranks']
            del actual_example.features.feature['numeric_relations']
            del actual_example.features.feature['numeric_values']
            del actual_example.features.feature['numeric_values_scale']
            del actual_example.features.feature['question_id_ints']
            # assertEqual struggles with NaNs inside protos
            del actual_example.features.feature['answer']

            self.assertEqual(actual_example, expected_conversational_example)
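
The 'answer' feature is deleted before the comparisons above because answer values can contain NaN, and NaN never compares equal to itself, so two otherwise identical Example protos would fail assertEqual. A standalone illustration of that pitfall (not part of the test fixtures above):

import math

import tensorflow as tf

nan = float('nan')
print(nan == nan)       # False: NaN is never equal to itself.
print(math.isnan(nan))  # True

# Two Examples that both carry a NaN float are reported as unequal by the
# default field-by-field proto comparison, even though their contents match.
a = tf.train.Example()
a.features.feature['answer'].float_list.value.append(nan)
b = tf.train.Example()
b.features.feature['answer'].float_list.value.append(nan)
print(a == b)  # False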
Example #25
  def test_simple(self, answer):
    with tempfile.TemporaryDirectory() as input_dir:
      vocab_file = os.path.join(input_dir, "vocab.txt")
      _create_vocab(vocab_file, ["answer"])
      interactions = [
          text_format.Parse(
              """
            table {
              rows {
                cells { text: "ANSWER UNKNOWN" }
              }
            }
            questions {
              id: "example_id-0_0"
              answer {
                class_index: 1
                answer_texts: "OTHER"
              }
              alternative_answers {
                answer_texts: "ANSWER"
              }
            }
          """, interaction_pb2.Interaction()),
          text_format.Parse(
              """
            table {
              rows {
                cells { text: "ANSWER UNKNOWN" }
              }
            }
            questions {
              id: "example_id-1_0"
              answer {
                class_index: 1
                answer_texts: "ANSWER"
              }
            }
          """, interaction_pb2.Interaction())
      ]
      predictions = [
          {
              "question_id": "example_id-0_0",
              "logits_cls": "2",
              "answer": "[0]",
          },
          {
              "question_id": "example_id-1_0",
              "logits_cls": "11",
              "answer": answer,
          },
      ]
      result = e2e_eval_utils._evaluate_retrieval_e2e(
          vocab_file,
          interactions,
          predictions,
      )

    logging.info("result: %s", result)
    for name, value in result.to_dict().items():
      if name in {
          "answer_accuracy_table",
          "answer_accuracy_passage",
          "answer_token_f1_table",
          "answer_token_f1_passage",
      }:
        self.assertIsNone(value)
      elif name in [
          "table_accuracy",
          "table_recall",
      ]:
        continue
      elif name in [
          "table_precision",
      ]:
        self.assertEqual(value, 1.0)
      else:
        if answer == "[6]":
          self.assertEqual(value, 1.0)
        elif answer == "[6, 1]":
          values = {
              "answer_accuracy": 0.0,
              "answer_token_recall": 1.0,
              "answer_precision": 0.0,
              "answer_token_f1": 0.6666666666666666,
              "answer_token_precision": 0.5,
              "oracle_answer_token_f1": 0.6666666666666666,
              "oracle_answer_accuracy": 0.0,
          }
          self.assertEqual(value, values[name])
        elif answer == "[1]":
          self.assertEqual(value, 0.0)
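
For the "[6, 1]" branch above, the expected metric values follow from standard token-overlap precision and recall: one of the two predicted tokens matches the single reference token (precision 0.5), the reference token is covered (recall 1.0), and F1 is their harmonic mean. A small standalone check of that arithmetic; this is a sketch of the usual definition, not e2e_eval_utils' own implementation.

def _token_f1(predicted, reference):
  # Token-overlap F1 over unique token ids; sketch only.
  common = len(set(predicted) & set(reference))
  if not predicted or not reference or common == 0:
    return 0.0
  precision = common / len(predicted)
  recall = common / len(reference)
  return 2 * precision * recall / (precision + recall)


# Prediction [6, 1] vs. reference [6]: precision 0.5, recall 1.0, F1 = 2/3.
print(_token_f1([6, 1], [6]))  # 0.6666666666666666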
Example #26
  def get_empty_example(self):
    # Builds an interaction that contains only a padded question id and
    # converts it into an (empty) example.
    interaction = interaction_pb2.Interaction(questions=[
        interaction_pb2.Question(id=text_utils.get_padded_question_id())
    ])
    return self.convert(interaction, index=0)
def _parse_answer_interactions(table_json,
                               descriptions,
                               example,
                               *,
                               single_cell_examples=False,
                               use_original_coordinates=True,
                               noisy_cell_strategy=NoisyCellStrategy.NONE):
    """Converts a single example to an interaction with a single question.

  Args:
    table_json: Table in JSON-format mapping.
    descriptions: The Wikipedia intro for each entity in the Table annotations.
    example: Question parsed from input JSON file.
    single_cell_examples: Generate multiple single-celled table interactions for
      each example (interaction) if this is set True.
    use_original_coordinates: Only use the coordinates of answer-text present in
      the dataset if this argument is set True, else it finds a more exhaustive
      set of coordinates where the answer-text is present --- either in the cell
      text, or in the description.
    noisy_cell_strategy: Determines the strategy for using noisy cells, which
      doesn't contain the answer-text along with the ground truth cells. If
      argument is set None, noisy cells are not included in the table.

  Yields:
    Interaction proto.
  """

    interaction = interaction_pb2.Interaction()

    # We append "-0", which corresponds to the position annotator field.
    interaction.id = f"{example['question_id']}/{0}-0"
    desc_map = interaction.Extensions[_annotation_descriptions].descriptions
    for key, value in descriptions.items():
        desc_map[key] = value

    question = interaction.questions.add()
    # We append "_0", which corresponds to the SQA position field.
    question.id = f'{interaction.id}_0'
    question.original_text = example['question']

    # Reference answer for the question. The test set answers are hidden.
    if 'answer-text' in example:
        question.answer.answer_texts.append(example['answer-text'])

        if use_original_coordinates:
            coordinates = hybridqa_utils.find_dataset_coordinates(example)
        else:
            original_table = hybridqa_utils.parse_table(
                table_json, descriptions)
            coordinates, _, _ = hybridqa_utils.find_answer_coordinates(
                example['answer-text'], original_table, desc_map)

        answer_coordinates = [*coordinates]

        random_seed = pretrain_utils.to_numpy_seed(interaction.id)
        random_state = np.random.RandomState(random_seed)
        n_noisy_samples = len(answer_coordinates)

        n_rows, n_columns = get_table_dimensions(table_json)
        noisy_coordinates, with_replacement = sample_noisy_coordinates(
            n_rows, n_columns, answer_coordinates, noisy_cell_strategy,
            n_noisy_samples, random_state)

        if with_replacement:
            logging.warning(
                "n_samples=%d exceeds the %s sample space for example_id %s; "
                "sampled with replacement.", n_noisy_samples,
                noisy_cell_strategy, example['question_id'])

        selected_coordinates = answer_coordinates + noisy_coordinates
        random_state.shuffle(selected_coordinates)

        if single_cell_examples:
            answer_tables = _parse_answer_tables(table_json, descriptions,
                                                 selected_coordinates)
            for table_idx, answer_table in enumerate(answer_tables):
                new_interaction = interaction_pb2.Interaction()
                new_interaction.CopyFrom(interaction)
                new_interaction.table.CopyFrom(answer_table)
                new_interaction.id = f"{example['question_id']}/{table_idx}-0"
                new_interaction.questions[0].id = f'{new_interaction.id}_0'
                remove_unreferred_annotation_descriptions(new_interaction)
                yield new_interaction

        else:
            answer_table = _parse_answer_table(table_json, descriptions,
                                               selected_coordinates)
            interaction.table.CopyFrom(answer_table)
            remove_unreferred_annotation_descriptions(interaction)
            yield interaction
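
The sampling above is deterministic because the NumPy seed is derived from the interaction id, so reruns produce the same noisy cells and the same shuffle. Below is a standalone sketch of that pattern; the CRC-based seed helper is an assumption standing in for pretrain_utils.to_numpy_seed, and the coordinate sampling is simplified rather than hybridqa_utils' logic.

import zlib

import numpy as np


def _to_numpy_seed_sketch(key):
  # Hypothetical stand-in for pretrain_utils.to_numpy_seed: any stable hash
  # of the id that fits in a uint32 can seed RandomState reproducibly.
  return zlib.crc32(key.encode('utf-8')) & 0xFFFFFFFF


interaction_id = "question-id-123/0-0"  # hypothetical id
random_state = np.random.RandomState(_to_numpy_seed_sketch(interaction_id))

# Simplified noisy-coordinate sampling: draw cells uniformly from a 3x4
# table, skipping the ground-truth coordinates.
answer_coordinates = [(0, 1)]
candidates = [(r, c) for r in range(3) for c in range(4)
              if (r, c) not in answer_coordinates]
noisy = [candidates[i] for i in random_state.choice(
    len(candidates), size=len(answer_coordinates), replace=False)]

selected = answer_coordinates + noisy
random_state.shuffle(selected)
print(selected)  # Same output on every run for the same interaction_id.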