Example #1
0
 def test_get_question_id(self, example_id):
     """Round-trips an example id through question-id construction/parsing."""
     seq_id = text_utils.get_sequence_id(example_id, "0")
     qid = text_utils.get_question_id(seq_id, 0)
     parsed_example_id, parsed_annotator, parsed_position = (
         text_utils.parse_question_id(qid))
     self.assertEqual(parsed_example_id, example_id)
     self.assertEqual(parsed_annotator, "0")
     self.assertEqual(parsed_position, 0)
def read_from_tsv_file(file_handle):
    """Reads a TSV file in SQA format and groups the rows into interactions.

    Args:
      file_handle: File handle of a TSV file in SQA format.

    Returns:
      A list of Interaction protos, one per (sequence id, table file) pair,
      with each interaction's questions ordered by position.
    """
    grouped = {}
    reader = csv.DictReader(file_handle, delimiter='\t')
    for row in reader:
        sequence_id = text_utils.get_sequence_id(row[_ID], row[_ANNOTATOR])
        group_key = (sequence_id, row[_TABLE_FILE])
        questions_by_position = grouped.setdefault(group_key, {})

        position = int(row[_POSITION])

        answer = interaction_pb2.Answer()
        _parse_answer_coordinates(row[_ANSWER_COORDINATES], answer)
        _parse_answer_text(row[_ANSWER_TEXT], answer)

        # The remaining answer fields come from optional columns; set each
        # proto field only when the column is present and non-empty.
        if _AGGREGATION in row:
            aggregation = row[_AGGREGATION].upper().strip()
            if aggregation:
                answer.aggregation_function = _AggregationFunction.Value(
                    aggregation)
        if _ANSWER_FLOAT_VALUE in row:
            raw_float = row[_ANSWER_FLOAT_VALUE]
            if raw_float:
                answer.float_value = float(raw_float)
        if _ANSWER_CLASS_INDEX in row:
            raw_class_index = row[_ANSWER_CLASS_INDEX]
            if raw_class_index:
                answer.class_index = int(raw_class_index)

        questions_by_position[position] = interaction_pb2.Question(
            id=text_utils.get_question_id(sequence_id, position),
            original_text=row[_QUESTION],
            answer=answer)

    # Keys are unique, so sorting the items sorts by key alone; the same
    # holds for the per-interaction position sort below.
    return [
        interaction_pb2.Interaction(
            id=sequence_id,
            questions=[
                question
                for _, question in sorted(questions_by_position.items())
            ],
            table=interaction_pb2.Table(table_id=table_file))
        for (sequence_id, table_file), questions_by_position in sorted(
            grouped.items())
    ]
