Ejemplo n.º 1
0
def _predict_sequence_for_set(
    estimator,
    do_model_aggregation,
    use_answer_as_supervision,
    example_file,
    prediction_file,
    other_prediction_file,
):
  """Runs realistic sequence evaluation for SQA and writes the predictions.

  Examples are read grouped by position and handed to
  `compute_prediction_sequence`; the resulting predictions are written to
  `prediction_file`, which is then copied over `other_prediction_file`.

  Args:
    estimator: Estimator passed to `compute_prediction_sequence`.
    do_model_aggregation: Forwarded as `add_aggregation_function_id` when
      reading examples and to `write_predictions`.
    use_answer_as_supervision: Forwarded as `add_answer` when reading examples.
    example_file: TFRecord file holding the examples.
    prediction_file: Path the predictions are written to.
    other_prediction_file: Second path that receives a copy (overwritten).
  """
  reader_options = dict(
      data_format='tfrecord',
      compression_type=FLAGS.compression_type,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=_MAX_PREDICTIONS_PER_SEQ,
      add_aggregation_function_id=do_model_aggregation,
      add_classification_labels=False,
      add_answer=use_answer_as_supervision,
  )
  grouped_examples = exp_prediction_utils.read_classifier_dataset(
      predict_data=example_file, **reader_options)
  sequence_predictions = exp_prediction_utils.compute_prediction_sequence(
      estimator=estimator, examples_by_position=grouped_examples)

  exp_prediction_utils.write_predictions(
      sequence_predictions,
      prediction_file,
      do_model_aggregation,
      do_model_classification=False,
      cell_classification_threshold=_CELL_CLASSIFICATION_THRESHOLD,
  )
  tf.io.gfile.copy(prediction_file, other_prediction_file, overwrite=True)
Ejemplo n.º 2
0
def _predict_for_set(
    estimator,
    do_model_aggregation,
    use_answer_as_supervision,
    example_file,
    prediction_file,
    other_prediction_file,
):
  """Predicts on `example_file` and writes the results as a TSV file.

  The predictions are written to `prediction_file`, which is then copied
  over `other_prediction_file`.

  Args:
    estimator: Estimator whose `predict` is called.
    do_model_aggregation: Forwarded as `add_aggregation_function_id` to the
      input function and as `do_model_aggregation` to `write_predictions`.
    use_answer_as_supervision: Forwarded as `add_answer` to the input function.
    example_file: TFRecord file pattern holding the examples.
    prediction_file: Path the predictions are written to.
    other_prediction_file: Second path that receives a copy (overwritten).
  """
  # TODO: also predict for dev.
  input_fn_options = {
      'name': 'predict',
      'file_patterns': example_file,
      'data_format': 'tfrecord',
      'compression_type': FLAGS.compression_type,
      'is_training': False,
      'max_seq_length': FLAGS.max_seq_length,
      'max_predictions_per_seq': _MAX_PREDICTIONS_PER_SEQ,
      'add_aggregation_function_id': do_model_aggregation,
      'add_classification_labels': False,
      'add_answer': use_answer_as_supervision,
      'include_id': False,
  }
  predict_input_fn = functools.partial(tapas_classifier_model.input_fn,
                                       **input_fn_options)
  predictions = estimator.predict(input_fn=predict_input_fn)

  exp_prediction_utils.write_predictions(
      predictions,
      prediction_file,
      do_model_aggregation=do_model_aggregation,
      do_model_classification=False,
      cell_classification_threshold=_CELL_CLASSIFICATION_THRESHOLD)
  tf.io.gfile.copy(prediction_file, other_prediction_file, overwrite=True)
