コード例 #1
0
ファイル: run_task_main.py プロジェクト: vishwajeet93/tapas
def _predict_sequence_for_set(
    estimator,
    do_model_aggregation,
    use_answer_as_supervision,
    example_file,
    prediction_file,
    other_prediction_file,
):
  """Runs realistic sequence evaluation for SQA and stores the predictions.

  The examples in `example_file` are read as tfrecords, predicted with
  `compute_prediction_sequence`, written to `prediction_file` and finally
  copied to `other_prediction_file` (overwriting an existing file).

  Args:
    estimator: Estimator handed to `compute_prediction_sequence`.
    do_model_aggregation: Forwarded to the dataset reader as
      `add_aggregation_function_id` and to the prediction writer.
    use_answer_as_supervision: Forwarded to the dataset reader as
      `add_answer`.
    example_file: Path of the tfrecord file holding the examples.
    prediction_file: Path the predictions are written to.
    other_prediction_file: Path that receives a copy of `prediction_file`.
  """
  reader_options = dict(
      predict_data=example_file,
      data_format='tfrecord',
      compression_type=FLAGS.compression_type,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=_MAX_PREDICTIONS_PER_SEQ,
      add_aggregation_function_id=do_model_aggregation,
      # Classification labels are never requested by this code path.
      add_classification_labels=False,
      add_answer=use_answer_as_supervision,
  )
  positioned_examples = exp_prediction_utils.read_classifier_dataset(
      **reader_options)

  predictions = exp_prediction_utils.compute_prediction_sequence(
      estimator=estimator,
      examples_by_position=positioned_examples,
  )

  exp_prediction_utils.write_predictions(
      predictions,
      prediction_file,
      do_model_aggregation,
      do_model_classification=False,
      cell_classification_threshold=_CELL_CLASSIFICATION_THRESHOLD,
  )

  # Keep a second copy of the predictions, replacing any previous one.
  tf.io.gfile.copy(prediction_file, other_prediction_file, overwrite=True)
コード例 #2
0
            def _predict_and_export_metrics(
                mode,
                input_fn,
                input_file,
                interactions_file,
            ):
                """Exports model predictions and calculates denotation metric.

                Reads `estimator`, `checkpoint`, `current_step`,
                `do_model_aggregation` and `do_model_classification` from the
                enclosing scope (not visible in this excerpt).

                Args:
                  mode: Prefix used for the names of the result files.
                  input_fn: Input function passed to `estimator.predict`.
                  input_file: Example file that is re-read when
                    `FLAGS.do_sequence_prediction` is set.
                  interactions_file: Not read by the code shown here.
                    NOTE(review): presumably consumed by the metric
                    computation that follows in the full file - confirm.
                """
                # Predict for each new checkpoint.
                tf.logging.info(
                    "Running predictor for step %d (%s).",
                    current_step,
                    checkpoint,
                )
                result = estimator.predict(
                    input_fn=input_fn,
                    checkpoint_path=checkpoint,
                )
                if FLAGS.prediction_output_dir:
                    output_dir = FLAGS.prediction_output_dir
                    tf.io.gfile.makedirs(output_dir)
                else:
                    output_dir = FLAGS.model_dir
                output_predict_file = os.path.join(
                    output_dir, f"{mode}_results_{current_step}.tsv")
                prediction_utils.write_predictions(
                    result, output_predict_file, do_model_aggregation,
                    do_model_classification,
                    FLAGS.cell_classification_threshold)

                if FLAGS.do_sequence_prediction:
                    examples_by_position = prediction_utils.read_classifier_dataset(
                        predict_data=input_file,
                        data_format=FLAGS.data_format,
                        compression_type=FLAGS.compression_type,
                        max_seq_length=FLAGS.max_seq_length,
                        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
                        add_aggregation_function_id=do_model_aggregation,
                        add_classification_labels=do_model_classification,
                        add_answer=FLAGS.use_answer_as_supervision)
                    result_sequence = prediction_utils.compute_prediction_sequence(
                        estimator=estimator,
                        examples_by_position=examples_by_position)
                    # Format only the file name: calling `.format` on the
                    # joined path would misinterpret any `{`/`}` contained
                    # in the directory or in `mode`.
                    # NOTE(review): sequence results always go to
                    # `FLAGS.model_dir`, not `output_dir` - confirm intended.
                    output_predict_file_sequence = os.path.join(
                        FLAGS.model_dir,
                        f"{mode}_results_sequence_{current_step}.tsv")
                    prediction_utils.write_predictions(
                        result_sequence, output_predict_file_sequence,
                        do_model_aggregation, do_model_classification,
                        FLAGS.cell_classification_threshold)
