Example #1
def _get_pretrain_model():
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)
    optimizer = optimization.create_optimizer(initial_lr,
                                              steps_per_epoch * epochs,
                                              warmup_steps,
                                              FLAGS.optimizer_type)
    pretrain_model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    return pretrain_model, core_model
Example #2
def _get_pretrain_model():
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config,
        max_seq_length,
        max_predictions_per_seq,
        use_next_sentence_label=use_next_sentence_label)
    optimizer = optimization.create_optimizer(initial_lr,
                                              steps_per_epoch * epochs,
                                              warmup_steps, end_lr,
                                              optimizer_type)
    pretrain_model.optimizer = performance.configure_optimizer(
        optimizer, use_float16=common_flags.use_float16())
    return pretrain_model, core_model
Example #3
def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (bert_models.classifier_model(
        bert_config,
        num_classes,
        max_seq_length,
        hub_module_url=FLAGS.hub_module_url,
        hub_module_trainable=FLAGS.hub_module_trainable))
    optimizer = optimization.create_optimizer(initial_lr,
                                              steps_per_epoch * epochs,
                                              warmup_steps, FLAGS.end_lr,
                                              FLAGS.optimizer_type)
    classifier_model.optimizer = performance.configure_optimizer(
        optimizer, use_float16=common_flags.use_float16())
    return classifier_model, core_model
Example #4
def _get_squad_model():
    """Get Squad model and optimizer."""
    squad_model, core_model = bert_models.squad_model(
        bert_config,
        max_seq_length,
        hub_module_url=FLAGS.hub_module_url,
        hub_module_trainable=FLAGS.hub_module_trainable)
    optimizer = optimization.create_optimizer(FLAGS.learning_rate,
                                              steps_per_epoch * epochs,
                                              warmup_steps, FLAGS.end_lr,
                                              FLAGS.optimizer_type)

    squad_model.optimizer = performance.configure_optimizer(
        optimizer, use_float16=common_flags.use_float16())
    return squad_model, core_model
Example #5
def _get_model():
    """Gets a siamese model."""
    if FLAGS.model_type == 'siamese':
        model, core_model = (siamese_bert.siamese_model(
            bert_config, num_classes, siamese_type=FLAGS.siamese_type))
    else:
        model, core_model = (bert_models.classifier_model(
            bert_config, num_classes, max_seq_length))
    optimizer = optimization.create_optimizer(initial_lr,
                                              steps_per_epoch * epochs,
                                              warmup_steps, FLAGS.end_lr,
                                              FLAGS.optimizer_type)
    model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    return model, core_model
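
Examples #1 through #5 all pass steps_per_epoch * epochs and warmup_steps to optimization.create_optimizer; Examples #6 and #7 below show how those values are derived from the training flags. A quick worked sketch of that arithmetic, with made-up numbers for the dataset size, batch size and epoch count:

# Illustrative values only; the formulas mirror Example #6/#7 below.
num_train_examples = 88641    # hypothetical input_meta_data['train_data_size']
train_batch_size = 32         # hypothetical FLAGS.train_batch_size
epochs = 3                    # hypothetical FLAGS.num_train_epochs

steps_per_epoch = int(num_train_examples / train_batch_size)              # 2770
warmup_steps = int(epochs * num_train_examples * 0.1 / train_batch_size)  # 831
total_train_steps = steps_per_epoch * epochs                              # 8310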
Example #6
def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False):
  """Run bert squad training."""
  if strategy:
    logging.info('Training using customized training loop with distribution'
                 ' strategy.')
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_config_v2(FLAGS.enable_xla)

  use_float16 = common_flags.use_float16()
  if use_float16:
    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

  bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
      FLAGS.bert_config_file)
  epochs = FLAGS.num_train_epochs
  num_train_examples = input_meta_data['train_data_size']
  max_seq_length = input_meta_data['max_seq_length']
  steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
  warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)

  def _get_squad_model():
    """Get Squad model and optimizer."""
    squad_model, core_model = bert_models.squad_model(
        bert_config,
        max_seq_length,
        float_type=tf.float16 if use_float16 else tf.float32,
        hub_module_url=FLAGS.hub_module_url)
    squad_model.optimizer = optimization.create_optimizer(
        FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
    if use_float16:
      # Wraps optimizer with a LossScaleOptimizer. This is done automatically
      # in compile() with the "mixed_float16" policy, but since we do not call
      # compile(), we must wrap the optimizer manually.
      squad_model.optimizer = (
          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
              squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
      # up.
      squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          squad_model.optimizer)
    return squad_model, core_model

  # The original BERT model does not scale the loss by
  # 1/num_replicas_in_sync. It could be an accident. So, in order to use
  # the same hyper parameter, we do the same thing here by keeping each
  # replica loss as it is.
  loss_fn = get_loss_fn(
      loss_factor=1.0 /
      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)

  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_squad_model,
      loss_fn=loss_fn,
      model_dir=FLAGS.model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=FLAGS.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      run_eagerly=run_eagerly,
      custom_callbacks=custom_callbacks)
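
Examples #6 and #7 use the tf.keras.mixed_precision.experimental API. In newer TensorFlow releases (roughly 2.4 and later) that experimental namespace was replaced by tf.keras.mixed_precision; the following is a minimal sketch of the equivalent calls, assuming such a TensorFlow version is installed (this is an assumption, not part of the original snippet):

import tensorflow as tf

# Replaces experimental.set_policy('mixed_float16') on newer TF versions.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

optimizer = tf.keras.optimizers.Adam()
# Replaces the experimental LossScaleOptimizer wrapping above; dynamic loss
# scaling is the default, so no explicit loss_scale argument is needed.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)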
Example #7
def train_squad(strategy,
                input_meta_data,
                bert_config,
                custom_callbacks=None,
                run_eagerly=False):
    """Run bert squad training."""
    if strategy:
        logging.info(
            'Training using customized training loop with distribution'
            ' strategy.')
    # Enables XLA in Session Config. Should not be set for TPU.
    keras_utils.set_config_v2(FLAGS.enable_xla)

    use_float16 = common_flags.use_float16()
    if use_float16:
        tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

    epochs = FLAGS.num_train_epochs
    num_train_examples = input_meta_data['train_data_size']
    max_seq_length = input_meta_data['max_seq_length']
    steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
    warmup_steps = int(epochs * num_train_examples * 0.1 /
                       FLAGS.train_batch_size)
    train_input_fn = get_dataset_fn(FLAGS.train_data_path,
                                    max_seq_length,
                                    FLAGS.train_batch_size,
                                    is_training=True)

    def _get_squad_model():
        """Get Squad model and optimizer."""
        squad_model, core_model = bert_models.squad_model(
            bert_config,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url,
            hub_module_trainable=FLAGS.hub_module_trainable)
        squad_model.optimizer = optimization.create_optimizer(
            FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
        if use_float16:
            # Wraps optimizer with a LossScaleOptimizer. This is done automatically
            # in compile() with the "mixed_float16" policy, but since we do not call
            # compile(), we must wrap the optimizer manually.
            squad_model.optimizer = (
                tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                    squad_model.optimizer,
                    loss_scale=common_flags.get_loss_scale()))
        if FLAGS.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
            # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
            # which will ensure tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
            # up.
            squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                squad_model.optimizer)
        return squad_model, core_model

    # The original BERT model does not scale the loss by
    # 1/num_replicas_in_sync. It could be an accident. So, in order to use
    # the same hyper parameter, we do the same thing here by keeping each
    # replica loss as it is.
    loss_fn = get_loss_fn(
        loss_factor=1.0 /
        strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)

    # when all_reduce_sum_gradients = False, apply_gradients() no longer
    # implicitly allreduce gradients, users manually allreduce gradient and
    # passed the allreduced grads_and_vars. For now, the clip_by_global_norm
    # will be moved to before users' manual allreduce to keep the math
    # unchanged.
    def clip_by_global_norm_callback(grads_and_vars):
        grads, variables = zip(*grads_and_vars)
        (clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
        return zip(clipped_grads, variables)

    model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_get_squad_model,
        loss_fn=loss_fn,
        model_dir=FLAGS.model_dir,
        steps_per_epoch=steps_per_epoch,
        steps_per_loop=FLAGS.steps_per_loop,
        epochs=epochs,
        train_input_fn=train_input_fn,
        init_checkpoint=FLAGS.init_checkpoint,
        run_eagerly=run_eagerly,
        custom_callbacks=custom_callbacks,
        explicit_allreduce=True,
        pre_allreduce_callbacks=[clip_by_global_norm_callback])
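
The clip_by_global_norm_callback in Example #7 clips all gradients jointly so that their combined norm never exceeds 1.0 before they are all-reduced. A small self-contained check of that behavior, using toy tensors (assumes TensorFlow 2.x with eager execution):

import tensorflow as tf

v1 = tf.Variable([1.0, 2.0])
v2 = tf.Variable([3.0])
grads_and_vars = [(tf.constant([3.0, 4.0]), v1), (tf.constant([12.0]), v2)]

grads, variables = zip(*grads_and_vars)
clipped_grads, global_norm = tf.clip_by_global_norm(list(grads), clip_norm=1.0)

print(global_norm.numpy())                 # 13.0 = sqrt(3^2 + 4^2 + 12^2)
print([g.numpy() for g in clipped_grads])  # every gradient scaled by 1/13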