def main(_):
  """Trains, evaluates and/or predicts with a TAPAS classifier model.

  Behavior is selected by FLAGS.do_train / FLAGS.do_eval / FLAGS.do_predict;
  all hyper-parameters come from command-line flags. Eval and predict poll
  checkpoints via experiment_utils.iterate_checkpoints and write marker files
  so repeated runs skip already-processed checkpoints.
  """
  # The aggregation / classification heads are enabled only when their label
  # counts are positive.
  do_model_aggregation = FLAGS.num_aggregation_labels > 0
  do_model_classification = FLAGS.num_classification_labels > 0
  bert_config = experiment_utils.bert_config_from_flags()
  total_steps = experiment_utils.num_train_steps()
  tapas_config = tapas_classifier_model.TapasClassifierConfig(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=total_steps,
      num_warmup_steps=experiment_utils.num_warmup_steps(),
      use_tpu=FLAGS.use_tpu,
      positive_weight=FLAGS.positive_weight,
      num_aggregation_labels=FLAGS.num_aggregation_labels,
      num_classification_labels=FLAGS.num_classification_labels,
      aggregation_loss_importance=FLAGS.aggregation_loss_importance,
      use_answer_as_supervision=FLAGS.use_answer_as_supervision,
      answer_loss_importance=FLAGS.answer_loss_importance,
      use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
      huber_loss_delta=FLAGS.huber_loss_delta,
      temperature=FLAGS.temperature,
      agg_temperature=FLAGS.agg_temperature,
      use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
      use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
      # Flag string is converted to the enum here so bad values fail early.
      average_approximation_function=tapas_classifier_model
      .AverageApproximationFunction(FLAGS.average_approximation_function),
      cell_select_pref=FLAGS.cell_select_pref,
      answer_loss_cutoff=FLAGS.answer_loss_cutoff,
      grad_clipping=FLAGS.grad_clipping,
      reset_output_cls=FLAGS.reset_output_cls,
      disabled_features=FLAGS.disabled_features,
      max_num_rows=FLAGS.max_num_rows,
      max_num_columns=FLAGS.max_num_columns,
      average_logits_per_cell=FLAGS.average_logits_per_cell,
      init_cell_selection_weights_to_zero=FLAGS
      .init_cell_selection_weights_to_zero,
      select_one_column=FLAGS.select_one_column,
      allow_empty_column_selection=FLAGS.allow_empty_column_selection,
      disable_position_embeddings=FLAGS.disable_position_embeddings,
      disable_per_token_loss=FLAGS.disable_per_token_loss,
      reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
      span_prediction=tapas_classifier_model.SpanPredictionMode(
          FLAGS.span_prediction),
      # Non-positive flag value means "no projection".
      proj_value_length=FLAGS.proj_value_length
      if FLAGS.proj_value_length > 0 else None,)
  model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
  estimator = experiment_utils.build_estimator(model_fn)

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    # Persist both configs next to the checkpoints for reproducibility.
    bert_config.to_json_file(
        os.path.join(FLAGS.model_dir, "bert_config.json"))
    tapas_config.to_json_file(
        os.path.join(FLAGS.model_dir, "tapas_config.json"))
    train_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="train",
        file_patterns=FLAGS.input_file_train,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=True,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=False)
    estimator.train(input_fn=train_input_fn, max_steps=total_steps)

  # Built unconditionally: used both by do_eval and by do_predict below.
  eval_input_fn = functools.partial(
      tapas_classifier_model.input_fn,
      name="eval",
      file_patterns=FLAGS.input_file_eval,
      data_format=FLAGS.data_format,
      compression_type=FLAGS.compression_type,
      is_training=False,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      add_aggregation_function_id=do_model_aggregation,
      add_classification_labels=do_model_classification,
      add_answer=FLAGS.use_answer_as_supervision,
      # IDs are string features; they are only included off-TPU.
      include_id=not FLAGS.use_tpu)

  if FLAGS.do_eval:
    eval_name = FLAGS.eval_name if FLAGS.eval_name is not None else "default"
    for _, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        marker_file_prefix=os.path.join(estimator.model_dir,
                                        f"eval_{eval_name}"),
        minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions):
      tf.logging.info("Running eval: %s", eval_name)
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=eval_name,
          checkpoint_path=checkpoint)
      tf.logging.info("Eval result:\n%s", result)

  if FLAGS.do_predict:
    predict_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="predict",
        file_patterns=FLAGS.input_file_predict,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=False,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=not FLAGS.use_tpu)
    if FLAGS.prediction_output_dir:
      prediction_output_dir = FLAGS.prediction_output_dir
      tf.io.gfile.makedirs(prediction_output_dir)
    else:
      prediction_output_dir = estimator.model_dir
    marker_file_prefix = os.path.join(prediction_output_dir, "predict")
    # When two separate jobs are launched we don't want conflicting markers.
    if FLAGS.input_file_predict is not None:
      marker_file_prefix += "_test"
    if FLAGS.input_file_eval is not None:
      marker_file_prefix += "_dev"
    for current_step, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        single_step=FLAGS.evaluated_checkpoint_step,
        marker_file_prefix=marker_file_prefix):
      # A broken checkpoint should not kill the whole prediction loop, so
      # errors are logged and the next checkpoint is attempted.
      try:
        if FLAGS.input_file_predict is not None:
          _predict_and_export_metrics(
              mode="predict",
              name=FLAGS.eval_name,
              input_fn=predict_input_fn,
              output_dir=prediction_output_dir,
              estimator=estimator,
              checkpoint=checkpoint,
              current_step=current_step,
              do_model_classification=do_model_classification,
              do_model_aggregation=do_model_aggregation,
              output_token_answers=not FLAGS.disable_per_token_loss)
        if FLAGS.input_file_eval is not None:
          _predict_and_export_metrics(
              mode="eval",
              name=FLAGS.eval_name,
              input_fn=eval_input_fn,
              output_dir=prediction_output_dir,
              estimator=estimator,
              checkpoint=checkpoint,
              current_step=current_step,
              do_model_classification=do_model_classification,
              do_model_aggregation=do_model_aggregation,
              output_token_answers=not FLAGS.disable_per_token_loss)
      except ValueError:
        tf.logging.error("Error getting predictions for checkpoint %s: %s",
                         checkpoint, traceback.format_exc())
