def test_get_question_id(self, example_id):
  sequence_id = text_utils.get_sequence_id(example_id, "0")
  question_id = text_utils.get_question_id(sequence_id, 0)
  new_example_id, annotator, position = text_utils.parse_question_id(
      question_id)
  self.assertEqual(new_example_id, example_id)
  self.assertEqual(annotator, "0")
  self.assertEqual(position, 0)
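# Hedged usage sketch (not in the source): the same round trip outside the
# test. The exact id strings produced by get_sequence_id/get_question_id are
# an internal encoding; only the round-trip property verified above is
# assumed here.
#
#   sequence_id = text_utils.get_sequence_id("some-example", "0")
#   question_id = text_utils.get_question_id(sequence_id, 0)
#   assert text_utils.parse_question_id(question_id) == ("some-example", "0", 0)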
def read_from_tsv_file(file_handle):
  """Parses a TSV file in SQA format into a list of interactions.

  Args:
    file_handle: File handle of a TSV file in SQA format.

  Returns:
    Questions grouped into interactions.
  """
  questions = {}
  for row in csv.DictReader(file_handle, delimiter='\t'):
    sequence_id = text_utils.get_sequence_id(row[_ID], row[_ANNOTATOR])
    key = sequence_id, row[_TABLE_FILE]
    if key not in questions:
      questions[key] = {}

    position = int(row[_POSITION])

    answer = interaction_pb2.Answer()
    _parse_answer_coordinates(row[_ANSWER_COORDINATES], answer)
    _parse_answer_text(row[_ANSWER_TEXT], answer)

    if _AGGREGATION in row:
      agg_func = row[_AGGREGATION].upper().strip()
      if agg_func:
        answer.aggregation_function = _AggregationFunction.Value(agg_func)
    if _ANSWER_FLOAT_VALUE in row:
      float_value = row[_ANSWER_FLOAT_VALUE]
      if float_value:
        answer.float_value = float(float_value)
    if _ANSWER_CLASS_INDEX in row:
      class_index = row[_ANSWER_CLASS_INDEX]
      if class_index:
        answer.class_index = int(class_index)

    questions[key][position] = interaction_pb2.Question(
        id=text_utils.get_question_id(sequence_id, position),
        original_text=row[_QUESTION],
        answer=answer)

  interactions = []
  for (sequence_id, table_file), question_dict in sorted(
      questions.items(), key=lambda sid: sid[0]):
    question_list = [
        question for _, question in sorted(
            question_dict.items(), key=lambda pos: pos[0])
    ]
    interactions.append(
        interaction_pb2.Interaction(
            id=sequence_id,
            questions=question_list,
            table=interaction_pb2.Table(table_id=table_file)))
  return interactions
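# Hedged usage sketch: loading an SQA-style TSV into interaction protos. The
# helper name and file path below are hypothetical, not part of the source.
def _example_read_sqa(path="data/SQA/random-split-1-train.tsv"):
  with open(path) as file_handle:
    interactions = read_from_tsv_file(file_handle)
  for interaction in interactions:
    # Each interaction groups the questions sharing a sequence id and table.
    print(interaction.id, interaction.table.table_id,
          len(interaction.questions))
  return interactions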
def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  print("Creating output dir ...")
  tf.io.gfile.makedirs(FLAGS.output_dir)

  interaction_files = []
  for filename in tf.io.gfile.listdir(FLAGS.input_dir):
    interaction_files.append(os.path.join(FLAGS.input_dir, filename))

  tables = {}
  if FLAGS.table_file:
    print("Reading tables ...")
    tables.update({
        table.table_id: table
        for table in tfidf_baseline_utils.iterate_tables(FLAGS.table_file)
    })

  print("Adding interactions tables ...")
  for interaction_file in interaction_files:
    interactions = prediction_utils.iterate_interactions(interaction_file)
    for interaction in interactions:
      tables[interaction.table.table_id] = interaction.table

  print("Creating index ...")
  if FLAGS.index_files_pattern:
    # Use precomputed neural nearest neighbors when an index is given.
    neighbors = _get_neural_nearest_neighbors(FLAGS.index_files_pattern)
    retrieve_fn = lambda question: neighbors.get(question.id, [])
  else:
    # Otherwise fall back to a BM25 index over all known tables.
    index = tfidf_baseline_utils.create_bm25_index(
        tables=tables.values(),
        title_multiplicator=FLAGS.title_multiplicator,
        num_tables=len(tables),
    )
    retrieve_fn = lambda question: index.retrieve(question.original_text)

  print("Processing interactions ...")
  for interaction_file in interaction_files:
    interactions = list(
        prediction_utils.iterate_interactions(interaction_file))

    # Group single-question interactions by their example id.
    examples = collections.defaultdict(list)
    for interaction in interactions:
      example_id, _ = preprocess_nq_utils.parse_interaction_id(interaction.id)
      examples[example_id].append(interaction)

    filename = os.path.basename(interaction_file)
    is_train = "train" in filename
    output = os.path.join(FLAGS.output_dir, filename)
    with tf.io.TFRecordWriter(output) as writer:
      num_correct = 0
      with tqdm.tqdm(
          examples.items(),
          total=len(examples),
          desc=filename,
          postfix=[{
              "prec": "0.00",
              "multiple_tables": 0,
              "multiple_answers": 0,
              "no_hits": 0,
          }],
      ) as pbar:
        for nr, example in enumerate(pbar):
          example_id, interaction_list = example

          questions = []
          for interaction in interaction_list:
            if len(interaction.questions) != 1:
              raise ValueError(f"Unexpected question in {interaction}")
            questions.append(interaction.questions[0])

          answers = get_answer_texts(questions)

          if len(set(q.original_text for q in questions)) != 1:
            raise ValueError(f"Different questions {questions}")
          question_text = questions[0].original_text

          scored_hits = retrieve_fn(questions[0])
          if not scored_hits:
            pbar.postfix[0]["no_hits"] += 1
          candidate_hits = scored_hits[:FLAGS.max_rank]

          correct_table_ids = {
              interaction.table.table_id for interaction in interaction_list
          }
          table_ids = {table_id for table_id, _ in candidate_hits}

          # Track retrieval precision at max_rank.
          if correct_table_ids & table_ids:
            num_correct += 1
          prec = num_correct / (nr + 1)
          pbar.postfix[0]["prec"] = f"{prec:.2f}"
          if len(correct_table_ids) > 1:
            pbar.postfix[0]["multiple_tables"] += 1

          # For training (or oracle retrieval) make sure the gold tables are
          # among the candidates.
          if is_train or FLAGS.oracle_retrieval:
            table_ids.update(correct_table_ids)

          for table_index, table_id in enumerate(sorted(table_ids)):
            table = tables[table_id]
            new_interaction = interaction_pb2.Interaction()
            new_interaction.table.CopyFrom(table)
            new_question = new_interaction.questions.add()
            new_question.original_text = question_text
            _try_to_set_answer(table, answers, new_question)
            _set_retriever_info(new_question, scored_hits, table_id)
            new_question.answer.is_valid = True
            if new_question.alternative_answers:
              pbar.postfix[0]["multiple_answers"] += 1
            if table_id in correct_table_ids:
              new_question.answer.class_index = 1
            else:
              new_question.answer.class_index = 0
              if not FLAGS.add_negatives:
                continue
            new_interaction.id = text_utils.get_sequence_id(
                example_id, str(table_index))
            new_question.id = text_utils.get_question_id(
                new_interaction.id, position=0)
            writer.write(new_interaction.SerializeToString())
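# Hedged sketch of the usual absl entry point; an assumption, since the
# original module footer is not shown. The FLAGS used above are expected to
# be defined elsewhere in this file.
if __name__ == "__main__":
  app.run(main)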