def _convert_single_wtq(interaction_file, prediction_file, output_file):
    """Convert predictions to WikiTablequestions format."""

    interactions = dict(
        (prediction_utils.parse_interaction_id(i.id), i)
        for i in prediction_utils.iterate_interactions(interaction_file))
    missing_interaction_ids = set(interactions.keys())

    # Avoid shadowing the `output_file` path argument with the file handle.
    with tf.io.gfile.GFile(output_file, 'w') as output:
        for prediction in prediction_utils.iterate_predictions(
                prediction_file):
            interaction_id = prediction['id']
            if interaction_id in missing_interaction_ids:
                missing_interaction_ids.remove(interaction_id)
            else:
                # Skip duplicate predictions and ids with no interaction.
                continue

            coordinates = prediction_utils.parse_coordinates(
                prediction['answer_coordinates'])

            # Execute the predicted aggregation over the selected cells to
            # obtain the predicted denotation.
            denot_pred, _ = calc_metrics_utils.execute(
                int(prediction.get('pred_aggr', '0')), coordinates,
                prediction_utils.table_to_panda_frame(
                    interactions[interaction_id].table))

            answers = '\t'.join(sorted(map(str, denot_pred)))
            output.write('{}\t{}\n'.format(interaction_id, answers))

        # Emit ids that received no prediction so every interaction appears
        # in the output.
        for interaction_id in missing_interaction_ids:
            output.write('{}\n'.format(interaction_id))
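
A minimal usage sketch for the converter above; the file paths are hypothetical and the sketch assumes the module's imports (tf, prediction_utils, calc_metrics_utils) are available:

# Hypothetical paths; adjust to your dataset layout.
_convert_single_wtq(
    interaction_file='dev_interactions.tfrecord',
    prediction_file='dev_predictions.tsv',
    output_file='dev_wtq_answers.tsv',
)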
Example #2
def read_data_examples_from_interactions(
    interactions_path):
  """Reads examples from an interactions file."""
  data_examples = {}
  for interaction in prediction_utils.iterate_interactions(interactions_path):
    for question in interaction.questions:
      data_examples[question.id] = example_from_question(interaction, question)
  return data_examples
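
A short usage sketch (the interactions path is hypothetical):

# Returns a mapping of question ids to data examples.
examples = read_data_examples_from_interactions('dev_interactions.tfrecord')
print(f'Read {len(examples)} examples')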
Example #3
def test_iterate_interactions(self):
  filepath = tempfile.mktemp(suffix='.tfrecord')
  interactions = [
      interaction_pb2.Interaction(id='dev_723'),
      interaction_pb2.Interaction(id='dev_456'),
      interaction_pb2.Interaction(id='dev_123'),
  ]
  with tf.io.TFRecordWriter(filepath) as writer:
    for interaction in interactions:
      writer.write(interaction.SerializeToString())
  actual_interactions = list(prediction_utils.iterate_interactions(filepath))
  self.assertEqual(interactions, actual_interactions)
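
This method takes self, so it presumably lives inside a test case class; a hedged sketch of the enclosing harness (class name and base class are assumptions):

from absl.testing import absltest

class PredictionUtilsTest(absltest.TestCase):  # assumed name and base class
  # The test_iterate_interactions method above would be defined here.
  pass

if __name__ == '__main__':
  absltest.main()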
Example #4
def create_interactions_from_hybridqa_predictions(output_dir):
    """Creates Interactions for the HybridQA RC task from interactions and predictions files.

    Args:
      output_dir: Location where the HybridQaRcConfig can be found.

    Returns:
      A mapping of dataset split to an iterable of interactions.
    """
    config_path = os.path.join(output_dir, _CONFIG_NAME)
    with tf.io.gfile.GFile(config_path, 'r') as input_file:
        config = HybridQaRcConfig.from_json(input_file)
    e2e_interactions = {}

    splits_dict = config.hybridqa_prediction_filepaths
    interaction_dirs = config.interactions_dirs

    for ((split, run_id), predictions_file) in splits_dict.items():

        answer_coordinates: MutableMapping[Text,
                                           Sequence[_CoordinateType]] = {}
        for p in get_predictions(predictions_file):
            if p.token_probabilities is not None:
                ranked_cells = hybridqa_utils.get_best_cells(
                    p.token_probabilities)
            else:
                ranked_cells = p.answer_coordinates
            example_id = text_utils.get_example_id(p.interaction_id)
            answer_coordinates[
                example_id] = ranked_cells[:config.num_predictions_to_keep]
        original_interactions_file = os.path.join(interaction_dirs[run_id],
                                                  f'{split.lower()}.tfrecord')
        original_interactions = prediction_utils.iterate_interactions(
            original_interactions_file)
        interaction_iterables = []
        for original_interaction in original_interactions:
            example_id = text_utils.get_example_id(original_interaction.id)
            interaction_iterables.append(
                _create_eval_answer_interactions(
                    original_interaction,
                    answer_coordinates[example_id],
                    single_cell_examples=config.single_cell_examples))

        if len(splits_dict) == 1:
            # Drop name for backwards compatibility.
            name = split
        else:
            name = f'{split}_{run_id:02}'
        e2e_interactions[name] = itertools.chain(*interaction_iterables)

    return e2e_interactions
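
A hedged usage sketch (the directory is hypothetical):

interactions_by_split = create_interactions_from_hybridqa_predictions(
    '/path/to/hybridqa_rc_output')  # hypothetical directory
for name, interactions in interactions_by_split.items():
    # Note: the values are one-shot iterators; counting consumes them.
    print(name, sum(1 for _ in interactions))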
Example #5
def generate_hybridqa_codalab_predictions(interaction_file, prediction_file):
    """Generates Codaab prediction files for HybridQA Competition.

  This function generates the json prediction files used to submit to HybridQA
  competition hosted on Codalab. (go/hybridqa-competition)

  Args:
    interaction_file: A TF record file containing the examples as interactions.
    prediction_file: A TSV file that is the output of the table-classifier
      predict job on the input interactions.

  Yields:
    An iterable of json serializable python dicts.
  """
    vocab_file = _guess_vocab_file(interaction_file)
    logging.info("Vocab file: %s ", vocab_file)
    logging.info("Read: %s ", interaction_file)
    interactions = prediction_utils.iterate_interactions(interaction_file)
    logging.info("Read: %s ", prediction_file)
    predictions = prediction_utils.iterate_predictions(prediction_file)

    detokenizer = DeTokenizer(vocab_file)

    interactions_by_qid = collections.defaultdict(list)
    for interaction in interactions:
        qid = interaction.questions[0].id
        interactions_by_qid[_get_example_id(qid)].append(interaction)

    predictions_by_qid = {}
    for prediction in predictions:
        qid = prediction["question_id"]
        # TODO(eisenjulian): Select the best answer using model scores.
        predictions_by_qid[qid] = prediction

    for qid, candidates in interactions_by_qid.items():
        answer_text = ""
        results = list(
            _get_scored_candidates(
                detokenizer,
                candidates,
                predictions_by_qid,
            ))
        example_id = text_utils.get_example_id(qid)
        if results:
            best_result = max(results, key=lambda result: result.score)
            answer_text = best_result.answer

        yield {"question_id": example_id, "pred": answer_text}
Example #6
def evaluate_retrieval_e2e(
    interaction_file,
    prediction_file,
    references_file=None,
    vocab_file=None,
):
  """Computes e2e retrieval-QA metrics."""
  vocab_file = vocab_file or _guess_vocab_file(interaction_file)
  # Note: references_file is not consumed in this snippet, so references
  # remains None.
  references = None
  logging.info("Vocab file: %s ", vocab_file)
  logging.info("Read: %s ", interaction_file)
  interactions = prediction_utils.iterate_interactions(interaction_file)
  logging.info("Read: %s ", prediction_file)
  predictions = prediction_utils.iterate_predictions(prediction_file)
  return _evaluate_retrieval_e2e(vocab_file, interactions, predictions,
                                 references)
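
A minimal usage sketch (hypothetical paths); when vocab_file is omitted it is guessed from the interaction file, as the first line of the function shows:

metrics = evaluate_retrieval_e2e(
    'dev_interactions.tfrecord',  # hypothetical path
    'dev_predictions.tsv',  # hypothetical path
)
print(metrics)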
Example #7
def get_cell_selection_metrics(
    interactions_file,
    predictions_file,
):
    """Evaluates cell selection results in HybridQA experiment."""
    questions = {}
    for interaction in prediction_utils.iterate_interactions(
            interactions_file):
        for question in interaction.questions:
            # Do not evaluate hidden test set examples
            if question.HasField('answer'):
                questions[question.id] = question
    if not questions:
        return {}
    cell_selection_metrics = dict(
        eval_cell_selection(questions, predictions_file))[AnswerType.ALL]
    return dataclasses.asdict(cell_selection_metrics)
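
A short usage sketch (paths are hypothetical); the function returns a dict of cell-selection metrics for AnswerType.ALL, or an empty dict when no question has an answer:

metrics = get_cell_selection_metrics(
    interactions_file='dev_interactions.tfrecord',  # hypothetical path
    predictions_file='dev_predictions.tsv',  # hypothetical path
)
for name, value in metrics.items():
    print(f'{name}: {value}')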
Example #8
def main(_):

    max_table_rank = FLAGS.max_table_rank
    thresholds = [1, 5, 10, 15, max_table_rank]

    for interaction_file in FLAGS.interaction_files:
        _print(f"Test set: {interaction_file}")
        interactions = list(
            prediction_utils.iterate_interactions(interaction_file))

        for use_local_index in [True, False]:

            rows = []
            row_names = []

            for hparams in get_hparams():

                name = "local" if use_local_index else "global"
                name += "_bm25" if hparams["use_bm25"] else "_tfidf"
                name += f'_tm{hparams["multiplier"]}'

                _print(name)
                if use_local_index:
                    index = create_index(
                        tables=(i.table for i in interactions),
                        title_multiplicator=hparams["multiplier"],
                        use_bm25=hparams["use_bm25"],
                    )
                else:
                    index = create_index(
                        tables=tfidf_baseline_utils.iterate_tables(
                            FLAGS.table_file),
                        title_multiplicator=hparams["multiplier"],
                        use_bm25=hparams["use_bm25"],
                    )
                _print("... index created.")
                evaluate(index, max_table_rank, thresholds, interactions, rows)
                row_names.append(name)

            # Print one results table per index type, after all hparams ran.
            df = pd.DataFrame(rows, columns=thresholds, index=row_names)
            _print(df.to_string())
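
The loop above reads only two keys from each hparams dict; a hedged illustration of a compatible get_hparams generator (the concrete grid values are assumptions):

def get_hparams():
    # Assumed grid; only "multiplier" and "use_bm25" are consumed above.
    for multiplier in (1, 2):
        for use_bm25 in (True, False):
            yield {"multiplier": multiplier, "use_bm25": use_bm25}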
Example #9
def _create_examples(
    interaction_dir,
    example_dir,
    vocab_file,
    filename,
    batch_size,
    test_mode,
):
  """Creates TF example for a single dataset."""

  filename = f'{filename}.tfrecord'
  interaction_path = os.path.join(interaction_dir, filename)
  example_path = os.path.join(example_dir, filename)

  config = tf_example_utils.ClassifierConversionConfig(
      vocab_file=vocab_file,
      max_seq_length=FLAGS.max_seq_length,
      max_column_id=_MAX_TABLE_ID,
      max_row_id=_MAX_TABLE_ID,
      strip_column_names=False,
      add_aggregation_candidates=False,
  )
  converter = tf_example_utils.ToClassifierTensorflowExample(config)

  examples = []
  num_questions = 0
  num_conversion_errors = 0
  for interaction in prediction_utils.iterate_interactions(interaction_path):
    number_annotation_utils.add_numeric_values(interaction)
    for i in range(len(interaction.questions)):
      num_questions += 1

      try:
        examples.append(converter.convert(interaction, i))
      except ValueError as e:
        num_conversion_errors += 1
        logging.info("Can't convert interaction: %s error: %s", interaction.id,
                     e)
    if test_mode and len(examples) >= 100:
      break

  _print(f'Processed: {filename}')
  _print(f'Num questions processed: {num_questions}')
  _print(f'Num examples: {len(examples)}')
  _print(f'Num conversion errors: {num_conversion_errors}')

  if batch_size is None:
    random.shuffle(examples)
  else:
    # Make sure the eval sets are divisible by the test batch size since
    # otherwise examples will be dropped on TPU.
    # These examples will later be ignored when writing the predictions.
    original_num_examples = len(examples)
    while len(examples) % batch_size != 0:
      examples.append(converter.get_empty_example())
    if original_num_examples != len(examples):
      _print(f'Padded with {len(examples) - original_num_examples} examples.')

  with tf.io.TFRecordWriter(
      example_path,
      options=_to_tf_compression_type(FLAGS.compression_type),
  ) as writer:
    for example in examples:
      writer.write(example.SerializeToString())
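
A hedged call sketch; all paths and the batch size are hypothetical, and FLAGS come from the surrounding script:

_create_examples(
    interaction_dir='/data/interactions',  # hypothetical
    example_dir='/data/examples',  # hypothetical
    vocab_file='/data/vocab.txt',  # hypothetical
    filename='dev',
    batch_size=32,  # assumed eval batch size
    test_mode=False,
)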
Example #10
def main(argv):
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    print("Creating output dir ...")
    tf.io.gfile.makedirs(FLAGS.output_dir)

    interaction_files = []
    for filename in tf.io.gfile.listdir(FLAGS.input_dir):
        interaction_files.append(os.path.join(FLAGS.input_dir, filename))

    tables = {}
    if FLAGS.table_file:
        print("Reading tables ...")
        tables.update({
            table.table_id: table
            for table in tfidf_baseline_utils.iterate_tables(FLAGS.table_file)
        })

    print("Adding interactions tables ...")
    for interaction_file in interaction_files:
        interactions = prediction_utils.iterate_interactions(interaction_file)
        for interaction in interactions:
            tables[interaction.table.table_id] = interaction.table

    print("Creating index ...")

    if FLAGS.index_files_pattern:
        neighbors = _get_neural_nearest_neighbors(FLAGS.index_files_pattern)
        retrieve_fn = lambda question: neighbors.get(question.id, [])
    else:
        index = tfidf_baseline_utils.create_bm25_index(
            tables=tables.values(),
            title_multiplicator=FLAGS.title_multiplicator,
            num_tables=len(tables),
        )
        retrieve_fn = lambda question: index.retrieve(question.original_text)

    print("Processing interactions ...")
    for interaction_file in interaction_files:
        interactions = list(
            prediction_utils.iterate_interactions(interaction_file))

        examples = collections.defaultdict(list)
        for interaction in interactions:
            example_id, _ = preprocess_nq_utils.parse_interaction_id(
                interaction.id)
            examples[example_id].append(interaction)

        filename = os.path.basename(interaction_file)
        is_train = "train" in filename
        output = os.path.join(FLAGS.output_dir, filename)
        with tf.io.TFRecordWriter(output) as writer:
            num_correct = 0
            with tqdm.tqdm(
                    examples.items(),
                    total=len(examples),
                    desc=filename,
                    postfix=[{
                        "prec": "0.00",
                        "multiple_tables": 0,
                        "multiple_answers": 0,
                        "no_hits": 0,
                    }],
            ) as pbar:
                for nr, example in enumerate(pbar):
                    example_id, interaction_list = example

                    questions = []
                    for interaction in interaction_list:
                        if len(interaction.questions) != 1:
                            raise ValueError(
                                f"Unexpected question in {interaction}")
                        questions.append(interaction.questions[0])

                    answers = get_answer_texts(questions)

                    if len(set(q.original_text for q in questions)) != 1:
                        raise ValueError(f"Different questions {questions}")
                    question_text = questions[0].original_text
                    scored_hits = retrieve_fn(questions[0])
                    if not scored_hits:
                        pbar.postfix[0]["no_hits"] += 1
                    candidate_hits = scored_hits[:FLAGS.max_rank]

                    correct_table_ids = {
                        interaction.table.table_id
                        for interaction in interaction_list
                    }

                    table_ids = {table_id for table_id, _ in candidate_hits}

                    if correct_table_ids & table_ids:
                        num_correct += 1
                    prec = num_correct / (nr + 1)
                    pbar.postfix[0]["prec"] = f"{prec:.2f}"
                    if len(correct_table_ids) > 1:
                        pbar.postfix[0]["multiple_tables"] += 1

                    if is_train or FLAGS.oracle_retrieval:
                        # Guarantee the gold tables appear among candidates.
                        table_ids.update(correct_table_ids)

                    for table_index, table_id in enumerate(sorted(table_ids)):
                        table = tables[table_id]
                        new_interaction = interaction_pb2.Interaction()
                        new_interaction.table.CopyFrom(table)
                        new_question = new_interaction.questions.add()
                        new_question.original_text = question_text
                        _try_to_set_answer(table, answers, new_question)
                        _set_retriever_info(new_question, scored_hits,
                                            table_id)
                        new_question.answer.is_valid = True
                        if new_question.alternative_answers:
                            pbar.postfix[0]["multiple_answers"] += 1
                        if table_id in correct_table_ids:
                            new_question.answer.class_index = 1
                        else:
                            new_question.answer.class_index = 0
                            if not FLAGS.add_negatives:
                                # Skip writing negative (non-gold) tables.
                                continue
                        new_interaction.id = text_utils.get_sequence_id(
                            example_id, str(table_index))
                        new_question.id = text_utils.get_question_id(
                            new_interaction.id, position=0)
                        writer.write(new_interaction.SerializeToString())
Example #11
def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  print("Creating output dir ...")
  tf.io.gfile.makedirs(FLAGS.output_dir)

  interaction_files = []
  for filename in tf.io.gfile.listdir(FLAGS.input_dir):
    interaction_files.append(os.path.join(FLAGS.input_dir, filename))

  tables = {}
  if FLAGS.table_file:
    print("Reading tables ...")
    for table in tqdm.tqdm(
        tfidf_baseline_utils.iterate_tables(FLAGS.table_file), total=375_000):
      tables[table.table_id] = table

  print("Adding interactions tables ...")
  for interaction_file in interaction_files:
    interactions = prediction_utils.iterate_interactions(interaction_file)
    for interaction in interactions:
      tables[interaction.table.table_id] = interaction.table

  print("Creating index ...")
  index = tfidf_baseline_utils.create_bm25_index(
      tables=tables.values(),
      title_multiplicator=FLAGS.title_multiplicator,
      num_tables=len(tables),
  )

  print("Processing interactions ...")
  for interaction_file in interaction_files:
    interactions = list(prediction_utils.iterate_interactions(interaction_file))

    examples = collections.defaultdict(list)
    for interaction in interactions:
      example_id, _ = preprocess_nq_utils.parse_interaction_id(interaction.id)
      examples[example_id].append(interaction)

    filename = os.path.basename(interaction_file)
    filename = os.path.splitext(filename)[0]
    output = os.path.join(FLAGS.output_dir, filename + "_results.jsonl")
    with tf.io.gfile.GFile(output, "w") as file_writer:
      num_correct = 0
      with tqdm.tqdm(
          examples.items(),
          total=len(examples),
          desc=filename,
      ) as pbar:
        for nr, example in enumerate(pbar):
          example_id, interaction_list = example

          questions = []
          for interaction in interaction_list:
            if len(interaction.questions) != 1:
              raise ValueError(f"Unexpected question in {interaction}")
            questions.append(interaction.questions[0])

          if len(set(q.original_text for q in questions)) != 1:
            raise ValueError(f"Different questions {questions}")
          question_text = questions[0].original_text
          scored_hits = index.retrieve(question_text)
          scored_hits = scored_hits[:FLAGS.max_rank]

          table_scores = []
          for scored_hit in scored_hits:
            table_scores.append({
                "table_id": scored_hit[0],
                "score": -scored_hit[1],
            })

          result = {
              "query_id": example_id + "_0_0",
              "table_scores": table_scores,
          }

          file_writer.write(json.dumps(result))
          file_writer.write("\n")
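
Both main(argv) functions follow the absl pattern (FLAGS, app.UsageError); assuming absl.app is imported as app, the usual entry point is:

if __name__ == "__main__":
    app.run(main)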