Example #1
0
def create_index(tables, title_multiplicator, use_bm25):
    """Build a retrieval index over `tables` with tfidf_baseline_utils.

    Args:
      tables: The tables to index.
      title_multiplicator: Forwarded unchanged to whichever builder is used.
      use_bm25: When truthy, build a BM25 index. Otherwise build an inverted
        index whose term filtering comes from FLAGS.min_term_rank and
        FLAGS.drop_term_frequency.

    Returns:
      The index object returned by the selected builder.
    """
    if not use_bm25:
        # FLAGS are only consulted for the inverted-index builder.
        return tfidf_baseline_utils.create_inverted_index(
            tables=tables,
            min_rank=FLAGS.min_term_rank,
            drop_term_frequency=FLAGS.drop_term_frequency,
            title_multiplicator=title_multiplicator,
        )
    return tfidf_baseline_utils.create_bm25_index(
        tables,
        title_multiplicator=title_multiplicator,
    )
Example #2
0
 def test_simple_bm25(self):
     """Checks BM25 retrieval scores over tables that only carry a title."""
     # Table i is named "table_i" and has titles[i] as its document title.
     titles = ["aa aa cc", "bb cc", "dd", "ee", "ff", "gg", "hh"]
     index = tfidf_baseline_utils.create_bm25_index([
         interaction_pb2.Table(table_id=f"table_{i}", document_title=title)
         for i, title in enumerate(titles)
     ])
     # Query -> expected ranked list of (table_id, score) pairs.
     expected_by_query = {
         "AA": [("table_0", 1.5475852968796064)],
         "BB": [("table_1", 1.2426585328757855)],
         "AA CC": [("table_0", 2.0749815245480145),
                   ("table_1", 0.668184203698534)],
     }
     for query, want in expected_by_query.items():
         self.assertEqual(index.retrieve(query), want)
