示例#1
0
 def run_training(self, distribution, model_dir, steps_per_loop, run_eagerly):
   """Runs `self._model_fn` through the customized training loop.

   Trains for 2 epochs of 20 steps each with categorical cross-entropy, and
   reuses `self._input_fn` for both training and the 10-step evaluation.
   `metric_fn` is looked up from the enclosing module scope.

   Args:
     distribution: distribution strategy, forwarded as `strategy`.
     model_dir: directory handed to the training loop.
     steps_per_loop: number of steps per inner training loop.
     run_eagerly: whether the training loop should run eagerly.
   """
   data_fn = self._input_fn
   model_training_utils.run_customized_training_loop(
       strategy=distribution,
       model_fn=self._model_fn,
       loss_fn=tf.keras.losses.categorical_crossentropy,
       model_dir=model_dir,
       train_input_fn=data_fn,
       eval_input_fn=data_fn,
       steps_per_epoch=20,
       steps_per_loop=steps_per_loop,
       epochs=2,
       eval_steps=10,
       metric_fn=metric_fn,
       init_checkpoint=None,
       custom_callbacks=None,
       run_eagerly=run_eagerly)
示例#2
0
def run_customized_training(strategy,
                            bert_config,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            input_files,
                            train_batch_size,
                            use_remote_tpu=False):
    """Pretrains BERT with the customized (low-level) training loop.

    After training, a fresh core model is built outside the distribution
    strategy scope. `model_saving_utils.export_pretraining_checkpoint` then
    restores it from the checkpoints in `model_dir` and writes a checkpoint
    that contains only the core model.

    Args:
      strategy: distribution strategy for the training loop; it is also bound
        into the input function.
      bert_config: configuration passed to `bert_models.pretrain_model`.
      max_seq_length: maximum input sequence length.
      max_predictions_per_seq: maximum masked-LM predictions per sequence.
      model_dir: directory for training checkpoints and the exported
        core-model checkpoint.
      steps_per_epoch: training steps per epoch.
      steps_per_loop: steps run per inner loop of the training loop.
      epochs: number of training epochs.
      initial_lr: initial learning rate for the optimizer schedule.
      warmup_steps: warmup steps for the optimizer schedule.
      input_files: pretraining input files.
      train_batch_size: training batch size.
      use_remote_tpu: forwarded unchanged to the training loop.

    Returns:
      The value returned by `model_training_utils.run_customized_training_loop`.
    """
    train_input_fn = functools.partial(get_pretrain_input_data, input_files,
                                       max_seq_length, max_predictions_per_seq,
                                       train_batch_size, strategy)
    total_steps = steps_per_epoch * epochs

    def _build_pretrain_model():
        """Returns (pretrain_model, core_model) with the optimizer attached."""
        model, core = bert_models.pretrain_model(
            bert_config, max_seq_length, max_predictions_per_seq)
        optimizer = optimization.create_optimizer(initial_lr, total_steps,
                                                  warmup_steps)
        if FLAGS.fp16_implementation == 'graph_rewrite':
            # In graph-rewrite mode the dtype chosen by
            # flags_core.get_tf_dtype(flags_obj) is 'float32', so Keras mixed
            # precision and the graph rewrite are never applied together.
            optimizer = (
                tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimizer))
        model.optimizer = optimizer
        return model, core

    # Divide each replica's loss by the replica count only when requested.
    if FLAGS.scale_loss:
        loss_factor = 1.0 / strategy.num_replicas_in_sync
    else:
        loss_factor = 1.0

    trained_model = model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_build_pretrain_model,
        loss_fn=get_loss_fn(loss_factor=loss_factor),
        model_dir=model_dir,
        train_input_fn=train_input_fn,
        steps_per_epoch=steps_per_epoch,
        steps_per_loop=steps_per_loop,
        epochs=epochs,
        use_remote_tpu=use_remote_tpu)

    # Build the core model outside the distribution strategy scope, then let
    # export_pretraining_checkpoint restore it from model_dir and write a
    # checkpoint holding only the core model.
    _, export_core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)
    model_saving_utils.export_pretraining_checkpoint(
        checkpoint_dir=model_dir, model=export_core_model)
    return trained_model
 def run_training(self, strategy, model_dir, steps_per_loop, run_eagerly):
     """Runs `self._model_fn` through the customized training loop on fake data.

     The data comes from `create_fake_data_input_fn(batch_size=8,
     features_shape=[128], num_classes=3)`. The same input function is used
     for training (2 epochs x 20 steps) and for the 10-step evaluation.

     Args:
       strategy: distribution strategy for the training loop.
       model_dir: directory handed to the training loop.
       steps_per_loop: number of steps per inner training loop.
       run_eagerly: whether the training loop should run eagerly.
     """
     fake_data_fn = create_fake_data_input_fn(
         batch_size=8, features_shape=[128], num_classes=3)
     loop_args = dict(
         strategy=strategy,
         model_fn=self._model_fn,
         loss_fn=tf.keras.losses.categorical_crossentropy,
         model_dir=model_dir,
         steps_per_epoch=20,
         steps_per_loop=steps_per_loop,
         epochs=2,
         train_input_fn=fake_data_fn,
         eval_input_fn=fake_data_fn,
         eval_steps=10,
         init_checkpoint=None,
         metric_fn=metric_fn,
         custom_callbacks=None,
         run_eagerly=run_eagerly)
     model_training_utils.run_customized_training_loop(**loop_args)
