コード例 #1
0
  def test_set_answer_text_strict(self):
    """Strict mode rejects an answer text that matches several cells.

    The table contains "2008" twice, so the answer text "2008" cannot be
    mapped to a single cell and parsing must fail with the expected message.
    """
    interaction_text = """
      table {
        columns { text: "Year" }
        rows { cells { text: "2008" } }
        rows { cells { text: "2010" } }
        rows { cells { text: "2008" } }
      }
      questions {
        answer {
          answer_texts: "2008"
        }
      }"""
    parsed = text_format.Parse(interaction_text, Interaction())

    with self.assertRaises(ValueError) as raised:
      parse_all_question(
          parsed.table,
          parsed.questions[0],
          _Mode.REMOVE_ALL_STRICT,
      )

    # The message of the raised error is the value under test.
    self.assertEqual(
        str(raised.exception),
        "Cannot parse answer: "
        "[answer_coordinates: Found multiple cells for answers]",
    )
コード例 #2
0
  def test_float_value(self):
    """A numeric answer text absent from the table gets a float_value.

    No cell contains "1.0", so the expected answer carries only the original
    text plus the parsed float value (no coordinates).
    """
    interaction_text = """
      table {
        columns { text: "Column0" }
        rows { cells { text: "a" } }
        rows { cells { text: "a" } }
      }
      questions {
        answer {
          answer_texts: "1.0"
        }
      }"""
    parsed = text_format.Parse(interaction_text, Interaction())

    result = parse_all_question(
        parsed.table, parsed.questions[0], _Mode.REMOVE_ALL)

    expected_text = """
      answer_texts: "1.0"
      float_value: 1.0
    """
    expected = text_format.Parse(expected_text, Answer())

    self.assertEqual(expected, result.answer)
コード例 #3
0
  def test_set_use_answer_text_when_single_float_answer(self):
    """A single float answer matching one cell keeps its text and value.

    The expected answer additionally carries the coordinates (0, 0) of the
    only cell, whose text equals the answer text.
    """
    interaction_text = """
      table {
        columns { text: "Column0" }
        rows { cells { text: "2008.00000000000" } }
      }
      questions {
        answer {
          answer_texts: "2008.00000000000"
          float_value: 2008.0
        }
      }"""
    parsed = text_format.Parse(interaction_text, Interaction())

    result = parse_all_question(
        parsed.table, parsed.questions[0], _Mode.REMOVE_ALL)
    # Kept after parsing, exactly as in the original statement order.
    _set_float32_safe_interaction(parsed)

    expected_text = """
          answer_coordinates {
            row_index: 0
            column_index: 0
          }
          answer_texts: "2008.00000000000"
          float_value: 2008.0
    """
    expected = text_format.Parse(expected_text, Answer())
    _set_float32_safe_answer(expected)

    self.assertEqual(expected, result.answer)
コード例 #4
0
 def test_set_answer_text_strange_float_format_when_multiple_answers(self):
   """Several answer texts plus a float value that matches no cell.

   The expected answer holds a single text, "2008.0009765625", together with
   the float value 2008.001 after both sides went through the float32-safe
   helpers.
   """
   interaction_text = """
     table {
       columns { text: "Column0" }
       rows { cells { text: "2008" } }
     }
     questions {
       answer {
         answer_texts: "1"
         answer_texts: "2"
         float_value: 2008.001
       }
     }"""
   parsed = text_format.Parse(interaction_text, Interaction())
   _set_float32_safe_interaction(parsed)

   result = parse_all_question(
       parsed.table, parsed.questions[0], _Mode.REMOVE_ALL)
   # NOTE(review): second round trip is kept from the original flow; it looks
   # redundant unless parsing mutates the interaction - confirm before removing.
   _set_float32_safe_interaction(parsed)

   expected_text = """
         answer_texts: "2008.0009765625"
         float_value: 2008.001
   """
   expected = text_format.Parse(expected_text, Answer())
   _set_float32_safe_answer(expected)

   self.assertEqual(expected, result.answer)
コード例 #5
0
  def test_strategies(self, mode, expected_answer):
    """Checks the answer produced for a COUNT question under a given mode.

    Args:
      mode: Parsing mode passed straight to parse_all_question.
      expected_answer: Text-format Answer proto the parsed answer must equal.

    Both arguments are presumably supplied by a parameterized decorator that
    lies outside this view.
    """
    interaction_text = """
      table {
        columns { text: "Column0" }
        rows { cells { text: "a" } }
        rows { cells { text: "b" } }
      }
      questions {
        answer {
          answer_coordinates {
            row_index: 0
            column_index: 0
          }
          answer_coordinates {
            row_index: 1
            column_index: 0
          }
          answer_texts: "2"
          aggregation_function: COUNT
        }
      }"""
    parsed = text_format.Parse(interaction_text, Interaction())

    result = parse_all_question(parsed.table, parsed.questions[0], mode)

    expected = text_format.Parse(expected_answer, Answer())
    self.assertEqual(expected, result.answer)
