Exemplo n.º 1
0
  def test_float_value(self):
    """Expects the answer text "1.0" to come back with `float_value` set to 1.0."""
    # No table cell contains "1.0"; the expected answer below carries no
    # answer coordinates.
    source = interaction_pb2.Interaction()
    text_format.Parse(
        """
      table {
        columns { text: "Column0" }
        rows { cells { text: "a" } }
        rows { cells { text: "a" } }
      }
      questions {
        answer {
          answer_texts: "1.0"
        }
      }""", source)

    parsed_question = interaction_utils_parser.parse_question(
        source.table, source.questions[0], _Mode.REMOVE_ALL)

    wanted = interaction_pb2.Answer()
    text_format.Parse(
        """
      answer_texts: "1.0"
      float_value: 1.0
    """, wanted)

    self.assertEqual(wanted, parsed_question.answer)
Exemplo n.º 2
0
def _convert_data(
    all_questions,
    input_file,
    tables,
):
    """Yields one single-question interaction per TabFact statement.

    Args:
      all_questions: Mapping from table id to a 3-item sequence; the first
        two items are the parallel question texts and labels, the third is
        ignored here.
      input_file: Path of a JSON file; iterating the loaded JSON yields the
        table ids to convert.
      tables: Mapping from table id to the table proto that is copied into
        every interaction built for that table.

    Yields:
      An `interaction_pb2.Interaction` holding exactly one question whose
      answer carries the label as `class_index`.
    """
    logging.info('Converting data from: %s...', input_file)

    stats = collections.Counter()  # Running totals, reported at the end.

    with tf.io.gfile.GFile(input_file) as file_in:
        table_ids = json.load(file_in)
        for table_id in table_ids:
            texts, labels, _ = all_questions[table_id]
            for index, (text, label) in enumerate(zip(texts, labels)):
                # The extra zeros are there to match SQA id format.
                interaction_id = f'{table_id}_{index}-0'
                answer = interaction_pb2.Answer(class_index=label)
                question = interaction_pb2.Question(
                    id=f'{interaction_id}_0',
                    original_text=text,
                    answer=answer,
                )
                table_copy = interaction_pb2.Table()
                table_copy.CopyFrom(tables[table_id])
                yield interaction_pb2.Interaction(
                    id=interaction_id,
                    questions=[question],
                    table=table_copy,
                )

                # Counted only after the consumer resumes the generator.
                stats['questions'] += 1
                processed = stats['questions']
                if processed % 1000 == 0:
                    logging.info('Processed %d questions...', processed)

        _log_stats(stats, input_file)
Exemplo n.º 3
0
  def test_set_use_answer_text_when_single_float_answer(self):
    """Expects a single float answer text to be matched to its table cell.

    The answer text equals the text of the only cell, and the expected answer
    keeps the text and float value while gaining coordinates (0, 0).
    """
    source = interaction_pb2.Interaction()
    text_format.Parse(
        """
      table {
        columns { text: "Column0" }
        rows { cells { text: "2008.00000000000" } }
      }
      questions {
        answer {
          answer_texts: "2008.00000000000"
          float_value: 2008.0
        }
      }""", source)

    parsed_question = interaction_utils_parser.parse_question(
        source.table, source.questions[0], _Mode.REMOVE_ALL)
    _set_float32_safe_interaction(source)

    wanted = interaction_pb2.Answer()
    text_format.Parse(
        """
          answer_coordinates {
            row_index: 0
            column_index: 0
          }
          answer_texts: "2008.00000000000"
          float_value: 2008.0
    """, wanted)
    _set_float32_safe_answer(wanted)

    self.assertEqual(wanted, parsed_question.answer)
Exemplo n.º 4
0
  def test_strategies(self, mode, expected_answer):
    """Parses a two-coordinate COUNT answer under `mode` and checks the result.

    Args:
      mode: Parsing mode handed to `parse_question` as its third argument.
      expected_answer: Text-format `Answer` proto the parsed answer must equal.
    """
    source = interaction_pb2.Interaction()
    text_format.Parse(
        """
      table {
        columns { text: "Column0" }
        rows { cells { text: "a" } }
        rows { cells { text: "b" } }
      }
      questions {
        answer {
          answer_coordinates {
            row_index: 0
            column_index: 0
          }
          answer_coordinates {
            row_index: 1
            column_index: 0
          }
          answer_texts: "2"
          aggregation_function: COUNT
        }
      }""", source)

    parsed_question = interaction_utils_parser.parse_question(
        source.table, source.questions[0], mode)

    wanted = text_format.Parse(expected_answer, interaction_pb2.Answer())
    self.assertEqual(wanted, parsed_question.answer)
