def test_train_check_callbacks(self, distribution):
  model_dir = self.get_temp_dir()
  callback = RecordingCallback()
  callbacks = [callback]
  input_fn = create_fake_data_input_fn(
      batch_size=8, features_shape=[128], num_classes=3)
  model_training_utils.run_customized_training_loop(
      strategy=distribution,
      model_fn=self._model_fn,
      loss_fn=tf.keras.losses.categorical_crossentropy,
      model_dir=model_dir,
      steps_per_epoch=20,
      steps_per_loop=10,
      epochs=2,
      train_input_fn=input_fn,
      eval_input_fn=input_fn,
      eval_steps=10,
      init_checkpoint=None,
      metric_fn=metric_fn,
      custom_callbacks=callbacks,
      run_eagerly=False)

  self.assertEqual(callback.epoch_begin, [(1, {}), (2, {})])
  epoch_ends, epoch_end_infos = zip(*callback.epoch_end)
  self.assertEqual(list(epoch_ends), [1, 2])
  for info in epoch_end_infos:
    self.assertIn('accuracy', info)

  self.assertEqual(callback.batch_begin, [(0, {}), (10, {}), (20, {}),
                                          (30, {})])
  batch_ends, batch_end_infos = zip(*callback.batch_end)
  self.assertEqual(list(batch_ends), [9, 19, 29, 39])
  for info in batch_end_infos:
    self.assertIn('loss', info)
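# The test above relies on a RecordingCallback helper that stores each callback
# invocation as (index, logs) tuples. Below is a minimal sketch that would be
# consistent with the assertions above; it is an assumption, not necessarily
# the helper from the original test file.
class RecordingCallback(tf.keras.callbacks.Callback):

  def __init__(self):
    super(RecordingCallback, self).__init__()
    self.batch_begin = []  # (batch, logs) for every on_batch_begin call.
    self.batch_end = []    # (batch, logs) for every on_batch_end call.
    self.epoch_begin = []  # (epoch, logs) for every on_epoch_begin call.
    self.epoch_end = []    # (epoch, logs) for every on_epoch_end call.

  def on_batch_begin(self, batch, logs=None):
    self.batch_begin.append((batch, logs or {}))

  def on_batch_end(self, batch, logs=None):
    self.batch_end.append((batch, logs or {}))

  def on_epoch_begin(self, epoch, logs=None):
    self.epoch_begin.append((epoch, logs or {}))

  def on_epoch_end(self, epoch, logs=None):
    self.epoch_end.append((epoch, logs or {}))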
def run_customized_training(strategy,
                            bert_config,
                            init_checkpoint,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            end_lr,
                            optimizer_type,
                            input_files,
                            train_batch_size,
                            use_next_sentence_label=True,
                            train_summary_interval=0,
                            custom_callbacks=None):
  """Run BERT pretrain model training using low-level API."""

  train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                           max_predictions_per_seq,
                                           train_batch_size,
                                           use_next_sentence_label)

  def _get_pretrain_model():
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config,
        max_seq_length,
        max_predictions_per_seq,
        use_next_sentence_label=use_next_sentence_label)
    optimizer = optimization.create_optimizer(initial_lr,
                                              steps_per_epoch * epochs,
                                              warmup_steps, end_lr,
                                              optimizer_type)
    pretrain_model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    return pretrain_model, core_model

  trained_model = model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrain_model,
      loss_fn=get_loss_fn(),
      scale_loss=FLAGS.scale_loss,
      model_dir=model_dir,
      init_checkpoint=init_checkpoint,
      train_input_fn=train_input_fn,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      sub_model_export_name='pretrained/bert_model',
      train_summary_interval=train_summary_interval,
      custom_callbacks=custom_callbacks)

  return trained_model
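# In the pretraining setup above, the Keras model returned by
# bert_models.pretrain_model() already emits its combined pretraining loss as
# an output tensor, so the loss_fn passed to run_customized_training_loop()
# only has to reduce that output. A minimal sketch of what get_loss_fn() could
# look like here (an assumption, not necessarily the original implementation):
def get_loss_fn():
  """Returns a loss function that reduces the model-computed pretraining loss."""

  def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
    # The model already computed per-example losses; just average them.
    return tf.reduce_mean(losses)

  return _bert_pretrain_loss_fn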
def run_training(self, strategy, model_dir, steps_per_loop, run_eagerly):
  input_fn = create_fake_data_input_fn(
      batch_size=8, features_shape=[128], num_classes=3)
  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=self._model_fn,
      loss_fn=tf.keras.losses.categorical_crossentropy,
      model_dir=model_dir,
      steps_per_epoch=20,
      steps_per_loop=steps_per_loop,
      epochs=2,
      train_input_fn=input_fn,
      eval_input_fn=input_fn,
      eval_steps=10,
      init_checkpoint=None,
      metric_fn=metric_fn,
      custom_callbacks=None,
      run_eagerly=run_eagerly)
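# Both test helpers above assume a create_fake_data_input_fn factory and a
# module-level metric_fn. Hedged sketches of what these could look like,
# written only to make the snippets self-contained; they are assumptions, not
# the original test utilities.
def create_fake_data_input_fn(batch_size, features_shape, num_classes):
  """Returns an input_fn producing random one-hot classification data."""

  def _input_fn(ctx=None):
    del ctx  # The training loop may pass a tf.distribute.InputContext.
    features = tf.random.uniform([batch_size * 4] + features_shape)
    labels = tf.one_hot(
        tf.random.uniform([batch_size * 4], maxval=num_classes,
                          dtype=tf.int32),
        depth=num_classes)
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset.repeat().batch(batch_size, drop_remainder=True)

  return _input_fn


def metric_fn():
  """Creates the evaluation metric inside the correct strategy scope."""
  return tf.keras.metrics.CategoricalAccuracy('accuracy', dtype=tf.float32)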
def run_customized_training(strategy, bert_config, max_seq_length,
                            max_predictions_per_seq, model_dir,
                            steps_per_epoch, steps_per_loop, epochs,
                            initial_lr, warmup_steps, input_files,
                            train_batch_size):
  """Run BERT pretrain model training using low-level API."""

  train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                           max_predictions_per_seq,
                                           train_batch_size)

  def _get_pretrain_model():
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)
    optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    pretrain_model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    return pretrain_model, core_model

  trained_model = model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrain_model,
      loss_fn=get_loss_fn(),
      scale_loss=FLAGS.scale_loss,
      model_dir=model_dir,
      train_input_fn=train_input_fn,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      sub_model_export_name='pretrained/bert_model')

  return trained_model
def train_squad(strategy,
                input_meta_data,
                bert_config,
                custom_callbacks=None,
                run_eagerly=False,
                init_checkpoint=None,
                sub_model_export_name=None):
  """Run bert squad training."""
  if strategy:
    logging.info('Training using customized training loop with distribution'
                 ' strategy.')
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_session_config(FLAGS.enable_xla)
  performance.set_mixed_precision_policy(common_flags.dtype())

  epochs = FLAGS.num_train_epochs
  num_train_examples = input_meta_data['train_data_size']
  max_seq_length = input_meta_data['max_seq_length']
  steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
  warmup_steps = int(epochs * num_train_examples * 0.1 /
                     FLAGS.train_batch_size)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)

  def _get_squad_model():
    """Get Squad model and optimizer."""
    squad_model, core_model = bert_models.squad_model(
        bert_config,
        max_seq_length,
        hub_module_url=FLAGS.hub_module_url,
        hub_module_trainable=FLAGS.hub_module_trainable)
    optimizer = optimization.create_optimizer(FLAGS.learning_rate,
                                              steps_per_epoch * epochs,
                                              warmup_steps, FLAGS.end_lr,
                                              FLAGS.optimizer_type)
    squad_model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    return squad_model, core_model

  # post_allreduce_callbacks and allreduce_bytes_per_pack take effect only when
  # explicit_allreduce=True. In that mode, optimizer.apply_gradients() no
  # longer implicitly all-reduces gradients; the caller all-reduces them
  # manually and passes the all-reduced grads_and_vars to apply_gradients().
  # With explicit_allreduce=True, clip_by_global_norm is applied after the
  # allreduce.
  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_squad_model,
      loss_fn=get_loss_fn(),
      model_dir=FLAGS.model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=FLAGS.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      init_checkpoint=init_checkpoint or FLAGS.init_checkpoint,
      sub_model_export_name=sub_model_export_name,
      run_eagerly=run_eagerly,
      custom_callbacks=custom_callbacks,
      explicit_allreduce=FLAGS.explicit_allreduce,
      pre_allreduce_callbacks=[
          model_training_utils.clip_by_global_norm_callback
      ],
      allreduce_bytes_per_pack=FLAGS.allreduce_bytes_per_pack)
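# A hedged example of how train_squad() above might be driven from a script;
# the flag names and the JSON meta-data layout are assumptions about the
# surrounding code, not part of this function.
import json

strategy = tf.distribute.MirroredStrategy()
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
  input_meta_data = json.loads(reader.read().decode('utf-8'))
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
train_squad(strategy, input_meta_data, bert_config, run_eagerly=False)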
def run_bert_classifier(strategy,
                        bert_config,
                        input_meta_data,
                        model_dir,
                        epochs,
                        steps_per_epoch,
                        steps_per_loop,
                        eval_steps,
                        warmup_steps,
                        initial_lr,
                        init_checkpoint,
                        train_input_fn,
                        eval_input_fn,
                        custom_callbacks=None,
                        run_eagerly=False,
                        use_keras_compile_fit=False):
  """Run BERT classifier training using low-level API."""
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data['num_labels']

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(
            bert_config,
            num_classes,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url,
            hub_module_trainable=FLAGS.hub_module_trainable))
    optimizer = optimization.create_optimizer(initial_lr,
                                              steps_per_epoch * epochs,
                                              warmup_steps)
    classifier_model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    return classifier_model, core_model

  loss_fn = get_loss_fn(num_classes)

  # Defines evaluation metrics function, which will create metrics in the
  # correct device and strategy scope.
  def metric_fn():
    return tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

  if use_keras_compile_fit:
    # Start training using Keras compile/fit API.
    logging.info('Training using TF 2.0 Keras compile/fit API with '
                 'distribution strategy.')
    return run_keras_compile_fit(
        model_dir,
        strategy,
        _get_classifier_model,
        train_input_fn,
        eval_input_fn,
        loss_fn,
        metric_fn,
        init_checkpoint,
        epochs,
        steps_per_epoch,
        steps_per_loop,
        eval_steps,
        custom_callbacks=custom_callbacks)

  # Use user-defined loop to start training.
  logging.info('Training using customized training loop TF 2.0 with '
               'distribution strategy.')
  return model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_classifier_model,
      loss_fn=loss_fn,
      model_dir=model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      eval_steps=eval_steps,
      init_checkpoint=init_checkpoint,
      metric_fn=metric_fn,
      custom_callbacks=custom_callbacks,
      run_eagerly=run_eagerly)
def run_bert_classifier(strategy,
                        bert_config,
                        input_meta_data,
                        model_dir,
                        epochs,
                        steps_per_epoch,
                        steps_per_loop,
                        eval_steps,
                        warmup_steps,
                        initial_lr,
                        init_checkpoint,
                        train_input_fn,
                        eval_input_fn,
                        custom_callbacks=None,
                        run_eagerly=False,
                        use_keras_compile_fit=False):
  """Run BERT classifier training using low-level API."""
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data['num_labels']

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(
            bert_config,
            num_classes,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url,
            hub_module_trainable=FLAGS.hub_module_trainable))
    classifier_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", the dtype
      # determined by flags_core.get_tf_dtype(flags_obj) is 'float32', which
      # ensures tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      classifier_model.optimizer = (
          tf.train.experimental.enable_mixed_precision_graph_rewrite(
              classifier_model.optimizer))
    return classifier_model, core_model

  # During distributed training, the loss used for gradient computation is
  # summed over all replicas. When the Keras compile/fit() API is used, fit()
  # internally normalizes the loss by dividing it by the number of replicas
  # used for computation. With a custom training loop this is not done
  # automatically and must be done manually by the end user (see the
  # loss-function sketch after this function).
  loss_multiplier = 1.0
  if FLAGS.scale_loss and not use_keras_compile_fit:
    loss_multiplier = 1.0 / strategy.num_replicas_in_sync

  loss_fn = get_loss_fn(num_classes, loss_factor=loss_multiplier)

  # Defines evaluation metrics function, which will create metrics in the
  # correct device and strategy scope.
  def metric_fn():
    return tf.keras.metrics.SparseCategoricalAccuracy(
        'accuracy', dtype=tf.float32)

  if use_keras_compile_fit:
    # Start training using Keras compile/fit API.
    logging.info('Training using TF 2.0 Keras compile/fit API with '
                 'distribution strategy.')
    return run_keras_compile_fit(
        model_dir,
        strategy,
        _get_classifier_model,
        train_input_fn,
        eval_input_fn,
        loss_fn,
        metric_fn,
        init_checkpoint,
        epochs,
        steps_per_epoch,
        eval_steps,
        input_meta_data['labels_list'],
        custom_callbacks=None)

  # Use user-defined loop to start training.
  logging.info('Training using customized training loop TF 2.0 with '
               'distribution strategy.')
  return model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_classifier_model,
      loss_fn=loss_fn,
      model_dir=model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      eval_steps=eval_steps,
      init_checkpoint=init_checkpoint,
      metric_fn=metric_fn,
      custom_callbacks=custom_callbacks,
      run_eagerly=run_eagerly)
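# The loss_multiplier logic above assumes get_loss_fn() returns a
# classification loss that is rescaled by loss_factor, so that losses summed
# across replicas average out correctly under a distribution strategy. A
# minimal sketch of such a function (an assumption, not necessarily the
# original implementation):
def get_loss_fn(num_classes, loss_factor=1.0):
  """Returns a categorical cross-entropy loss scaled by loss_factor."""

  def classification_loss_fn(labels, logits):
    labels = tf.squeeze(labels)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    one_hot_labels = tf.one_hot(
        tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)
    # Rescale so that summing per-replica losses yields the global mean.
    return loss * loss_factor

  return classification_loss_fn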
def train_squad(strategy,
                input_meta_data,
                bert_config,
                custom_callbacks=None,
                run_eagerly=False):
  """Run bert squad training."""
  if strategy:
    logging.info('Training using customized training loop with distribution'
                 ' strategy.')
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_config_v2(FLAGS.enable_xla)
  performance.set_mixed_precision_policy(common_flags.dtype())

  epochs = FLAGS.num_train_epochs
  num_train_examples = input_meta_data['train_data_size']
  max_seq_length = input_meta_data['max_seq_length']
  steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
  warmup_steps = int(epochs * num_train_examples * 0.1 /
                     FLAGS.train_batch_size)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)

  def _get_squad_model():
    """Get Squad model and optimizer."""
    squad_model, core_model = bert_models.squad_model(
        bert_config,
        max_seq_length,
        hub_module_url=FLAGS.hub_module_url,
        hub_module_trainable=FLAGS.hub_module_trainable)
    optimizer = optimization.create_optimizer(FLAGS.learning_rate,
                                              steps_per_epoch * epochs,
                                              warmup_steps,
                                              FLAGS.optimizer_type)
    squad_model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    return squad_model, core_model

  # If explicit_allreduce=True, apply_gradients() no longer implicitly
  # all-reduces gradients; the caller all-reduces them manually and passes the
  # all-reduced grads_and_vars to apply_gradients(). clip_by_global_norm is
  # applied to the all-reduced gradients.
  def clip_by_global_norm_callback(grads_and_vars):
    grads, variables = zip(*grads_and_vars)
    (clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
    return zip(clipped_grads, variables)

  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_squad_model,
      loss_fn=get_loss_fn(),
      model_dir=FLAGS.model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=FLAGS.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      run_eagerly=run_eagerly,
      custom_callbacks=custom_callbacks,
      explicit_allreduce=False,
      post_allreduce_callbacks=[clip_by_global_norm_callback])
def run_classifier(self, train_input_fn, validation_input_fn, epochs,
                   steps_per_epoch, validation_steps, num_classes):
  """Creates classifier and runs the classifier training."""
  if epochs is None:
    epochs = self.default_training_epochs

  bert_config = bert_configs.BertConfig(
      0,
      initializer_range=self.initializer_range,
      hidden_dropout_prob=self.dropout_rate)
  warmup_steps = int(epochs * steps_per_epoch * 0.1)
  initial_lr = self.learning_rate

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(
            bert_config, num_classes, self.seq_len, hub_module_url=self.uri))
    classifier_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    return classifier_model, core_model

  # During distributed training, the loss used for gradient computation is
  # summed over all replicas. When the Keras compile/fit() API is used, fit()
  # internally normalizes the loss by dividing it by the number of replicas
  # used for computation. With a custom training loop this is not done
  # automatically and must be done manually by the end user.
  loss_multiplier = 1.0
  if self.scale_loss:
    loss_multiplier = 1.0 / self.strategy.num_replicas_in_sync

  loss_fn = self.get_classification_loss_fn(
      num_classes, loss_factor=loss_multiplier)

  # Defines evaluation metrics function, which will create metrics in the
  # correct device and strategy scope.
  def metric_fn():
    return tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

  # Use user-defined loop to start training.
  tf.compat.v1.logging.info('Training using customized training loop TF 2.0 '
                            'with distribution strategy.')
  bert_model = model_training_utils.run_customized_training_loop(
      strategy=self.strategy,
      model_fn=_get_classifier_model,
      loss_fn=loss_fn,
      model_dir=self.model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=self.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      eval_input_fn=validation_input_fn,
      eval_steps=validation_steps,
      init_checkpoint=None,
      metric_fn=metric_fn,
      custom_callbacks=None,
      run_eagerly=False)

  # Used in evaluation.
  with self.strategy.scope():
    bert_model, _ = _get_classifier_model()
    checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
    checkpoint = tf.train.Checkpoint(model=bert_model)
    checkpoint.restore(checkpoint_path).expect_partial()
    bert_model.compile(loss=loss_fn, metrics=[metric_fn()])
  return bert_model
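# Hedged usage example: the compiled model returned by run_classifier() above
# can be evaluated directly with Keras. Here, `classifier` is assumed to be an
# instance of the surrounding class, and train_input_fn/validation_input_fn
# are assumed to be zero-argument factories returning tf.data datasets; these
# names are assumptions, not part of the original code.
bert_model = classifier.run_classifier(
    train_input_fn,
    validation_input_fn,
    epochs=None,  # None falls back to default_training_epochs.
    steps_per_epoch=100,
    validation_steps=20,
    num_classes=3)
loss, accuracy = bert_model.evaluate(validation_input_fn(), steps=20)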