def main(_):
  """Trains, evaluates and/or predicts with a TAPAS table retriever model.

  Behavior is selected by FLAGS.do_train / FLAGS.do_eval / FLAGS.do_predict.
  Eval and predict iterate over checkpoints via
  experiment_utils.iterate_checkpoints, using marker files to avoid
  re-processing checkpoints across restarts.
  """
  bert_config = experiment_utils.bert_config_from_flags()
  total_steps = experiment_utils.num_train_steps()
  retriever_config = table_retriever_model.RetrieverConfig(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=total_steps,
      num_warmup_steps=experiment_utils.num_warmup_steps(),
      use_tpu=FLAGS.use_tpu,
      grad_clipping=FLAGS.grad_clipping,
      down_projection_dim=FLAGS.down_projection_dim,
      init_from_single_encoder=FLAGS.init_from_single_encoder,
      max_query_length=FLAGS.max_query_length,
      mask_repeated_tables=FLAGS.mask_repeated_tables,
      mask_repeated_questions=FLAGS.mask_repeated_questions,
      use_out_of_core_negatives=FLAGS.use_out_of_core_negatives,
      ignore_table_content=FLAGS.ignore_table_content,
      disabled_features=FLAGS.disabled_features,
      use_mined_negatives=FLAGS.use_mined_negatives,
  )
  model_fn = table_retriever_model.model_fn_builder(retriever_config)
  estimator = experiment_utils.build_estimator(model_fn)

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    # Persist both configs next to the checkpoints for reproducibility.
    bert_config.to_json_file(
        os.path.join(FLAGS.model_dir, "bert_config.json"))
    retriever_config.to_json_file(
        os.path.join(FLAGS.model_dir, "tapas_config.json"))
    train_input_fn = functools.partial(
        table_retriever_model.input_fn,
        name="train",
        file_patterns=FLAGS.input_file_train,
        data_format=FLAGS.data_format,
        is_training=True,
        max_seq_length=FLAGS.max_seq_length,
        compression_type=FLAGS.compression_type,
        use_mined_negatives=FLAGS.use_mined_negatives,
        include_id=False)
    estimator.train(input_fn=train_input_fn, max_steps=total_steps)

  # Built unconditionally: used both by do_eval and by do_predict below.
  # May be None when no eval input file is given.
  eval_input_fn = _get_test_input_fn("eval", FLAGS.input_file_eval)
  if FLAGS.do_eval:
    if eval_input_fn is None:
      raise ValueError("No input_file_eval specified!")
    for _, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        marker_file_prefix=os.path.join(estimator.model_dir,
                                        f"eval_{FLAGS.eval_name}"),
        minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions):
      tf.logging.info("Running eval: %s", FLAGS.eval_name)
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=FLAGS.eval_name,
          checkpoint_path=checkpoint)
      tf.logging.info("Eval result:\n%s", result)

  if FLAGS.do_predict:
    predict_input_fn = _get_test_input_fn("predict", FLAGS.input_file_predict)
    if FLAGS.prediction_output_dir:
      prediction_output_dir = FLAGS.prediction_output_dir
      tf.io.gfile.makedirs(prediction_output_dir)
    else:
      prediction_output_dir = estimator.model_dir
    marker_file_prefix = os.path.join(prediction_output_dir, "predict")
    # When two separate jobs are launched we don't want conflicting markers.
    if predict_input_fn is not None:
      marker_file_prefix += "_test"
    if eval_input_fn is not None:
      marker_file_prefix += "_dev"
    # Either predict a single explicit step, or the checkpoint that is best
    # according to the requested metric.
    single_step = FLAGS.evaluated_checkpoint_step
    if FLAGS.evaluated_checkpoint_metric:
      single_step = experiment_utils.get_best_step_for_metric(
          estimator.model_dir, FLAGS.evaluated_checkpoint_metric)
    for current_step, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        marker_file_prefix=marker_file_prefix,
        single_step=single_step):
      if predict_input_fn is not None:
        _predict_and_export_metrics(
            mode="predict",
            input_fn=predict_input_fn,
            checkpoint_path=checkpoint,
            step=current_step,
            estimator=estimator,
            output_dir=prediction_output_dir)
      if eval_input_fn is not None:
        _predict_and_export_metrics(
            mode="eval",
            input_fn=eval_input_fn,
            checkpoint_path=checkpoint,
            step=current_step,
            estimator=estimator,
            output_dir=prediction_output_dir)
