def main(_):
  """Runs the train / eval / predict phases selected by the flags.

  A `TapasClassifierConfig` is assembled from FLAGS and turned into an
  estimator. Depending on `FLAGS.do_train`, `FLAGS.do_eval` and
  `FLAGS.do_predict`, the estimator is then trained, evaluated on each
  checkpoint yielded by `experiment_utils.iterate_checkpoints`, and/or used
  to export predictions for each yielded checkpoint.

  Args:
    _: Unused positional argument.
  """
  with_aggregation = FLAGS.num_aggregation_labels > 0
  with_classification = FLAGS.num_classification_labels > 0

  bert_config = experiment_utils.bert_config_from_flags()
  total_steps = experiment_utils.num_train_steps()
  warmup_steps = experiment_utils.num_warmup_steps()
  approximation = tapas_classifier_model.AverageApproximationFunction(
      FLAGS.average_approximation_function)
  span_mode = tapas_classifier_model.SpanPredictionMode(FLAGS.span_prediction)
  # A non-positive projection length is handed to the config as None.
  projection_length = None
  if FLAGS.proj_value_length > 0:
    projection_length = FLAGS.proj_value_length

  tapas_config = tapas_classifier_model.TapasClassifierConfig(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=total_steps,
      num_warmup_steps=warmup_steps,
      use_tpu=FLAGS.use_tpu,
      positive_weight=FLAGS.positive_weight,
      num_aggregation_labels=FLAGS.num_aggregation_labels,
      num_classification_labels=FLAGS.num_classification_labels,
      aggregation_loss_importance=FLAGS.aggregation_loss_importance,
      use_answer_as_supervision=FLAGS.use_answer_as_supervision,
      answer_loss_importance=FLAGS.answer_loss_importance,
      use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
      huber_loss_delta=FLAGS.huber_loss_delta,
      temperature=FLAGS.temperature,
      agg_temperature=FLAGS.agg_temperature,
      use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
      use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
      average_approximation_function=approximation,
      cell_select_pref=FLAGS.cell_select_pref,
      answer_loss_cutoff=FLAGS.answer_loss_cutoff,
      grad_clipping=FLAGS.grad_clipping,
      reset_output_cls=FLAGS.reset_output_cls,
      disabled_features=FLAGS.disabled_features,
      max_num_rows=FLAGS.max_num_rows,
      max_num_columns=FLAGS.max_num_columns,
      average_logits_per_cell=FLAGS.average_logits_per_cell,
      init_cell_selection_weights_to_zero=(
          FLAGS.init_cell_selection_weights_to_zero),
      select_one_column=FLAGS.select_one_column,
      allow_empty_column_selection=FLAGS.allow_empty_column_selection,
      disable_position_embeddings=FLAGS.disable_position_embeddings,
      disable_per_token_loss=FLAGS.disable_per_token_loss,
      reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
      span_prediction=span_mode,
      proj_value_length=projection_length,
  )
  estimator = experiment_utils.build_estimator(
      tapas_classifier_model.model_fn_builder(tapas_config))

  def _make_input_fn(name, file_patterns, is_training, include_id):
    """Binds the flag-derived reader options to the model's input_fn."""
    return functools.partial(
        tapas_classifier_model.input_fn,
        name=name,
        file_patterns=file_patterns,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=is_training,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=with_aggregation,
        add_classification_labels=with_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=include_id)

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    # Persist both configs next to the checkpoints.
    for config, file_name in ((bert_config, "bert_config.json"),
                              (tapas_config, "tapas_config.json")):
      config.to_json_file(os.path.join(FLAGS.model_dir, file_name))
    estimator.train(
        input_fn=_make_input_fn(
            "train",
            FLAGS.input_file_train,
            is_training=True,
            include_id=False),
        max_steps=total_steps)

  # The eval input is shared by the evaluation and the prediction phases.
  eval_input_fn = _make_input_fn(
      "eval",
      FLAGS.input_file_eval,
      is_training=False,
      include_id=not FLAGS.use_tpu)

  if FLAGS.do_eval:
    run_name = "default" if FLAGS.eval_name is None else FLAGS.eval_name
    eval_checkpoints = experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        marker_file_prefix=os.path.join(estimator.model_dir,
                                        f"eval_{run_name}"),
        minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions)
    for _, checkpoint in eval_checkpoints:
      tf.logging.info("Running eval: %s", run_name)
      metrics = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=run_name,
          checkpoint_path=checkpoint)
      tf.logging.info("Eval result:\n%s", metrics)

  if FLAGS.do_predict:
    predict_input_fn = _make_input_fn(
        "predict",
        FLAGS.input_file_predict,
        is_training=False,
        include_id=not FLAGS.use_tpu)

    prediction_output_dir = estimator.model_dir
    if FLAGS.prediction_output_dir:
      prediction_output_dir = FLAGS.prediction_output_dir
      tf.io.gfile.makedirs(prediction_output_dir)

    has_test_input = FLAGS.input_file_predict is not None
    has_dev_input = FLAGS.input_file_eval is not None
    # When two separate jobs are launched we don't want conflicting markers,
    # so the marker name encodes which inputs this job handles.
    marker_name = "predict"
    if has_test_input:
      marker_name += "_test"
    if has_dev_input:
      marker_name += "_dev"
    marker_file_prefix = os.path.join(prediction_output_dir, marker_name)

    # Arguments that are the same for every export call.
    export_kwargs = dict(
        name=FLAGS.eval_name,
        output_dir=prediction_output_dir,
        estimator=estimator,
        do_model_classification=with_classification,
        do_model_aggregation=with_aggregation,
        output_token_answers=not FLAGS.disable_per_token_loss)

    predict_checkpoints = experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        single_step=FLAGS.evaluated_checkpoint_step,
        marker_file_prefix=marker_file_prefix)
    for current_step, checkpoint in predict_checkpoints:
      try:
        if has_test_input:
          _predict_and_export_metrics(
              mode="predict",
              input_fn=predict_input_fn,
              checkpoint=checkpoint,
              current_step=current_step,
              **export_kwargs)
        if has_dev_input:
          _predict_and_export_metrics(
              mode="eval",
              input_fn=eval_input_fn,
              checkpoint=checkpoint,
              current_step=current_step,
              **export_kwargs)
      except ValueError:
        # A ValueError for one checkpoint is logged; the loop then moves on.
        tf.logging.error(
            "Error getting predictions for checkpoint %s: %s",
            checkpoint,
            traceback.format_exc())