Exemplo n.º 5
0
 def test_set_answer_text_strange_float_format_when_multiple_answers(self):
   """Expects the float value to replace several non-matching answer texts.

   The input has answer texts "1" and "2" plus a float value; the expected
   answer has a single text spelling out that float value.
   """
   source = interaction_pb2.Interaction()
   text_format.Parse(
       """
     table {
       columns { text: "Column0" }
       rows { cells { text: "2008" } }
     }
     questions {
       answer {
         answer_texts: "1"
         answer_texts: "2"
         float_value: 2008.001
       }
     }""", source)
   # The helper is applied both before and after parsing.
   _set_float32_safe_interaction(source)
   parsed_question = interaction_utils_parser.parse_question(
       source.table, source.questions[0], _Mode.REMOVE_ALL)
   _set_float32_safe_interaction(source)

   # 2008.0009765625 is the float32 value nearest to 2008.001.
   wanted = interaction_pb2.Answer()
   text_format.Parse(
       """
         answer_texts: "2008.0009765625"
         float_value: 2008.001
   """, wanted)
   _set_float32_safe_answer(wanted)

   self.assertEqual(wanted, parsed_question.answer)
    def test_unambiguous_matching(self):
        """Expects answer texts to resolve to the cells that equal them.

        The table also holds "ab" and "bc"; the expected coordinates point
        only at the rows whose text is exactly "a" (row 0) and "b" (row 2).
        """
        source = interaction_pb2.Interaction()
        text_format.Parse(
            """
      table {
        columns { text: "Column0" }
        rows { cells { text: "a" } }
        rows { cells { text: "ab" } }
        rows { cells { text: "b" } }
        rows { cells { text: "bc" } }
      }
      questions {
        answer {
          answer_texts: "a"
          answer_texts: "b"
        }
      }""", source)

        parsed_question = interaction_utils_parser.parse_question(
            source.table,
            source.questions[0],
            _Mode.REMOVE_ALL,
        )

        wanted = interaction_pb2.Answer()
        text_format.Parse(
            """
          answer_coordinates {
            row_index: 0
            column_index: 0
          }
          answer_coordinates {
            row_index: 2
            column_index: 0
          }
          answer_texts: "a"
          answer_texts: "b"
    """, wanted)

        self.assertEqual(wanted, parsed_question.answer)
Exemplo n.º 7
0
def read_from_tsv_file(file_handle):
    """Reads SQA-format TSV rows and groups their questions into interactions.

    Args:
      file_handle: File handle of a TSV file in SQA format.

    Returns:
      A list of interactions ordered by (sequence id, table file). Within an
      interaction the questions are ordered by their position; a later row
      with the same position replaces the earlier one.
    """
    # (sequence_id, table_file) -> {position: Question}
    grouped = {}
    for record in csv.DictReader(file_handle, delimiter='\t'):
        sequence_id = text_utils.get_sequence_id(record[_ID],
                                                 record[_ANNOTATOR])
        by_position = grouped.setdefault((sequence_id, record[_TABLE_FILE]),
                                         {})
        position = int(record[_POSITION])

        parsed_answer = interaction_pb2.Answer()
        _parse_answer_coordinates(record[_ANSWER_COORDINATES], parsed_answer)
        _parse_answer_text(record[_ANSWER_TEXT], parsed_answer)

        # The remaining columns are optional; empty values are skipped.
        if _AGGREGATION in record:
            aggregation_name = record[_AGGREGATION].upper().strip()
            if aggregation_name:
                parsed_answer.aggregation_function = (
                    _AggregationFunction.Value(aggregation_name))
        if _ANSWER_FLOAT_VALUE in record and record[_ANSWER_FLOAT_VALUE]:
            parsed_answer.float_value = float(record[_ANSWER_FLOAT_VALUE])
        if _ANSWER_CLASS_INDEX in record and record[_ANSWER_CLASS_INDEX]:
            parsed_answer.class_index = int(record[_ANSWER_CLASS_INDEX])

        by_position[position] = interaction_pb2.Question(
            id=text_utils.get_question_id(sequence_id, position),
            original_text=record[_QUESTION],
            answer=parsed_answer,
        )

    # Keys are unique, so sorting the keys gives the same order as sorting
    # the items by key.
    return [
        interaction_pb2.Interaction(
            id=group_key[0],
            questions=[
                grouped[group_key][position]
                for position in sorted(grouped[group_key])
            ],
            table=interaction_pb2.Table(table_id=group_key[1]),
        ) for group_key in sorted(grouped)
    ]
