def test_iterate_checkpoints_multi_step(self):
  test_tmpdir = tempfile.mkdtemp()
  checkpoints = [
      os.path.join(test_tmpdir, checkpoint) for checkpoint in
      ['model.ckpt-00001', 'model.ckpt-00002', 'model.ckpt-00003']
  ]
  # Write fake checkpoint file to tmpdir.
  state = tf.train.generate_checkpoint_state_proto(
      test_tmpdir,
      model_checkpoint_path=checkpoints[-1],
      all_model_checkpoint_paths=checkpoints)
  with open(os.path.join(test_tmpdir, 'checkpoint'), 'w') as f:
    f.write(text_format.MessageToString(state))
  for checkpoint in checkpoints:
    with open(f'{checkpoint}.index', 'w') as f:
      f.write('\n')
  marker_file_prefix = os.path.join(test_tmpdir, 'marker')
  results = list(
      experiment_utils.iterate_checkpoints(
          model_dir=test_tmpdir,
          total_steps=3,
          marker_file_prefix=marker_file_prefix))
  expected_steps = [1, 2, 3]
  self.assertEqual(results, list(zip(expected_steps, checkpoints)))
  for step in expected_steps:
    self.assertTrue(tf.gfile.Exists(f'{marker_file_prefix}-{step}.done'))
  # A second pass must yield nothing: every step now has a done marker.
  results = list(
      experiment_utils.iterate_checkpoints(
          model_dir=test_tmpdir,
          total_steps=3,
          marker_file_prefix=marker_file_prefix))
  self.assertEmpty(results)
def test_iterate_checkpoints_single_step(self):
  results = list(
      experiment_utils.iterate_checkpoints(
          model_dir='path',
          single_step=100,
          marker_file_prefix='path',
          total_steps=None))
  self.assertEqual(results, [(100, 'path/model.ckpt-100')])
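# For reference, a minimal sketch of the contract the two tests above
# exercise. This is NOT the repo's implementation: the real
# experiment_utils.iterate_checkpoints also polls for new checkpoints as
# training progresses and honors minutes_to_sleep, which this sketch omits.
# The name iterate_checkpoints_sketch is hypothetical.
def iterate_checkpoints_sketch(model_dir,
                               total_steps,
                               marker_file_prefix,
                               single_step=None):
  """Yields (step, checkpoint_path) pairs that have no done marker yet."""
  del total_steps  # The real helper uses this to detect the final checkpoint.
  if single_step is not None:
    # Single-step mode: yield the requested checkpoint path and stop.
    yield single_step, os.path.join(model_dir, f'model.ckpt-{single_step}')
    return
  state = tf.train.get_checkpoint_state(model_dir)
  for checkpoint in state.all_model_checkpoint_paths:
    step = int(checkpoint.rsplit('-', 1)[1])
    marker_file = f'{marker_file_prefix}-{step}.done'
    if tf.gfile.Exists(marker_file):
      continue  # Already processed in an earlier run; skip.
    yield step, checkpoint
    # Write the done marker so a rerun yields nothing for this step.
    with tf.gfile.Open(marker_file, 'w') as f:
      f.write('\n')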
def main(_):
  do_model_aggregation = FLAGS.num_aggregation_labels > 0
  do_model_classification = FLAGS.num_classification_labels > 0
  bert_config = experiment_utils.bert_config_from_flags()
  total_steps = experiment_utils.num_train_steps()
  tapas_config = tapas_classifier_model.TapasClassifierConfig(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=total_steps,
      num_warmup_steps=experiment_utils.num_warmup_steps(),
      use_tpu=FLAGS.use_tpu,
      positive_weight=FLAGS.positive_weight,
      num_aggregation_labels=FLAGS.num_aggregation_labels,
      num_classification_labels=FLAGS.num_classification_labels,
      aggregation_loss_importance=FLAGS.aggregation_loss_importance,
      use_answer_as_supervision=FLAGS.use_answer_as_supervision,
      answer_loss_importance=FLAGS.answer_loss_importance,
      use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
      huber_loss_delta=FLAGS.huber_loss_delta,
      temperature=FLAGS.temperature,
      agg_temperature=FLAGS.agg_temperature,
      use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
      use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
      average_approximation_function=(
          tapas_classifier_model.AverageApproximationFunction(
              FLAGS.average_approximation_function)),
      cell_select_pref=FLAGS.cell_select_pref,
      answer_loss_cutoff=FLAGS.answer_loss_cutoff,
      grad_clipping=FLAGS.grad_clipping,
      reset_output_cls=FLAGS.reset_output_cls,
      disabled_features=FLAGS.disabled_features,
      max_num_rows=FLAGS.max_num_rows,
      max_num_columns=FLAGS.max_num_columns,
      average_logits_per_cell=FLAGS.average_logits_per_cell,
      init_cell_selection_weights_to_zero=(
          FLAGS.init_cell_selection_weights_to_zero),
      select_one_column=FLAGS.select_one_column,
      allow_empty_column_selection=FLAGS.allow_empty_column_selection,
      disable_position_embeddings=FLAGS.disable_position_embeddings,
      disable_per_token_loss=FLAGS.disable_per_token_loss,
      reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
      span_prediction=tapas_classifier_model.SpanPredictionMode(
          FLAGS.span_prediction),
      proj_value_length=(
          FLAGS.proj_value_length if FLAGS.proj_value_length > 0 else None),
  )

  model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
  estimator = experiment_utils.build_estimator(model_fn)

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    bert_config.to_json_file(os.path.join(FLAGS.model_dir, "bert_config.json"))
    tapas_config.to_json_file(
        os.path.join(FLAGS.model_dir, "tapas_config.json"))
    train_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="train",
        file_patterns=FLAGS.input_file_train,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=True,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=False)
    estimator.train(input_fn=train_input_fn, max_steps=total_steps)

  eval_input_fn = functools.partial(
      tapas_classifier_model.input_fn,
      name="eval",
      file_patterns=FLAGS.input_file_eval,
      data_format=FLAGS.data_format,
      compression_type=FLAGS.compression_type,
      is_training=False,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      add_aggregation_function_id=do_model_aggregation,
      add_classification_labels=do_model_classification,
      add_answer=FLAGS.use_answer_as_supervision,
      include_id=not FLAGS.use_tpu)

  if FLAGS.do_eval:
    eval_name = FLAGS.eval_name if FLAGS.eval_name is not None else "default"
    for _, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        marker_file_prefix=os.path.join(estimator.model_dir,
                                        f"eval_{eval_name}"),
        minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions):
      tf.logging.info("Running eval: %s", eval_name)
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=eval_name,
          checkpoint_path=checkpoint)
      tf.logging.info("Eval result:\n%s", result)

  if FLAGS.do_predict:
    predict_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="predict",
        file_patterns=FLAGS.input_file_predict,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=False,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=not FLAGS.use_tpu)

    if FLAGS.prediction_output_dir:
      prediction_output_dir = FLAGS.prediction_output_dir
      tf.io.gfile.makedirs(prediction_output_dir)
    else:
      prediction_output_dir = estimator.model_dir

    marker_file_prefix = os.path.join(prediction_output_dir, "predict")
    # When two separate jobs are launched we don't want conflicting markers.
    if FLAGS.input_file_predict is not None:
      marker_file_prefix += "_test"
    if FLAGS.input_file_eval is not None:
      marker_file_prefix += "_dev"

    for current_step, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        single_step=FLAGS.evaluated_checkpoint_step,
        marker_file_prefix=marker_file_prefix):
      try:
        if FLAGS.input_file_predict is not None:
          _predict_and_export_metrics(
              mode="predict",
              name=FLAGS.eval_name,
              input_fn=predict_input_fn,
              output_dir=prediction_output_dir,
              estimator=estimator,
              checkpoint=checkpoint,
              current_step=current_step,
              do_model_classification=do_model_classification,
              do_model_aggregation=do_model_aggregation,
              output_token_answers=not FLAGS.disable_per_token_loss)
        if FLAGS.input_file_eval is not None:
          _predict_and_export_metrics(
              mode="eval",
              name=FLAGS.eval_name,
              input_fn=eval_input_fn,
              output_dir=prediction_output_dir,
              estimator=estimator,
              checkpoint=checkpoint,
              current_step=current_step,
              do_model_classification=do_model_classification,
              do_model_aggregation=do_model_aggregation,
              output_token_answers=not FLAGS.disable_per_token_loss)
      except ValueError:
        tf.logging.error("Error getting predictions for checkpoint %s: %s",
                         checkpoint, traceback.format_exc())
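# main() above calls a module-level _predict_and_export_metrics helper that is
# not shown in this excerpt. Below is a hedged sketch of its likely shape: the
# signature is taken from the call sites, but the body, the output file naming,
# and the dump-everything TSV layout are assumptions, not the repo's actual
# format. The _sketch suffix marks it as hypothetical.
def _predict_and_export_metrics_sketch(mode, name, input_fn, output_dir,
                                       estimator, checkpoint, current_step,
                                       do_model_classification,
                                       do_model_aggregation,
                                       output_token_answers):
  """Runs prediction for one checkpoint and writes one file per step."""
  del name, output_token_answers  # Used by the real helper, not this sketch.
  tf.logging.info("Predicting with checkpoint %s.", checkpoint)
  predictions = estimator.predict(
      input_fn=input_fn, checkpoint_path=checkpoint)
  output_file = os.path.join(output_dir, f"{mode}_results_{current_step}.tsv")
  with tf.gfile.Open(output_file, "w") as f:
    for prediction in predictions:
      # The real helper writes answer coordinates plus, depending on
      # do_model_classification / do_model_aggregation, the extra heads'
      # outputs; this sketch just dumps every field it gets back.
      del do_model_classification, do_model_aggregation
      f.write("\t".join(
          f"{key}={value}" for key, value in sorted(prediction.items())))
      f.write("\n")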
def main(_):
  bert_config = experiment_utils.bert_config_from_flags()
  total_steps = experiment_utils.num_train_steps()
  retriever_config = table_retriever_model.RetrieverConfig(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=total_steps,
      num_warmup_steps=experiment_utils.num_warmup_steps(),
      use_tpu=FLAGS.use_tpu,
      grad_clipping=FLAGS.grad_clipping,
      down_projection_dim=FLAGS.down_projection_dim,
      init_from_single_encoder=FLAGS.init_from_single_encoder,
      max_query_length=FLAGS.max_query_length,
      mask_repeated_tables=FLAGS.mask_repeated_tables,
      mask_repeated_questions=FLAGS.mask_repeated_questions,
      use_out_of_core_negatives=FLAGS.use_out_of_core_negatives,
      ignore_table_content=FLAGS.ignore_table_content,
      disabled_features=FLAGS.disabled_features,
      use_mined_negatives=FLAGS.use_mined_negatives,
  )

  model_fn = table_retriever_model.model_fn_builder(retriever_config)
  estimator = experiment_utils.build_estimator(model_fn)

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    bert_config.to_json_file(os.path.join(FLAGS.model_dir, "bert_config.json"))
    retriever_config.to_json_file(
        os.path.join(FLAGS.model_dir, "tapas_config.json"))
    train_input_fn = functools.partial(
        table_retriever_model.input_fn,
        name="train",
        file_patterns=FLAGS.input_file_train,
        data_format=FLAGS.data_format,
        is_training=True,
        max_seq_length=FLAGS.max_seq_length,
        compression_type=FLAGS.compression_type,
        use_mined_negatives=FLAGS.use_mined_negatives,
        include_id=False)
    estimator.train(input_fn=train_input_fn, max_steps=total_steps)

  eval_input_fn = _get_test_input_fn("eval", FLAGS.input_file_eval)
  if FLAGS.do_eval:
    if eval_input_fn is None:
      raise ValueError("No input_file_eval specified!")
    for _, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        marker_file_prefix=os.path.join(estimator.model_dir,
                                        f"eval_{FLAGS.eval_name}"),
        minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions):
      tf.logging.info("Running eval: %s", FLAGS.eval_name)
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=FLAGS.eval_name,
          checkpoint_path=checkpoint)
      tf.logging.info("Eval result:\n%s", result)

  if FLAGS.do_predict:
    predict_input_fn = _get_test_input_fn("predict", FLAGS.input_file_predict)
    if FLAGS.prediction_output_dir:
      prediction_output_dir = FLAGS.prediction_output_dir
      tf.io.gfile.makedirs(prediction_output_dir)
    else:
      prediction_output_dir = estimator.model_dir

    marker_file_prefix = os.path.join(prediction_output_dir, "predict")
    # When two separate jobs are launched we don't want conflicting markers.
    if predict_input_fn is not None:
      marker_file_prefix += "_test"
    if eval_input_fn is not None:
      marker_file_prefix += "_dev"

    single_step = FLAGS.evaluated_checkpoint_step
    if FLAGS.evaluated_checkpoint_metric:
      single_step = experiment_utils.get_best_step_for_metric(
          estimator.model_dir, FLAGS.evaluated_checkpoint_metric)

    for current_step, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        marker_file_prefix=marker_file_prefix,
        single_step=single_step):
      if predict_input_fn is not None:
        _predict_and_export_metrics(
            mode="predict",
            input_fn=predict_input_fn,
            checkpoint_path=checkpoint,
            step=current_step,
            estimator=estimator,
            output_dir=prediction_output_dir)
      if eval_input_fn is not None:
        _predict_and_export_metrics(
            mode="eval",
            input_fn=eval_input_fn,
            checkpoint_path=checkpoint,
            step=current_step,
            estimator=estimator,
            output_dir=prediction_output_dir)
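# _get_test_input_fn is referenced above but not defined in this excerpt.
# Given how main() uses it (it must return None when no file pattern is set)
# and the train_input_fn construction it mirrors, it is plausibly a thin
# wrapper like the one below; treat the exact keyword set, and in particular
# include_id=True, as assumptions rather than the repo's confirmed code.
def _get_test_input_fn(name, file_patterns):
  if file_patterns is None:
    return None
  return functools.partial(
      table_retriever_model.input_fn,
      name=name,
      file_patterns=file_patterns,
      data_format=FLAGS.data_format,
      is_training=False,
      max_seq_length=FLAGS.max_seq_length,
      compression_type=FLAGS.compression_type,
      use_mined_negatives=FLAGS.use_mined_negatives,
      include_id=True)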
def main(_):
  do_model_aggregation = FLAGS.num_aggregation_labels > 0
  do_model_classification = FLAGS.num_classification_labels > 0
  bert_config = experiment_utils.bert_config_from_flags()
  total_steps = experiment_utils.num_train_steps()
  tapas_config = tapas_classifier_model.TapasClassifierConfig(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=total_steps,
      num_warmup_steps=experiment_utils.num_warmup_steps(),
      use_tpu=FLAGS.use_tpu,
      positive_weight=FLAGS.positive_weight,
      num_aggregation_labels=FLAGS.num_aggregation_labels,
      num_classification_labels=FLAGS.num_classification_labels,
      aggregation_loss_importance=FLAGS.aggregation_loss_importance,
      use_answer_as_supervision=FLAGS.use_answer_as_supervision,
      answer_loss_importance=FLAGS.answer_loss_importance,
      use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
      huber_loss_delta=FLAGS.huber_loss_delta,
      temperature=FLAGS.temperature,
      agg_temperature=FLAGS.agg_temperature,
      use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
      use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
      average_approximation_function=(
          tapas_classifier_model.AverageApproximationFunction(
              FLAGS.average_approximation_function)),
      cell_select_pref=FLAGS.cell_select_pref,
      answer_loss_cutoff=FLAGS.answer_loss_cutoff,
      grad_clipping=FLAGS.grad_clipping,
      classification_label_weight={
          int(pair.split(":")[0]): float(pair.split(":")[1])
          for pair in FLAGS.classification_label_weight.split(",")
          if pair
      },
      table_pruning_config_file=FLAGS.table_pruning_config_file,
      restrict_attention_mode=attention_utils.RestrictAttentionMode(
          FLAGS.restrict_attention_mode),
      restrict_attention_bucket_size=FLAGS.restrict_attention_bucket_size,
      restrict_attention_header_size=FLAGS.restrict_attention_header_size,
      restrict_attention_row_heads_ratio=(
          FLAGS.restrict_attention_row_heads_ratio),
      mask_examples_without_labels=FLAGS.mask_examples_without_labels,
      cell_cross_entropy_hard_em=FLAGS.cell_cross_entropy_hard_em,
      cell_cross_entropy=FLAGS.cell_cross_entropy,
      reset_output_cls=FLAGS.reset_output_cls,
      disabled_features=FLAGS.disabled_features,
      max_num_rows=FLAGS.max_num_rows,
      max_num_columns=FLAGS.max_num_columns,
      average_logits_per_cell=FLAGS.average_logits_per_cell,
      init_cell_selection_weights_to_zero=(
          FLAGS.init_cell_selection_weights_to_zero),
      select_one_column=FLAGS.select_one_column,
      allow_empty_column_selection=FLAGS.allow_empty_column_selection,
      disable_position_embeddings=FLAGS.disable_position_embeddings,
      disable_per_token_loss=FLAGS.disable_per_token_loss,
      reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
      span_prediction=tapas_classifier_model.SpanPredictionMode(
          FLAGS.span_prediction),
      proj_value_length=_get_projection_length(FLAGS.proj_value_length),
      attention_bias_disabled=FLAGS.attention_bias_disabled,
      attention_bias_use_relative_scalar_only=(
          FLAGS.attention_bias_use_relative_scalar_only),
  )

  model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
  estimator = experiment_utils.build_estimator(model_fn)

  if tapas_config.classification_label_weight:
    if any(x < 0 for x in tapas_config.classification_label_weight.values()):
      raise ValueError("Label weights cannot be negative in input: "
                       f"{tapas_config.classification_label_weight}.")
    if any(x < 0 or x >= tapas_config.num_classification_labels
           for x in tapas_config.classification_label_weight.keys()):
      raise ValueError("Invalid label in label weights for input: "
                       f"{tapas_config.classification_label_weight}.")

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    # Copy the table pruning config if pruning is used.
    if FLAGS.table_pruning_config_file:
      table_pruning_path = os.path.join(FLAGS.model_dir,
                                        "table_pruning_config.textproto")
      tf.io.gfile.copy(
          FLAGS.table_pruning_config_file, table_pruning_path, overwrite=True)
    bert_config.to_json_file(os.path.join(FLAGS.model_dir, "bert_config.json"))
    tapas_config.to_json_file(
        os.path.join(FLAGS.model_dir, "tapas_config.json"))
    train_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="train",
        file_patterns=FLAGS.input_file_train,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=True,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=False)
    estimator.train(input_fn=train_input_fn, max_steps=total_steps)

  eval_input_fn = functools.partial(
      tapas_classifier_model.input_fn,
      name="eval",
      file_patterns=FLAGS.input_file_eval,
      data_format=FLAGS.data_format,
      compression_type=FLAGS.compression_type,
      is_training=False,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      add_aggregation_function_id=do_model_aggregation,
      add_classification_labels=do_model_classification,
      add_answer=FLAGS.use_answer_as_supervision,
      include_id=not FLAGS.use_tpu)

  if FLAGS.do_eval:
    eval_name = FLAGS.eval_name if FLAGS.eval_name is not None else "default"
    for _, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        marker_file_prefix=os.path.join(estimator.model_dir,
                                        f"eval_{eval_name}"),
        minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions):
      tf.logging.info("Running eval: %s", eval_name)
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=eval_name,
          checkpoint_path=checkpoint)
      tf.logging.info("Eval result:\n%s", result)

  if FLAGS.do_predict:
    predict_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="predict",
        file_patterns=FLAGS.input_file_predict,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=False,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=not FLAGS.use_tpu)

    if FLAGS.prediction_output_dir:
      prediction_output_dir = FLAGS.prediction_output_dir
      tf.io.gfile.makedirs(prediction_output_dir)
    else:
      prediction_output_dir = estimator.model_dir

    marker_file_prefix = os.path.join(prediction_output_dir, "predict")
    # When two separate jobs are launched we don't want conflicting markers.
    if FLAGS.input_file_predict is not None:
      marker_file_prefix += "_test"
    if FLAGS.input_file_eval is not None:
      marker_file_prefix += "_dev"

    for current_step, checkpoint in experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        single_step=FLAGS.evaluated_checkpoint_step,
        marker_file_prefix=marker_file_prefix):
      try:
        if FLAGS.input_file_predict is not None:
          _predict_and_export_metrics(
              mode="predict",
              name=FLAGS.eval_name,
              input_fn=predict_input_fn,
              output_dir=prediction_output_dir,
              estimator=estimator,
              checkpoint=checkpoint,
              current_step=current_step,
              do_model_classification=do_model_classification,
              do_model_aggregation=do_model_aggregation,
              output_token_answers=not FLAGS.disable_per_token_loss)
        if FLAGS.input_file_eval is not None:
          _predict_and_export_metrics(
              mode="eval",
              name=FLAGS.eval_name,
              input_fn=eval_input_fn,
              output_dir=prediction_output_dir,
              estimator=estimator,
              checkpoint=checkpoint,
              current_step=current_step,
              do_model_classification=do_model_classification,
              do_model_aggregation=do_model_aggregation,
              output_token_answers=not FLAGS.disable_per_token_loss)
      except ValueError:
        tf.logging.error("Error getting predictions for checkpoint %s: %s",
                         checkpoint, traceback.format_exc())
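# _get_projection_length is called when building tapas_config above but is not
# defined in this excerpt. The older main() earlier in this section inlines
# the same logic (FLAGS.proj_value_length if FLAGS.proj_value_length > 0 else
# None), so it presumably reduces to:
def _get_projection_length(proj_value_length):
  # Treat non-positive flag values as "no value projection".
  return proj_value_length if proj_value_length > 0 else None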