def main(_):
  """Runs the train / eval / predict phases selected by the flags.

  A `TapasClassifierConfig` is assembled from FLAGS and turned into an
  estimator. Depending on `FLAGS.do_train`, `FLAGS.do_eval` and
  `FLAGS.do_predict`, the estimator is trained, and/or the latest checkpoint
  in the model directory is polled and evaluated / used for prediction until
  a checkpoint at or past the configured number of training steps has been
  processed.

  Args:
    _: Unused positional argument.
  """
  do_model_aggregation = FLAGS.num_aggregation_labels > 0
  do_model_classification = FLAGS.num_classification_labels > 0
  bert_config = experiment_utils.bert_config_from_flags()
  # Computed once and reused for the config, training and the stop checks.
  total_steps = experiment_utils.num_train_steps()
  tapas_config = tapas_classifier_model.TapasClassifierConfig(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=total_steps,
      num_warmup_steps=experiment_utils.num_warmup_steps(),
      use_tpu=FLAGS.use_tpu,
      positive_weight=FLAGS.positive_weight,
      num_aggregation_labels=FLAGS.num_aggregation_labels,
      num_classification_labels=FLAGS.num_classification_labels,
      aggregation_loss_importance=FLAGS.aggregation_loss_importance,
      use_answer_as_supervision=FLAGS.use_answer_as_supervision,
      answer_loss_importance=FLAGS.answer_loss_importance,
      use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
      huber_loss_delta=FLAGS.huber_loss_delta,
      temperature=FLAGS.temperature,
      agg_temperature=FLAGS.agg_temperature,
      use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
      use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
      average_approximation_function=(
          tapas_classifier_model.AverageApproximationFunction(
              FLAGS.average_approximation_function)),
      cell_select_pref=FLAGS.cell_select_pref,
      answer_loss_cutoff=FLAGS.answer_loss_cutoff,
      grad_clipping=FLAGS.grad_clipping,
      disabled_features=FLAGS.disabled_features,
      max_num_rows=FLAGS.max_num_rows,
      max_num_columns=FLAGS.max_num_columns,
      average_logits_per_cell=FLAGS.average_logits_per_cell,
      init_cell_selection_weights_to_zero=(
          FLAGS.init_cell_selection_weights_to_zero),
      select_one_column=FLAGS.select_one_column,
      allow_empty_column_selection=FLAGS.allow_empty_column_selection,
      disable_position_embeddings=FLAGS.disable_position_embeddings,
      disable_per_token_loss=FLAGS.disable_per_token_loss,
      reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
      span_prediction=tapas_classifier_model.SpanPredictionMode(
          FLAGS.span_prediction),
  )

  model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
  estimator = experiment_utils.build_estimator(model_fn)

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    bert_config.to_json_file(os.path.join(FLAGS.model_dir, "bert_config.json"))
    tapas_config.to_json_file(
        os.path.join(FLAGS.model_dir, "tapas_config.json"))
    train_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="train",
        file_patterns=FLAGS.input_file_train,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=True,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=False)
    estimator.train(input_fn=train_input_fn, max_steps=total_steps)

  # The eval input is shared by the evaluation and the prediction phases.
  eval_input_fn = functools.partial(
      tapas_classifier_model.input_fn,
      name="eval",
      file_patterns=FLAGS.input_file_eval,
      data_format=FLAGS.data_format,
      compression_type=FLAGS.compression_type,
      is_training=False,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      add_aggregation_function_id=do_model_aggregation,
      add_classification_labels=do_model_classification,
      add_answer=FLAGS.use_answer_as_supervision,
      include_id=True)

  if FLAGS.do_eval:
    prev_checkpoint = None
    while True:
      checkpoint = estimator.latest_checkpoint()
      # `prev_checkpoint` starts as None, so "no checkpoint yet" and "no new
      # checkpoint" both take this branch.
      if checkpoint == prev_checkpoint:
        # NOTE(review): with a non-positive sleep flag this polls without
        # any delay between attempts.
        if FLAGS.minutes_to_sleep_before_predictions > 0:
          tf.logging.info("Sleeping %d mins before evaluation",
                          FLAGS.minutes_to_sleep_before_predictions)
          time.sleep(FLAGS.minutes_to_sleep_before_predictions * 60)
        continue

      tf.logging.info("Running eval: %s", FLAGS.eval_name)
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=FLAGS.eval_name,
          checkpoint_path=checkpoint)
      tf.logging.info("Eval result:\n%s", result)

      # The step is parsed from the checkpoint file name ("<prefix>-<step>").
      current_step = int(os.path.basename(checkpoint).split("-")[1])
      if total_steps is None or current_step >= total_steps:
        tf.logging.info("Evaluation finished after training step %d",
                        current_step)
        break
      prev_checkpoint = checkpoint

  if FLAGS.do_predict:
    predict_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name="predict",
        file_patterns=FLAGS.input_file_predict,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=False,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=do_model_aggregation,
        add_classification_labels=do_model_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=True)

    # Resolve (and create) the output directory once instead of per
    # checkpoint.
    if FLAGS.prediction_output_dir:
      output_dir = FLAGS.prediction_output_dir
      tf.io.gfile.makedirs(output_dir)
    else:
      output_dir = FLAGS.model_dir

    def _predict_and_export_metrics(
        mode,
        input_fn,
        input_file,
        interactions_file,
        current_step,
        checkpoint,
    ):
      """Writes the predictions of one checkpoint to TSV files.

      Args:
        mode: Prefix used in the output file names.
        input_fn: Input function passed to `estimator.predict`.
        input_file: Examples read again when `FLAGS.do_sequence_prediction`
          is set.
        interactions_file: Accepted for call-site compatibility; not used.
        current_step: Training step of `checkpoint`; part of the file names.
        checkpoint: Checkpoint path passed to `estimator.predict`.
      """
      del interactions_file  # Unused.
      tf.logging.info(
          "Running predictor for step %d (%s).",
          current_step,
          checkpoint,
      )
      result = estimator.predict(
          input_fn=input_fn,
          checkpoint_path=checkpoint,
      )
      output_predict_file = os.path.join(
          output_dir, f"{mode}_results_{current_step}.tsv")
      prediction_utils.write_predictions(
          result, output_predict_file, do_model_aggregation,
          do_model_classification, FLAGS.cell_classification_threshold)

      if FLAGS.do_sequence_prediction:
        examples_by_position = prediction_utils.read_classifier_dataset(
            predict_data=input_file,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            add_aggregation_function_id=do_model_aggregation,
            add_classification_labels=do_model_classification,
            add_answer=FLAGS.use_answer_as_supervision)
        result_sequence = prediction_utils.compute_prediction_sequence(
            estimator=estimator, examples_by_position=examples_by_position)
        # Written next to the regular predictions. Only the file name is
        # formatted, so braces in the directory path are left untouched.
        output_predict_file_sequence = os.path.join(
            output_dir, f"{mode}_results_sequence_{current_step}.tsv")
        prediction_utils.write_predictions(
            result_sequence, output_predict_file_sequence,
            do_model_aggregation, do_model_classification,
            FLAGS.cell_classification_threshold)

    prev_checkpoint = None
    while True:
      checkpoint = estimator.latest_checkpoint()
      if checkpoint == prev_checkpoint:
        tf.logging.info("Sleeping 5 mins before predicting")
        time.sleep(5 * 60)
        continue

      current_step = int(os.path.basename(checkpoint).split("-")[1])

      if FLAGS.input_file_predict is not None:
        _predict_and_export_metrics(
            mode="predict",
            input_fn=predict_input_fn,
            input_file=FLAGS.input_file_predict,
            interactions_file=FLAGS.predict_interactions_file,
            current_step=current_step,
            checkpoint=checkpoint)
      if FLAGS.input_file_eval is not None:
        _predict_and_export_metrics(
            mode="eval",
            input_fn=eval_input_fn,
            input_file=FLAGS.input_file_eval,
            interactions_file=FLAGS.eval_interactions_file,
            current_step=current_step,
            checkpoint=checkpoint)

      if total_steps is None or current_step >= total_steps:
        tf.logging.info(
            "Predictor finished after training step %d",
            current_step,
        )
        break
      prev_checkpoint = checkpoint