示例#4
0
def run_customized_training(strategy,
                            bert_config,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            input_files,
                            train_batch_size):
  """Pretrains BERT with the customized (low-level) training loop.

  The training loop is asked to export the sub-model under the name
  'pretrained/bert_model'.

  Args:
    strategy: distribution strategy for the training loop.
    bert_config: configuration passed to `bert_models.pretrain_model`.
    max_seq_length: maximum input sequence length.
    max_predictions_per_seq: maximum masked-LM predictions per sequence.
    model_dir: directory handed to the training loop.
    steps_per_epoch: training steps per epoch.
    steps_per_loop: steps run per inner loop of the training loop.
    epochs: number of training epochs.
    initial_lr: initial learning rate for the optimizer schedule.
    warmup_steps: warmup steps for the optimizer schedule.
    input_files: pretraining input files.
    train_batch_size: training batch size.

  Returns:
    The value returned by `model_training_utils.run_customized_training_loop`.
  """
  train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                           max_predictions_per_seq,
                                           train_batch_size)
  total_steps = steps_per_epoch * epochs

  def _build_pretrain_model():
    """Returns (pretrain_model, core_model) with the optimizer attached."""
    model, core = bert_models.pretrain_model(bert_config, max_seq_length,
                                             max_predictions_per_seq)
    optimizer = optimization.create_optimizer(initial_lr, total_steps,
                                              warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # In graph-rewrite mode the dtype chosen by
      # flags_core.get_tf_dtype(flags_obj) is 'float32', so Keras mixed
      # precision and the graph rewrite are never applied together.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)
    model.optimizer = optimizer
    return model, core

  # Divide each replica's loss by the replica count only when requested.
  loss_factor = 1.0
  if FLAGS.scale_loss:
    loss_factor /= strategy.num_replicas_in_sync

  return model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_build_pretrain_model,
      loss_fn=get_loss_fn(loss_factor=loss_factor),
      model_dir=model_dir,
      train_input_fn=train_input_fn,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      sub_model_export_name='pretrained/bert_model')
示例#5
0
def run_customized_training(strategy, bert_config, max_seq_length,
                            max_predictions_per_seq, model_dir,
                            steps_per_epoch, steps_per_loop, epochs,
                            initial_lr, warmup_steps, input_files,
                            train_batch_size):
    """Pretrains BERT with the customized training loop.

    The pretraining graph uses float16 when FLAGS.use_fp16 is set, and the
    optimizer is then wrapped in a dynamic LossScaleOptimizer. Gradient
    accumulation, the initial checkpoint, and Horovod usage all come from
    FLAGS. A dllogger instance is handed to the loop via `params`.

    Args:
      strategy: distribution strategy; it may be falsy, in which case no
        per-replica loss scaling is applied.
      bert_config: configuration passed to `bert_models.pretrain_model`.
      max_seq_length: maximum input sequence length.
      max_predictions_per_seq: maximum masked-LM predictions per sequence.
      model_dir: directory handed to the training loop.
      steps_per_epoch: training steps per epoch.
      steps_per_loop: steps run per inner loop of the training loop.
      epochs: number of training epochs.
      initial_lr: initial learning rate for the optimizer schedule.
      warmup_steps: warmup steps for the optimizer schedule.
      input_files: pretraining input files.
      train_batch_size: training batch size.

    Returns:
      The value returned by `model_training_utils.run_customized_training_loop`.
    """
    train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                             max_predictions_per_seq,
                                             train_batch_size)

    def _build_pretrain_model():
        """Returns (pretrain_model, core_model) with the optimizer attached."""
        compute_dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
        model, core = bert_models.pretrain_model(
            bert_config,
            max_seq_length,
            max_predictions_per_seq,
            float_type=compute_dtype)
        optimizer = optimization.create_optimizer(
            initial_lr, steps_per_epoch * epochs, warmup_steps,
            FLAGS.optimizer_type)
        if FLAGS.use_fp16:
            # float16 training uses dynamic loss scaling.
            optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
                optimizer, dynamic=True)
        model.optimizer = optimizer
        return model, core

    dllogging = dllogger_class.dllogger_class(FLAGS.dllog_path)
    params = {'dllogging': dllogging, 'FLAGS': FLAGS}
    logging.info("init_lr = %f", initial_lr)

    # Per-replica loss scaling only applies when requested AND a strategy
    # object is present.
    loss_factor = 1.0
    if FLAGS.scale_loss and strategy:
        loss_factor = 1.0 / strategy.num_replicas_in_sync

    return model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_build_pretrain_model,
        loss_fn=get_loss_fn(loss_factor=loss_factor),
        model_dir=model_dir,
        train_input_fn=train_input_fn,
        steps_per_epoch=steps_per_epoch,
        num_accumulative_step=FLAGS.num_accumulation_steps,
        steps_per_loop=steps_per_loop,
        epochs=epochs,
        sub_model_export_name='pretrained/bert_model',
        init_checkpoint=FLAGS.init_checkpoint,
        hvd=hvd if FLAGS.use_horovod else None,
        params=params)
def run_customized_training(strategy, bert_config, max_seq_length,
                            max_predictions_per_seq, model_dir,
                            steps_per_epoch, steps_per_loop, epochs,
                            initial_lr, warmup_steps, input_files,
                            train_batch_size):
    """Pretrains BERT with the customized (low-level) training loop.

    The optimizer is passed through `performance.configure_optimizer` together
    with the `common_flags.use_float16()` and
    `common_flags.use_graph_rewrite()` settings.

    Args:
      strategy: distribution strategy for the training loop.
      bert_config: configuration passed to `bert_models.pretrain_model`.
      max_seq_length: maximum input sequence length.
      max_predictions_per_seq: maximum masked-LM predictions per sequence.
      model_dir: directory handed to the training loop.
      steps_per_epoch: training steps per epoch.
      steps_per_loop: steps run per inner loop of the training loop.
      epochs: number of training epochs.
      initial_lr: initial learning rate for the optimizer schedule.
      warmup_steps: warmup steps for the optimizer schedule.
      input_files: pretraining input files.
      train_batch_size: training batch size.

    Returns:
      The value returned by `model_training_utils.run_customized_training_loop`.
    """
    train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                             max_predictions_per_seq,
                                             train_batch_size)

    def _build_pretrain_model():
        """Returns (pretrain_model, core_model) with a configured optimizer."""
        model, core = bert_models.pretrain_model(bert_config, max_seq_length,
                                                 max_predictions_per_seq)
        base_optimizer = optimization.create_optimizer(
            initial_lr, steps_per_epoch * epochs, warmup_steps)
        model.optimizer = performance.configure_optimizer(
            base_optimizer,
            use_float16=common_flags.use_float16(),
            use_graph_rewrite=common_flags.use_graph_rewrite())
        return model, core

    # Divide each replica's loss by the replica count only when requested.
    loss_factor = (1.0 / strategy.num_replicas_in_sync
                   if FLAGS.scale_loss else 1.0)

    return model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_build_pretrain_model,
        loss_fn=get_loss_fn(loss_factor=loss_factor),
        model_dir=model_dir,
        train_input_fn=train_input_fn,
        steps_per_epoch=steps_per_epoch,
        steps_per_loop=steps_per_loop,
        epochs=epochs,
        sub_model_export_name='pretrained/bert_model')
