def test_set_answer_text_strict(self):
    """Strict mode rejects an answer text that matches several cells.

    The answer text "2008" occurs in two rows of the "Year" column, so
    parsing with ``_Mode.REMOVE_ALL_STRICT`` must raise a ``ValueError``
    carrying the message asserted below.
    """
    interaction = text_format.Parse(
        ' table { columns { text: "Year" } '
        'rows { cells { text: "2008" } } '
        'rows { cells { text: "2010" } } '
        'rows { cells { text: "2008" } } } '
        'questions { answer { answer_texts: "2008" } }',
        Interaction())

    # assertRaises makes the "no exception was raised" case fail explicitly,
    # instead of comparing a parsed question object against an error string.
    with self.assertRaises(ValueError) as raised:
        parse_all_question(
            interaction.table,
            interaction.questions[0],
            _Mode.REMOVE_ALL_STRICT,
        )

    self.assertEqual(
        str(raised.exception),
        "Cannot parse answer: "
        "[answer_coordinates: Found multiple cells for answers]",
    )
def test_float_value(self):
    """An answer text of "1.0" yields ``float_value: 1.0`` under REMOVE_ALL.

    Neither table cell ("a", "a") equals the answer text, and the expected
    answer carries only the text and the float value.
    """
    interaction_text = (
        ' table { columns { text: "Column0" } '
        'rows { cells { text: "a" } } '
        'rows { cells { text: "a" } } } '
        'questions { answer { answer_texts: "1.0" } }'
    )
    expected_text = ' answer_texts: "1.0" float_value: 1.0 '

    parsed = text_format.Parse(interaction_text, Interaction())
    first_question = parsed.questions[0]
    result = parse_all_question(parsed.table, first_question, _Mode.REMOVE_ALL)

    wanted = text_format.Parse(expected_text, Answer())
    self.assertEqual(wanted, result.answer)
def test_set_use_answer_text_when_single_float_answer(self):
    """A single answer text matching the only cell gets coordinates (0, 0).

    The answer already carries ``float_value: 2008.0``; the expected answer
    keeps the text and float value and adds the coordinates of the cell.
    """
    interaction_text = (
        ' table { columns { text: "Column0" } '
        'rows { cells { text: "2008.00000000000" } } } '
        'questions { answer { answer_texts: "2008.00000000000" '
        'float_value: 2008.0 } }'
    )
    expected_text = (
        ' answer_coordinates { row_index: 0 column_index: 0 } '
        'answer_texts: "2008.00000000000" float_value: 2008.0 '
    )

    parsed = text_format.Parse(interaction_text, Interaction())
    result = parse_all_question(
        parsed.table,
        parsed.questions[0],
        _Mode.REMOVE_ALL,
    )
    # Kept after the parse call, exactly where the original test had it.
    _set_float32_safe_interaction(parsed)

    wanted = text_format.Parse(expected_text, Answer())
    _set_float32_safe_answer(wanted)
    self.assertEqual(wanted, result.answer)
def test_set_answer_text_strange_float_format_when_multiple_answers(self):
    """Several answer texts plus a float value collapse to one float text.

    The input answer has texts "1" and "2" and ``float_value: 2008.001``;
    the expected answer has the single text "2008.0009765625" and the same
    float value. Both sides go through the float32-safe helpers before the
    comparison.
    """
    interaction_text = (
        ' table { columns { text: "Column0" } '
        'rows { cells { text: "2008" } } } '
        'questions { answer { answer_texts: "1" answer_texts: "2" '
        'float_value: 2008.001 } }'
    )
    expected_text = ' answer_texts: "2008.0009765625" float_value: 2008.001 '

    parsed = text_format.Parse(interaction_text, Interaction())
    _set_float32_safe_interaction(parsed)
    result = parse_all_question(
        parsed.table,
        parsed.questions[0],
        _Mode.REMOVE_ALL,
    )
    # The helper is applied a second time after parsing, as in the original.
    _set_float32_safe_interaction(parsed)

    wanted = text_format.Parse(expected_text, Answer())
    _set_float32_safe_answer(wanted)
    self.assertEqual(wanted, result.answer)
def test_strategies(self, mode, expected_answer):
    """Parses a COUNT answer under ``mode`` and compares with the expectation.

    Args:
      mode: Parsing mode handed to ``parse_all_question``.
      expected_answer: Text-format ``Answer`` the parsed question must equal.

    NOTE(review): the arguments are presumably supplied by a parameterized
    decorator that is not visible in this chunk -- confirm.
    """
    interaction_text = (
        ' table { columns { text: "Column0" } '
        'rows { cells { text: "a" } } '
        'rows { cells { text: "b" } } } '
        'questions { answer { '
        'answer_coordinates { row_index: 0 column_index: 0 } '
        'answer_coordinates { row_index: 1 column_index: 0 } '
        'answer_texts: "2" aggregation_function: COUNT } }'
    )

    parsed = text_format.Parse(interaction_text, Interaction())
    result = parse_all_question(parsed.table, parsed.questions[0], mode)

    wanted = text_format.Parse(expected_answer, Answer())
    self.assertEqual(wanted, result.answer)