def main(_):
  """Trains, evaluates and/or predicts with a TAPAS classifier model.

  Variant with table pruning, restricted attention and per-class label
  weights. Behavior is selected by FLAGS.do_train / FLAGS.do_eval /
  FLAGS.do_predict; eval and predict poll checkpoints via
  experiment_utils.iterate_checkpoints with marker files.

  Raises:
    ValueError: If FLAGS.classification_label_weight contains a negative
      weight or a label index outside [0, num_classification_labels).
  """
  # The aggregation / classification heads are enabled only when their label
  # counts are positive.
  do_model_aggregation = FLAGS.num_aggregation_labels > 0
  do_model_classification = FLAGS.num_classification_labels > 0
  bert_config = experiment_utils.bert_config_from_flags()
  total_steps = experiment_utils.num_train_steps()
  tapas_config = tapas_classifier_model.TapasClassifierConfig(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=total_steps,
      num_warmup_steps=experiment_utils.num_warmup_steps(),
      use_tpu=FLAGS.use_tpu,
      positive_weight=FLAGS.positive_weight,
      num_aggregation_labels=FLAGS.num_aggregation_labels,
      num_classification_labels=FLAGS.num_classification_labels,
      aggregation_loss_importance=FLAGS.aggregation_loss_importance,
      use_answer_as_supervision=FLAGS.use_answer_as_supervision,
      answer_loss_importance=FLAGS.answer_loss_importance,
      use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
      huber_loss_delta=FLAGS.huber_loss_delta,
      temperature=FLAGS.temperature,
      agg_temperature=FLAGS.agg_temperature,
      use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
      use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
      # Flag strings are converted to enums here so bad values fail early.
      average_approximation_function=(
          tapas_classifier_model.AverageApproximationFunction(
              FLAGS.average_approximation_function)),
      cell_select_pref=FLAGS.cell_select_pref,
      answer_loss_cutoff=FLAGS.answer_loss_cutoff,
      grad_clipping=FLAGS.grad_clipping,
      # Parses "label:weight" pairs, e.g. "0:1.0,1:2.5"; empty entries from
      # trailing commas are skipped. Validated after construction below.
      classification_label_weight={
          int(pair.split(":")[0]): float(pair.split(":")[1])
          for pair in FLAGS.classification_label_weight.split(",")
          if pair
      },
      table_pruning_config_file=FLAGS.table_pruning_config_file,
      restrict_attention_mode=(attention_utils.RestrictAttentionMode(
          FLAGS.restrict_attention_mode)),
      restrict_attention_bucket_size=FLAGS.restrict_attention_bucket_size,
      restrict_attention_header_size=FLAGS.restrict_attention_header_size,
      restrict_attention_row_heads_ratio=(
          FLAGS.restrict_attention_row_heads_ratio),
      mask_examples_without_labels=FLAGS.mask_examples_without_labels,
      cell_cross_entropy_hard_em=FLAGS.cell_cross_entropy_hard_em,
      cell_cross_entropy=FLAGS.cell_cross_entropy,
      reset_output_cls=FLAGS.reset_output_cls,
      disabled_features=FLAGS.disabled_features,
      max_num_rows=FLAGS.max_num_rows,
      max_num_columns=FLAGS.max_num_columns,
      average_logits_per_cell=FLAGS.average_logits_per_cell,
      init_cell_selection_weights_to_zero=FLAGS
      .init_cell_selection_weights_to_zero,
      select_one_column=FLAGS.select_one_column,
      allow_empty_column_selection=FLAGS.allow_empty_column_selection,
      disable_position_embeddings=FLAGS.disable_position_embeddings,
      disable_per_token_loss=FLAGS.disable_per_token_loss,
      reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
      span_prediction=tapas_classifier_model.SpanPredictionMode(
          FLAGS.span_prediction),
      proj_value_length=_get_projection_length(FLAGS.proj_value_length),
      attention_bias_disabled=FLAGS.attention_bias_disabled,
      attention_bias_use_relative_scalar_only=FLAGS
      .attention_bias_use_relative_scalar_only,
  )
  model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
  estimator = experiment_utils.build_estimator(model_fn)

  # Validate the parsed label-weight map: weights must be non-negative and
  # keys must be valid label indices.
  if tapas_config.classification_label_weight:
    if any(x < 0 for x in tapas_config.classification_label_weight.values()):
      raise ValueError("Label weights cannot be negative in input: "
                       f"{tapas_config.classification_label_weight}.")
    if any(x < 0 or x >= tapas_config.num_classification_labels
           for x in tapas_config.classification_label_weight.keys()):
      raise ValueError("Invalid label in label weights for input: "
                       f"{tapas_config.classification_label_weight}.")

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    # Copy the table pruning config if pruning is used.
    if FLAGS.table_pruning_config_file:
      table_pruning_path = os.path.join(FLAGS.model_dir,
                                        "table_pruning_config.textproto")
      tf.io.gfile.copy(
          FLAGS.table_pruning_config_file, table_pruning_path, overwrite=True)
    # Persist both configs next to the checkpoints for reproducibility.
    bert_config.to_json_file(os.path.join(FLAGS.model_dir, "bert_config.json"))
    tapas_config.to_json_file(
        os.path.join(FLAGS.model_dir, "tapas_config.json"))
    train_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="train",
        file_patterns=FLAGS.input_file_train,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=True,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=False)
    estimator.train(input_fn=train_input_fn, max_steps=total_steps)

  # Built unconditionally: used both by do_eval and by do_predict below.
  eval_input_fn = functools.partial(
      tapas_classifier_model.input_fn,
      name="eval",
      file_patterns=FLAGS.input_file_eval,
      data_format=FLAGS.data_format,
      compression_type=FLAGS.compression_type,
      is_training=False,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      add_aggregation_function_id=do_model_aggregation,
      add_classification_labels=do_model_classification,
      add_answer=FLAGS.use_answer_as_supervision,
      # IDs are only included off-TPU.
      include_id=not FLAGS.use_tpu)

  if FLAGS.do_eval:
    eval_name = FLAGS.eval_name if FLAGS.eval_name is not None else "default"
    for _, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        marker_file_prefix=os.path.join(estimator.model_dir,
                                        f"eval_{eval_name}"),
        minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions):
      tf.logging.info("Running eval: %s", eval_name)
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=eval_name,
          checkpoint_path=checkpoint)
      tf.logging.info("Eval result:\n%s", result)

  if FLAGS.do_predict:
    predict_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="predict",
        file_patterns=FLAGS.input_file_predict,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=False,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=not FLAGS.use_tpu)
    if FLAGS.prediction_output_dir:
      prediction_output_dir = FLAGS.prediction_output_dir
      tf.io.gfile.makedirs(prediction_output_dir)
    else:
      prediction_output_dir = estimator.model_dir
    marker_file_prefix = os.path.join(prediction_output_dir, "predict")
    # When two separate jobs are launched we don't want conflicting markers.
    if FLAGS.input_file_predict is not None:
      marker_file_prefix += "_test"
    if FLAGS.input_file_eval is not None:
      marker_file_prefix += "_dev"
    for current_step, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        single_step=FLAGS.evaluated_checkpoint_step,
        marker_file_prefix=marker_file_prefix):
      # A broken checkpoint should not kill the whole prediction loop, so
      # errors are logged and the next checkpoint is attempted.
      try:
        if FLAGS.input_file_predict is not None:
          _predict_and_export_metrics(
              mode="predict",
              name=FLAGS.eval_name,
              input_fn=predict_input_fn,
              output_dir=prediction_output_dir,
              estimator=estimator,
              checkpoint=checkpoint,
              current_step=current_step,
              do_model_classification=do_model_classification,
              do_model_aggregation=do_model_aggregation,
              output_token_answers=not FLAGS.disable_per_token_loss)
        if FLAGS.input_file_eval is not None:
          _predict_and_export_metrics(
              mode="eval",
              name=FLAGS.eval_name,
              input_fn=eval_input_fn,
              output_dir=prediction_output_dir,
              estimator=estimator,
              checkpoint=checkpoint,
              current_step=current_step,
              do_model_classification=do_model_classification,
              do_model_aggregation=do_model_aggregation,
              output_token_answers=not FLAGS.disable_per_token_loss)
      except ValueError:
        tf.logging.error("Error getting predictions for checkpoint %s: %s",
                         checkpoint, traceback.format_exc())