示例#7
0
def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False):
  """Fine-tunes a BERT SQuAD model with the customized training loop.

  Args:
    strategy: distribution strategy passed to the training loop; its
      `num_replicas_in_sync` is read when FLAGS.scale_loss is set.
    input_meta_data: dict providing 'train_data_size' and 'max_seq_length'.
    custom_callbacks: callbacks forwarded to the training loop.
    run_eagerly: forwarded to the training loop.
  """
  if strategy:
    logging.info('Training using customized training loop with distribution'
                 ' strategy.')
  # XLA is enabled through the session config; it should not be set for TPU.
  keras_utils.set_config_v2(FLAGS.enable_xla)

  use_float16 = common_flags.use_float16()
  if use_float16:
    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

  config_class = MODEL_CLASSES[FLAGS.model_type][0]
  bert_config = config_class.from_json_file(FLAGS.bert_config_file)
  epochs = FLAGS.num_train_epochs
  num_train_examples = input_meta_data['train_data_size']
  max_seq_length = input_meta_data['max_seq_length']
  batch_size = FLAGS.train_batch_size
  steps_per_epoch = int(num_train_examples / batch_size)
  # Warmup covers 10% of the total number of training steps.
  warmup_steps = int(epochs * num_train_examples * 0.1 / batch_size)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path, max_seq_length, batch_size, is_training=True)

  def _build_squad_model():
    """Returns (squad_model, core_model) with the optimizer attached."""
    model, core = bert_models.squad_model(
        bert_config,
        max_seq_length,
        float_type=tf.float16 if use_float16 else tf.float32,
        hub_module_url=FLAGS.hub_module_url)
    optimizer = optimization.create_optimizer(
        FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
    if use_float16:
      # compile() would add the LossScaleOptimizer automatically under the
      # "mixed_float16" policy; the custom loop never calls compile(), so the
      # optimizer is wrapped by hand.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=common_flags.get_loss_scale())
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # In graph-rewrite mode the dtype chosen by
      # flags_core.get_tf_dtype(flags_obj) is 'float32', so Keras mixed
      # precision and the graph rewrite are never applied together.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)
    model.optimizer = optimizer
    return model, core

  # The original BERT implementation keeps each replica's loss unscaled
  # (possibly by accident). The 1/num_replicas_in_sync factor is therefore
  # applied only when FLAGS.scale_loss asks for it.
  if FLAGS.scale_loss:
    loss_factor = 1.0 / strategy.num_replicas_in_sync
  else:
    loss_factor = 1.0
  loss_fn = get_loss_fn(loss_factor=loss_factor)

  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_build_squad_model,
      loss_fn=loss_fn,
      model_dir=FLAGS.model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=FLAGS.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      run_eagerly=run_eagerly,
      custom_callbacks=custom_callbacks)
示例#8
0
def train_squad(strategy,
                input_meta_data,
                bert_config,
                custom_callbacks=None,
                run_eagerly=False):
    """Fine-tunes a BERT SQuAD model with the customized training loop.

    Gradients are all-reduced explicitly by the training loop. They are
    clipped to a global norm of 1.0 before that all-reduce.

    Args:
      strategy: distribution strategy passed to the training loop; its
        `num_replicas_in_sync` is read when FLAGS.scale_loss is set.
      input_meta_data: dict providing 'train_data_size' and 'max_seq_length'.
      bert_config: configuration passed to `bert_models.squad_model`.
      custom_callbacks: callbacks forwarded to the training loop.
      run_eagerly: forwarded to the training loop.
    """
    if strategy:
        logging.info(
            'Training using customized training loop with distribution'
            ' strategy.')
    # XLA is enabled through the session config; it should not be set for TPU.
    keras_utils.set_config_v2(FLAGS.enable_xla)

    use_float16 = common_flags.use_float16()
    if use_float16:
        tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

    epochs = FLAGS.num_train_epochs
    num_train_examples = input_meta_data['train_data_size']
    max_seq_length = input_meta_data['max_seq_length']
    batch_size = FLAGS.train_batch_size
    steps_per_epoch = int(num_train_examples / batch_size)
    # Warmup covers 10% of the total number of training steps.
    warmup_steps = int(epochs * num_train_examples * 0.1 / batch_size)
    train_input_fn = get_dataset_fn(FLAGS.train_data_path,
                                    max_seq_length,
                                    batch_size,
                                    is_training=True)

    def _build_squad_model():
        """Returns (squad_model, core_model) with the optimizer attached."""
        model, core = bert_models.squad_model(
            bert_config,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url,
            hub_module_trainable=FLAGS.hub_module_trainable)
        optimizer = optimization.create_optimizer(
            FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
        if use_float16:
            # compile() would add the LossScaleOptimizer automatically under
            # the "mixed_float16" policy; the custom loop never calls
            # compile(), so the optimizer is wrapped by hand.
            optimizer = (
                tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                    optimizer, loss_scale=common_flags.get_loss_scale()))
        if FLAGS.fp16_implementation == 'graph_rewrite':
            # In graph-rewrite mode the dtype chosen by
            # flags_core.get_tf_dtype(flags_obj) is 'float32', so Keras mixed
            # precision and the graph rewrite are never applied together.
            optimizer = (
                tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimizer))
        model.optimizer = optimizer
        return model, core

    # The original BERT implementation keeps each replica's loss unscaled
    # (possibly by accident). The 1/num_replicas_in_sync factor is therefore
    # applied only when FLAGS.scale_loss asks for it.
    loss_factor = 1.0
    if FLAGS.scale_loss:
        loss_factor = 1.0 / strategy.num_replicas_in_sync
    loss_fn = get_loss_fn(loss_factor=loss_factor)

    # With explicit all-reduce, apply_gradients() no longer all-reduces
    # implicitly: the caller all-reduces the gradients and passes the reduced
    # grads_and_vars. Clipping is run before that manual all-reduce to keep
    # the math unchanged.
    def _clip_gradients(grads_and_vars):
        gradients, tvars = zip(*grads_and_vars)
        clipped, _ = tf.clip_by_global_norm(gradients, clip_norm=1.0)
        return zip(clipped, tvars)

    model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_build_squad_model,
        loss_fn=loss_fn,
        model_dir=FLAGS.model_dir,
        steps_per_epoch=steps_per_epoch,
        steps_per_loop=FLAGS.steps_per_loop,
        epochs=epochs,
        train_input_fn=train_input_fn,
        init_checkpoint=FLAGS.init_checkpoint,
        run_eagerly=run_eagerly,
        custom_callbacks=custom_callbacks,
        explicit_allreduce=True,
        pre_allreduce_callbacks=[_clip_gradients])