コード例 #3
0
 def test_read_classifier_dataset(self):
   """Checks that the reader returns non-empty examples with core features."""
   dataset = prediction_utils.read_classifier_dataset(
       predict_data=self._predict_data(),
       data_format='tfrecord',
       compression_type='',
       max_seq_length=512,
       max_predictions_per_seq=20,
       add_aggregation_function_id=False,
       add_classification_labels=False,
       add_answer=False,
   )

   # Something must have been loaded at every level of the mapping.
   self.assertNotEmpty(dataset)
   for per_question in dataset.values():
     self.assertNotEmpty(per_question)
     for features in per_question.values():
       # Only a couple of features are spot-checked.
       for feature_name in ('input_ids', 'label_ids'):
         self.assertIn(feature_name, features)
コード例 #4
0
  def test_compute_prediction_sequence(self):
    """Smoke test: `compute_prediction_sequence` runs and returns results."""
    dataset = prediction_utils.read_classifier_dataset(
        predict_data=self._predict_data(),
        data_format='tfrecord',
        compression_type='',
        max_seq_length=512,
        max_predictions_per_seq=20,
        add_aggregation_function_id=False,
        add_classification_labels=False,
        add_answer=False,
    )

    # The test data is incomplete, whereas the full data always has a
    # predecessor for every example. Walk from the last position down to
    # position 1 and copy each example into the preceding position when it is
    # missing there.
    for position in reversed(range(1, len(dataset))):
      for example_id, example in dataset[position].items():
        dataset[position - 1].setdefault(example_id, example)

    predictions = prediction_utils.compute_prediction_sequence(
        estimator=self._create_estimator(),
        examples_by_position=dataset,
    )
    self.assertNotEmpty(predictions)
コード例 #5
0
def _predict_and_export_metrics(
    mode,
    name,
    input_fn,
    estimator,
    current_step,
    checkpoint,
    output_dir,
    do_model_aggregation,
    do_model_classification,
    output_token_answers,
):
    """Exports model predictions and calculates denotation metric.

    Args:
      mode: Prediction mode. Can be "predict" or "eval".
      name: Used as name appendix if not default.
      input_fn: Function to generate examples to be passed into
        `estimator.predict`.
      estimator: The Estimator instance.
      current_step: Current checkpoint step to be evaluated.
      checkpoint: Path to the checkpoint to be evaluated.
      output_dir: Path to save predictions generated by the checkpoint.
      do_model_aggregation: Whether model does aggregation.
      do_model_classification: Whether model does classification.
      output_token_answers: If true, output answer coordinates.

    Raises:
      ValueError: if an invalid mode is passed.
    """
    # Predict for each new checkpoint.
    tf.logging.info(
        "Running predictor for step %d (%s).",
        current_step,
        checkpoint,
    )
    # Pick the input files for the mode; this also validates `mode` before
    # any prediction work is started.
    # NOTE(review): `interactions_file` is assigned but never read in the code
    # shown here - presumably used by a metric step further down; confirm.
    if mode == "predict":
        input_file = FLAGS.input_file_predict
        interactions_file = FLAGS.predict_interactions_file
    elif mode == "eval":
        input_file = FLAGS.input_file_eval
        interactions_file = FLAGS.eval_interactions_file
    else:
        raise ValueError(f"Invalid mode {mode}")
    result = estimator.predict(
        input_fn=input_fn,
        checkpoint_path=checkpoint,
    )

    base_name = mode
    if name:
        base_name = f"{mode}_{name}"

    output_predict_file = os.path.join(
        output_dir, f"{base_name}_results_{current_step}.tsv")
    prediction_utils.write_predictions(
        result,
        output_predict_file,
        do_model_aggregation,
        do_model_classification,
        FLAGS.cell_classification_threshold,
        FLAGS.output_token_probabilities,
        output_token_answers=output_token_answers,
    )

    if FLAGS.do_sequence_prediction:
        examples_by_position = prediction_utils.read_classifier_dataset(
            predict_data=input_file,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            add_aggregation_function_id=do_model_aggregation,
            add_classification_labels=do_model_classification,
            add_answer=FLAGS.use_answer_as_supervision)
        result_sequence = prediction_utils.compute_prediction_sequence(
            estimator=estimator, examples_by_position=examples_by_position)
        # Format only the file name: calling `.format` on the joined path
        # would misinterpret any `{`/`}` contained in the directory or in
        # `base_name`.
        # NOTE(review): sequence results go to `FLAGS.model_dir` rather than
        # `output_dir` - confirm this is intended.
        output_predict_file_sequence = os.path.join(
            FLAGS.model_dir,
            f"{base_name}_results_sequence_{current_step}.tsv")
        prediction_utils.write_predictions(
            result_sequence, output_predict_file_sequence,
            do_model_aggregation, do_model_classification,
            FLAGS.cell_classification_threshold,
            FLAGS.output_token_probabilities, output_token_answers)
