def run_customized_training(strategy,
                            bert_config,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            input_files,
                            train_batch_size,
                            use_remote_tpu=False):
  """Run BERT pretrain model training using low-level API."""

  train_input_fn = functools.partial(get_pretrain_input_data, input_files,
                                     max_seq_length, max_predictions_per_seq,
                                     train_batch_size)

  def _get_pretrain_model():
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)
    pretrain_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    return pretrain_model, core_model

  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrain_model,
      loss_fn=get_loss_fn(),
      model_dir=model_dir,
      train_input_fn=train_input_fn,
      steps_per_epoch=steps_per_epoch,
      epochs=epochs,
      use_remote_tpu=use_remote_tpu)
def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False):
  """Run BERT SQuAD training."""
  if strategy:
    logging.info('Training using customized training loop with distribution'
                 ' strategy.')

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  epochs = FLAGS.num_train_epochs
  num_train_examples = input_meta_data['train_data_size']
  max_seq_length = input_meta_data['max_seq_length']
  steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
  warmup_steps = int(epochs * num_train_examples * 0.1 /
                     FLAGS.train_batch_size)
  train_input_fn = functools.partial(
      input_pipeline.create_squad_dataset,
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)

  def _get_squad_model():
    squad_model, core_model = bert_models.squad_model(
        bert_config, max_seq_length, float_type=tf.float32)
    squad_model.optimizer = optimization.create_optimizer(
        FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
    return squad_model, core_model

  # The original BERT model does not scale the loss by 1/num_replicas_in_sync,
  # which may have been unintentional. To reuse the same hyperparameters, we do
  # the same here and keep each replica loss unscaled.
  loss_fn = get_loss_fn(loss_scale=1.0)

  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_squad_model,
      loss_fn=loss_fn,
      model_dir=FLAGS.model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=FLAGS.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      use_remote_tpu=use_remote_tpu,
      run_eagerly=run_eagerly,
      custom_callbacks=custom_callbacks)
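# Illustrative arithmetic for the warmup computation above (numbers are made
# up, not from the SQuAD data): with 88,000 training examples, a batch size of
# 32 and 2 epochs, steps_per_epoch is int(88000 / 32) = 2750 and warmup_steps
# is int(2 * 88000 * 0.1 / 32) = 550, i.e. warmup covers roughly 10% of the
# total training steps.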
def run_customized_training(strategy,
                            bert_config,
                            input_meta_data,
                            model_dir,
                            epochs,
                            steps_per_epoch,
                            steps_per_loop,
                            eval_steps,
                            warmup_steps,
                            initial_lr,
                            init_checkpoint,
                            use_remote_tpu=False,
                            custom_callbacks=None):
  """Run BERT classifier training using low-level API."""
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data['num_labels']

  train_input_fn = functools.partial(
      input_pipeline.create_classifier_dataset,
      FLAGS.train_data_path,
      seq_length=max_seq_length,
      batch_size=FLAGS.train_batch_size)
  eval_input_fn = functools.partial(
      input_pipeline.create_classifier_dataset,
      FLAGS.eval_data_path,
      seq_length=max_seq_length,
      batch_size=FLAGS.eval_batch_size,
      is_training=False,
      drop_remainder=False)

  def _get_classifier_model():
    classifier_model, core_model = (
        bert_models.classifier_model(bert_config, tf.float32, num_classes,
                                     max_seq_length))
    classifier_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    return classifier_model, core_model

  loss_fn = get_loss_fn(num_classes, loss_scale=1.0)

  # Defines the evaluation metrics function, which will create metrics in the
  # correct device and strategy scope.
  def metric_fn():
    return tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

  return model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_classifier_model,
      loss_fn=loss_fn,
      model_dir=model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      eval_steps=eval_steps,
      init_checkpoint=init_checkpoint,
      metric_fn=metric_fn,
      use_remote_tpu=use_remote_tpu,
      custom_callbacks=custom_callbacks)
def run_customized_training(strategy,
                            bert_config,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            input_files,
                            train_batch_size,
                            use_remote_tpu=False):
  """Run BERT pretrain model training using low-level API."""

  train_input_fn = functools.partial(get_pretrain_input_data, input_files,
                                     max_seq_length, max_predictions_per_seq,
                                     train_batch_size, strategy)

  def _get_pretrain_model():
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)
    pretrain_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", the dtype
      # returned by flags_core.get_tf_dtype(flags_obj) is 'float32', which
      # ensures tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not both
      # apply mixed precision.
      pretrain_model.optimizer = (
          tf.train.experimental.enable_mixed_precision_graph_rewrite(
              pretrain_model.optimizer))
    return pretrain_model, core_model

  trained_model = model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrain_model,
      loss_fn=get_loss_fn(
          loss_factor=1.0 / strategy.num_replicas_in_sync
          if FLAGS.scale_loss else 1.0),
      model_dir=model_dir,
      train_input_fn=train_input_fn,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      use_remote_tpu=use_remote_tpu)

  # Creates the BERT core model outside the distribution strategy scope.
  _, core_model = bert_models.pretrain_model(bert_config, max_seq_length,
                                             max_predictions_per_seq)

  # Restores the core model from the model checkpoints and writes a new
  # checkpoint that contains only the core model.
  model_saving_utils.export_pretraining_checkpoint(
      checkpoint_dir=model_dir, model=core_model)
  return trained_model
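# `tf.train.experimental.enable_mixed_precision_graph_rewrite` used above
# returns the optimizer wrapped for the graph-rewrite fp16 path. A minimal
# standalone sketch of the same call, with an illustrative optimizer that is
# not part of this codebase:
import tensorflow as tf

opt = tf.keras.optimizers.SGD(learning_rate=0.01)
# Enables the automatic mixed-precision graph rewrite and wraps the optimizer
# with dynamic loss scaling.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)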
def run_customized_training(strategy,
                            bert_config,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            input_files,
                            train_batch_size,
                            use_remote_tpu=False):
  """Run BERT pretrain model training using low-level API."""

  train_input_fn = functools.partial(get_pretrain_input_data, input_files,
                                     max_seq_length, max_predictions_per_seq,
                                     train_batch_size, strategy)

  def _get_pretrain_model():
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)
    pretrain_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    return pretrain_model, core_model

  trained_model = model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrain_model,
      loss_fn=get_loss_fn(
          loss_factor=1.0 / strategy.num_replicas_in_sync
          if FLAGS.scale_loss else 1.0),
      model_dir=model_dir,
      train_input_fn=train_input_fn,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      use_remote_tpu=use_remote_tpu)

  # Creates the BERT core model outside the distribution strategy scope.
  _, core_model = bert_models.pretrain_model(bert_config, max_seq_length,
                                             max_predictions_per_seq)

  # Restores the core model from the model checkpoints and writes a new
  # checkpoint that contains only the core model.
  model_saving_utils.export_pretraining_checkpoint(
      checkpoint_dir=model_dir, model=core_model)
  return trained_model
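# `get_loss_fn` is referenced by the pretraining entry points above but is not
# defined in this section. A minimal sketch of what a wrapper with a
# `loss_factor` argument could look like, assuming the pretraining model
# already outputs a per-example loss tensor (an illustration, not the actual
# implementation):
def get_loss_fn(loss_factor=1.0):
  """Returns a loss function that averages and scales model losses (sketch)."""

  def _pretrain_loss_fn(unused_labels, losses, **unused_args):
    # loss_factor is 1/num_replicas_in_sync when FLAGS.scale_loss is set, so
    # that summing the replica losses matches the single-replica loss.
    return tf.reduce_mean(losses) * loss_factor

  return _pretrain_loss_fn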
def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False):
  """Run BERT SQuAD training."""
  if strategy:
    logging.info('Training using customized training loop with distribution'
                 ' strategy.')
  # Enables XLA in the session config. Should not be set for TPU.
  keras_utils.set_config_v2(FLAGS.enable_xla)

  use_float16 = common_flags.use_float16()
  if use_float16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  epochs = FLAGS.num_train_epochs
  num_train_examples = input_meta_data['train_data_size']
  max_seq_length = input_meta_data['max_seq_length']
  steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
  warmup_steps = int(epochs * num_train_examples * 0.1 /
                     FLAGS.train_batch_size)
  train_input_fn = functools.partial(
      input_pipeline.create_squad_dataset,
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)

  def _get_squad_model():
    """Gets the SQuAD model and optimizer."""
    squad_model, core_model = bert_models.squad_model(
        bert_config,
        max_seq_length,
        float_type=tf.float16 if use_float16 else tf.float32)
    squad_model.optimizer = optimization.create_optimizer(
        FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
    if use_float16:
      # Wraps the optimizer with a LossScaleOptimizer. This is done
      # automatically in compile() with the "mixed_float16" policy, but since
      # we do not call compile(), we must wrap the optimizer manually.
      squad_model.optimizer = (
          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
              squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
    return squad_model, core_model

  # The original BERT model does not scale the loss by 1/num_replicas_in_sync,
  # which may have been unintentional. To reuse the same hyperparameters, we do
  # the same here and keep each replica loss unscaled.
  loss_fn = get_loss_fn(loss_factor=1.0)

  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_squad_model,
      loss_fn=loss_fn,
      model_dir=FLAGS.model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=FLAGS.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      use_remote_tpu=use_remote_tpu,
      run_eagerly=run_eagerly,
      custom_callbacks=custom_callbacks)
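# The manual LossScaleOptimizer wrapping above is only required because the
# custom training loop never calls compile(). A standalone sketch of the same
# pattern with the TF 2.x experimental mixed-precision API; the model and
# optimizer here are illustrative, not from this codebase:
import tensorflow as tf

policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy)

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# compile() would wrap the optimizer automatically under the mixed_float16
# policy; a custom loop has to wrap it itself.
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    optimizer, loss_scale='dynamic')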
def run_customized_training(strategy,
                            bert_config,
                            input_meta_data,
                            model_dir,
                            epochs,
                            steps_per_epoch,
                            steps_per_loop,
                            eval_steps,
                            warmup_steps,
                            initial_lr,
                            init_checkpoint,
                            use_remote_tpu=False,
                            custom_callbacks=None,
                            run_eagerly=False):
  """Run BERT classifier training using low-level API."""
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data['num_labels']

  train_input_fn = functools.partial(
      input_pipeline.create_classifier_dataset,
      FLAGS.train_data_path,
      seq_length=max_seq_length,
      batch_size=FLAGS.train_batch_size)
  eval_input_fn = functools.partial(
      input_pipeline.create_classifier_dataset,
      FLAGS.eval_data_path,
      seq_length=max_seq_length,
      batch_size=FLAGS.eval_batch_size,
      is_training=False,
      drop_remainder=False)

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(bert_config, tf.float32, num_classes,
                                     max_seq_length))
    classifier_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", the dtype
      # returned by flags_core.get_tf_dtype(flags_obj) is 'float32', which
      # ensures tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not both
      # apply mixed precision.
      classifier_model.optimizer = (
          tf.train.experimental.enable_mixed_precision_graph_rewrite(
              classifier_model.optimizer))
    return classifier_model, core_model

  loss_fn = get_loss_fn(
      num_classes,
      loss_factor=1.0 / strategy.num_replicas_in_sync
      if FLAGS.scale_loss else 1.0)

  # Defines the evaluation metrics function, which will create metrics in the
  # correct device and strategy scope.
  def metric_fn():
    return tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

  return model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_classifier_model,
      loss_fn=loss_fn,
      model_dir=model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      eval_steps=eval_steps,
      init_checkpoint=init_checkpoint,
      metric_fn=metric_fn,
      use_remote_tpu=use_remote_tpu,
      custom_callbacks=custom_callbacks,
      run_eagerly=run_eagerly)
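# A hedged usage sketch of the classifier entry point above, run inside this
# module's context. The strategy, metadata values, and derived step counts are
# illustrative; `input_meta_data` is assumed to carry at least
# 'max_seq_length', 'num_labels', and 'train_data_size'.
strategy = tf.distribute.MirroredStrategy()
input_meta_data = {
    'max_seq_length': 128,
    'num_labels': 3,
    'train_data_size': 25000,
}
epochs = FLAGS.num_train_epochs
steps_per_epoch = int(input_meta_data['train_data_size'] /
                      FLAGS.train_batch_size)
run_customized_training(
    strategy,
    bert_config=modeling.BertConfig.from_json_file(FLAGS.bert_config_file),
    input_meta_data=input_meta_data,
    model_dir=FLAGS.model_dir,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    steps_per_loop=FLAGS.steps_per_loop,
    eval_steps=100,  # Illustrative; normally derived from the eval data size.
    warmup_steps=int(0.1 * steps_per_epoch * epochs),
    initial_lr=FLAGS.learning_rate,
    init_checkpoint=FLAGS.init_checkpoint)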