示例#9
0
def run_bert_classifier(strategy,
                        bert_config,
                        input_meta_data,
                        model_dir,
                        epochs,
                        steps_per_epoch,
                        steps_per_loop,
                        eval_steps,
                        warmup_steps,
                        initial_lr,
                        init_checkpoint,
                        custom_callbacks=None,
                        run_eagerly=False):
    """Run BERT classifier training using low-level API.

    Training goes through Keras compile/fit when FLAGS.use_keras_compile_fit
    is set, and through the customized training loop otherwise.

    Args:
      strategy: distribution strategy used for training; its
        `num_replicas_in_sync` is read when the custom loop runs with
        FLAGS.scale_loss.
      bert_config: configuration passed to `bert_models.classifier_model`.
      input_meta_data: dict providing 'max_seq_length' and 'num_labels'.
      model_dir: directory handed to the training routine.
      epochs: number of training epochs.
      steps_per_epoch: training steps per epoch.
      steps_per_loop: steps per inner loop (custom training loop only).
      eval_steps: number of evaluation steps.
      warmup_steps: warmup steps for the learning-rate schedule.
      initial_lr: initial learning rate.
      init_checkpoint: checkpoint to initialize from, or None.
      custom_callbacks: callbacks forwarded to whichever training path runs.
      run_eagerly: forwarded to the custom training loop only.

    Returns:
      The result of `run_keras_compile_fit` or of
      `model_training_utils.run_customized_training_loop`.
    """
    max_seq_length = input_meta_data['max_seq_length']
    num_classes = input_meta_data['num_labels']
    use_keras_compile_fit = FLAGS.use_keras_compile_fit

    train_input_fn = functools.partial(
        input_pipeline.create_classifier_dataset,
        FLAGS.train_data_path,
        seq_length=max_seq_length,
        batch_size=FLAGS.train_batch_size)
    eval_input_fn = functools.partial(input_pipeline.create_classifier_dataset,
                                      FLAGS.eval_data_path,
                                      seq_length=max_seq_length,
                                      batch_size=FLAGS.eval_batch_size,
                                      is_training=False,
                                      drop_remainder=False)

    def _get_classifier_model():
        """Gets a classifier model."""
        classifier_model, core_model = (bert_models.classifier_model(
            bert_config,
            tf.float32,
            num_classes,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url))
        classifier_model.optimizer = optimization.create_optimizer(
            initial_lr, steps_per_epoch * epochs, warmup_steps)
        if FLAGS.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
            # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
            # which will ensure tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
            # up.
            classifier_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                classifier_model.optimizer)
        return classifier_model, core_model

    # During distributed training the loss used for gradient computation is
    # summed over all replicas. Keras compile/fit() already divides it by the
    # number of replicas; the custom training loop does not. The manual
    # 1/num_replicas_in_sync factor is therefore applied only on the
    # custom-loop path. Applying it under fit() as well would scale the loss
    # twice.
    loss_multiplier = 1.0
    if FLAGS.scale_loss and not use_keras_compile_fit:
        loss_multiplier = 1.0 / strategy.num_replicas_in_sync

    loss_fn = get_loss_fn(num_classes, loss_factor=loss_multiplier)

    # Defines evaluation metrics function, which will create metrics in the
    # correct device and strategy scope.
    def metric_fn():
        return tf.keras.metrics.SparseCategoricalAccuracy('test_accuracy',
                                                          dtype=tf.float32)

    if use_keras_compile_fit:
        # Start training using Keras compile/fit API.
        logging.info('Training using TF 2.0 Keras compile/fit API with '
                     'distribution strategy.')
        return run_keras_compile_fit(model_dir,
                                     strategy,
                                     _get_classifier_model,
                                     train_input_fn,
                                     eval_input_fn,
                                     loss_fn,
                                     metric_fn,
                                     init_checkpoint,
                                     epochs,
                                     steps_per_epoch,
                                     eval_steps,
                                     # Forward the caller's callbacks instead
                                     # of silently dropping them.
                                     custom_callbacks=custom_callbacks)

    # Use user-defined loop to start training.
    logging.info('Training using customized training loop TF 2.0 with '
                 'distribution strategy.')
    return model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_get_classifier_model,
        loss_fn=loss_fn,
        model_dir=model_dir,
        steps_per_epoch=steps_per_epoch,
        steps_per_loop=steps_per_loop,
        epochs=epochs,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        eval_steps=eval_steps,
        init_checkpoint=init_checkpoint,
        metric_fn=metric_fn,
        custom_callbacks=custom_callbacks,
        run_eagerly=run_eagerly)