コード例 #6
0
def _append_probability_table(result):
  """Appends one per-cell probability table per query to `probs.npy`."""
  # The file must already exist: every query's table is concatenated onto
  # the array loaded from it and the file is rewritten.
  probs_path = f'{FLAGS.output_dir}/probs.npy'
  for query in result:
    segment_ids = query['segment_ids']
    row_ids = query['row_ids']
    column_ids = query['column_ids']
    table = np.zeros((max(row_ids), FLAGS.num_columns))
    for i in range(1, FLAGS.max_seq_length):
      if segment_ids[i] != 1:
        continue
      # Only the first token of a cell (row or column differs from the
      # previous token) contributes its probability.
      if row_ids[i - 1] != row_ids[i] or column_ids[i - 1] != column_ids[i]:
        # NOTE(review): a row id of 0 indexes row -1, i.e. the last row of
        # the table - confirm row ids of these tokens start at 1.
        # NOTE(review): unlike the embedding table, the raw column id is
        # used here without `FLAGS.column_order` - confirm intended.
        table[row_ids[i] - 1][column_ids[i] - 1] = query['probabilities'][i]
    np.save(probs_path, np.concatenate((np.load(probs_path), table)))


def _save_embedding_tables(result):
  """Saves mean question and per-cell embeddings to `query.npy`/`embeds.npy`."""
  for query in result:
    embeddings = query['embeddings']
    segment_ids = query['segment_ids']
    row_ids = query['row_ids']
    column_ids = query['column_ids']

    len_embedding = len(embeddings[0])
    num_rows = max(row_ids)
    embed_array = np.zeros((num_rows, FLAGS.num_columns, len_embedding))
    num_array = np.zeros((num_rows, FLAGS.num_columns))
    query_array = np.zeros(len_embedding)
    query_array_num = 0
    at_beginning = True  # True until the first token with segment id 1.

    for i in range(1, FLAGS.max_seq_length):
      if segment_ids[i] == 1:
        at_beginning = False
        # NOTE(review): a row id of 0 indexes row -1, i.e. the last row of
        # the table - confirm row ids of these tokens start at 1.
        row = row_ids[i] - 1
        column = int(FLAGS.column_order[column_ids[i] - 1]) - 1
        embed_array[row][column] += np.array(embeddings[i])
        num_array[row][column] += 1
      elif segment_ids[i] == 0 and at_beginning:
        # Skip the token that directly precedes the first segment-1 token.
        # The bounds check avoids reading past the end of `segment_ids` when
        # no segment-1 token has been seen before the last index.
        next_is_table = i + 1 < len(segment_ids) and segment_ids[i + 1] == 1
        if not next_is_table:
          query_array += np.array(embeddings[i])
          query_array_num += 1

    # Diagnostic output kept from the original implementation.
    print(query_array_num)
    print(query)
    # A count of zero (no question token, or an empty cell) yields NaN here.
    query_array = query_array / query_array_num
    result_array = np.array([
        embed / num
        for embed_row, num_row in zip(embed_array, num_array)
        for embed, num in zip(embed_row, num_row)
    ])
    # Both files are overwritten per query, so they end up holding the values
    # of the last query in `result`.
    np.save(f'{FLAGS.output_dir}/query.npy', query_array)
    np.save(f'{FLAGS.output_dir}/embeds.npy', result_array)


def _predict_sequence_for_set(
    estimator,
    do_model_aggregation,
    use_answer_as_supervision,
    example_file,
    prediction_file,
    other_prediction_file,
):
  """Runs realistic sequence evaluation for SQA.

  Besides writing the predictions, this optionally dumps NumPy tables into
  `FLAGS.output_dir`: cell probabilities when `FLAGS.write_prob_table` is set
  and mean embeddings when `FLAGS.write_embed_table` is set.

  Args:
    estimator: Estimator handed to `compute_prediction_sequence`.
    do_model_aggregation: Forwarded to the dataset reader as
      `add_aggregation_function_id` and to the prediction writer.
    use_answer_as_supervision: Forwarded to the dataset reader as
      `add_answer`.
    example_file: Path of the tfrecord file holding the examples.
    prediction_file: Path the predictions are written to.
    other_prediction_file: Path that receives a copy of `prediction_file`
      (an existing file is overwritten).
  """
  examples_by_position = exp_prediction_utils.read_classifier_dataset(
      predict_data=example_file,
      data_format='tfrecord',
      compression_type=FLAGS.compression_type,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=_MAX_PREDICTIONS_PER_SEQ,
      add_aggregation_function_id=do_model_aggregation,
      add_classification_labels=False,
      add_answer=use_answer_as_supervision)
  result = exp_prediction_utils.compute_prediction_sequence(
      estimator=estimator, examples_by_position=examples_by_position)

  # NOTE(review): `result` is iterated by each dump and again by the writer,
  # so it has to be re-iterable (e.g. a list, not a generator) - confirm.
  if FLAGS.write_prob_table:
    _append_probability_table(result)

  if FLAGS.write_embed_table:
    _save_embedding_tables(result)

  exp_prediction_utils.write_predictions(
      result,
      prediction_file,
      do_model_aggregation,
      do_model_classification=False,
      cell_classification_threshold=_CELL_CLASSIFICATION_THRESHOLD,
  )
  tf.io.gfile.copy(prediction_file, other_prediction_file, overwrite=True)