Ejemplo n.º 3
0
    def test_write_predictions(
        self,
        span_prediction,
        do_model_aggregation,
        do_model_classification,
        cell_classification_threshold,
    ):
        """Checks that write_predictions writes one TSV row per prediction."""
        import os  # Local import: only needed to build the temp file path.

        estimator = self._create_estimator(
            span_prediction=span_prediction,
            num_aggregation_labels=2 if do_model_aggregation else 0,
            num_classification_labels=2 if do_model_classification else 0,
        )

        def _input_fn(params):
            # Two batches of random examples; the row-count assertion below
            # assumes params['batch_size'] == _BATCH_SIZE.
            return table_dataset_test_utils.create_random_dataset(
                num_examples=params['batch_size'] * 2,
                batch_size=params['batch_size'],
                repeat=False,
                generator_kwargs=self._generator_kwargs(
                    add_aggregation_function_id=do_model_aggregation,
                    add_classification_labels=do_model_classification,
                ))

        # tempfile.mktemp() is deprecated and race-prone, and its file was
        # never removed; a private temp directory is created securely and
        # deleted when the `with` block exits.
        with tempfile.TemporaryDirectory() as tmp_dir:
            output_predict_file = os.path.join(tmp_dir, 'predictions.tsv')
            prediction_utils.write_predictions(
                predictions=estimator.predict(_input_fn),
                output_predict_file=output_predict_file,
                do_model_aggregation=do_model_aggregation,
                do_model_classification=do_model_classification,
                cell_classification_threshold=cell_classification_threshold,
            )
            # The csv module expects files opened with newline=''.
            with open(output_predict_file, 'r', newline='') as inputfile:
                rows = list(csv.DictReader(inputfile, delimiter='\t'))
        self.assertLen(rows, _BATCH_SIZE * 2)
        self.assertIn('question_id', rows[0])
Ejemplo n.º 4
0
            def _predict_and_export_metrics(
                mode,
                input_fn,
                input_file,
                interactions_file,
            ):
                """Writes the current checkpoint's predictions to TSV files.

                Reads `estimator`, `checkpoint`, `current_step`,
                `do_model_aggregation` and `do_model_classification` from the
                enclosing scope.

                NOTE(review): the original summary also said "calculates
                denotation metric", but no metric is computed in this body.

                Args:
                  mode: Prefix for the output file names.
                  input_fn: Input function passed to `estimator.predict`.
                  input_file: Examples re-read for sequence prediction when
                    `FLAGS.do_sequence_prediction` is set.
                  interactions_file: Not read in this function body.
                """
                # Predict for each new checkpoint.
                tf.logging.info(
                    "Running predictor for step %d (%s).",
                    current_step,
                    checkpoint,
                )
                result = estimator.predict(
                    input_fn=input_fn,
                    checkpoint_path=checkpoint,
                )
                if FLAGS.prediction_output_dir:
                    output_dir = FLAGS.prediction_output_dir
                    tf.io.gfile.makedirs(output_dir)
                else:
                    output_dir = FLAGS.model_dir
                output_predict_file = os.path.join(
                    output_dir, f"{mode}_results_{current_step}.tsv")
                prediction_utils.write_predictions(
                    result, output_predict_file, do_model_aggregation,
                    do_model_classification,
                    FLAGS.cell_classification_threshold)

                if FLAGS.do_sequence_prediction:
                    examples_by_position = prediction_utils.read_classifier_dataset(
                        predict_data=input_file,
                        data_format=FLAGS.data_format,
                        compression_type=FLAGS.compression_type,
                        max_seq_length=FLAGS.max_seq_length,
                        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
                        add_aggregation_function_id=do_model_aggregation,
                        add_classification_labels=do_model_classification,
                        add_answer=FLAGS.use_answer_as_supervision)
                    result_sequence = prediction_utils.compute_prediction_sequence(
                        estimator=estimator,
                        examples_by_position=examples_by_position)
                    # Format only the file name: calling .format() on the
                    # joined path broke if FLAGS.model_dir contained braces.
                    # NOTE(review): this goes to FLAGS.model_dir even when
                    # FLAGS.prediction_output_dir is set -- confirm intended.
                    output_predict_file_sequence = os.path.join(
                        FLAGS.model_dir,
                        f"{mode}_results_sequence_{current_step}.tsv")
                    prediction_utils.write_predictions(
                        result_sequence, output_predict_file_sequence,
                        do_model_aggregation, do_model_classification,
                        FLAGS.cell_classification_threshold)