示例#10
0
def train_squad(strategy,
                input_meta_data,
                bert_config,
                custom_callbacks=None,
                run_eagerly=False):
    """Fine-tunes a BERT SQuAD model with the customized training loop.

    The mixed-precision policy is set from `common_flags.dtype()`. The
    optimizer is passed through `performance.configure_optimizer` together
    with the float16 / graph-rewrite settings from `common_flags`.

    Args:
      strategy: distribution strategy passed to the training loop.
      input_meta_data: dict providing 'train_data_size' and 'max_seq_length'.
      bert_config: configuration passed to `bert_models.squad_model`.
      custom_callbacks: callbacks forwarded to the training loop.
      run_eagerly: forwarded to the training loop.
    """
    if strategy:
        logging.info(
            'Training using customized training loop with distribution'
            ' strategy.')
    # XLA is enabled through the session config; it should not be set for TPU.
    keras_utils.set_config_v2(FLAGS.enable_xla)
    performance.set_mixed_precision_policy(common_flags.dtype())

    epochs = FLAGS.num_train_epochs
    num_train_examples = input_meta_data['train_data_size']
    max_seq_length = input_meta_data['max_seq_length']
    batch_size = FLAGS.train_batch_size
    steps_per_epoch = int(num_train_examples / batch_size)
    # Warmup covers 10% of the total number of training steps.
    warmup_steps = int(epochs * num_train_examples * 0.1 / batch_size)
    train_input_fn = get_dataset_fn(FLAGS.train_data_path,
                                    max_seq_length,
                                    batch_size,
                                    is_training=True)

    def _build_squad_model():
        """Returns (squad_model, core_model) with a configured optimizer."""
        model, core = bert_models.squad_model(
            bert_config,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url,
            hub_module_trainable=FLAGS.hub_module_trainable)
        base_optimizer = optimization.create_optimizer(
            FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
        model.optimizer = performance.configure_optimizer(
            base_optimizer,
            use_float16=common_flags.use_float16(),
            use_graph_rewrite=common_flags.use_graph_rewrite())
        return model, core

    # Clips the gradients to a global norm of 1.0. With explicit_allreduce=True,
    # apply_gradients() would not all-reduce implicitly; the caller would pass
    # already all-reduced grads_and_vars and this clip would run on them.
    # NOTE(review): explicit_allreduce is False below. Confirm that
    # run_customized_training_loop still invokes post_allreduce_callbacks in
    # that mode; otherwise this callback has no effect.
    def _clip_gradients(grads_and_vars):
        gradients, tvars = zip(*grads_and_vars)
        clipped, _ = tf.clip_by_global_norm(gradients, clip_norm=1.0)
        return zip(clipped, tvars)

    model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_build_squad_model,
        loss_fn=get_loss_fn(),
        model_dir=FLAGS.model_dir,
        steps_per_epoch=steps_per_epoch,
        steps_per_loop=FLAGS.steps_per_loop,
        epochs=epochs,
        train_input_fn=train_input_fn,
        init_checkpoint=FLAGS.init_checkpoint,
        run_eagerly=run_eagerly,
        custom_callbacks=custom_callbacks,
        explicit_allreduce=False,
        post_allreduce_callbacks=[_clip_gradients])