def main(argv):
    """Retrieves candidate tables per question and writes new interactions.

    Reads every file in FLAGS.input_dir as interactions and builds a table_id ->
    table lookup from FLAGS.table_file (if set) plus the interactions' own
    tables. It then retrieves candidate tables per question:
      - neighbors loaded from FLAGS.index_files_pattern, when that is set;
      - otherwise, a BM25 index built over all known tables.
    For every input file it writes a TFRecord file of the same basename to
    FLAGS.output_dir. Each record is one interaction per (example, table)
    pair. Tables that are not among the example's correct tables are only
    written when FLAGS.add_negatives is set.

    Args:
      argv: Command-line arguments; only the program name is accepted.

    Raises:
      app.UsageError: If extra positional arguments are passed.
      ValueError: If an interaction does not have exactly one question, or if
        the interactions grouped under one example id have different
        question texts.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    print("Creating output dir ...")
    tf.io.gfile.makedirs(FLAGS.output_dir)

    interaction_files = []
    for filename in tf.io.gfile.listdir(FLAGS.input_dir):
        interaction_files.append(os.path.join(FLAGS.input_dir, filename))

    # table_id -> table. Tables embedded in interactions (added below)
    # overwrite entries with the same id read from FLAGS.table_file.
    tables = {}
    if FLAGS.table_file:
        print("Reading tables ...")
        tables.update({
            table.table_id: table
            for table in tfidf_baseline_utils.iterate_tables(FLAGS.table_file)
        })

    print("Adding interactions tables ...")
    for interaction_file in interaction_files:
        interactions = prediction_utils.iterate_interactions(interaction_file)
        for interaction in interactions:
            tables[interaction.table.table_id] = interaction.table

    print("Creating index ...")

    # retrieve_fn(question) returns scored hits. They are consumed below as
    # (table_id, score)-style pairs.
    if FLAGS.index_files_pattern:
        # Neighbors are looked up by question id; unknown ids yield no hits.
        neighbors = _get_neural_nearest_neighbors(FLAGS.index_files_pattern)
        retrieve_fn = lambda question: neighbors.get(question.id, [])
    else:
        index = tfidf_baseline_utils.create_bm25_index(
            tables=tables.values(),
            title_multiplicator=FLAGS.title_multiplicator,
            num_tables=len(tables),
        )
        retrieve_fn = lambda question: index.retrieve(question.original_text)

    print("Processing interactions ...")
    for interaction_file in interaction_files:
        interactions = list(
            prediction_utils.iterate_interactions(interaction_file))

        # Group interactions by the example id parsed from the interaction id.
        examples = collections.defaultdict(list)
        for interaction in interactions:
            example_id, _ = preprocess_nq_utils.parse_interaction_id(
                interaction.id)
            examples[example_id].append(interaction)

        filename = os.path.basename(interaction_file)
        # Training files are recognized by name. For them, the correct
        # tables are always emitted (see below).
        is_train = "train" in filename
        output = os.path.join(FLAGS.output_dir, filename)
        with tf.io.TFRecordWriter(output) as writer:
            num_correct = 0
            # The postfix dict is mutated in place below as running stats.
            # NOTE(review): without a custom bar_format, tqdm shows this list
            # via its str() form.
            with tqdm.tqdm(
                    examples.items(),
                    total=len(examples),
                    desc=filename,
                    postfix=[{
                        "prec": "0.00",
                        "multiple_tables": 0,
                        "multiple_answers": 0,
                        "no_hits": 0,
                    }],
            ) as pbar:
                for nr, example in enumerate(pbar):
                    example_id, interaction_list = example

                    questions = []
                    for interaction in interaction_list:
                        if len(interaction.questions) != 1:
                            raise ValueError(
                                f"Unexpected question in {interaction}")
                        questions.append(interaction.questions[0])

                    answers = get_answer_texts(questions)

                    if len(set(q.original_text for q in questions)) != 1:
                        raise ValueError(f"Different questions {questions}")
                    question_text = questions[0].original_text
                    scored_hits = retrieve_fn(questions[0])
                    if not scored_hits:
                        pbar.postfix[0]["no_hits"] += 1
                    # Only the top FLAGS.max_rank hits become candidates. The
                    # full hit list is still passed to _set_retriever_info.
                    candidate_hits = scored_hits[:FLAGS.max_rank]

                    correct_table_ids = {
                        interaction.table.table_id
                        for interaction in interaction_list
                    }

                    table_ids = {table_id for table_id, _ in candidate_hits}

                    # "prec" is the running fraction of examples seen so far
                    # with at least one correct table among the candidates.
                    if correct_table_ids & table_ids:
                        num_correct += 1
                    prec = num_correct / (nr + 1)
                    pbar.postfix[0]["prec"] = f"{prec:.2f}"
                    if len(correct_table_ids) > 1:
                        pbar.postfix[0]["multiple_tables"] += 1

                    # Ensure the correct tables are emitted even if retrieval
                    # missed them.
                    if is_train or FLAGS.oracle_retrieval:
                        table_ids.update(correct_table_ids)

                    # Sorting makes table_index (used in the interaction id)
                    # deterministic across runs.
                    for table_index, table_id in enumerate(sorted(table_ids)):
                        # NOTE(review): raises KeyError if a retrieved id is
                        # not among the known tables.
                        table = tables[table_id]
                        new_interaction = interaction_pb2.Interaction()
                        new_interaction.table.CopyFrom(table)
                        new_question = new_interaction.questions.add()
                        new_question.original_text = question_text
                        _try_to_set_answer(table, answers, new_question)
                        _set_retriever_info(new_question, scored_hits,
                                            table_id)
                        new_question.answer.is_valid = True
                        # NOTE(review): counted for every table, including
                        # negatives that are skipped just below.
                        if new_question.alternative_answers:
                            pbar.postfix[0]["multiple_answers"] += 1
                        # class_index marks whether this table is a correct
                        # table for the example.
                        if table_id in correct_table_ids:
                            new_question.answer.class_index = 1
                        else:
                            new_question.answer.class_index = 0
                            if not FLAGS.add_negatives:
                                continue
                        new_interaction.id = text_utils.get_sequence_id(
                            example_id, str(table_index))
                        new_question.id = text_utils.get_question_id(
                            new_interaction.id, position=0)
                        writer.write(new_interaction.SerializeToString())
def main(argv):
  """Runs BM25 retrieval for every question and writes scores as JSON lines.

  Reads every file in FLAGS.input_dir as interactions. Builds a BM25 index
  over the tables from FLAGS.table_file (if set) plus the tables embedded in
  the interactions. For each input file it writes
  `<FLAGS.output_dir>/<basename without extension>_results.jsonl`, with one
  JSON object per example:
  `{"query_id": "<example_id>_0_0", "table_scores": [{"table_id": ..., "score": ...}]}`.
  At most FLAGS.max_rank tables are listed per example.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are passed.
    ValueError: If an interaction does not have exactly one question, or if
      the interactions grouped under one example id have different question
      texts.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  print("Creating output dir ...")
  tf.io.gfile.makedirs(FLAGS.output_dir)

  interaction_files = [
      os.path.join(FLAGS.input_dir, filename)
      for filename in tf.io.gfile.listdir(FLAGS.input_dir)
  ]

  # table_id -> table. Tables embedded in interactions (added below) overwrite
  # entries with the same id read from FLAGS.table_file.
  tables = {}
  if FLAGS.table_file:
    print("Reading tables ...")
    # `total` is only a progress-bar length hint; it does not limit reading.
    for table in tqdm.tqdm(
        tfidf_baseline_utils.iterate_tables(FLAGS.table_file), total=375_000):
      tables[table.table_id] = table

  print("Adding interactions tables ...")
  for interaction_file in interaction_files:
    for interaction in prediction_utils.iterate_interactions(interaction_file):
      tables[interaction.table.table_id] = interaction.table

  print("Creating index ...")
  index = tfidf_baseline_utils.create_bm25_index(
      tables=tables.values(),
      title_multiplicator=FLAGS.title_multiplicator,
      num_tables=len(tables),
  )

  print("Processing interactions ...")
  for interaction_file in interaction_files:
    # Group interactions by the example id parsed from the interaction id.
    examples = collections.defaultdict(list)
    for interaction in prediction_utils.iterate_interactions(interaction_file):
      example_id, _ = preprocess_nq_utils.parse_interaction_id(interaction.id)
      examples[example_id].append(interaction)

    filename = os.path.splitext(os.path.basename(interaction_file))[0]
    output = os.path.join(FLAGS.output_dir, filename + "_results.jsonl")
    with tf.io.gfile.GFile(output, "w") as file_writer:
      with tqdm.tqdm(
          examples.items(),
          total=len(examples),
          desc=filename,
      ) as pbar:
        for example_id, interaction_list in pbar:
          questions = []
          for interaction in interaction_list:
            if len(interaction.questions) != 1:
              raise ValueError(f"Unexpected question in {interaction}")
            questions.append(interaction.questions[0])

          if len({q.original_text for q in questions}) != 1:
            raise ValueError(f"Different questions {questions}")
          question_text = questions[0].original_text
          scored_hits = index.retrieve(question_text)[:FLAGS.max_rank]

          # Scores are written negated.
          # NOTE(review): presumably the consumer of this file ranks by
          # ascending score; confirm against the reader.
          table_scores = [{
              "table_id": table_id,
              "score": -score,
          } for table_id, score in scored_hits]

          result = {
              "query_id": example_id + "_0_0",
              "table_scores": table_scores,
          }
          file_writer.write(json.dumps(result) + "\n")