Exemplo n.º 8
0
 def test_convert_with_context_heading(self):
     """Converts a titled table with a context heading and checks the features.

     Builds a vocabulary file in a temporary directory, converts a one-row,
     two-column interaction and compares `input_ids` and `label_ids` with
     fixed expected sequences of length 20.
     """
     sequence_length = 20
     with tempfile.TemporaryDirectory() as input_dir:
         vocab_file = os.path.join(input_dir, 'vocab.txt')
         _create_vocab(vocab_file, ['a', 'b', 'c', 'd', 'e'])

         config = tf_example_utils.ClassifierConversionConfig(
             vocab_file=vocab_file,
             max_seq_length=sequence_length,
             max_column_id=sequence_length,
             max_row_id=sequence_length,
             strip_column_names=False,
             add_aggregation_candidates=False,
             use_document_title=True,
             use_context_title=True,
             update_answer_coordinates=True,
         )
         converter = tf_example_utils.ToClassifierTensorflowExample(
             config=config)

         header = [
             interaction_pb2.Cell(text='A'),
             interaction_pb2.Cell(text='A B C'),
         ]
         only_row = interaction_pb2.Cells(cells=[
             interaction_pb2.Cell(text='A B'),
             interaction_pb2.Cell(text='A B C'),
         ])
         table = interaction_pb2.Table(
             document_title='E E',
             columns=header,
             rows=[only_row],
             context_heading='B',
         )
         question = interaction_pb2.Question(
             id='id',
             original_text='D',
             answer=interaction_pb2.Answer(answer_texts=['B C']),
         )
         interaction = interaction_pb2.Interaction(table=table,
                                                   questions=[question])

         example = converter.convert(interaction, 0)
         logging.info(example)

         expected_input_ids = [
             2, 5, 3, 10, 10, 3, 7, 3, 6, 6, 7, 8, 6, 7, 6, 7, 8, 0, 0, 0
         ]
         expected_label_ids = [
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0
         ]
         self.assertEqual(_get_int_feature(example, 'input_ids'),
                          expected_input_ids)
         self.assertEqual(_get_int_feature(example, 'label_ids'),
                          expected_label_ids)
def _get_interaction(interaction_id, question_id,
                     table):
  """Builds a single-question interaction over a copy of `table`.

  Args:
    interaction_id: Id assigned to the returned interaction.
    question_id: Id assigned to its only question.
    table: Table message copied into the interaction's `table` field.

  Returns:
    An `interaction_pb2.Interaction` whose question has one answer coordinate
    (0, 0) and the two fixed answer texts below.
  """
  answer = interaction_pb2.Answer(
      answer_coordinates=[
          interaction_pb2.AnswerCoordinate(row_index=0, column_index=0),
      ],
      answer_texts=["first_answer", "second_answer"],
  )
  question = interaction_pb2.Question(
      id=question_id,
      text="",
      original_text="What position does the player who played for team 1?",
      answer=answer,
  )
  result = interaction_pb2.Interaction(id=interaction_id, questions=[question])
  result.table.CopyFrom(table)
  return result