示例#11
0
def run_bert_classifier(strategy,
                        bert_config,
                        input_meta_data,
                        model_dir,
                        epochs,
                        steps_per_epoch,
                        steps_per_loop,
                        eval_steps,
                        warmup_steps,
                        initial_lr,
                        init_checkpoint,
                        train_input_fn,
                        eval_input_fn,
                        custom_callbacks=None,
                        run_eagerly=False,
                        use_keras_compile_fit=False):
    """Run BERT classifier training using low-level API.

    Training goes through Keras compile/fit when `use_keras_compile_fit` is
    True, and through the customized training loop otherwise.

    Args:
      strategy: distribution strategy used for training; its
        `num_replicas_in_sync` is read when the custom loop runs with
        FLAGS.scale_loss.
      bert_config: configuration passed to `bert_models.classifier_model`.
      input_meta_data: dict providing 'max_seq_length' and 'num_labels'.
      model_dir: directory handed to the training routine.
      epochs: number of training epochs.
      steps_per_epoch: training steps per epoch.
      steps_per_loop: steps per inner loop (custom training loop only).
      eval_steps: number of evaluation steps.
      warmup_steps: warmup steps for the learning-rate schedule.
      initial_lr: initial learning rate.
      init_checkpoint: checkpoint to initialize from, or None.
      train_input_fn: input function for training data.
      eval_input_fn: input function for evaluation data.
      custom_callbacks: callbacks forwarded to whichever training path runs.
      run_eagerly: forwarded to the custom training loop only.
      use_keras_compile_fit: train with Keras compile/fit instead of the
        custom loop.

    Returns:
      The result of `run_keras_compile_fit` or of
      `model_training_utils.run_customized_training_loop`.
    """
    max_seq_length = input_meta_data['max_seq_length']
    num_classes = input_meta_data['num_labels']

    def _get_classifier_model():
        """Gets a classifier model."""
        classifier_model, core_model = (bert_models.classifier_model(
            bert_config,
            tf.float32,
            num_classes,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url))
        classifier_model.optimizer = optimization.create_optimizer(
            initial_lr, steps_per_epoch * epochs, warmup_steps)
        if FLAGS.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
            # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
            # which will ensure tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
            # up.
            classifier_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                classifier_model.optimizer)
        return classifier_model, core_model

    # During distributed training, loss used for gradient computation is
    # summed over from all replicas. When Keras compile/fit() API is used,
    # the fit() API internally normalizes the loss by dividing the loss by
    # the number of replicas used for computation. However, when custom
    # training loop is used this is not done automatically and should be
    # done manually by the end user.
    loss_multiplier = 1.0
    if FLAGS.scale_loss and not use_keras_compile_fit:
        loss_multiplier = 1.0 / strategy.num_replicas_in_sync

    loss_fn = get_loss_fn(num_classes, loss_factor=loss_multiplier)

    # Defines evaluation metrics function, which will create metrics in the
    # correct device and strategy scope.
    def metric_fn():
        return tf.keras.metrics.SparseCategoricalAccuracy('test_accuracy',
                                                          dtype=tf.float32)

    if use_keras_compile_fit:
        # Start training using Keras compile/fit API.
        logging.info('Training using TF 2.0 Keras compile/fit API with '
                     'distribution strategy.')
        return run_keras_compile_fit(model_dir,
                                     strategy,
                                     _get_classifier_model,
                                     train_input_fn,
                                     eval_input_fn,
                                     loss_fn,
                                     metric_fn,
                                     init_checkpoint,
                                     epochs,
                                     steps_per_epoch,
                                     eval_steps,
                                     # Forward the caller's callbacks instead
                                     # of silently dropping them.
                                     custom_callbacks=custom_callbacks)

    # Use user-defined loop to start training.
    logging.info('Training using customized training loop TF 2.0 with '
                 'distribution strategy.')
    return model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_get_classifier_model,
        loss_fn=loss_fn,
        model_dir=model_dir,
        steps_per_epoch=steps_per_epoch,
        steps_per_loop=steps_per_loop,
        epochs=epochs,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        eval_steps=eval_steps,
        init_checkpoint=init_checkpoint,
        metric_fn=metric_fn,
        custom_callbacks=custom_callbacks,
        run_eagerly=run_eagerly)
    def run_classifier(self, train_input_fn, validation_input_fn, epochs,
                       steps_per_epoch, validation_steps, num_classes):
        """Builds a BERT classifier, trains it, and reloads it for evaluation.

        Args:
          train_input_fn: input function for training data.
          validation_input_fn: input function for evaluation data.
          epochs: number of training epochs; None means
            `self.default_training_epochs`.
          steps_per_epoch: training steps per epoch.
          validation_steps: number of evaluation steps.
          num_classes: number of output classes.

        Returns:
          A classifier rebuilt under `self.strategy.scope()`, restored from the
          latest checkpoint in `self.model_dir`, and compiled with the
          classification loss and an accuracy metric.
        """
        epochs = self.default_training_epochs if epochs is None else epochs

        bert_config = bert_configs.BertConfig(
            0,
            initializer_range=self.initializer_range,
            hidden_dropout_prob=self.dropout_rate)
        # Warmup covers 10% of the total number of training steps.
        warmup_steps = int(epochs * steps_per_epoch * 0.1)
        total_steps = steps_per_epoch * epochs
        initial_lr = self.learning_rate

        def _build_classifier():
            """Returns (classifier_model, core_model) with an optimizer."""
            model, core = bert_models.classifier_model(
                bert_config, num_classes, self.seq_len, hub_module_url=self.uri)
            model.optimizer = optimization.create_optimizer(
                initial_lr, total_steps, warmup_steps)
            return model, core

        # Keras compile/fit() divides the replica-summed loss by the number of
        # replicas itself. The custom loop used here does not, so the factor
        # is applied manually when self.scale_loss is set.
        loss_multiplier = (1.0 / self.strategy.num_replicas_in_sync
                           if self.scale_loss else 1.0)
        loss_fn = self.get_classification_loss_fn(num_classes,
                                                  loss_factor=loss_multiplier)

        def metric_fn():
            # Built lazily so the metric is created in the correct device and
            # strategy scope.
            return tf.keras.metrics.SparseCategoricalAccuracy('test_accuracy',
                                                              dtype=tf.float32)

        tf.compat.v1.logging.info(
            'Training using customized training loop TF 2.0 '
            'with distribution strategy.')
        model_training_utils.run_customized_training_loop(
            strategy=self.strategy,
            model_fn=_build_classifier,
            loss_fn=loss_fn,
            model_dir=self.model_dir,
            steps_per_epoch=steps_per_epoch,
            steps_per_loop=self.steps_per_loop,
            epochs=epochs,
            train_input_fn=train_input_fn,
            eval_input_fn=validation_input_fn,
            eval_steps=validation_steps,
            init_checkpoint=None,
            metric_fn=metric_fn,
            custom_callbacks=None,
            run_eagerly=False)

        # Rebuild the classifier under the strategy scope and restore the
        # newest checkpoint so the returned model can be evaluated.
        with self.strategy.scope():
            restored_model, _ = _build_classifier()
            latest_ckpt = tf.train.latest_checkpoint(self.model_dir)
            tf.train.Checkpoint(model=restored_model).restore(
                latest_ckpt).expect_partial()
            restored_model.compile(loss=loss_fn, metrics=[metric_fn()])
        return restored_model
