Example #1
def run_training(self, distribution, model_dir, steps_per_loop, run_eagerly):
  model_training_utils.run_customized_training_loop(
      strategy=distribution,
      model_fn=self._model_fn,
      loss_fn=tf.keras.losses.categorical_crossentropy,
      model_dir=model_dir,
      steps_per_epoch=20,
      steps_per_loop=steps_per_loop,
      epochs=2,
      train_input_fn=self._input_fn,
      eval_input_fn=self._input_fn,
      eval_steps=10,
      init_checkpoint=None,
      metric_fn=metric_fn,
      custom_callbacks=None,
      run_eagerly=run_eagerly)
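This test excerpt references a module-level `metric_fn` and the fixtures `self._model_fn` / `self._input_fn` without showing them. A minimal sketch of what such a metric factory could look like, assuming `run_customized_training_loop` calls it to create metrics inside the active strategy scope (the body below is an assumption for illustration, not the original test code):

import tensorflow as tf

def metric_fn():
  # Hypothetical metric factory: returns a fresh Keras metric instance so the
  # training loop can instantiate it under the distribution strategy scope.
  return tf.keras.metrics.CategoricalAccuracy('accuracy', dtype=tf.float32)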
Example #2
def run_training(self, strategy, model_dir, steps_per_loop, run_eagerly):
    input_fn = create_fake_data_input_fn(batch_size=8,
                                         features_shape=[128],
                                         num_classes=3)
    model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=self._model_fn,
        loss_fn=tf.keras.losses.categorical_crossentropy,
        model_dir=model_dir,
        steps_per_epoch=20,
        steps_per_loop=steps_per_loop,
        epochs=2,
        train_input_fn=input_fn,
        eval_input_fn=input_fn,
        eval_steps=10,
        init_checkpoint=None,
        metric_fn=metric_fn,
        custom_callbacks=None,
        run_eagerly=run_eagerly)
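The `create_fake_data_input_fn` helper above is not shown in the excerpt. A plausible sketch, assuming it returns an `input_fn` producing an infinitely repeating `tf.data.Dataset` of random features and one-hot labels (names match the call site; the body is an assumption for illustration only):

import numpy as np
import tensorflow as tf

def create_fake_data_input_fn(batch_size, features_shape, num_classes):
  # Hypothetical test fixture: builds random data matching the requested shape.
  features = np.random.rand(64, *features_shape).astype(np.float32)
  labels = np.random.randint(0, num_classes, size=(64,))
  one_hot_labels = tf.one_hot(labels, depth=num_classes)

  def _input_fn(ctx=None):
    del ctx  # Per-replica input context, unused in this sketch.
    dataset = tf.data.Dataset.from_tensor_slices((features, one_hot_labels))
    return dataset.shuffle(64).repeat().batch(batch_size, drop_remainder=True)

  return _input_fn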
Example #3
def run_customized_training(strategy, bert_config, max_seq_length,
                            max_predictions_per_seq, model_dir,
                            steps_per_epoch, steps_per_loop, epochs,
                            initial_lr, warmup_steps, input_files,
                            train_batch_size):
    """Run BERT pretrain model training using low-level API."""

    train_input_fn = functools.partial(get_pretrain_input_data, input_files,
                                       max_seq_length, max_predictions_per_seq,
                                       train_batch_size, strategy)

    def _get_pretrain_model():
        """Gets a pretraining model."""
        pretrain_model, core_model = bert_models.pretrain_model(
            bert_config, max_seq_length, max_predictions_per_seq)
        pretrain_model.optimizer = optimization.create_optimizer(
            initial_lr, steps_per_epoch * epochs, warmup_steps)
        if FLAGS.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite", the
            # dtype determined by flags_core.get_tf_dtype(flags_obj) is
            # 'float32', which ensures that tf.compat.v2.keras.mixed_precision
            # and tf.train.experimental.enable_mixed_precision_graph_rewrite do
            # not double up.
            pretrain_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                pretrain_model.optimizer)
        return pretrain_model, core_model

    trained_model = model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_get_pretrain_model,
        loss_fn=get_loss_fn(
            loss_factor=1.0 /
            strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0),
        model_dir=model_dir,
        train_input_fn=train_input_fn,
        steps_per_epoch=steps_per_epoch,
        steps_per_loop=steps_per_loop,
        epochs=epochs)

    # Creates the BERT core model outside the distribution strategy scope.
    _, core_model = bert_models.pretrain_model(bert_config, max_seq_length,
                                               max_predictions_per_seq)

    # Restores the core model from the model checkpoints and exports a new
    # checkpoint that contains only the core model.
    model_saving_utils.export_pretraining_checkpoint(checkpoint_dir=model_dir,
                                                     model=core_model)
    return trained_model
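`get_loss_fn` is called here (and in the next example) but its definition is not part of the excerpt. For the pretraining case, where the model itself emits the per-example loss as its output, a plausible reconstruction is a small factory that only applies the replica scaling factor (this is an assumption based on the call site, not the original definition):

import tensorflow as tf

def get_loss_fn(loss_factor=1.0):
  # Hypothetical sketch: the pretraining model already outputs its loss, so the
  # loss_fn only averages it and applies the optional 1/num_replicas factor.
  def _loss_fn(unused_labels, losses):
    return tf.reduce_mean(losses) * loss_factor
  return _loss_fn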
Example #4
def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False):
  """Run bert squad training."""
  if strategy:
    logging.info('Training using customized training loop with distribution'
                 ' strategy.')
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_config_v2(FLAGS.enable_xla)

  use_float16 = common_flags.use_float16()
  if use_float16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  epochs = FLAGS.num_train_epochs
  num_train_examples = input_meta_data['train_data_size']
  max_seq_length = input_meta_data['max_seq_length']
  steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
  warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
  train_input_fn = functools.partial(
      input_pipeline.create_squad_dataset,
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)

  def _get_squad_model():
    """Get Squad model and optimizer."""
    squad_model, core_model = bert_models.squad_model(
        bert_config,
        max_seq_length,
        float_type=tf.float16 if use_float16 else tf.float32)
    squad_model.optimizer = optimization.create_optimizer(
        FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
    if use_float16:
      # Wraps optimizer with a LossScaleOptimizer. This is done automatically
      # in compile() with the "mixed_float16" policy, but since we do not call
      # compile(), we must wrap the optimizer manually.
      squad_model.optimizer = (
          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
              squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", the dtype
      # determined by flags_core.get_tf_dtype(flags_obj) is 'float32', which
      # ensures that tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          squad_model.optimizer)
    return squad_model, core_model

  # The original BERT model does not scale the loss by 1/num_replicas_in_sync.
  # It could be an accident. So, in order to use the same hyperparameters, we
  # do the same thing here and keep each replica loss as it is.
  loss_fn = get_loss_fn(
      loss_factor=1.0 /
      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)

  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_squad_model,
      loss_fn=loss_fn,
      model_dir=FLAGS.model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=FLAGS.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      run_eagerly=run_eagerly,
      custom_callbacks=custom_callbacks)
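A typical invocation of `train_squad`, assuming the preprocessing step wrote the metadata as a JSON dictionary; the file path and loading code below are assumptions for illustration:

import json
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
# Hypothetical metadata file containing 'train_data_size', 'max_seq_length', etc.
with tf.io.gfile.GFile('/path/to/squad_meta_data.json', 'r') as reader:
  input_meta_data = json.load(reader)
train_squad(strategy, input_meta_data, run_eagerly=False)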
Example #5
def run_customized_training(strategy,
                            bert_config,
                            input_meta_data,
                            model_dir,
                            epochs,
                            steps_per_epoch,
                            steps_per_loop,
                            eval_steps,
                            warmup_steps,
                            initial_lr,
                            init_checkpoint,
                            custom_callbacks=None,
                            run_eagerly=False):
  """Run BERT classifier training using low-level API."""
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data['num_labels']

  train_input_fn = functools.partial(
      input_pipeline.create_classifier_dataset,
      FLAGS.train_data_path,
      seq_length=max_seq_length,
      batch_size=FLAGS.train_batch_size)
  eval_input_fn = functools.partial(
      input_pipeline.create_classifier_dataset,
      FLAGS.eval_data_path,
      seq_length=max_seq_length,
      batch_size=FLAGS.eval_batch_size,
      is_training=False,
      drop_remainder=False)

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(
            bert_config,
            tf.float32,
            num_classes,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url))
    classifier_model.optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", the dtype
      # determined by flags_core.get_tf_dtype(flags_obj) is 'float32', which
      # ensures that tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      classifier_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          classifier_model.optimizer)
    return classifier_model, core_model

  loss_fn = get_loss_fn(
      num_classes,
      loss_factor=1.0 /
      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)

  # Defines evaluation metrics function, which will create metrics in the
  # correct device and strategy scope.
  def metric_fn():
    return tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

  return model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_classifier_model,
      loss_fn=loss_fn,
      model_dir=model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      eval_steps=eval_steps,
      init_checkpoint=init_checkpoint,
      metric_fn=metric_fn,
      custom_callbacks=custom_callbacks,
      run_eagerly=run_eagerly)
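In this classifier example `get_loss_fn` also takes `num_classes`, which suggests the variant computes a categorical cross-entropy over the model's logits before applying the scaling factor. A sketch consistent with that call site (an assumption, not the original code):

import tensorflow as tf

def get_loss_fn(num_classes, loss_factor=1.0):
  # Hypothetical classification loss: softmax cross-entropy over the logits,
  # scaled by the optional 1/num_replicas factor.
  def _classification_loss_fn(labels, logits):
    labels = tf.squeeze(labels)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    one_hot_labels = tf.one_hot(
        tf.cast(labels, tf.int32), depth=num_classes, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    return tf.reduce_mean(per_example_loss) * loss_factor
  return _classification_loss_fn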