Ejemplo n.º 5
0
def _predict_and_export_metrics(
    mode,
    name,
    input_fn,
    estimator,
    current_step,
    checkpoint,
    output_dir,
    do_model_aggregation,
    do_model_classification,
    output_token_answers,
):
    """Exports model predictions for one checkpoint to TSV files.

    NOTE(review): the original summary also said "calculates denotation
    metric", but no metric is computed in this function body.

    Args:
      mode: Prediction mode. Can be "predict" or "eval".
      name: Appended to the output file prefix (`{mode}_{name}`) when
        non-empty.
      input_fn: Function generating the examples to pass into
        `estimator.predict`.
      estimator: The Estimator instance.
      current_step: Current checkpoint step to be evaluated.
      checkpoint: Path to the checkpoint to be evaluated.
      output_dir: Path to save predictions generated by the checkpoint.
      do_model_aggregation: Whether model does aggregation.
      do_model_classification: Whether model does classification.
      output_token_answers: If true, output answer coordinates.

    Raises:
      ValueError: if an invalid mode is passed.
    """
    # Predict for each new checkpoint.
    tf.logging.info(
        "Running predictor for step %d (%s).",
        current_step,
        checkpoint,
    )
    if mode == "predict":
        input_file = FLAGS.input_file_predict
        interactions_file = FLAGS.predict_interactions_file
    elif mode == "eval":
        input_file = FLAGS.input_file_eval
        interactions_file = FLAGS.eval_interactions_file
    else:
        raise ValueError(f"Invalid mode {mode}")
    # NOTE(review): interactions_file is not read further in this function.
    result = estimator.predict(
        input_fn=input_fn,
        checkpoint_path=checkpoint,
    )

    base_name = mode
    if name:
        base_name = f"{mode}_{name}"

    output_predict_file = os.path.join(
        output_dir, f"{base_name}_results_{current_step}.tsv")
    prediction_utils.write_predictions(
        result,
        output_predict_file,
        do_model_aggregation,
        do_model_classification,
        FLAGS.cell_classification_threshold,
        FLAGS.output_token_probabilities,
        output_token_answers=output_token_answers,
    )

    if FLAGS.do_sequence_prediction:
        examples_by_position = prediction_utils.read_classifier_dataset(
            predict_data=input_file,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            add_aggregation_function_id=do_model_aggregation,
            add_classification_labels=do_model_classification,
            add_answer=FLAGS.use_answer_as_supervision)
        result_sequence = prediction_utils.compute_prediction_sequence(
            estimator=estimator, examples_by_position=examples_by_position)
        # Format only the file name: calling .format() on the joined path
        # broke if FLAGS.model_dir contained braces.
        # NOTE(review): this goes to FLAGS.model_dir rather than `output_dir`
        # -- confirm intended.
        output_predict_file_sequence = os.path.join(
            FLAGS.model_dir, f"{base_name}_results_sequence_{current_step}.tsv")
        prediction_utils.write_predictions(
            result_sequence, output_predict_file_sequence,
            do_model_aggregation, do_model_classification,
            FLAGS.cell_classification_threshold,
            FLAGS.output_token_probabilities, output_token_answers)
Ejemplo n.º 6
0
def _cell_probability_table(query, num_columns, max_seq_length):
  """Builds a (num_rows, num_columns) table of per-cell probabilities.

  Each cell takes the probability of its first token, i.e. a table token whose
  row or column id differs from the previous token's.

  Args:
    query: One prediction dict; reads 'segment_ids', 'row_ids', 'column_ids'
      and 'probabilities', each indexed by token position.
    num_columns: Number of columns in the output table.
    max_seq_length: Number of token positions to scan (position 0 is skipped).

  Returns:
    A float array of shape (max(row_ids), num_columns); cells with no matching
    token stay 0.
  """
  segment_ids = query['segment_ids']
  row_ids = query['row_ids']
  column_ids = query['column_ids']
  probabilities = query['probabilities']
  table = np.zeros((max(row_ids), num_columns))
  for i in range(1, max_seq_length):
    if segment_ids[i] != 1:
      continue
    row_id, column_id = row_ids[i], column_ids[i]
    if row_id == row_ids[i - 1] and column_id == column_ids[i - 1]:
      continue  # Not the first token of its cell.
    if row_id < 1 or column_id < 1:
      # Ids are used 1-based; an id of 0 (presumably header tokens) would
      # index -1 and silently overwrite the last row/column.
      continue
    # NOTE(review): unlike the embedding table, FLAGS.column_order is not
    # applied here -- confirm intended.
    table[row_id - 1][column_id - 1] = probabilities[i]
  return table