def main(_):
  """Trains and/or continuously evaluates a TAPAS pre-training model.

  With FLAGS.do_train, trains up to experiment_utils.num_train_steps().
  With FLAGS.do_eval, polls for new checkpoints, evaluating each one until
  the final training step is reached (or, when no step limit is configured,
  until the first evaluation completes).
  """
  bert_config = experiment_utils.bert_config_from_flags()
  # Hoisted: num_train_steps() was previously re-evaluated at every use.
  total_steps = experiment_utils.num_train_steps()
  model_fn = tapas_pretraining_model.model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=total_steps,
      num_warmup_steps=experiment_utils.num_warmup_steps(),
      use_tpu=FLAGS.use_tpu,
      # Flag string is converted to the enum here so bad values fail early.
      restrict_attention_mode=attention_utils.RestrictAttentionMode(
          FLAGS.restrict_attention_mode),
      restrict_attention_bucket_size=FLAGS.restrict_attention_bucket_size,
      restrict_attention_header_size=FLAGS.restrict_attention_header_size,
      restrict_attention_row_heads_ratio=(
          FLAGS.restrict_attention_row_heads_ratio),
      disabled_features=FLAGS.disabled_features,
      disable_position_embeddings=FLAGS.disable_position_embeddings,
      reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
      # Non-positive flag value means "no projection".
      proj_value_length=FLAGS.proj_value_length
      if FLAGS.proj_value_length > 0 else None,
      attention_bias_disabled=FLAGS.attention_bias_disabled,
      attention_bias_use_relative_scalar_only=(
          FLAGS.attention_bias_use_relative_scalar_only),
  )
  estimator = experiment_utils.build_estimator(model_fn)

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    # Persist the config next to the checkpoints for reproducibility.
    bert_config.to_json_file(
        os.path.join(FLAGS.model_dir, "bert_config.json"))
    train_input_fn = functools.partial(
        tapas_pretraining_model.input_fn,
        name="train",
        file_patterns=FLAGS.input_file_train,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=True,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq)
    estimator.train(input_fn=train_input_fn, max_steps=total_steps)

  if FLAGS.do_eval:
    eval_input_fn = functools.partial(
        tapas_pretraining_model.input_fn,
        name="eval",
        file_patterns=FLAGS.input_file_eval,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=False,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq)
    prev_checkpoint = None
    while True:
      checkpoint = estimator.latest_checkpoint()
      # No new checkpoint yet (or none at all): wait and poll again.
      if checkpoint == prev_checkpoint:
        tf.logging.info("Sleeping 5 mins before evaluation")
        time.sleep(5 * 60)
        continue
      tf.logging.info("Running eval: %s", FLAGS.eval_name)
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=FLAGS.eval_name)
      tf.logging.info("Eval result:\n%s", result)
      # Checkpoint files are named "model.ckpt-<step>".
      current_step = int(os.path.basename(checkpoint).split("-")[1])
      # Bug fix: total_steps may be None; the original compared
      # current_step >= None, which raises TypeError. Treat None as "stop
      # after evaluating the latest checkpoint", matching the sibling mains
      # in this file.
      if total_steps is None or current_step >= total_steps:
        tf.logging.info("Evaluation finished after training step %d",
                        current_step)
        break
      prev_checkpoint = checkpoint