示例#13
0
def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False):
    """Run BERT SQuAD fine-tuning via the customized training loop.

    Trains for ``FLAGS.num_train_epochs`` epochs, or for a single 800-step
    epoch when ``FLAGS.benchmark`` is set.

    Args:
      strategy: distribution strategy forwarded to the training loop. It may
        be falsy; in that case the strategy log line is skipped and the loss
        is never divided by the replica count.
      input_meta_data: dict providing 'train_data_size', 'max_seq_length' and
        'dllogging' (the last one is forwarded to the loop inside ``params``).
      custom_callbacks: optional callbacks forwarded to the training loop.
      run_eagerly: forwarded to the training loop.
    """
    if strategy:
        logging.info(
            'Training using customized training loop with distribution'
            ' strategy.')
    # Enables XLA in Session Config. Should not be set for TPU.
    keras_utils.set_config_v2(FLAGS.enable_xla)

    use_float16 = common_flags.use_float16()
    if use_float16:
        tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

    bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
        FLAGS.bert_config_file)
    epochs = FLAGS.num_train_epochs
    num_train_examples = input_meta_data['train_data_size']
    max_seq_length = input_meta_data['max_seq_length']
    # Examples consumed per optimizer step: per-worker batch times the
    # accumulation count (presumably gradient accumulation, see
    # num_accumulative_step below), times the Horovod world size.
    global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
    if FLAGS.use_horovod:
        global_batch_size *= hvd.size()
    steps_per_epoch = int(num_train_examples / global_batch_size)
    # Warm up over the first 10% of all training steps.
    warmup_steps = int(epochs * num_train_examples * 0.1 / global_batch_size)
    train_input_fn = get_dataset_fn(FLAGS.train_data_path,
                                    max_seq_length,
                                    FLAGS.train_batch_size,
                                    is_training=True,
                                    use_horovod=FLAGS.use_horovod)

    if FLAGS.benchmark:
        # NOTE(review): warmup_steps above was derived from the full schedule,
        # not from this shortened benchmark schedule — confirm that is intended.
        steps_per_epoch = 800
        epochs = 1

    def _get_squad_model():
        """Build the SQuAD model and attach its (optionally loss-scaled) optimizer."""
        squad_model, core_model = bert_models.squad_model(
            bert_config,
            max_seq_length,
            # NOTE(review): this reads FLAGS.use_fp16, whereas the Keras policy
            # above is keyed on common_flags.use_float16(); confirm they agree.
            float_type=tf.float16 if FLAGS.use_fp16 else tf.float32,
            hub_module_url=FLAGS.hub_module_url)
        # Scale the base learning rate by the number of Horovod workers.
        if FLAGS.use_horovod:
            learning_rate = FLAGS.learning_rate * hvd.size()
        else:
            learning_rate = FLAGS.learning_rate
        squad_model.optimizer = optimization.create_optimizer(
            learning_rate, steps_per_epoch * epochs, warmup_steps,
            FLAGS.optimizer_type)
        if FLAGS.use_fp16:
            squad_model.optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
                squad_model.optimizer, dynamic=True)
        return squad_model, core_model

    # The original BERT model does not scale the loss by
    # 1/num_replicas_in_sync, so each replica's loss is kept as-is (to reuse
    # the same hyper-parameters) UNLESS FLAGS.scale_loss is set and a
    # strategy is available, in which case it is divided by the replica count.
    if FLAGS.scale_loss and strategy:
        loss_factor = 1.0 / strategy.num_replicas_in_sync
    else:
        loss_factor = 1.0
    loss_fn = get_loss_fn(loss_factor=loss_factor)

    params = {'dllogging': input_meta_data['dllogging'], 'FLAGS': FLAGS}

    model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_get_squad_model,
        loss_fn=loss_fn,
        model_dir=FLAGS.model_dir,
        steps_per_epoch=steps_per_epoch,
        num_accumulative_step=FLAGS.num_accumulation_steps,
        steps_per_loop=FLAGS.steps_per_loop,
        epochs=epochs,
        train_input_fn=train_input_fn,
        init_checkpoint=FLAGS.init_checkpoint,
        hvd=hvd if FLAGS.use_horovod else None,
        run_eagerly=run_eagerly,
        custom_callbacks=custom_callbacks,
        params=params)
def run(args, strategy):
    """Pretrain a BERT model with TF2 (adapted from the tensorflow/models GitHub repo).

    Initializes from the pretrained checkpoint registered for
    ``args.model_class`` in PRETRAINED_MODELS, runs the customized training
    loop, and writes a JSON run log to ``<output_dir>/run_logs.json`` before
    and after training.

    Args:
      args: parsed argument namespace (``vars(args)`` is merged into the run
        log); supplies bucket/project names, model class, data, optimizer and
        schedule settings.
      strategy: distribution strategy forwarded to the training loop.

    Raises:
      ValueError: if ``args.model_class`` has no usable PRETRAINED_MODELS entry.
    """
    # CONFIG
    # Unique run name; all outputs of this run live under it.
    run_name = get_run_name(args)
    logger.info(f'*** Starting run {run_name} ***')
    output_dir = f'gs://{args.bucket_name}/{args.project_name}/pretrain/runs/{run_name}'

    # pretrained model path
    try:
        pretrained_model_path = PRETRAINED_MODELS[
            args.model_class]['bucket_location']
    except KeyError as err:
        raise ValueError(
            f'Could not find a pretrained model matching the model class {args.model_class}'
        ) from err
    pretrained_model_config_path = f'gs://{args.bucket_name}/{pretrained_model_path}/bert_config.json'
    pretrained_model_checkpoint_path = f'gs://{args.bucket_name}/{pretrained_model_path}/bert_model.ckpt'

    # some logging
    logger.info(
        f'Running pretraining of model {args.model_class} on pretrain data {args.pretrain_data}'
    )
    logger.info(
        f'Initializing model from checkpoint {pretrained_model_checkpoint_path}'
    )

    # load model config based on model_class
    model_config = get_model_config(pretrained_model_config_path)

    # Input data functions; evaluation is only wired up when args.do_eval is set.
    train_input_fn = get_dataset_fn(args, _type='train')
    eval_input_fn = None
    eval_metric_fn = None
    if args.do_eval:
        logger.info('Setting up evaluation dataset')
        eval_metric_fn = get_eval_metric_fn
        eval_input_fn = get_dataset_fn(args, _type='dev')

    # Total optimizer steps the loop runs; drives the LR schedule and the log.
    num_train_steps = args.num_steps_per_epoch * args.num_epochs

    # model_fn
    def _get_pretrained_model(end_lr=0.0):
        """Build the pretraining model and attach its configured optimizer.

        Returns the ``(pretrain_model, core_model)`` pair produced by
        ``bert_models.pretrain_model``. ``end_lr`` is unused (the final
        learning rate comes from ``args.end_lr``); it is kept only so any
        caller passing it keeps working.
        """
        pretrain_model, core_model = bert_models.pretrain_model(
            model_config, args.max_seq_length, args.max_predictions_per_seq)
        # The decay length must match the steps the training loop runs.
        # NOTE(review): assumes the argument parser defines `warmup_steps` —
        # confirm against the CLI definition.
        optimizer = utils.optimizer.create_optimizer(
            args.learning_rate, num_train_steps,
            args.warmup_steps, args.end_lr, args.optimizer_type)
        pretrain_model.optimizer = configure_optimizer(
            optimizer,
            use_float16=args.dtype == 'fp16',
            use_graph_rewrite=False)
        return pretrain_model, core_model

    # custom callbacks
    summary_dir = os.path.join(output_dir, 'summaries')
    time_history_callback = keras_utils.TimeHistory(
        batch_size=args.train_batch_size,
        log_steps=args.time_history_log_steps,
        logdir=summary_dir)
    custom_callbacks = [time_history_callback]

    # Save an initial version of the log file
    data = {
        'created_at': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'run_name': run_name,
        'num_train_steps': num_train_steps,
        'eval_steps': args.eval_steps,
        'model_dir': output_dir,
        'output_dir': output_dir,
        **vars(args),
    }
    f_path_training_log = os.path.join(output_dir, 'run_logs.json')
    logger.info(
        f'Writing training preliminary log to {f_path_training_log}...')
    save_to_json(data, f_path_training_log)

    # run training loop
    num_train_examples = num_train_steps * args.train_batch_size
    logger.info(
        f'Run training for {args.num_epochs:,} epochs, '
        f'{args.num_steps_per_epoch:,} steps each, processing '
        f'{num_train_examples:,} training examples in total...')
    # Monotonic clock: the measured duration is immune to wall-clock changes.
    time_start = time.monotonic()
    model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_get_pretrained_model,
        loss_fn=get_loss_fn(),
        scale_loss=True,
        model_dir=output_dir,
        train_input_fn=train_input_fn,
        steps_per_epoch=args.num_steps_per_epoch,
        steps_per_loop=args.steps_per_loop,
        epochs=args.num_epochs,
        eval_input_fn=eval_input_fn,
        eval_steps=args.eval_steps,
        metric_fn=eval_metric_fn,
        init_checkpoint=pretrained_model_checkpoint_path,
        custom_callbacks=custom_callbacks,
        run_eagerly=False,
        sub_model_export_name='pretrained/bert_model',
        explicit_allreduce=False,
        pre_allreduce_callbacks=None,
        post_allreduce_callbacks=None)
    training_time_min = (time.monotonic() - time_start) / 60
    logger.info(f'Finished training after {training_time_min:.1f} min')
    # Write to run directory
    data['training_time_min'] = training_time_min
    logger.info(f'Writing final log to {f_path_training_log}...')
    save_to_json(data, f_path_training_log)