def _question_and_cell_embeddings(query, num_columns, max_seq_length,
                                  column_order):
  """Averages token embeddings over the question and over each table cell.

  Args:
    query: One prediction dict; reads 'embeddings', 'segment_ids', 'row_ids'
      and 'column_ids', each indexed by token position.
    num_columns: Number of columns in the cell table.
    max_seq_length: Number of token positions to scan (position 0 is skipped).
    column_order: Maps (column_id - 1) to a 1-based output column; each entry
      is parsed with int().

  Returns:
    (question_embedding, cell_embeddings): the mean of the question-token
    embeddings, shape (embedding_size,), and the mean embedding per cell,
    flattened row-major to (num_rows * num_columns, embedding_size). Cells
    without tokens (or a question without tokens) come out as NaN.
  """
  embeddings = query['embeddings']
  segment_ids = query['segment_ids']
  row_ids = query['row_ids']
  column_ids = query['column_ids']
  embedding_size = len(embeddings[0])
  num_rows = max(row_ids)

  cell_sums = np.zeros((num_rows, num_columns, embedding_size))
  cell_counts = np.zeros((num_rows, num_columns))
  question_sum = np.zeros(embedding_size)
  question_count = 0
  in_question = True  # True until the first table (segment 1) token.

  for i in range(1, max_seq_length):
    if segment_ids[i] == 0:
      if not in_question:
        continue
      # Skip the token just before the table segment (presumably [SEP]);
      # the bounds check avoids an IndexError at the last position.
      next_is_table = i + 1 < len(segment_ids) and segment_ids[i + 1] == 1
      if not next_is_table:
        question_sum += np.array(embeddings[i])
        question_count += 1
    elif segment_ids[i] == 1:
      in_question = False
      if row_ids[i] < 1 or column_ids[i] < 1:
        # An id of 0 (presumably header tokens) would index -1 and add its
        # embedding into the last row/column.
        continue
      row = row_ids[i] - 1
      column = int(column_order[column_ids[i] - 1]) - 1
      cell_sums[row][column] += np.array(embeddings[i])
      cell_counts[row][column] += 1

  question_embedding = question_sum / question_count
  cell_embeddings = (cell_sums / cell_counts[..., np.newaxis]).reshape(
      -1, embedding_size)
  return question_embedding, cell_embeddings


def _predict_sequence_for_set(
    estimator,
    do_model_aggregation,
    use_answer_as_supervision,
    example_file,
    prediction_file,
    other_prediction_file,
):
  """Runs realistic sequence evaluation for SQA, optionally dumping tables.

  Predictions come from `compute_prediction_sequence` and are written to
  `prediction_file`, then copied over `other_prediction_file`.

  When FLAGS.write_prob_table is set, each prediction's per-cell probability
  table is appended (row-wise) to `{FLAGS.output_dir}/probs.npy`. The file is
  created if it does not exist yet.

  When FLAGS.write_embed_table is set, the averaged question embedding and the
  averaged per-cell embeddings are saved to `query.npy` / `embeds.npy` in
  FLAGS.output_dir.

  Args:
    estimator: Estimator passed to `compute_prediction_sequence`.
    do_model_aggregation: Forwarded as `add_aggregation_function_id` when
      reading examples and to `write_predictions`.
    use_answer_as_supervision: Forwarded as `add_answer` when reading examples.
    example_file: TFRecord file holding the examples.
    prediction_file: Path the predictions are written to.
    other_prediction_file: Second path that receives a copy (overwritten).
  """
  examples_by_position = exp_prediction_utils.read_classifier_dataset(
      predict_data=example_file,
      data_format='tfrecord',
      compression_type=FLAGS.compression_type,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=_MAX_PREDICTIONS_PER_SEQ,
      add_aggregation_function_id=do_model_aggregation,
      add_classification_labels=False,
      add_answer=use_answer_as_supervision)
  # Materialize once: the predictions are consumed by up to three passes.
  result = list(
      exp_prediction_utils.compute_prediction_sequence(
          estimator=estimator, examples_by_position=examples_by_position))

  if FLAGS.write_prob_table:
    tables = [
        _cell_probability_table(query, FLAGS.num_columns, FLAGS.max_seq_length)
        for query in result
    ]
    if tables:
      probs_path = f'{FLAGS.output_dir}/probs.npy'
      # Load and save once instead of once per query; same final contents.
      try:
        previous = [np.load(probs_path)]
      except FileNotFoundError:
        previous = []
      np.save(probs_path, np.concatenate(previous + tables))

  if FLAGS.write_embed_table:
    for query in result:
      question_embedding, cell_embeddings = _question_and_cell_embeddings(
          query, FLAGS.num_columns, FLAGS.max_seq_length, FLAGS.column_order)
      # NOTE(review): every query overwrites the same two files, so only the
      # last query's embeddings are kept -- confirm intended.
      np.save(f'{FLAGS.output_dir}/query.npy', question_embedding)
      np.save(f'{FLAGS.output_dir}/embeds.npy', cell_embeddings)

  exp_prediction_utils.write_predictions(
      result,
      prediction_file,
      do_model_aggregation,
      do_model_classification=False,
      cell_classification_threshold=_CELL_CLASSIFICATION_THRESHOLD,
  )
  tf.io.gfile.copy(prediction_file, other_prediction_file, overwrite=True)