def read_from_tsv_file(file_handle):
    """Parses a TSV file in SQA format into a list of interactions.

    Rows are grouped by (sequence id, table file). Within a group the
    questions are ordered by their position, and the groups themselves are
    emitted sorted by that grouping key.

    Args:
      file_handle: File handle of a TSV file in SQA format.

    Returns:
      Questions grouped into interactions.
    """
    questions = {}
    for row in csv.DictReader(file_handle, delimiter='\t'):
        sequence_id = get_sequence_id(row[_ID], row[_ANNOTATOR])
        key = sequence_id, row[_TABLE_FILE]
        position = int(row[_POSITION])

        answer = Answer()
        _parse_answer_coordinates(row[_ANSWER_COORDINATES], answer)
        _parse_answer_text(row[_ANSWER_TEXT], answer)

        # The remaining columns are optional. csv.DictReader fills fields
        # that are missing from a short row with None, so a column can be
        # present in the header and still be None here; treat that the same
        # as an empty value instead of crashing on None.upper().
        agg_func = (row.get(_AGGREGATION) or '').upper().strip()
        if agg_func:
            answer.aggregation_function = _AggregationFunction.Value(agg_func)

        float_value = row.get(_ANSWER_FLOAT_VALUE)
        if float_value:
            answer.float_value = float(float_value)

        class_index = row.get(_ANSWER_CLASS_INDEX)
        if class_index:
            answer.class_index = int(class_index)

        questions.setdefault(key, {})[position] = Question(
            id=get_question_id(sequence_id, position),
            original_text=row[_QUESTION],
            answer=answer)

    interactions = []
    for (sequence_id, table_file), question_dict in sorted(
            questions.items(), key=lambda item: item[0]):
        # Positions are the dict keys, so sorting the keys orders the
        # questions by position.
        question_list = [
            question_dict[position] for position in sorted(question_dict)
        ]
        interactions.append(
            Interaction(id=sequence_id,
                        questions=question_list,
                        table=Table(table_id=table_file)))
    return interactions
def test_ambiguous_matching(self):
    """Two identical answer texts over two identical cells raise ValueError.

    Both rows contain "a" and the answer lists "a" twice; parsing under
    ``_Mode.REMOVE_ALL`` is expected to fail.
    """
    interaction_text = (
        ' table { columns { text: "Column0" } '
        'rows { cells { text: "a" } } '
        'rows { cells { text: "a" } } } '
        'questions { answer { answer_texts: "a" answer_texts: "a" } }'
    )
    parsed = text_format.Parse(interaction_text, Interaction())

    with self.assertRaises(ValueError):
        parse_all_question(
            parsed.table,
            parsed.questions[0],
            _Mode.REMOVE_ALL,
        )
def test_unambiguous_matching(self):
    """Answer texts "a" and "b" resolve to rows 0 and 2 respectively.

    The table also holds "ab" and "bc"; the expected answer contains
    coordinates only for the cells whose text is exactly "a" and "b".
    """
    interaction_text = (
        ' table { columns { text: "Column0" } '
        'rows { cells { text: "a" } } '
        'rows { cells { text: "ab" } } '
        'rows { cells { text: "b" } } '
        'rows { cells { text: "bc" } } } '
        'questions { answer { answer_texts: "a" answer_texts: "b" } }'
    )
    expected_text = (
        ' answer_coordinates { row_index: 0 column_index: 0 } '
        'answer_coordinates { row_index: 2 column_index: 0 } '
        'answer_texts: "a" answer_texts: "b" '
    )

    parsed = text_format.Parse(interaction_text, Interaction())
    result = parse_all_question(
        parsed.table,
        parsed.questions[0],
        _Mode.REMOVE_ALL,
    )

    wanted = text_format.Parse(expected_text, Answer())
    self.assertEqual(wanted, result.answer)
def _set_float32_safe_interaction(interaction):
    """Round-trips ``interaction`` through its serialized form, in place.

    The message is serialized, parsed into a fresh ``Interaction`` and the
    result copied back over the argument.

    NOTE(review): presumably this makes float fields take the values they
    have after wire serialization (float32 precision) -- confirm against the
    proto definition.

    Args:
      interaction: Message to normalise; it is modified in place.
    """
    wire_bytes = interaction.SerializeToString()
    round_tripped = Interaction()
    round_tripped.ParseFromString(wire_bytes)
    interaction.CopyFrom(round_tripped)