def main(_):
  """Runs the train / eval / predict phases selected by the flags.

  A `TapasClassifierConfig` (including per-label classification weights
  parsed from `FLAGS.classification_label_weight`) is assembled from FLAGS
  and turned into an estimator. Depending on `FLAGS.do_train`,
  `FLAGS.do_eval` and `FLAGS.do_predict`, the estimator is then trained,
  evaluated on each checkpoint yielded by
  `experiment_utils.iterate_checkpoints`, and/or used to export predictions
  for each yielded checkpoint.

  Args:
    _: Unused positional argument.

  Raises:
    ValueError: If a configured label weight is negative, or a weighted label
      is outside `[0, num_classification_labels)`.
  """
  with_aggregation = FLAGS.num_aggregation_labels > 0
  with_classification = FLAGS.num_classification_labels > 0

  bert_config = experiment_utils.bert_config_from_flags()
  total_steps = experiment_utils.num_train_steps()
  warmup_steps = experiment_utils.num_warmup_steps()
  approximation = tapas_classifier_model.AverageApproximationFunction(
      FLAGS.average_approximation_function)

  # The flag holds comma-separated "label:weight" pairs; empty items are
  # skipped.
  label_weights = {}
  for pair in FLAGS.classification_label_weight.split(","):
    if not pair:
      continue
    fields = pair.split(":")
    label = int(fields[0])
    weight = float(fields[1])
    label_weights[label] = weight

  attention_mode = attention_utils.RestrictAttentionMode(
      FLAGS.restrict_attention_mode)
  span_mode = tapas_classifier_model.SpanPredictionMode(FLAGS.span_prediction)
  projection_length = _get_projection_length(FLAGS.proj_value_length)

  tapas_config = tapas_classifier_model.TapasClassifierConfig(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=total_steps,
      num_warmup_steps=warmup_steps,
      use_tpu=FLAGS.use_tpu,
      positive_weight=FLAGS.positive_weight,
      num_aggregation_labels=FLAGS.num_aggregation_labels,
      num_classification_labels=FLAGS.num_classification_labels,
      aggregation_loss_importance=FLAGS.aggregation_loss_importance,
      use_answer_as_supervision=FLAGS.use_answer_as_supervision,
      answer_loss_importance=FLAGS.answer_loss_importance,
      use_normalized_answer_loss=FLAGS.use_normalized_answer_loss,
      huber_loss_delta=FLAGS.huber_loss_delta,
      temperature=FLAGS.temperature,
      agg_temperature=FLAGS.agg_temperature,
      use_gumbel_for_cells=FLAGS.use_gumbel_for_cells,
      use_gumbel_for_agg=FLAGS.use_gumbel_for_agg,
      average_approximation_function=approximation,
      cell_select_pref=FLAGS.cell_select_pref,
      answer_loss_cutoff=FLAGS.answer_loss_cutoff,
      grad_clipping=FLAGS.grad_clipping,
      classification_label_weight=label_weights,
      table_pruning_config_file=FLAGS.table_pruning_config_file,
      restrict_attention_mode=attention_mode,
      restrict_attention_bucket_size=FLAGS.restrict_attention_bucket_size,
      restrict_attention_header_size=FLAGS.restrict_attention_header_size,
      restrict_attention_row_heads_ratio=(
          FLAGS.restrict_attention_row_heads_ratio),
      mask_examples_without_labels=FLAGS.mask_examples_without_labels,
      cell_cross_entropy_hard_em=FLAGS.cell_cross_entropy_hard_em,
      cell_cross_entropy=FLAGS.cell_cross_entropy,
      reset_output_cls=FLAGS.reset_output_cls,
      disabled_features=FLAGS.disabled_features,
      max_num_rows=FLAGS.max_num_rows,
      max_num_columns=FLAGS.max_num_columns,
      average_logits_per_cell=FLAGS.average_logits_per_cell,
      init_cell_selection_weights_to_zero=(
          FLAGS.init_cell_selection_weights_to_zero),
      select_one_column=FLAGS.select_one_column,
      allow_empty_column_selection=FLAGS.allow_empty_column_selection,
      disable_position_embeddings=FLAGS.disable_position_embeddings,
      disable_per_token_loss=FLAGS.disable_per_token_loss,
      reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
      span_prediction=span_mode,
      proj_value_length=projection_length,
      attention_bias_disabled=FLAGS.attention_bias_disabled,
      attention_bias_use_relative_scalar_only=(
          FLAGS.attention_bias_use_relative_scalar_only),
  )
  estimator = experiment_utils.build_estimator(
      tapas_classifier_model.model_fn_builder(tapas_config))

  # Sanity-check the label weights as stored on the config.
  configured_weights = tapas_config.classification_label_weight
  if configured_weights:
    if any(weight < 0 for weight in configured_weights.values()):
      raise ValueError("Label weights cannot be negative in input: "
                       f"{tapas_config.classification_label_weight}.")
    num_labels = tapas_config.num_classification_labels
    if any(label < 0 or label >= num_labels
           for label in configured_weights.keys()):
      raise ValueError("Invalid label in label weights for input: "
                       f"{tapas_config.classification_label_weight}.")

  def _make_input_fn(name, file_patterns, is_training, include_id):
    """Binds the flag-derived reader options to the model's input_fn."""
    return functools.partial(
        tapas_classifier_model.input_fn,
        name=name,
        file_patterns=file_patterns,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=is_training,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        add_aggregation_function_id=with_aggregation,
        add_classification_labels=with_classification,
        add_answer=FLAGS.use_answer_as_supervision,
        include_id=include_id)

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    # Copy the table pruning config if pruning is used.
    if FLAGS.table_pruning_config_file:
      tf.io.gfile.copy(
          FLAGS.table_pruning_config_file,
          os.path.join(FLAGS.model_dir, "table_pruning_config.textproto"),
          overwrite=True)
    # Persist both configs next to the checkpoints.
    for config, file_name in ((bert_config, "bert_config.json"),
                              (tapas_config, "tapas_config.json")):
      config.to_json_file(os.path.join(FLAGS.model_dir, file_name))
    estimator.train(
        input_fn=_make_input_fn(
            "train",
            FLAGS.input_file_train,
            is_training=True,
            include_id=False),
        max_steps=total_steps)

  # The eval input is shared by the evaluation and the prediction phases.
  eval_input_fn = _make_input_fn(
      "eval",
      FLAGS.input_file_eval,
      is_training=False,
      include_id=not FLAGS.use_tpu)

  if FLAGS.do_eval:
    run_name = "default" if FLAGS.eval_name is None else FLAGS.eval_name
    eval_checkpoints = experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        marker_file_prefix=os.path.join(estimator.model_dir,
                                        f"eval_{run_name}"),
        minutes_to_sleep=FLAGS.minutes_to_sleep_before_predictions)
    for _, checkpoint in eval_checkpoints:
      tf.logging.info("Running eval: %s", run_name)
      metrics = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=run_name,
          checkpoint_path=checkpoint)
      tf.logging.info("Eval result:\n%s", metrics)

  if FLAGS.do_predict:
    predict_input_fn = _make_input_fn(
        "predict",
        FLAGS.input_file_predict,
        is_training=False,
        include_id=not FLAGS.use_tpu)

    prediction_output_dir = estimator.model_dir
    if FLAGS.prediction_output_dir:
      prediction_output_dir = FLAGS.prediction_output_dir
      tf.io.gfile.makedirs(prediction_output_dir)

    has_test_input = FLAGS.input_file_predict is not None
    has_dev_input = FLAGS.input_file_eval is not None
    # When two separate jobs are launched we don't want conflicting markers,
    # so the marker name encodes which inputs this job handles.
    marker_name = "predict"
    if has_test_input:
      marker_name += "_test"
    if has_dev_input:
      marker_name += "_dev"
    marker_file_prefix = os.path.join(prediction_output_dir, marker_name)

    # Arguments that are the same for every export call.
    export_kwargs = dict(
        name=FLAGS.eval_name,
        output_dir=prediction_output_dir,
        estimator=estimator,
        do_model_classification=with_classification,
        do_model_aggregation=with_aggregation,
        output_token_answers=not FLAGS.disable_per_token_loss)

    predict_checkpoints = experiment_utils.iterate_checkpoints(
        model_dir=estimator.model_dir,
        total_steps=total_steps,
        single_step=FLAGS.evaluated_checkpoint_step,
        marker_file_prefix=marker_file_prefix)
    for current_step, checkpoint in predict_checkpoints:
      try:
        if has_test_input:
          _predict_and_export_metrics(
              mode="predict",
              input_fn=predict_input_fn,
              checkpoint=checkpoint,
              current_step=current_step,
              **export_kwargs)
        if has_dev_input:
          _predict_and_export_metrics(
              mode="eval",
              input_fn=eval_input_fn,
              checkpoint=checkpoint,
              current_step=current_step,
              **export_kwargs)
      except ValueError:
        # A ValueError for one checkpoint is logged; the loop then moves on.
        tf.logging.error(
            "Error getting predictions for checkpoint %s: %s",
            checkpoint,
            traceback.format_exc())