示例#15
0
def run_bert_classifier(strategy,
                        bert_config,
                        input_meta_data,
                        model_dir,
                        epochs,
                        steps_per_epoch,
                        steps_per_loop,
                        eval_steps,
                        warmup_steps,
                        initial_lr,
                        init_checkpoint,
                        train_input_fn,
                        eval_input_fn,
                        custom_callbacks=None,
                        run_eagerly=False,
                        use_keras_compile_fit=False):
    """Train a BERT sequence classifier with Keras fit or a customized loop.

    Args:
      strategy: distribution strategy used for training.
      bert_config: BERT configuration passed to the classifier builder.
      input_meta_data: mapping that must contain 'max_seq_length' and
        'num_labels'.
      model_dir: directory handed to the selected training routine.
      epochs, steps_per_epoch, steps_per_loop, eval_steps: training schedule.
      warmup_steps, initial_lr: learning-rate schedule inputs; the schedule
        spans ``steps_per_epoch * epochs`` steps.
      init_checkpoint: checkpoint forwarded to the training routine.
      train_input_fn, eval_input_fn: dataset factories for train/eval.
      custom_callbacks: optional callbacks forwarded to the training routine.
      run_eagerly: used only by the customized-loop path.
      use_keras_compile_fit: when true, train through ``run_keras_compile_fit``
        instead of the customized training loop.

    Returns:
      Whatever the selected training routine returns.
    """
    seq_len = input_meta_data['max_seq_length']
    label_count = input_meta_data['num_labels']

    def _build_classifier():
        """Return (classifier, core model) with a configured optimizer attached."""
        classifier, encoder = bert_models.classifier_model(
            bert_config,
            label_count,
            seq_len,
            hub_module_url=FLAGS.hub_module_url,
            hub_module_trainable=FLAGS.hub_module_trainable)
        base_optimizer = optimization.create_optimizer(
            initial_lr, steps_per_epoch * epochs, warmup_steps)
        classifier.optimizer = performance.configure_optimizer(
            base_optimizer,
            use_float16=common_flags.use_float16(),
            use_graph_rewrite=common_flags.use_graph_rewrite())
        return classifier, encoder

    classification_loss = get_loss_fn(label_count)

    def _make_accuracy_metric():
        # A factory rather than an instance, so the metric gets created in the
        # correct device and strategy scope by the training routine.
        return tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)

    if not use_keras_compile_fit:
        # Default path: user-defined training loop.
        logging.info('Training using customized training loop TF 2.0 with '
                     'distribution strategy.')
        return model_training_utils.run_customized_training_loop(
            strategy=strategy,
            model_fn=_build_classifier,
            loss_fn=classification_loss,
            model_dir=model_dir,
            steps_per_epoch=steps_per_epoch,
            steps_per_loop=steps_per_loop,
            epochs=epochs,
            train_input_fn=train_input_fn,
            eval_input_fn=eval_input_fn,
            eval_steps=eval_steps,
            init_checkpoint=init_checkpoint,
            metric_fn=_make_accuracy_metric,
            custom_callbacks=custom_callbacks,
            run_eagerly=run_eagerly)

    # Opt-in path: Keras compile/fit API.
    logging.info('Training using TF 2.0 Keras compile/fit API with '
                 'distribution strategy.')
    return run_keras_compile_fit(model_dir, strategy, _build_classifier,
                                 train_input_fn, eval_input_fn,
                                 classification_loss, _make_accuracy_metric,
                                 init_checkpoint, epochs, steps_per_epoch,
                                 steps_per_loop, eval_steps,
                                 custom_callbacks=custom_callbacks)