def run_customized_training(strategy,
                            bert_config,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            input_files,
                            train_batch_size,
                            use_remote_tpu=False):
  """Runs BERT pretraining using the low-level custom training loop API."""
  train_input_fn = functools.partial(get_pretrain_input_data, input_files,
                                     max_seq_length, max_predictions_per_seq,
                                     train_batch_size, strategy)

  def _get_pretrain_model():
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)
    pretrain_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", the dtype
      # determined by flags_core.get_tf_dtype(flags_obj) is 'float32', which
      # ensures tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      pretrain_model.optimizer = (
          tf.train.experimental.enable_mixed_precision_graph_rewrite(
              pretrain_model.optimizer))
    return pretrain_model, core_model

  trained_model = model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrain_model,
      loss_fn=get_loss_fn(
          loss_factor=1.0 /
          strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0),
      model_dir=model_dir,
      train_input_fn=train_input_fn,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      use_remote_tpu=use_remote_tpu)

  # Creates the BERT core model outside the distribution strategy scope.
  _, core_model = bert_models.pretrain_model(bert_config, max_seq_length,
                                             max_predictions_per_seq)

  # Restores the core model from the training checkpoints and writes a new
  # checkpoint that contains only the core model.
  model_saving_utils.export_pretraining_checkpoint(
      checkpoint_dir=model_dir, model=core_model)
  return trained_model
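
# Usage sketch (not part of the original script, added for illustration):
# a minimal, hypothetical way to drive run_customized_training with a
# MirroredStrategy. The sequence length, schedule values, model_dir path, and
# batch size below are placeholder assumptions; in the real script they come
# from command-line FLAGS, and bert_config / input_files are built elsewhere.
def _example_run_pretraining_sketch(bert_config, input_files):
  """Hypothetical example of calling run_customized_training directly."""
  strategy = tf.distribute.MirroredStrategy()
  return run_customized_training(
      strategy=strategy,
      bert_config=bert_config,
      max_seq_length=128,
      max_predictions_per_seq=20,
      model_dir='/tmp/bert_pretrain',
      steps_per_epoch=1000,
      steps_per_loop=200,
      epochs=3,
      initial_lr=2e-5,
      warmup_steps=10000,
      input_files=input_files,
      train_batch_size=32)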
def _get_pretrain_model():
  """Gets a pretraining model."""
  pretrain_model, core_model = bert_models.pretrain_model(
      bert_config, max_seq_length, max_predictions_per_seq,
      float_type=tf.float16 if FLAGS.use_fp16 else tf.float32)
  pretrain_model.optimizer = optimization.create_optimizer(
      initial_lr, steps_per_epoch * epochs, warmup_steps,
      FLAGS.optimizer_type)
  if FLAGS.use_fp16:
    pretrain_model.optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
        pretrain_model.optimizer, dynamic=True)
  return pretrain_model, core_model
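
# Sketch (an assumption, not taken from the original file): once the optimizer
# is wrapped in a LossScaleOptimizer as above, a custom training step is
# expected to scale the loss before computing gradients and unscale the
# gradients before applying them. The real loop lives in
# model_training_utils.run_customized_training_loop; the step below only
# illustrates that pattern with hypothetical model/loss_fn arguments.
def _example_fp16_train_step(model, optimizer, features, labels, loss_fn):
  """Hypothetical fp16 train step using explicit loss scaling."""
  with tf.GradientTape() as tape:
    loss = loss_fn(labels, model(features, training=True))
    # Scale the loss so small fp16 gradients do not underflow to zero.
    scaled_loss = optimizer.get_scaled_loss(loss)
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Undo the scaling before the optimizer update.
  grads = optimizer.get_unscaled_gradients(scaled_grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss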