def _train_and_predict(
    task,
    tpu_options,
    test_batch_size,
    train_batch_size,
    gradient_accumulation_steps,
    bert_config_file,
    init_checkpoint,
    test_mode,
    mode,
    output_dir,
    model_dir,
    loop_predict,
):
  """Trains a model for `task`, or predicts (and optionally evaluates) with it.

  Args:
    task: Task selecting the label setup and the hyper-parameters.
    tpu_options: Object providing the TPU settings read below (`use_tpu`,
      `tpu_name`, `tpu_zone`, `gcp_project`, `master`, `iterations_per_loop`,
      `num_tpu_cores`).
    test_batch_size: Prediction batch size; forced to 1 when `test_mode`.
    train_batch_size: Training batch size. When None it is 1 in `test_mode`
      and the task's `train_batch_size` hyper-parameter otherwise.
    gradient_accumulation_steps: Passed to the estimator params and used to
      divide the training batch size.
    bert_config_file: JSON file the BERT config is loaded from.
    init_checkpoint: Initial checkpoint forwarded to the model config.
    test_mode: If true, run with 10 training steps and 1 warmup step.
    mode: Mode.TRAIN, Mode.PREDICT or Mode.PREDICT_AND_EVALUATE.
    output_dir: Directory passed on for reading examples / writing results.
    model_dir: Directory for checkpoints and the saved configs.
    loop_predict: If true, keep polling for new checkpoints until one at or
      past the configured number of training steps has been processed.

  Raises:
    ValueError: For an unknown task or mode, or when `loop_predict` is false
      and no checkpoint is available.
  """
  file_utils.make_directories(model_dir)

  # Per-task label setup.
  if task in (tasks.Task.SQA, tasks.Task.HYBRIDQA, tasks.Task.HYBRIDQA_RC):
    num_aggregation_labels, num_classification_labels = 0, 0
    use_answer_as_supervision = False
  elif task in [
      tasks.Task.WTQ, tasks.Task.WIKISQL, tasks.Task.WIKISQL_SUPERVISED
  ]:
    num_aggregation_labels, num_classification_labels = 4, 0
    use_answer_as_supervision = task != tasks.Task.WIKISQL_SUPERVISED
  elif task == tasks.Task.TABFACT:
    num_aggregation_labels, num_classification_labels = 0, 2
    use_answer_as_supervision = True
  elif task == tasks.Task.NQ_RETRIEVAL:
    num_aggregation_labels, num_classification_labels = 0, 2
    use_answer_as_supervision = False
  else:
    raise ValueError(f'Unknown task: {task.name}')

  with_aggregation = num_aggregation_labels > 0
  with_classification = num_classification_labels > 0

  hparams = hparam_utils.get_hparams(task)
  if test_mode:
    # Tiny fixed schedule for smoke tests.
    if train_batch_size is None:
      train_batch_size = 1
    test_batch_size = 1
    num_train_steps, num_warmup_steps = 10, 1
  else:
    if train_batch_size is None:
      train_batch_size = hparams['train_batch_size']
    num_train_steps = int(hparams['num_train_examples'] / train_batch_size)
    num_warmup_steps = int(num_train_steps * hparams['warmup_ratio'])

  bert_config = modeling.BertConfig.from_json_file(bert_config_file)
  # Optional per-task dropout overrides.
  for attribute in ('attention_probs_dropout_prob', 'hidden_dropout_prob'):
    override_key = f'bert_config_{attribute}'
    if override_key in hparams:
      setattr(bert_config, attribute, hparams[override_key])

  tapas_config = tapas_classifier_model.TapasClassifierConfig(
      bert_config=bert_config,
      init_checkpoint=init_checkpoint,
      learning_rate=hparams['learning_rate'],
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=tpu_options.use_tpu,
      positive_weight=10.0,
      num_aggregation_labels=num_aggregation_labels,
      num_classification_labels=num_classification_labels,
      aggregation_loss_importance=1.0,
      use_answer_as_supervision=use_answer_as_supervision,
      answer_loss_importance=1.0,
      use_normalized_answer_loss=False,
      huber_loss_delta=hparams.get('huber_loss_delta'),
      temperature=hparams.get('temperature', 1.0),
      agg_temperature=1.0,
      use_gumbel_for_cells=False,
      use_gumbel_for_agg=False,
      average_approximation_function=(
          tapas_classifier_model.AverageApproximationFunction.RATIO),
      cell_select_pref=hparams.get('cell_select_pref'),
      answer_loss_cutoff=hparams.get('answer_loss_cutoff'),
      grad_clipping=hparams.get('grad_clipping'),
      disabled_features=[],
      max_num_rows=64,
      max_num_columns=32,
      average_logits_per_cell=False,
      disable_per_token_loss=hparams.get('disable_per_token_loss', False),
      mask_examples_without_labels=hparams.get(
          'mask_examples_without_labels', False),
      init_cell_selection_weights_to_zero=(
          hparams['init_cell_selection_weights_to_zero']),
      select_one_column=hparams['select_one_column'],
      allow_empty_column_selection=hparams['allow_empty_column_selection'],
      span_prediction=tapas_classifier_model.SpanPredictionMode(
          hparams.get('span_prediction',
                      tapas_classifier_model.SpanPredictionMode.NONE)),
      disable_position_embeddings=False,
      reset_output_cls=FLAGS.reset_output_cls,
      reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
      table_pruning_config_file=FLAGS.table_pruning_config_file)

  model_fn = tapas_classifier_model.model_fn_builder(tapas_config)
  is_per_host = tf_estimator.tpu.InputPipelineConfig.PER_HOST_V2

  # A cluster resolver is only needed when a named TPU is used.
  tpu_cluster_resolver = None
  if tpu_options.use_tpu and tpu_options.tpu_name:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=tpu_options.tpu_name,
        zone=tpu_options.tpu_zone,
        project=tpu_options.gcp_project,
    )

  tpu_config = tf_estimator.tpu.TPUConfig(
      iterations_per_loop=tpu_options.iterations_per_loop,
      num_shards=tpu_options.num_tpu_cores,
      per_host_input_for_training=is_per_host)
  run_config = tf_estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=tpu_options.master,
      model_dir=model_dir,
      tf_random_seed=FLAGS.tf_random_seed,
      save_checkpoints_steps=1000,
      keep_checkpoint_max=5,
      keep_checkpoint_every_n_hours=4.0,
      tpu_config=tpu_config)

  # If TPU is not available, this will fall back to normal Estimator on
  # CPU/GPU.
  estimator = tf_estimator.tpu.TPUEstimator(
      params={'gradient_accumulation_steps': gradient_accumulation_steps},
      use_tpu=tpu_options.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=train_batch_size // gradient_accumulation_steps,
      eval_batch_size=None,
      predict_batch_size=test_batch_size)

  if mode == Mode.TRAIN:
    _print('Training')
    # Persist both configs next to the checkpoints.
    for config, file_name in ((bert_config, 'bert_config.json'),
                              (tapas_config, 'tapas_config.json')):
      config.to_json_file(os.path.join(model_dir, file_name))
    train_input_fn = functools.partial(
        tapas_classifier_model.input_fn,
        name='train',
        file_patterns=_get_train_examples_file(task, output_dir),
        data_format='tfrecord',
        compression_type=FLAGS.compression_type,
        is_training=True,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=_MAX_PREDICTIONS_PER_SEQ,
        add_aggregation_function_id=with_aggregation,
        add_classification_labels=with_classification,
        add_answer=use_answer_as_supervision,
        include_id=False,
    )
    estimator.train(
        input_fn=train_input_fn,
        max_steps=tapas_config.num_train_steps,
    )
  elif mode in (Mode.PREDICT_AND_EVALUATE, Mode.PREDICT):
    # Continuous prediction: start from the latest checkpoint and keep going
    # until a checkpoint with 'num_train_steps' is reached (or stop after one
    # pass when not looping).
    poll_seconds = 5 * 60
    prev_checkpoint = None
    while True:
      checkpoint = estimator.latest_checkpoint()
      if loop_predict:
        if checkpoint == prev_checkpoint:
          _print('Sleeping 5 mins before predicting')
          time.sleep(poll_seconds)
          continue
      elif not checkpoint:
        raise ValueError(f'No checkpoint found at {model_dir}.')

      # The step is parsed from the checkpoint file name ("<prefix>-<step>").
      checkpoint_name = os.path.basename(checkpoint)
      current_step = int(checkpoint_name.split('-')[1])
      _predict(
          estimator,
          task,
          output_dir,
          model_dir,
          with_aggregation,
          with_classification,
          use_answer_as_supervision,
          use_tpu=tapas_config.use_tpu,
          global_step=current_step,
      )
      if mode == Mode.PREDICT_AND_EVALUATE:
        _eval(
            task=task,
            output_dir=output_dir,
            model_dir=model_dir,
            global_step=current_step)

      reached_final_step = current_step >= tapas_config.num_train_steps
      if not loop_predict or reached_final_step:
        _print(f'Evaluation finished after training step {current_step}.')
        break
      prev_checkpoint = checkpoint
  else:
    raise ValueError(f'Unexpected mode: {mode}.')