Exemplo n.º 10
0
 def _get_interaction(self):
     """Returns a fixed 3-column, 2-row interaction with two questions."""

     def _cells(*texts):
         # One Cell message per given text, in order.
         return [interaction_pb2.Cell(text=value) for value in texts]

     table = interaction_pb2.Table(
         columns=_cells("A:/, c", "B", "C"),
         rows=[
             interaction_pb2.Cells(cells=_cells("0", "4", "6")),
             interaction_pb2.Cells(cells=_cells("1", "3", "5")),
         ],
     )

     # NOTE(review): row_index=2 lies beyond the two rows defined above;
     # presumably intentional for the tests using this fixture - confirm.
     first_answer = interaction_pb2.Answer(answer_coordinates=[
         interaction_pb2.AnswerCoordinate(row_index=2, column_index=2),
         interaction_pb2.AnswerCoordinate(row_index=0, column_index=2),
     ])
     first_question = interaction_pb2.Question(
         id="id-1",
         original_text="A is 5",
         text="A is 5",
         answer=first_answer,
     )
     second_question = interaction_pb2.Question(
         id="id-2",
         original_text="B is A",
         text="A is 5 B is A",
     )

     return interaction_pb2.Interaction(
         table=table,
         questions=[first_question, second_question],
     )
Exemplo n.º 11
0
def _set_float32_safe_answer(answer):
  """Overwrites `answer` in place with its serialize-then-parse round trip.

  NOTE(review): presumably this makes float fields hold the values they have
  after going through the wire format - confirm against the proto definition.
  """
  round_tripped = interaction_pb2.Answer()
  wire_bytes = answer.SerializeToString()
  round_tripped.ParseFromString(wire_bytes)
  answer.CopyFrom(round_tripped)
Exemplo n.º 12
0
    def test_eval_cell_selection(self):
        """Checks the per-answer-type metrics returned by `eval_cell_selection`.

        Writes a one-row predictions TSV for question "meaning-1_0", evaluates
        it against two questions (one of which has no prediction row) and
        compares the resulting metrics with fixed expected values.
        """
        # Local import: only needed to manage the temporary file below.
        import os

        # mkstemp() returns an already-open OS-level descriptor plus the path.
        # Write through that descriptor so it gets closed, and remove the file
        # when the test finishes, instead of leaking both.
        fd, tempfile_name = tempfile.mkstemp()
        self.addCleanup(os.remove, tempfile_name)

        answer_coordinates = str(['(5, 6)', '(1, 1)', '(1, 2)'])
        token_probabilities = json.dumps([(1, 1, 0.8), (2, 1, 0.7),
                                          (2, 3, 0.4), (6, 5, 0.9)])
        with os.fdopen(fd, 'w') as f:
            f.write('question_id\tanswer_coordinates\ttoken_probabilities\n')
            f.write(
                f'meaning-1_0\t{answer_coordinates}\t{token_probabilities}\n')

        question_1 = interaction_pb2.Question(
            id='meaning-0_0', alternative_answers=[interaction_pb2.Answer()])
        question_2 = text_format.Parse(
            """
    id: "meaning-1_0"
    original_text: "Meaning of life"
    answer {
      answer_coordinates {
        row_index: 5
        column_index: 6
      }
      answer_coordinates {
        row_index: 2
        column_index: 3
      }
      answer_texts: "42"
    }
    alternative_answers {
    }
    """, interaction_pb2.Question())
        questions = {question_1.id: question_1, question_2.id: question_2}
        metrics = dict(
            hybridqa_utils.eval_cell_selection(questions, tempfile_name))

        self.assertEqual(
            metrics, {
                _AnswerType.ALL:
                _CellSelectionMetrics(recall=0.5,
                                      precision=1 / 3,
                                      non_empty=0.5,
                                      answer_len=1.5,
                                      coverage=0.5,
                                      recall_at_1=0.5,
                                      recall_at_3=0.5,
                                      recall_at_5=0.5),
                _AnswerType.MANY_IN_TEXT:
                _CellSelectionMetrics(recall=1.0,
                                      precision=1 / 3,
                                      non_empty=1.0,
                                      answer_len=3.0,
                                      coverage=1.0,
                                      recall_at_1=1.0,
                                      recall_at_3=1.0,
                                      recall_at_5=1.0),
                _AnswerType.NO_ANSWER:
                _CellSelectionMetrics(recall=0.0,
                                      precision=None,
                                      non_empty=0.0,
                                      answer_len=0.0,
                                      coverage=0.0,
                                      recall_at_1=0.0,
                                      recall_at_3=0.0,
                                      recall_at_5=0.0),
            })