def run_training(self, distribution, model_dir, steps_per_loop, run_eagerly):
  model_training_utils.run_customized_training_loop(
      strategy=distribution,
      model_fn=self._model_fn,
      loss_fn=tf.keras.losses.categorical_crossentropy,
      model_dir=model_dir,
      steps_per_epoch=20,
      steps_per_loop=steps_per_loop,
      epochs=2,
      train_input_fn=self._input_fn,
      eval_input_fn=self._input_fn,
      eval_steps=10,
      init_checkpoint=None,
      metric_fn=metric_fn,
      custom_callbacks=None,
      run_eagerly=run_eagerly)

def run_customized_training(strategy,
                            bert_config,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            input_files,
                            train_batch_size,
                            use_remote_tpu=False):
  """Run BERT pretrain model training using low-level API."""

  train_input_fn = functools.partial(get_pretrain_input_data, input_files,
                                     max_seq_length, max_predictions_per_seq,
                                     train_batch_size, strategy)

  def _get_pretrain_model():
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)
    pretrain_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      pretrain_model.optimizer = (
          tf.train.experimental.enable_mixed_precision_graph_rewrite(
              pretrain_model.optimizer))
    return pretrain_model, core_model

  trained_model = model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrain_model,
      loss_fn=get_loss_fn(
          loss_factor=1.0 /
          strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0),
      model_dir=model_dir,
      train_input_fn=train_input_fn,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      use_remote_tpu=use_remote_tpu)

  # Creates the BERT core model outside distribution strategy scope.
  _, core_model = bert_models.pretrain_model(bert_config, max_seq_length,
                                             max_predictions_per_seq)

  # Restores the core model from model checkpoints and gets a new checkpoint
  # that contains only the core model.
  model_saving_utils.export_pretraining_checkpoint(
      checkpoint_dir=model_dir, model=core_model)
  return trained_model

def run_training(self, strategy, model_dir, steps_per_loop, run_eagerly):
  input_fn = create_fake_data_input_fn(
      batch_size=8, features_shape=[128], num_classes=3)
  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=self._model_fn,
      loss_fn=tf.keras.losses.categorical_crossentropy,
      model_dir=model_dir,
      steps_per_epoch=20,
      steps_per_loop=steps_per_loop,
      epochs=2,
      train_input_fn=input_fn,
      eval_input_fn=input_fn,
      eval_steps=10,
      init_checkpoint=None,
      metric_fn=metric_fn,
      custom_callbacks=None,
      run_eagerly=run_eagerly)

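# The two test helpers above reference a module-level `metric_fn` that is not
# shown in this listing. A minimal sketch of what such a helper could look
# like follows; the metric class and name are assumptions made here for
# illustration, not taken from the original tests.
def metric_fn():
  """Hypothetical metric factory for the fake 3-class classification data."""
  return tf.keras.metrics.CategoricalAccuracy('accuracy', dtype=tf.float32)
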
def run_customized_training(strategy,
                            bert_config,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            input_files,
                            train_batch_size):
  """Run BERT pretrain model training using low-level API."""

  train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                           max_predictions_per_seq,
                                           train_batch_size)

  def _get_pretrain_model():
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)
    pretrain_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      pretrain_model.optimizer = (
          tf.train.experimental.enable_mixed_precision_graph_rewrite(
              pretrain_model.optimizer))
    return pretrain_model, core_model

  trained_model = model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrain_model,
      loss_fn=get_loss_fn(
          loss_factor=1.0 /
          strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0),
      model_dir=model_dir,
      train_input_fn=train_input_fn,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      sub_model_export_name='pretrained/bert_model')

  return trained_model

def run_customized_training(strategy,
                            bert_config,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            input_files,
                            train_batch_size):
  """Run BERT pretrain model training using low-level API."""

  train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                           max_predictions_per_seq,
                                           train_batch_size)

  def _get_pretrain_model():
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq,
        float_type=tf.float16 if FLAGS.use_fp16 else tf.float32)
    pretrain_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps,
        FLAGS.optimizer_type)
    if FLAGS.use_fp16:
      pretrain_model.optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
          pretrain_model.optimizer, dynamic=True)
    return pretrain_model, core_model

  dllogging = dllogger_class.dllogger_class(FLAGS.dllog_path)
  params = {'dllogging': dllogging, 'FLAGS': FLAGS}
  logging.info("init_lr = %f", initial_lr)

  trained_model = model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrain_model,
      loss_fn=get_loss_fn(
          loss_factor=1.0 / strategy.num_replicas_in_sync
          if FLAGS.scale_loss and strategy else 1.0),
      model_dir=model_dir,
      train_input_fn=train_input_fn,
      steps_per_epoch=steps_per_epoch,
      num_accumulative_step=FLAGS.num_accumulation_steps,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      sub_model_export_name='pretrained/bert_model',
      init_checkpoint=FLAGS.init_checkpoint,
      hvd=hvd if FLAGS.use_horovod else None,
      params=params)

  return trained_model

def run_customized_training(strategy,
                            bert_config,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            input_files,
                            train_batch_size):
  """Run BERT pretrain model training using low-level API."""

  train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                           max_predictions_per_seq,
                                           train_batch_size)

  def _get_pretrain_model():
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)
    optimizer = optimization.create_optimizer(initial_lr,
                                              steps_per_epoch * epochs,
                                              warmup_steps)
    pretrain_model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    return pretrain_model, core_model

  trained_model = model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrain_model,
      loss_fn=get_loss_fn(
          loss_factor=1.0 /
          strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0),
      model_dir=model_dir,
      train_input_fn=train_input_fn,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      sub_model_export_name='pretrained/bert_model')

  return trained_model

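# The pretraining variants above call a module-level `get_loss_fn(loss_factor)`
# that is not shown in this listing. In this setup the Keras pretraining model
# already emits a per-example loss tensor, so the loss function only needs to
# average and rescale it. A minimal sketch under that assumption (not the
# original implementation):
def get_loss_fn(loss_factor=1.0):
  """Hypothetical loss-function factory for BERT pretraining."""

  def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
    # `losses` is assumed to be the per-example pretraining loss produced by
    # bert_models.pretrain_model(); `loss_factor` rescales it, e.g. by
    # 1/num_replicas_in_sync when loss scaling across replicas is enabled.
    return tf.reduce_mean(losses) * loss_factor

  return _bert_pretrain_loss_fn
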
def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False):
  """Run bert squad training."""
  if strategy:
    logging.info('Training using customized training loop with distribution'
                 ' strategy.')
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_config_v2(FLAGS.enable_xla)

  use_float16 = common_flags.use_float16()
  if use_float16:
    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

  bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
      FLAGS.bert_config_file)
  epochs = FLAGS.num_train_epochs
  num_train_examples = input_meta_data['train_data_size']
  max_seq_length = input_meta_data['max_seq_length']
  steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
  warmup_steps = int(
      epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)

  def _get_squad_model():
    """Get Squad model and optimizer."""
    squad_model, core_model = bert_models.squad_model(
        bert_config,
        max_seq_length,
        float_type=tf.float16 if use_float16 else tf.float32,
        hub_module_url=FLAGS.hub_module_url)
    squad_model.optimizer = optimization.create_optimizer(
        FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
    if use_float16:
      # Wraps optimizer with a LossScaleOptimizer. This is done automatically
      # in compile() with the "mixed_float16" policy, but since we do not call
      # compile(), we must wrap the optimizer manually.
      squad_model.optimizer = (
          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
              squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      squad_model.optimizer = (
          tf.train.experimental.enable_mixed_precision_graph_rewrite(
              squad_model.optimizer))
    return squad_model, core_model

  # The original BERT model does not scale the loss by
  # 1/num_replicas_in_sync. It could be an accident. So, in order to use
  # the same hyper parameter, we do the same thing here by keeping each
  # replica loss as it is.
  loss_fn = get_loss_fn(
      loss_factor=1.0 /
      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)

  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_squad_model,
      loss_fn=loss_fn,
      model_dir=FLAGS.model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=FLAGS.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      run_eagerly=run_eagerly,
      custom_callbacks=custom_callbacks)

def train_squad(strategy,
                input_meta_data,
                bert_config,
                custom_callbacks=None,
                run_eagerly=False):
  """Run bert squad training."""
  if strategy:
    logging.info('Training using customized training loop with distribution'
                 ' strategy.')
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_config_v2(FLAGS.enable_xla)

  use_float16 = common_flags.use_float16()
  if use_float16:
    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

  epochs = FLAGS.num_train_epochs
  num_train_examples = input_meta_data['train_data_size']
  max_seq_length = input_meta_data['max_seq_length']
  steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
  warmup_steps = int(
      epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)

  def _get_squad_model():
    """Get Squad model and optimizer."""
    squad_model, core_model = bert_models.squad_model(
        bert_config,
        max_seq_length,
        hub_module_url=FLAGS.hub_module_url,
        hub_module_trainable=FLAGS.hub_module_trainable)
    squad_model.optimizer = optimization.create_optimizer(
        FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
    if use_float16:
      # Wraps optimizer with a LossScaleOptimizer. This is done automatically
      # in compile() with the "mixed_float16" policy, but since we do not call
      # compile(), we must wrap the optimizer manually.
      squad_model.optimizer = (
          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
              squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      squad_model.optimizer = (
          tf.train.experimental.enable_mixed_precision_graph_rewrite(
              squad_model.optimizer))
    return squad_model, core_model

  # The original BERT model does not scale the loss by
  # 1/num_replicas_in_sync. It could be an accident. So, in order to use
  # the same hyper parameter, we do the same thing here by keeping each
  # replica loss as it is.
  loss_fn = get_loss_fn(
      loss_factor=1.0 /
      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)

  # When all_reduce_sum_gradients = False, apply_gradients() no longer
  # implicitly allreduces gradients; users manually allreduce gradients and
  # pass the allreduced grads_and_vars. For now, clip_by_global_norm is moved
  # to before the users' manual allreduce to keep the math unchanged.
  def clip_by_global_norm_callback(grads_and_vars):
    grads, variables = zip(*grads_and_vars)
    (clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
    return zip(clipped_grads, variables)

  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_squad_model,
      loss_fn=loss_fn,
      model_dir=FLAGS.model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=FLAGS.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      run_eagerly=run_eagerly,
      custom_callbacks=custom_callbacks,
      explicit_allreduce=True,
      pre_allreduce_callbacks=[clip_by_global_norm_callback])

def run_bert_classifier(strategy,
                        bert_config,
                        input_meta_data,
                        model_dir,
                        epochs,
                        steps_per_epoch,
                        steps_per_loop,
                        eval_steps,
                        warmup_steps,
                        initial_lr,
                        init_checkpoint,
                        custom_callbacks=None,
                        run_eagerly=False):
  """Run BERT classifier training using low-level API."""
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data['num_labels']

  train_input_fn = functools.partial(
      input_pipeline.create_classifier_dataset,
      FLAGS.train_data_path,
      seq_length=max_seq_length,
      batch_size=FLAGS.train_batch_size)
  eval_input_fn = functools.partial(
      input_pipeline.create_classifier_dataset,
      FLAGS.eval_data_path,
      seq_length=max_seq_length,
      batch_size=FLAGS.eval_batch_size,
      is_training=False,
      drop_remainder=False)

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(
            bert_config,
            tf.float32,
            num_classes,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url))
    classifier_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      classifier_model.optimizer = (
          tf.train.experimental.enable_mixed_precision_graph_rewrite(
              classifier_model.optimizer))
    return classifier_model, core_model

  loss_fn = get_loss_fn(
      num_classes,
      loss_factor=1.0 /
      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)

  # Defines evaluation metrics function, which will create metrics in the
  # correct device and strategy scope.
  def metric_fn():
    return tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

  if FLAGS.use_keras_compile_fit:
    # Start training using Keras compile/fit API.
    logging.info('Training using TF 2.0 Keras compile/fit API with '
                 'distribution strategy.')
    return run_keras_compile_fit(
        model_dir,
        strategy,
        _get_classifier_model,
        train_input_fn,
        eval_input_fn,
        loss_fn,
        metric_fn,
        init_checkpoint,
        epochs,
        steps_per_epoch,
        eval_steps,
        custom_callbacks=None)

  # Use user-defined loop to start training.
  logging.info('Training using customized training loop TF 2.0 with '
               'distribution strategy.')
  return model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_classifier_model,
      loss_fn=loss_fn,
      model_dir=model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      eval_steps=eval_steps,
      init_checkpoint=init_checkpoint,
      metric_fn=metric_fn,
      custom_callbacks=custom_callbacks,
      run_eagerly=run_eagerly)

def train_squad(strategy,
                input_meta_data,
                bert_config,
                custom_callbacks=None,
                run_eagerly=False):
  """Run bert squad training."""
  if strategy:
    logging.info('Training using customized training loop with distribution'
                 ' strategy.')
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_config_v2(FLAGS.enable_xla)
  performance.set_mixed_precision_policy(common_flags.dtype())

  epochs = FLAGS.num_train_epochs
  num_train_examples = input_meta_data['train_data_size']
  max_seq_length = input_meta_data['max_seq_length']
  steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
  warmup_steps = int(
      epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)

  def _get_squad_model():
    """Get Squad model and optimizer."""
    squad_model, core_model = bert_models.squad_model(
        bert_config,
        max_seq_length,
        hub_module_url=FLAGS.hub_module_url,
        hub_module_trainable=FLAGS.hub_module_trainable)
    optimizer = optimization.create_optimizer(FLAGS.learning_rate,
                                              steps_per_epoch * epochs,
                                              warmup_steps)
    squad_model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    return squad_model, core_model

  # If explicit_allreduce = True, apply_gradients() no longer implicitly
  # allreduces gradients; users manually allreduce gradients and pass the
  # allreduced grads_and_vars to apply_gradients(). clip_by_global_norm will
  # be applied to the allreduced gradients.
  def clip_by_global_norm_callback(grads_and_vars):
    grads, variables = zip(*grads_and_vars)
    (clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
    return zip(clipped_grads, variables)

  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_squad_model,
      loss_fn=get_loss_fn(),
      model_dir=FLAGS.model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=FLAGS.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      run_eagerly=run_eagerly,
      custom_callbacks=custom_callbacks,
      explicit_allreduce=False,
      post_allreduce_callbacks=[clip_by_global_norm_callback])

def run_bert_classifier(strategy,
                        bert_config,
                        input_meta_data,
                        model_dir,
                        epochs,
                        steps_per_epoch,
                        steps_per_loop,
                        eval_steps,
                        warmup_steps,
                        initial_lr,
                        init_checkpoint,
                        train_input_fn,
                        eval_input_fn,
                        custom_callbacks=None,
                        run_eagerly=False,
                        use_keras_compile_fit=False):
  """Run BERT classifier training using low-level API."""
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data['num_labels']

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(
            bert_config,
            tf.float32,
            num_classes,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url))
    classifier_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      classifier_model.optimizer = (
          tf.train.experimental.enable_mixed_precision_graph_rewrite(
              classifier_model.optimizer))
    return classifier_model, core_model

  # During distributed training, the loss used for gradient computation is
  # summed over all replicas. When the Keras compile/fit() API is used, fit()
  # internally normalizes the loss by dividing it by the number of replicas
  # used for computation. However, when a custom training loop is used this
  # is not done automatically and must be done manually by the end user.
  loss_multiplier = 1.0
  if FLAGS.scale_loss and not use_keras_compile_fit:
    loss_multiplier = 1.0 / strategy.num_replicas_in_sync

  loss_fn = get_loss_fn(num_classes, loss_factor=loss_multiplier)

  # Defines evaluation metrics function, which will create metrics in the
  # correct device and strategy scope.
  def metric_fn():
    return tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

  if use_keras_compile_fit:
    # Start training using Keras compile/fit API.
    logging.info('Training using TF 2.0 Keras compile/fit API with '
                 'distribution strategy.')
    return run_keras_compile_fit(
        model_dir,
        strategy,
        _get_classifier_model,
        train_input_fn,
        eval_input_fn,
        loss_fn,
        metric_fn,
        init_checkpoint,
        epochs,
        steps_per_epoch,
        eval_steps,
        custom_callbacks=None)

  # Use user-defined loop to start training.
  logging.info('Training using customized training loop TF 2.0 with '
               'distribution strategy.')
  return model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_classifier_model,
      loss_fn=loss_fn,
      model_dir=model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      eval_steps=eval_steps,
      init_checkpoint=init_checkpoint,
      metric_fn=metric_fn,
      custom_callbacks=custom_callbacks,
      run_eagerly=run_eagerly)

def run_classifier(self, train_input_fn, validation_input_fn, epochs,
                   steps_per_epoch, validation_steps, num_classes):
  """Creates classifier and runs the classifier training."""
  if epochs is None:
    epochs = self.default_training_epochs

  bert_config = bert_configs.BertConfig(
      0,
      initializer_range=self.initializer_range,
      hidden_dropout_prob=self.dropout_rate)
  warmup_steps = int(epochs * steps_per_epoch * 0.1)
  initial_lr = self.learning_rate

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(
            bert_config, num_classes, self.seq_len, hub_module_url=self.uri))
    classifier_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    return classifier_model, core_model

  # During distributed training, the loss used for gradient computation is
  # summed over all replicas. When the Keras compile/fit() API is used, fit()
  # internally normalizes the loss by dividing it by the number of replicas
  # used for computation. However, when a custom training loop is used this
  # is not done automatically and must be done manually by the end user.
  loss_multiplier = 1.0
  if self.scale_loss:
    loss_multiplier = 1.0 / self.strategy.num_replicas_in_sync

  loss_fn = self.get_classification_loss_fn(
      num_classes, loss_factor=loss_multiplier)

  # Defines evaluation metrics function, which will create metrics in the
  # correct device and strategy scope.
  def metric_fn():
    return tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

  # Use user-defined loop to start training.
  tf.compat.v1.logging.info('Training using customized training loop TF 2.0 '
                            'with distribution strategy.')
  bert_model = model_training_utils.run_customized_training_loop(
      strategy=self.strategy,
      model_fn=_get_classifier_model,
      loss_fn=loss_fn,
      model_dir=self.model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=self.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      eval_input_fn=validation_input_fn,
      eval_steps=validation_steps,
      init_checkpoint=None,
      metric_fn=metric_fn,
      custom_callbacks=None,
      run_eagerly=False)

  # Used in evaluation.
  with self.strategy.scope():
    bert_model, _ = _get_classifier_model()
    checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
    checkpoint = tf.train.Checkpoint(model=bert_model)
    checkpoint.restore(checkpoint_path).expect_partial()
    bert_model.compile(loss=loss_fn, metrics=[metric_fn()])
  return bert_model

def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False):
  """Run bert squad training."""
  if strategy:
    logging.info('Training using customized training loop with distribution'
                 ' strategy.')
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_config_v2(FLAGS.enable_xla)

  use_float16 = common_flags.use_float16()
  if use_float16:
    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

  bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
      FLAGS.bert_config_file)
  epochs = FLAGS.num_train_epochs
  num_train_examples = input_meta_data['train_data_size']
  max_seq_length = input_meta_data['max_seq_length']
  global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
  if FLAGS.use_horovod:
    global_batch_size *= hvd.size()
  steps_per_epoch = int(num_train_examples / global_batch_size)
  warmup_steps = int(epochs * num_train_examples * 0.1 / global_batch_size)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True,
      use_horovod=FLAGS.use_horovod)

  if FLAGS.benchmark:
    steps_per_epoch = 800
    epochs = 1

  def _get_squad_model():
    """Get Squad model and optimizer."""
    squad_model, core_model = bert_models.squad_model(
        bert_config,
        max_seq_length,
        float_type=tf.float16 if FLAGS.use_fp16 else tf.float32,
        hub_module_url=FLAGS.hub_module_url)
    learning_rate = (
        FLAGS.learning_rate * hvd.size()
        if FLAGS.use_horovod else FLAGS.learning_rate)
    squad_model.optimizer = optimization.create_optimizer(
        learning_rate, steps_per_epoch * epochs, warmup_steps,
        FLAGS.optimizer_type)
    if FLAGS.use_fp16:
      squad_model.optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
          squad_model.optimizer, dynamic=True)
    return squad_model, core_model

  # The original BERT model does not scale the loss by
  # 1/num_replicas_in_sync. It could be an accident. So, in order to use
  # the same hyper parameter, we do the same thing here by keeping each
  # replica loss as it is.
  loss_fn = get_loss_fn(
      loss_factor=1.0 / strategy.num_replicas_in_sync
      if FLAGS.scale_loss and strategy else 1.0)

  params = {'dllogging': input_meta_data['dllogging'], 'FLAGS': FLAGS}

  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_squad_model,
      loss_fn=loss_fn,
      model_dir=FLAGS.model_dir,
      steps_per_epoch=steps_per_epoch,
      num_accumulative_step=FLAGS.num_accumulation_steps,
      steps_per_loop=FLAGS.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      hvd=hvd if FLAGS.use_horovod else None,
      run_eagerly=run_eagerly,
      custom_callbacks=custom_callbacks,
      params=params)

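# The SQuAD variants above rely on a module-level `get_loss_fn()` that is not
# shown in this listing. The span-prediction model emits start and end logits,
# so a plausible loss averages sparse categorical cross-entropy over the two
# positions. This is only a sketch under that assumption (including the
# 'start_positions'/'end_positions' label keys), not the original code.
def get_loss_fn(loss_factor=1.0):
  """Hypothetical loss-function factory for the SQuAD span-prediction task."""

  def _squad_loss_fn(labels, model_outputs):
    start_positions = labels['start_positions']
    end_positions = labels['end_positions']
    start_logits, end_logits = model_outputs
    start_loss = tf.keras.losses.sparse_categorical_crossentropy(
        start_positions, start_logits, from_logits=True)
    end_loss = tf.keras.losses.sparse_categorical_crossentropy(
        end_positions, end_logits, from_logits=True)
    total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
    return total_loss * loss_factor

  return _squad_loss_fn
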
def run(args, strategy):
  """Pretrains model using TF2. Adapted from the tensorflow/models Github."""
  # CONFIG
  # Use timestamp to generate a unique run name
  run_name = get_run_name(args)
  logger.info(f'*** Starting run {run_name} ***')
  output_dir = f'gs://{args.bucket_name}/{args.project_name}/pretrain/runs/{run_name}'

  # pretrained model path
  try:
    pretrained_model_path = PRETRAINED_MODELS[
        args.model_class]['bucket_location']
  except KeyError:
    raise ValueError(
        f'Could not find a pretrained model matching the model class {args.model_class}')
  pretrained_model_config_path = f'gs://{args.bucket_name}/{pretrained_model_path}/bert_config.json'
  pretrained_model_checkpoint_path = f'gs://{args.bucket_name}/{pretrained_model_path}/bert_model.ckpt'

  # some logging
  logger.info(
      f'Running pretraining of model {args.model_class} on pretrain data {args.pretrain_data}')
  logger.info(
      f'Initializing model from checkpoint {pretrained_model_checkpoint_path}')

  # load model config based on model_class
  model_config = get_model_config(pretrained_model_config_path)

  # input data function
  train_input_fn = get_dataset_fn(args, _type='train')
  eval_input_fn = None
  eval_metric_fn = None
  if args.do_eval:
    logger.info('Setting up evaluation dataset')
    eval_metric_fn = get_eval_metric_fn
    eval_input_fn = get_dataset_fn(args, _type='dev')

  # model_fn
  def _get_pretrained_model(end_lr=0.0):
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        model_config, args.max_seq_length, args.max_predictions_per_seq)
    optimizer = utils.optimizer.create_optimizer(
        args.learning_rate, steps_per_epoch * args.num_epochs, warmup_steps,
        args.end_lr, args.optimizer_type)
    pretrain_model.optimizer = configure_optimizer(
        optimizer, use_float16=args.dtype == 'fp16', use_graph_rewrite=False)
    return pretrain_model, core_model

  # custom callbacks
  summary_dir = os.path.join(output_dir, 'summaries')
  time_history_callback = keras_utils.TimeHistory(
      batch_size=args.train_batch_size,
      log_steps=args.time_history_log_steps,
      logdir=summary_dir)
  custom_callbacks = [time_history_callback]

  # Save an initial version of the log file
  data = {
      'created_at': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
      'run_name': run_name,
      'num_train_steps': args.num_steps_per_epoch * args.num_epochs,
      'eval_steps': args.eval_steps,
      'model_dir': output_dir,
      'output_dir': output_dir,
      **vars(args),
  }
  f_path_training_log = os.path.join(output_dir, 'run_logs.json')
  logger.info(f'Writing training preliminary log to {f_path_training_log}...')
  save_to_json(data, f_path_training_log)

  # run training loop
  logger.info(
      f'Run training for {args.num_epochs:,} epochs, '
      f'{args.num_steps_per_epoch:,} steps each, processing '
      f'{args.num_epochs * args.num_steps_per_epoch * args.train_batch_size:,} '
      f'training examples in total...')
  time_start = time.time()
  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrained_model,
      loss_fn=get_loss_fn(),
      scale_loss=True,
      model_dir=output_dir,
      train_input_fn=train_input_fn,
      steps_per_epoch=args.num_steps_per_epoch,
      steps_per_loop=args.steps_per_loop,
      epochs=args.num_epochs,
      eval_input_fn=eval_input_fn,
      eval_steps=args.eval_steps,
      metric_fn=eval_metric_fn,
      init_checkpoint=pretrained_model_checkpoint_path,
      custom_callbacks=custom_callbacks,
      run_eagerly=False,
      sub_model_export_name='pretrained/bert_model',
      explicit_allreduce=False,
      pre_allreduce_callbacks=None,
      post_allreduce_callbacks=None)
  time_end = time.time()
  training_time_min = (time_end - time_start) / 60
  logger.info(f'Finished training after {training_time_min:.1f} min')

  # Write to run directory
  data['training_time_min'] = training_time_min
  logger.info(f'Writing final log to {f_path_training_log}...')
  save_to_json(data, f_path_training_log)

def run_bert_classifier(strategy,
                        bert_config,
                        input_meta_data,
                        model_dir,
                        epochs,
                        steps_per_epoch,
                        steps_per_loop,
                        eval_steps,
                        warmup_steps,
                        initial_lr,
                        init_checkpoint,
                        train_input_fn,
                        eval_input_fn,
                        custom_callbacks=None,
                        run_eagerly=False,
                        use_keras_compile_fit=False):
  """Run BERT classifier training using low-level API."""
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data['num_labels']

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(
            bert_config,
            num_classes,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url,
            hub_module_trainable=FLAGS.hub_module_trainable))
    optimizer = optimization.create_optimizer(initial_lr,
                                              steps_per_epoch * epochs,
                                              warmup_steps)
    classifier_model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    return classifier_model, core_model

  loss_fn = get_loss_fn(num_classes)

  # Defines evaluation metrics function, which will create metrics in the
  # correct device and strategy scope.
  def metric_fn():
    return tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

  if use_keras_compile_fit:
    # Start training using Keras compile/fit API.
    logging.info('Training using TF 2.0 Keras compile/fit API with '
                 'distribution strategy.')
    return run_keras_compile_fit(
        model_dir,
        strategy,
        _get_classifier_model,
        train_input_fn,
        eval_input_fn,
        loss_fn,
        metric_fn,
        init_checkpoint,
        epochs,
        steps_per_epoch,
        steps_per_loop,
        eval_steps,
        custom_callbacks=custom_callbacks)

  # Use user-defined loop to start training.
  logging.info('Training using customized training loop TF 2.0 with '
               'distribution strategy.')
  return model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_classifier_model,
      loss_fn=loss_fn,
      model_dir=model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      eval_steps=eval_steps,
      init_checkpoint=init_checkpoint,
      metric_fn=metric_fn,
      custom_callbacks=custom_callbacks,
      run_eagerly=run_eagerly)

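# The classifier variants call `get_loss_fn(num_classes, loss_factor=...)`,
# which is also not shown in this listing. A minimal sketch, assuming the
# model's classification head returns raw logits over `num_classes` and that
# `loss_factor` only rescales the mean loss (e.g. by 1/num_replicas_in_sync):
def get_loss_fn(num_classes, loss_factor=1.0):
  """Hypothetical classification loss factory for the BERT classifier."""

  def _classification_loss_fn(labels, logits):
    labels = tf.squeeze(labels)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    one_hot_labels = tf.one_hot(
        tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    return tf.reduce_mean(per_example_loss) * loss_factor

  return _classification_loss_fn
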