def main(_):
  """Trains, evaluates and/or predicts with a TAPAS classifier model.

  Older polling-loop variant: eval and predict repeatedly poll
  estimator.latest_checkpoint() and process each new checkpoint until the
  configured number of training steps has been reached.
  """
  # The aggregation / classification heads are enabled only when their label
  # counts are positive.
  do_model_aggregation = FLAGS.num_aggregation_labels > 0
  do_model_classification = FLAGS.num_classification_labels > 0
  bert_config = experiment_utils.bert_config_from_flags()
  tapas_config = tapas_classifier_model.TapasClassifierConfig(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=experiment_utils.num_train_steps(),
      num_warmup_steps=experiment_utils.num_warmup_steps(),
      use_tpu=FLAGS.use_tpu,
      positive_weight=FLAGS.positive_weight,
      num_aggregation_labels=FLAGS.num_aggregation_labels,
      num_classification_labels=FLAGS.num_classification_labels,
      aggregation_loss_importance=FLAGS.aggregation_loss_importance,
      use_answer_as_supervision=FLAGS.use_answer_as_supervision,
      answer_loss_importance=FLAGS.answer_loss_importance,
      use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
      huber_loss_delta=FLAGS.huber_loss_delta,
      temperature=FLAGS.temperature,
      agg_temperature=FLAGS.agg_temperature,
      use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
      use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
      # Flag string is converted to the enum here so bad values fail early.
      average_approximation_function=tapas_classifier_model
      .AverageApproximationFunction(FLAGS.average_approximation_function),
      cell_select_pref=FLAGS.cell_select_pref,
      answer_loss_cutoff=FLAGS.answer_loss_cutoff,
      grad_clipping=FLAGS.grad_clipping,
      disabled_features=FLAGS.disabled_features,
      max_num_rows=FLAGS.max_num_rows,
      max_num_columns=FLAGS.max_num_columns,
      average_logits_per_cell=FLAGS.average_logits_per_cell,
      init_cell_selection_weights_to_zero=FLAGS
      .init_cell_selection_weights_to_zero,
      select_one_column=FLAGS.select_one_column,
      allow_empty_column_selection=FLAGS.allow_empty_column_selection,
      disable_position_embeddings=FLAGS.disable_position_embeddings,
      disable_per_token_loss=FLAGS.disable_per_token_loss)
  model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
  estimator = experiment_utils.build_estimator(model_fn)

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    # Persist both configs next to the checkpoints for reproducibility.
    bert_config.to_json_file(
        os.path.join(FLAGS.model_dir, "bert_config.json"))
    tapas_config.to_json_file(
        os.path.join(FLAGS.model_dir, "tapas_config.json"))
    train_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="train",
        file_patterns=FLAGS.input_file_train,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=True,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=False)
    estimator.train(
        input_fn=train_input_fn,
        max_steps=experiment_utils.num_train_steps())

  # Built unconditionally: used both by do_eval and by do_predict below.
  eval_input_fn = functools.partial(
      tapas_classifier_model.input_fn,
      name="eval",
      file_patterns=FLAGS.input_file_eval,
      data_format=FLAGS.data_format,
      compression_type=FLAGS.compression_type,
      is_training=False,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      add_aggregation_function_id=do_model_aggregation,
      add_classification_labels=do_model_classification,
      add_answer=FLAGS.use_answer_as_supervision,
      include_id=True)

  if FLAGS.do_eval:
    prev_checkpoint = None
    while True:
      checkpoint = estimator.latest_checkpoint()
      # No new checkpoint yet: optionally sleep, then poll again.
      if checkpoint == prev_checkpoint:
        if FLAGS.minutes_to_sleep_before_predictions > 0:
          tf.logging.info("Sleeping %d mins before evaluation",
                          FLAGS.minutes_to_sleep_before_predictions)
          time.sleep(FLAGS.minutes_to_sleep_before_predictions * 60)
        continue
      tf.logging.info("Running eval: %s", FLAGS.eval_name)
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=FLAGS.eval_name)
      tf.logging.info("Eval result:\n%s", result)
      # Checkpoint files are named "model.ckpt-<step>".
      current_step = int(os.path.basename(checkpoint).split("-")[1])
      # None means no fixed step budget: stop after the first evaluation.
      if experiment_utils.num_train_steps(
      ) is None or current_step >= experiment_utils.num_train_steps():
        tf.logging.info("Evaluation finished after training step %d",
                        current_step)
        break
      prev_checkpoint = checkpoint

  if FLAGS.do_predict:
    predict_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="predict",
        file_patterns=FLAGS.input_file_predict,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=False,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=True)
    prev_checkpoint = None
    while True:
      checkpoint = estimator.latest_checkpoint()
      # No new checkpoint yet: wait and poll again.
      if checkpoint == prev_checkpoint:
        tf.logging.info("Sleeping 5 mins before predicting")
        time.sleep(5 * 60)
        continue
      # Checkpoint files are named "model.ckpt-<step>".
      current_step = int(os.path.basename(checkpoint).split("-")[1])

      # Defined inside the loop so it closes over the current iteration's
      # current_step; it is called immediately below, before current_step is
      # reassigned, so the late-binding closure is safe here.
      def _predict_and_export_metrics(
          mode,
          input_fn,
          input_file,
          interactions_file,
      ):
        """Exports model predictions and calculates denotation metric."""
        # NOTE(review): interactions_file is unused in this body — possibly
        # kept for interface parity with other variants; confirm before
        # removing.
        # Predict for each new checkpoint.
        tf.logging.info("Running predictor for step %d.", current_step)
        result = estimator.predict(input_fn=input_fn)
        if FLAGS.prediction_output_dir:
          output_dir = FLAGS.prediction_output_dir
          tf.io.gfile.makedirs(output_dir)
        else:
          output_dir = FLAGS.model_dir
        output_predict_file = os.path.join(
            output_dir, f"{mode}_results_{current_step}.tsv")
        prediction_utils.write_predictions(
            result, output_predict_file, do_model_aggregation,
            do_model_classification, FLAGS.cell_classification_threshold)
        if FLAGS.do_sequence_prediction:
          # Re-reads the dataset grouped by position to resolve conversational
          # (sequence) predictions.
          examples_by_position = prediction_utils.read_classifier_dataset(
              predict_data=input_file,
              data_format=FLAGS.data_format,
              compression_type=FLAGS.compression_type,
              max_seq_length=FLAGS.max_seq_length,
              max_predictions_per_seq=FLAGS.max_predictions_per_seq,
              add_aggregation_function_id=do_model_aggregation,
              add_classification_labels=do_model_classification,
              add_answer=FLAGS.use_answer_as_supervision)
          result_sequence = prediction_utils.compute_prediction_sequence(
              estimator=estimator, examples_by_position=examples_by_position)
          # NOTE(review): sequence results go to FLAGS.model_dir, not
          # output_dir — looks inconsistent with the non-sequence path above;
          # confirm whether intentional.
          output_predict_file_sequence = os.path.join(
              FLAGS.model_dir,
              mode + "_results_sequence_{}.tsv").format(current_step)
          prediction_utils.write_predictions(
              result_sequence, output_predict_file_sequence,
              do_model_aggregation, do_model_classification,
              FLAGS.cell_classification_threshold)

      if FLAGS.input_file_predict is not None:
        _predict_and_export_metrics(
            mode="predict",
            input_fn=predict_input_fn,
            input_file=FLAGS.input_file_predict,
            interactions_file=FLAGS.predict_interactions_file)
      _predict_and_export_metrics(
          mode="eval",
          input_fn=eval_input_fn,
          input_file=FLAGS.input_file_eval,
          interactions_file=FLAGS.eval_interactions_file)

      num_train_steps = experiment_utils.num_train_steps()
      # None means no fixed step budget: stop after the first prediction pass.
      if num_train_steps is None or current_step >= num_train_steps:
        tf.logging.info("Predictor finished after training step %d",
                        current_step)
        break
      prev_checkpoint = checkpoint