Example #1
0
def run_customized_training(strategy,
                            bert_config,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            input_files,
                            train_batch_size,
                            use_remote_tpu=False):
    """Pretrain BERT with the custom training loop, then export the core model.

    The function wires an input function and a model factory into
    `model_training_utils.run_customized_training_loop`, and afterwards calls
    `model_saving_utils.export_pretraining_checkpoint` with a freshly built
    core model and `model_dir` as the checkpoint directory.

    Returns:
        The value returned by `run_customized_training_loop`.
    """

    train_input_fn = functools.partial(
        get_pretrain_input_data,
        input_files,
        max_seq_length,
        max_predictions_per_seq,
        train_batch_size,
        strategy,
    )

    def _get_pretrain_model():
        """Build the pretraining model pair and attach an optimizer to it."""
        pretrain_model, core_model = bert_models.pretrain_model(
            bert_config, max_seq_length, max_predictions_per_seq)

        total_train_steps = steps_per_epoch * epochs
        optimizer = optimization.create_optimizer(
            initial_lr, total_train_steps, warmup_steps)

        if FLAGS.fp16_implementation == 'graph_rewrite':
            # With fp16_implementation == "graph_rewrite", the dtype reported
            # by flags_core.get_tf_dtype(flags_obj) is 'float32', so
            # tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite are
            # not both applied.
            rewrite = tf.train.experimental.enable_mixed_precision_graph_rewrite
            optimizer = rewrite(optimizer)

        pretrain_model.optimizer = optimizer
        return pretrain_model, core_model

    # Divide the loss by the replica count only when --scale_loss is set.
    if FLAGS.scale_loss:
        loss_factor = 1.0 / strategy.num_replicas_in_sync
    else:
        loss_factor = 1.0

    trained_model = model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_get_pretrain_model,
        loss_fn=get_loss_fn(loss_factor=loss_factor),
        model_dir=model_dir,
        train_input_fn=train_input_fn,
        steps_per_epoch=steps_per_epoch,
        steps_per_loop=steps_per_loop,
        epochs=epochs,
        use_remote_tpu=use_remote_tpu)

    # Build a second BERT core model, this time outside the distribution
    # strategy scope; only the core model of the returned pair is used.
    _unused_pretrain_model, export_core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)

    # Restore the core model from the checkpoints in `model_dir` and write a
    # new checkpoint that contains only the core model.
    model_saving_utils.export_pretraining_checkpoint(
        checkpoint_dir=model_dir, model=export_core_model)
    return trained_model
Example #2
0
 def _get_pretrain_model():
     """Build the pretraining model pair and attach its optimizer.

     When `FLAGS.use_fp16` is set, the model is built with `tf.float16` and
     the optimizer is wrapped in a dynamic `LossScaleOptimizer`; otherwise
     the model uses `tf.float32` and the optimizer is used as created.

     Returns:
         The `(pretrain_model, core_model)` pair from
         `bert_models.pretrain_model`, with `pretrain_model.optimizer` set.
     """
     fp16_enabled = FLAGS.use_fp16
     model_dtype = tf.float16 if fp16_enabled else tf.float32

     pretrain_model, core_model = bert_models.pretrain_model(
         bert_config,
         max_seq_length,
         max_predictions_per_seq,
         float_type=model_dtype)

     total_train_steps = steps_per_epoch * epochs
     optimizer = optimization.create_optimizer(
         initial_lr, total_train_steps, warmup_steps, FLAGS.optimizer_type)

     if fp16_enabled:
         # Wrap the optimizer so dynamic loss scaling is applied.
         optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
             optimizer, dynamic=True)

     pretrain_model.optimizer = optimizer
     return pretrain_model, core_model
Example #3
0
 def _get_pretrain_model():
     """Build the pretraining model pair and attach its optimizer.

     Returns:
         The `(pretrain_model, core_model)` pair from
         `bert_models.pretrain_model`, with `pretrain_model.optimizer` set.
     """
     pretrain_model, core_model = bert_models.pretrain_model(
         bert_config, max_seq_length, max_predictions_per_seq)

     total_train_steps = steps_per_epoch * epochs
     optimizer = optimization.create_optimizer(
         initial_lr, total_train_steps, warmup_steps)

     if FLAGS.fp16_implementation == 'graph_rewrite':
         # With fp16_implementation == "graph_rewrite", the dtype reported by
         # flags_core.get_tf_dtype(flags_obj) is 'float32', so
         # tf.compat.v2.keras.mixed_precision and
         # tf.train.experimental.enable_mixed_precision_graph_rewrite are not
         # both applied.
         rewrite = tf.train.experimental.enable_mixed_precision_graph_rewrite
         optimizer = rewrite(optimizer)

     pretrain_model.optimizer = optimizer
     return pretrain_model, core_model