コード例 #6
0
def read_from_tsv_file(file_handle):
    """Parses a TSV file in SQA format into a list of interactions.

    Rows are grouped by (sequence id, table file). Within a group the
    questions are ordered by their position; a later row with the same
    position replaces the earlier one. The aggregation, float value and
    class index columns are optional: they are skipped when the column is
    absent, empty, or missing from a short row.

    Args:
      file_handle:  File handle of a TSV file in SQA format.

    Returns:
      Questions grouped into interactions, sorted by (sequence id, table
      file).

    Raises:
      ValueError: If a position, float value or class index cannot be
        converted by int() / float().
    """
    questions = {}
    for row in csv.DictReader(file_handle, delimiter='\t'):
        sequence_id = get_sequence_id(row[_ID], row[_ANNOTATOR])
        key = sequence_id, row[_TABLE_FILE]
        position = int(row[_POSITION])

        answer = Answer()
        _parse_answer_coordinates(row[_ANSWER_COORDINATES], answer)
        _parse_answer_text(row[_ANSWER_TEXT], answer)

        # csv.DictReader fills fields missing from a short row with None, so
        # every optional column is read through .get() and treated as absent
        # when it is None or empty.
        agg_func = (row.get(_AGGREGATION) or '').upper().strip()
        if agg_func:
            answer.aggregation_function = _AggregationFunction.Value(
                agg_func)

        float_value = row.get(_ANSWER_FLOAT_VALUE)
        if float_value:
            answer.float_value = float(float_value)

        class_index = row.get(_ANSWER_CLASS_INDEX)
        if class_index:
            answer.class_index = int(class_index)

        questions.setdefault(key, {})[position] = Question(
            id=get_question_id(sequence_id, position),
            original_text=row[_QUESTION],
            answer=answer)

    interactions = []
    # Sort on the (sequence_id, table_file) key only, never on the payload.
    for (sequence_id, table_file), question_dict in sorted(
            questions.items(), key=lambda item: item[0]):
        question_list = [
            question_dict[position] for position in sorted(question_dict)
        ]
        interactions.append(
            Interaction(id=sequence_id,
                        questions=question_list,
                        table=Table(table_id=table_file)))
    return interactions
コード例 #7
0
  def test_ambiguous_matching(self):
    """Parsing fails when answer texts match more than one cell each.

    Both rows contain "a" and the answer lists "a" twice, so the call is
    expected to raise ValueError.
    """
    interaction_text = """
      table {
        columns { text: "Column0" }
        rows { cells { text: "a" } }
        rows { cells { text: "a" } }
      }
      questions {
        answer {
          answer_texts: "a"
          answer_texts: "a"
        }
      }"""
    parsed = text_format.Parse(interaction_text, Interaction())

    self.assertRaises(
        ValueError,
        parse_all_question,
        parsed.table,
        parsed.questions[0],
        _Mode.REMOVE_ALL,
    )
コード例 #8
0
  def test_unambiguous_matching(self):
    """Answer texts that each equal exactly one cell get coordinates.

    "a" and "b" are expected to resolve to rows 0 and 2; the cells "ab" and
    "bc" do not appear in the expected coordinates.
    """
    interaction_text = """
      table {
        columns { text: "Column0" }
        rows { cells { text: "a" } }
        rows { cells { text: "ab" } }
        rows { cells { text: "b" } }
        rows { cells { text: "bc" } }
      }
      questions {
        answer {
          answer_texts: "a"
          answer_texts: "b"
        }
      }"""
    parsed = text_format.Parse(interaction_text, Interaction())

    result = parse_all_question(
        parsed.table, parsed.questions[0], _Mode.REMOVE_ALL)

    expected_text = """
          answer_coordinates {
            row_index: 0
            column_index: 0
          }
          answer_coordinates {
            row_index: 2
            column_index: 0
          }
          answer_texts: "a"
          answer_texts: "b"
    """
    expected = text_format.Parse(expected_text, Answer())

    self.assertEqual(expected, result.answer)
コード例 #9
0
def _set_float32_safe_interaction(interaction):
  """Replaces `interaction` in place with its serialize/parse round trip.

  The message is serialized, parsed into a fresh Interaction and copied back
  over the argument. Presumably this makes float fields take the precision of
  their serialized form - confirm against the proto definition.

  Args:
    interaction: Interaction message to normalize; modified in place.
  """
  serialized = interaction.SerializeToString()
  round_tripped = Interaction()
  round_tripped.ParseFromString(serialized)
  interaction.CopyFrom(round_tripped)