def main(argv):
    """Pairs questions with retrieved candidate tables and writes TFRecords.

    For each interaction file in FLAGS.input_dir, candidate tables are
    retrieved per question — either from a precomputed neural
    nearest-neighbor index (FLAGS.index_files_pattern) or from a BM25 index
    built over all known tables — and one new interaction per
    (question, candidate table) pair is written to FLAGS.output_dir.

    Args:
      argv: Command-line argument list; must contain only the program name.

    Raises:
      app.UsageError: If extra command-line arguments are passed.
      ValueError: If an interaction carries more than one question, or if
        the interactions of one example disagree on the question text.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    print("Creating output dir ...")
    tf.io.gfile.makedirs(FLAGS.output_dir)

    # Every file in the input directory is treated as an interaction file.
    interaction_files = []
    for filename in tf.io.gfile.listdir(FLAGS.input_dir):
        interaction_files.append(os.path.join(FLAGS.input_dir, filename))

    # table_id -> Table proto; used to materialize candidate tables below.
    tables = {}
    if FLAGS.table_file:
        print("Reading tables ...")
        tables.update({
            table.table_id: table
            for table in tfidf_baseline_utils.iterate_tables(FLAGS.table_file)
        })

    print("Adding interactions tables ...")
    # Tables attached to the interactions are added as well, overwriting any
    # same-id entry read from FLAGS.table_file.
    for interaction_file in interaction_files:
        interactions = prediction_utils.iterate_interactions(interaction_file)
        for interaction in interactions:
            tables[interaction.table.table_id] = interaction.table

    print("Creating index ...")

    # retrieve_fn(question) returns scored hits; each hit unpacks as a
    # (table_id, ...) pair (see the set comprehension over candidate_hits).
    if FLAGS.index_files_pattern:
        neighbors = _get_neural_nearest_neighbors(FLAGS.index_files_pattern)
        retrieve_fn = lambda question: neighbors.get(question.id, [])
    else:
        index = tfidf_baseline_utils.create_bm25_index(
            tables=tables.values(),
            title_multiplicator=FLAGS.title_multiplicator,
            num_tables=len(tables),
        )
        retrieve_fn = lambda question: index.retrieve(question.original_text)

    print("Processing interactions ...")
    for interaction_file in interaction_files:
        interactions = list(
            prediction_utils.iterate_interactions(interaction_file))

        # Group the file's interactions by their example id.
        examples = collections.defaultdict(list)
        for interaction in interactions:
            example_id, _ = preprocess_nq_utils.parse_interaction_id(
                interaction.id)
            examples[example_id].append(interaction)

        filename = os.path.basename(interaction_file)
        # Heuristic split: files with "train" in the name get gold tables
        # force-added to their candidates (see below).
        is_train = "train" in filename
        output = os.path.join(FLAGS.output_dir, filename)
        with tf.io.TFRecordWriter(output) as writer:
            num_correct = 0
            # The postfix dict shows live retrieval statistics in the bar.
            with tqdm.tqdm(
                    examples.items(),
                    total=len(examples),
                    desc=filename,
                    postfix=[{
                        "prec": "0.00",
                        "multiple_tables": 0,
                        "multiple_answers": 0,
                        "no_hits": 0,
                    }],
            ) as pbar:
                for nr, example in enumerate(pbar):
                    example_id, interaction_list = example

                    # Every interaction must carry exactly one question.
                    questions = []
                    for interaction in interaction_list:
                        if len(interaction.questions) != 1:
                            raise ValueError(
                                f"Unexpected question in {interaction}")
                        questions.append(interaction.questions[0])

                    answers = get_answer_texts(questions)

                    # All interactions of one example share a question text.
                    if len(set(q.original_text for q in questions)) != 1:
                        raise ValueError(f"Different questions {questions}")
                    question_text = questions[0].original_text
                    scored_hits = retrieve_fn(questions[0])
                    if not scored_hits:
                        pbar.postfix[0]["no_hits"] += 1
                    # Only the top FLAGS.max_rank hits become candidates.
                    candidate_hits = scored_hits[:FLAGS.max_rank]

                    correct_table_ids = {
                        interaction.table.table_id
                        for interaction in interaction_list
                    }

                    table_ids = {table_id for table_id, _ in candidate_hits}

                    # Running precision: fraction of examples so far whose
                    # candidates include at least one gold table.
                    if correct_table_ids & table_ids:
                        num_correct += 1
                    prec = num_correct / (nr + 1)
                    pbar.postfix[0]["prec"] = f"{prec:.2f}"
                    if len(correct_table_ids) > 1:
                        pbar.postfix[0]["multiple_tables"] += 1

                    # In training (or oracle) mode the gold tables are always
                    # included among the candidates.
                    if is_train or FLAGS.oracle_retrieval:
                        table_ids.update(correct_table_ids)

                    for table_index, table_id in enumerate(sorted(table_ids)):
                        table = tables[table_id]
                        new_interaction = interaction_pb2.Interaction()
                        new_interaction.table.CopyFrom(table)
                        new_question = new_interaction.questions.add()
                        new_question.original_text = question_text
                        _try_to_set_answer(table, answers, new_question)
                        _set_retriever_info(new_question, scored_hits,
                                            table_id)
                        new_question.answer.is_valid = True
                        if new_question.alternative_answers:
                            pbar.postfix[0]["multiple_answers"] += 1
                        # class_index labels the candidate: 1 = gold table,
                        # 0 = negative.
                        if table_id in correct_table_ids:
                            new_question.answer.class_index = 1
                        else:
                            new_question.answer.class_index = 0
                            # Negatives are only written when requested.
                            if not FLAGS.add_negatives:
                                continue
                        new_interaction.id = text_utils.get_sequence_id(
                            example_id, str(table_index))
                        new_question.id = text_utils.get_question_id(
                            new_interaction.id, position=0)
                        writer.write(new_interaction.SerializeToString())