def _predict_sequence_for_set(
    estimator,
    do_model_aggregation,
    use_answer_as_supervision,
    example_file,
    prediction_file,
    other_prediction_file,
):
  """Runs realistic sequence evaluation for SQA."""
  examples_by_position = exp_prediction_utils.read_classifier_dataset(
      predict_data=example_file,
      data_format='tfrecord',
      compression_type=FLAGS.compression_type,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=_MAX_PREDICTIONS_PER_SEQ,
      add_aggregation_function_id=do_model_aggregation,
      add_classification_labels=False,
      add_answer=use_answer_as_supervision)
  result = exp_prediction_utils.compute_prediction_sequence(
      estimator=estimator, examples_by_position=examples_by_position)
  exp_prediction_utils.write_predictions(
      result,
      prediction_file,
      do_model_aggregation,
      do_model_classification=False,
      cell_classification_threshold=_CELL_CLASSIFICATION_THRESHOLD,
  )
  tf.io.gfile.copy(prediction_file, other_prediction_file, overwrite=True)


def _predict_for_set(
    estimator,
    do_model_aggregation,
    use_answer_as_supervision,
    example_file,
    prediction_file,
    other_prediction_file,
):
  """Gets predictions and writes them to a TSV file."""
  # TODO: also predict for dev.
  predict_input_fn = functools.partial(
      tapas_classifier_model.input_fn,
      name='predict',
      file_patterns=example_file,
      data_format='tfrecord',
      compression_type=FLAGS.compression_type,
      is_training=False,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=_MAX_PREDICTIONS_PER_SEQ,
      add_aggregation_function_id=do_model_aggregation,
      add_classification_labels=False,
      add_answer=use_answer_as_supervision,
      include_id=False)
  result = estimator.predict(input_fn=predict_input_fn)
  exp_prediction_utils.write_predictions(
      result,
      prediction_file,
      do_model_aggregation=do_model_aggregation,
      do_model_classification=False,
      cell_classification_threshold=_CELL_CLASSIFICATION_THRESHOLD)
  tf.io.gfile.copy(prediction_file, other_prediction_file, overwrite=True)
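

# A minimal usage sketch (hypothetical helper and file names, not part of the
# original module): run prediction on a single example file and mirror the
# resulting TSV to a second location.
def _predict_test_set_sketch(estimator):
  _predict_for_set(
      estimator,
      do_model_aggregation=True,
      use_answer_as_supervision=True,
      example_file='/tmp/tapas/test.tfrecord',
      prediction_file='/tmp/tapas/test.tsv',
      other_prediction_file='/tmp/tapas/test_latest.tsv',
  )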


def test_write_predictions(
    self,
    span_prediction,
    do_model_aggregation,
    do_model_classification,
    cell_classification_threshold,
):
  estimator = self._create_estimator(
      span_prediction=span_prediction,
      num_aggregation_labels=2 if do_model_aggregation else 0,
      num_classification_labels=2 if do_model_classification else 0,
  )

  def _input_fn(params):
    return table_dataset_test_utils.create_random_dataset(
        num_examples=params['batch_size'] * 2,
        batch_size=params['batch_size'],
        repeat=False,
        generator_kwargs=self._generator_kwargs(
            add_aggregation_function_id=do_model_aggregation,
            add_classification_labels=do_model_classification,
        ))

  output_predict_file = tempfile.mktemp()
  prediction_utils.write_predictions(
      predictions=estimator.predict(_input_fn),
      output_predict_file=output_predict_file,
      do_model_aggregation=do_model_aggregation,
      do_model_classification=do_model_classification,
      cell_classification_threshold=cell_classification_threshold,
  )
  with open(output_predict_file, 'r') as inputfile:
    rows = list(csv.DictReader(inputfile, delimiter='\t'))

  self.assertLen(rows, _BATCH_SIZE * 2)
  self.assertIn('question_id', rows[0])


def _predict_and_export_metrics(
    mode,
    input_fn,
    input_file,
    interactions_file,
):
  """Exports model predictions and calculates denotation metric."""
  # Relies on `estimator`, `checkpoint`, `current_step`, `do_model_aggregation`
  # and `do_model_classification` from the enclosing scope.
  # Predict for each new checkpoint.
  tf.logging.info(
      "Running predictor for step %d (%s).",
      current_step,
      checkpoint,
  )
  result = estimator.predict(
      input_fn=input_fn,
      checkpoint_path=checkpoint,
  )
  if FLAGS.prediction_output_dir:
    output_dir = FLAGS.prediction_output_dir
    tf.io.gfile.makedirs(output_dir)
  else:
    output_dir = FLAGS.model_dir
  output_predict_file = os.path.join(
      output_dir, f"{mode}_results_{current_step}.tsv")
  prediction_utils.write_predictions(result, output_predict_file,
                                     do_model_aggregation,
                                     do_model_classification,
                                     FLAGS.cell_classification_threshold)

  if FLAGS.do_sequence_prediction:
    examples_by_position = prediction_utils.read_classifier_dataset(
        predict_data=input_file,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision)
    result_sequence = prediction_utils.compute_prediction_sequence(
        estimator=estimator, examples_by_position=examples_by_position)
    output_predict_file_sequence = os.path.join(
        FLAGS.model_dir,
        mode + "_results_sequence_{}.tsv").format(current_step)
    prediction_utils.write_predictions(result_sequence,
                                       output_predict_file_sequence,
                                       do_model_aggregation,
                                       do_model_classification,
                                       FLAGS.cell_classification_threshold)


def _predict_and_export_metrics(
    mode,
    name,
    input_fn,
    estimator,
    current_step,
    checkpoint,
    output_dir,
    do_model_aggregation,
    do_model_classification,
    output_token_answers,
):
  """Exports model predictions and calculates denotation metric.

  Args:
    mode: Prediction mode. Can be "predict" or "eval".
    name: Appended to the output file name if non-empty.
    input_fn: Function that generates the examples passed to
      `estimator.predict`.
    estimator: The Estimator instance.
    current_step: Current checkpoint step to be evaluated.
    checkpoint: Path to the checkpoint to be evaluated.
    output_dir: Path to save predictions generated by the checkpoint.
    do_model_aggregation: Whether the model does aggregation.
    do_model_classification: Whether the model does classification.
    output_token_answers: If true, output answer coordinates.

  Raises:
    ValueError: If an invalid mode is passed.
  """
  # Predict for each new checkpoint.
  tf.logging.info(
      "Running predictor for step %d (%s).",
      current_step,
      checkpoint,
  )
  if mode == "predict":
    input_file = FLAGS.input_file_predict
    interactions_file = FLAGS.predict_interactions_file
  elif mode == "eval":
    input_file = FLAGS.input_file_eval
    interactions_file = FLAGS.eval_interactions_file
  else:
    raise ValueError(f"Invalid mode {mode}")
  result = estimator.predict(
      input_fn=input_fn,
      checkpoint_path=checkpoint,
  )
  base_name = mode
  if name:
    base_name = f"{mode}_{name}"
  output_predict_file = os.path.join(
      output_dir, f"{base_name}_results_{current_step}.tsv")
  prediction_utils.write_predictions(
      result,
      output_predict_file,
      do_model_aggregation,
      do_model_classification,
      FLAGS.cell_classification_threshold,
      FLAGS.output_token_probabilities,
      output_token_answers=output_token_answers,
  )

  if FLAGS.do_sequence_prediction:
    examples_by_position = prediction_utils.read_classifier_dataset(
        predict_data=input_file,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision)
    result_sequence = prediction_utils.compute_prediction_sequence(
        estimator=estimator, examples_by_position=examples_by_position)
    output_predict_file_sequence = os.path.join(
        FLAGS.model_dir,
        base_name + "_results_sequence_{}.tsv").format(current_step)
    prediction_utils.write_predictions(
        result_sequence,
        output_predict_file_sequence,
        do_model_aggregation,
        do_model_classification,
        FLAGS.cell_classification_threshold,
        FLAGS.output_token_probabilities,
        output_token_answers)
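

# A minimal driver sketch (hypothetical wiring, not from the original file):
# poll the model directory for new checkpoints and export eval predictions for
# each one. Assumes `estimator`, `eval_input_fn`, and the `do_model_*` values
# are built elsewhere in the module.
def _continuous_eval_sketch(estimator, eval_input_fn, do_model_aggregation,
                            do_model_classification):
  for checkpoint in tf.train.checkpoints_iterator(FLAGS.model_dir):
    # Checkpoint paths end in the global step, e.g. "model.ckpt-1000".
    current_step = int(os.path.basename(checkpoint).split("-")[-1])
    _predict_and_export_metrics(
        mode="eval",
        name="",
        input_fn=eval_input_fn,
        estimator=estimator,
        current_step=current_step,
        checkpoint=checkpoint,
        output_dir=FLAGS.model_dir,
        do_model_aggregation=do_model_aggregation,
        do_model_classification=do_model_classification,
        output_token_answers=True,
    )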


def _predict_sequence_for_set(
    estimator,
    do_model_aggregation,
    use_answer_as_supervision,
    example_file,
    prediction_file,
    other_prediction_file,
):
  """Runs realistic sequence evaluation for SQA."""
  examples_by_position = exp_prediction_utils.read_classifier_dataset(
      predict_data=example_file,
      data_format='tfrecord',
      compression_type=FLAGS.compression_type,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=_MAX_PREDICTIONS_PER_SEQ,
      add_aggregation_function_id=do_model_aggregation,
      add_classification_labels=False,
      add_answer=use_answer_as_supervision)
  result = exp_prediction_utils.compute_prediction_sequence(
      estimator=estimator, examples_by_position=examples_by_position)

  if FLAGS.write_prob_table:
    # Collect per-cell selection probabilities into a (num_rows, num_columns)
    # table and append it to the running probs.npy file.
    for query in result:
      num_rows = max(query['row_ids'])
      result_array = np.zeros((num_rows, FLAGS.num_columns))
      for i in range(1, FLAGS.max_seq_length):
        # Only take the first token of each table cell (segment 1).
        if (query['segment_ids'][i] == 1 and
            (query['row_ids'][i - 1] != query['row_ids'][i] or
             query['column_ids'][i - 1] != query['column_ids'][i])):
          row = query['row_ids'][i]
          column = FLAGS.column_order[query['column_ids'][i] - 1]
          prob = query['probabilities'][i]
          result_array[row - 1][query['column_ids'][i] - 1] = prob
      prob_array = np.concatenate(
          (np.load(f'{FLAGS.output_dir}/probs.npy'), result_array))
      np.save(f'{FLAGS.output_dir}/probs.npy', prob_array)

  if FLAGS.write_embed_table:
    # Average token embeddings per table cell and for the query tokens, then
    # write them to embeds.npy / query.npy.
    for query in result:
      len_embedding = len(query['embeddings'][0])
      num_rows = max(query['row_ids'])
      embed_array = np.zeros((num_rows, FLAGS.num_columns, len_embedding))
      num_array = np.zeros((num_rows, FLAGS.num_columns))
      row = 0
      column = 0
      query_array = np.zeros(len_embedding)
      query_array_num = 0
      at_beginning = True
      for i in range(1, FLAGS.max_seq_length):
        # Accumulate query-token embeddings (segment 0) seen before the table.
        if query['segment_ids'][i] == 0 and at_beginning:
          if i != 0 and query['segment_ids'][i + 1] != 1:
            query_array += np.array(query['embeddings'][i])
            query_array_num += 1
        # First token of a new table cell.
        if (query['segment_ids'][i] == 1 and
            not (row == query['row_ids'][i] - 1 and
                 column == int(FLAGS.column_order[query['column_ids'][i] - 1]) - 1)):
          at_beginning = False
          row = query['row_ids'][i] - 1
          column = int(FLAGS.column_order[query['column_ids'][i] - 1]) - 1
          embed_array[row][column] += np.array(query['embeddings'][i])
          num_array[row][column] += 1
        # Remaining tokens of the current cell.
        elif query['segment_ids'][i] == 1:
          at_beginning = False
          embed_array[row][column] += np.array(query['embeddings'][i])
          num_array[row][column] += 1
      print(query_array_num)
      print(query)
      query_array = query_array / query_array_num
      result_array = np.array([
          embed / num
          for (embed_row, num_row) in zip(embed_array, num_array)
          for (embed, num) in zip(embed_row, num_row)
      ])
      np.save(f'{FLAGS.output_dir}/query.npy', query_array)
      np.save(f'{FLAGS.output_dir}/embeds.npy', result_array)

  exp_prediction_utils.write_predictions(
      result,
      prediction_file,
      do_model_aggregation,
      do_model_classification=False,
      cell_classification_threshold=_CELL_CLASSIFICATION_THRESHOLD,
  )
  tf.io.gfile.copy(prediction_file, other_prediction_file, overwrite=True)
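

# Note: the `write_prob_table` branch above concatenates each new table onto an
# existing f'{FLAGS.output_dir}/probs.npy'. A minimal seeding sketch (an
# assumption about how that file is created, not part of the original code):
# save an empty array with the right number of columns before the first run.
def _init_prob_table_sketch():
  np.save(f'{FLAGS.output_dir}/probs.npy', np.zeros((0, FLAGS.num_columns)))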