Example #2
  def eval_begin(self):
    """See base class."""
    if self.test_loss:
      self.test_loss.reset_states()
    if self.test_accuracy:
      self.test_accuracy.reset_states()
    # self.test_corrects.reset_states()

    epoch_num = int(self.epoch_helper.current_epoch)
    mlp_log.mlperf_print('eval_start', None,
                         metadata={'epoch_num': epoch_num + 1})
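For context, a minimal driver sketch (an assumption, not part of the example) showing how the distributed test steps would typically be bracketed between eval_begin() and the matching eval_end() shown in a later example; `eval_dataset`, `test_step` and `num_eval_steps` are hypothetical helpers:

  def evaluate(self, num_eval_steps):
    """Hedged sketch of an eval driver; the helpers used here are assumed."""
    self.eval_begin()                   # resets metrics and logs 'eval_start'
    iterator = iter(self.eval_dataset)
    for _ in range(num_eval_steps):
      self.test_step(next(iterator))    # updates self.test_loss / self.test_accuracy
    return self.eval_end()              # logs 'eval_stop' and 'eval_accuracy'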
Example #5
    def eval_end(self):
        """See base class."""
        epoch_num = int(self.epoch_helper.current_epoch)
        mlp_log.mlperf_print('eval_stop',
                             None,
                             metadata={'epoch_num': epoch_num + 1})

        eval_accuracy = float(self.test_accuracy.result())
        # eval_accuracy = float(self.test_corrects.result()
        #                      ) / imagenet_preprocessing.NUM_IMAGES['validation']
        # eval_accuracy = float(self.test_accuracy.result()) * \
        #     self.flags_obj.batch_size * self.num_eval_steps / \
        #     imagenet_preprocessing.NUM_IMAGES['validation']
        mlp_log.mlperf_print('eval_accuracy',
                             eval_accuracy,
                             metadata={'epoch_num': epoch_num + 1})

        first_epoch_num = max(epoch_num - self.epochs_between_evals + 1, 0)
        epoch_count = self.epochs_between_evals
        if first_epoch_num == 0:
            epoch_count = self.flags_obj.eval_offset_epochs
            if epoch_count == 0:
                epoch_count = self.flags_obj.epochs_between_evals
        mlp_log.mlperf_print('block_stop',
                             None,
                             metadata={
                                 'first_epoch_num': first_epoch_num + 1,
                                 'epoch_count': epoch_count
                             })

        continue_training = True
        if (eval_accuracy >= self.flags_obj.target_accuracy
                or eval_accuracy <= 0.002):
            continue_training = False
        else:
            mlp_log.mlperf_print('block_start',
                                 None,
                                 metadata={
                                     'first_epoch_num': epoch_num + 2,
                                     'epoch_count': self.epochs_between_evals
                                 })

        results = {}
        if self.test_loss:
            results['test_loss'] = self.test_loss.result()
        if self.test_accuracy:
            results['test_accuracy'] = self.test_accuracy.result()
        results['continue_training'] = continue_training
        return results
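A quick worked illustration of the block bookkeeping in eval_end above, using hypothetical flag values that are not taken from the example:

# Hypothetical values: epochs_between_evals = 4, eval_offset_epochs = 2, and
# the first eval fires at 0-based epoch_num = 1.
epoch_num = 1
epochs_between_evals = 4
eval_offset_epochs = 2

first_epoch_num = max(epoch_num - epochs_between_evals + 1, 0)  # -> 0
# first_epoch_num == 0, so the offset block length is reported instead:
epoch_count = eval_offset_epochs                                # -> 2
# 'block_stop' is logged with first_epoch_num + 1 = 1 and epoch_count = 2
# (the offset block covered 1-based epochs 1-2); the next 'block_start' uses
# epoch_num + 2 = 3 and epoch_count = epochs_between_evals = 4.
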
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
    NotImplementedError: If some features are not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  mlp_log.mlperf_print('init_start', None)
  common.print_flags(flags_obj)

  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.float16:
    loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_float16', loss_scale=loss_scale)
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
    if not keras_utils.is_v2_0():
      raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
  elif dtype == tf.bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu,
      tpu_zone=flags_obj.tpu_zone if flags_obj.tpu else None)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether to enable
    # get_next_as_optional behavior in DistributedIterator. If true, the last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=flags_obj.num_classes,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  # drop_remainder = flags_obj.enable_xla

  # Current resnet_model.resnet50 input format is always channels-last.
  # The keras_applications MobileNet model's input format depends on the Keras
  # backend image data format.
  # The use_keras_image_data_format flag indicates whether the image
  # preprocessor output format should match the Keras backend image data
  # format or simply be channels-last.
  use_keras_image_data_format = (flags_obj.model == 'mobilenet')
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=flags_obj.drop_train_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=flags_obj.drop_eval_remainder)

  steps_per_epoch, train_epochs = common.get_num_train_iterations(flags_obj)

  mlp_log.mlperf_print('global_batch_size', flags_obj.batch_size)
  mlp_log.mlperf_print('num_train_examples',
                       imagenet_preprocessing.NUM_IMAGES['train'])
  mlp_log.mlperf_print('num_eval_examples',
                       imagenet_preprocessing.NUM_IMAGES['validation'])

  learning_rate_schedule_fn = None
  with strategy_scope:
    optimizer, learning_rate_schedule_fn = common.get_optimizer(
        flags_obj=flags_obj,
        steps_per_epoch=steps_per_epoch,
        train_steps=train_epochs * steps_per_epoch)
    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
      # up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    if flags_obj.model == 'resnet50_v1.5':
      resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
      model = resnet_model.resnet50(num_classes=flags_obj.num_classes)
    elif flags_obj.model == 'mobilenet':
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # support 2.0 layers by default.
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=None, classes=flags_obj.num_classes, layers=tf.keras.layers)
    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError(
          'Only polynomial_decay is currently supported.')
    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
    # a valid arg for this model. Also remove as a valid flag.
    if flags_obj.force_v2_in_keras_compile is not None:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly,
          experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
    else:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly)

  callbacks = common.get_callbacks(
      steps_per_epoch=steps_per_epoch,
      learning_rate_schedule_fn=learning_rate_schedule_fn,
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  num_eval_steps = common.get_num_eval_steps(flags_obj)
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  mlp_log.mlperf_print('init_stop', None)
  mlp_log.mlperf_print('run_start', None)

  for epoch in range(train_epochs):
    mlp_log.mlperf_print('epoch_start', None,
                         metadata={'first_epoch_num': epoch,
                                   'epoch_count': 1})
    mlp_log.mlperf_print('block_start', None)
    history = model.fit(train_input_dataset,
                        epochs=1,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=callbacks,
                        verbose=2)
    mlp_log.mlperf_print('block_stop', None)

    eval_output = None
    if not flags_obj.skip_eval:
      mlp_log.mlperf_print('eval_start', None)
      eval_output = model.evaluate(eval_input_dataset,
                                   steps=num_eval_steps,
                                   verbose=2)
      mlp_log.mlperf_print('eval_stop', None)

      eval_accuracy = eval_output[1]
      mlp_log.mlperf_print(
          'eval_accuracy', eval_accuracy, metadata={'epoch_num': epoch})
      if eval_accuracy >= flags_obj.target_accuracy:
        break

    mlp_log.mlperf_print('epoch_stop', None)

  mlp_log.mlperf_print('run_stop', None)

  if flags_obj.pruning_method:
    model = tfmot.sparsity.keras.strip_pruning(model)
  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
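For reference, a hedged sketch of how run() is typically wired into an absl entry point; the flag-definition helper mentioned in the comment is an assumption, not shown here:

from absl import app
from absl import flags
from absl import logging


def main(_):
  stats = run(flags.FLAGS)  # flags.FLAGS serves as flags_obj
  logging.info('Run stats:\n%s', stats)


if __name__ == '__main__':
  # Flag definitions (e.g. a common.define_keras_flags()-style helper) are
  # assumed to have been registered before app.run parses argv.
  app.run(main)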
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
    mlp_log.mlperf_print('cache_clear', True)
    mlp_log.mlperf_print('init_start', None)
    mlp_log.mlperf_print('submission_benchmark', 'resnet')
    mlp_log.mlperf_print('submission_division', 'closed')
    mlp_log.mlperf_print('submission_org', 'google')
    mlp_log.mlperf_print(
        'submission_platform', 'tpu-v3-{}'.format(flags_obj.num_replicas)
        if flags_obj.tpu else 'gpu-v100-{}'.format(flags_obj.num_gpus))
    mlp_log.mlperf_print('submission_status', 'cloud')

    common.print_flags(flags_obj)

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)
    performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

    if tf.config.list_physical_devices('GPU'):
        if flags_obj.tf_gpu_thread_mode:
            datasets_num_private_threads = keras_utils.set_gpu_thread_mode_and_count(
                per_gpu_thread_count=flags_obj.per_gpu_thread_count,
                gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
                num_gpus=flags_obj.num_gpus)
            if not flags_obj.datasets_num_private_threads:
                flags_obj.datasets_num_private_threads = datasets_num_private_threads
        common.set_cudnn_batchnorm_mode()

    # TODO(anj-s): Set data_format without using Keras.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu,
        tpu_zone=flags_obj.tpu_zone if flags_obj.tpu else None)
    mlp_log.mlperf_print('global_batch_size', flags_obj.batch_size)
    mlp_log.mlperf_print('train_samples',
                         imagenet_preprocessing.NUM_IMAGES['train'])
    mlp_log.mlperf_print('eval_samples',
                         imagenet_preprocessing.NUM_IMAGES['validation'])
    mlp_log.mlperf_print(
        'model_bn_span',
        int(flags_obj.batch_size /
            (flags_obj.num_replicas if flags_obj.tpu else flags_obj.num_gpus)))

    per_epoch_steps, train_epochs = common.get_num_train_iterations(flags_obj)
    eval_steps = common.get_num_eval_steps(flags_obj)
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        train_epochs * per_epoch_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        flags_obj.batch_size,
        flags_obj.log_steps,
        logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
    with distribution_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback)

    eval_interval = (flags_obj.epochs_between_evals *
                     per_epoch_steps if not flags_obj.skip_eval else None)
    eval_offset = (flags_obj.eval_offset_epochs *
                   per_epoch_steps if not flags_obj.skip_eval else 0)
    if eval_offset != 0:
        eval_offset -= eval_interval
    checkpoint_interval = (per_epoch_steps
                           if flags_obj.enable_checkpoint_and_export else None)
    summary_interval = per_epoch_steps if flags_obj.enable_tensorboard else None

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=flags_obj.model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    device_warmup_steps = (flags_obj.device_warmup_steps
                           if flags_obj.enable_device_warmup else 0)
    if flags_obj.enable_device_warmup:
        logging.info('Warmup for %d steps.', device_warmup_steps)

    resnet_controller = controller.Controller(
        strategy,
        runnable.train,
        runnable.evaluate,
        runnable.warmup,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        train_steps=per_epoch_steps * train_epochs,
        device_warmup_steps=device_warmup_steps,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        eval_steps=eval_steps,
        eval_interval=eval_interval,
        eval_offset=eval_offset)

    if flags_obj.enable_device_warmup:
        resnet_controller.warmup()

    mlp_log.mlperf_print('init_stop', None)

    profile_steps = flags_obj.profile_steps
    if profile_steps:
        profile_steps = [int(i) for i in profile_steps.split(',')]
        if profile_steps[0] < 0:
            runnable.trace_start(-1)

    time_callback.on_train_begin()
    mlp_log.mlperf_print('run_start', None)
    mlp_log.mlperf_print(
        'block_start',
        None,
        metadata={
            'first_epoch_num':
            1,
            'epoch_count':
            (flags_obj.eval_offset_epochs if flags_obj.eval_offset_epochs != 0
             else flags_obj.epochs_between_evals)
        })
    resnet_controller.train(evaluate=not flags_obj.skip_eval)
    mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
    time_callback.on_train_end()
    mlp_log.mlperf_print('run_final', None)

    stats = build_stats(runnable, time_callback)
    return stats
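A short numeric illustration of the eval scheduling values computed above, with hypothetical flag values:

# Hypothetical flags: 100 steps per epoch, eval every 4 epochs, offset 2 epochs.
per_epoch_steps = 100
epochs_between_evals = 4
eval_offset_epochs = 2

eval_interval = epochs_between_evals * per_epoch_steps  # 400 steps
eval_offset = eval_offset_epochs * per_epoch_steps      # 200 steps
if eval_offset != 0:
  eval_offset -= eval_interval                          # -> -200
# With an interval of 400 and an offset of -200, the controller presumably
# runs the first eval after step 200 (epoch 2) and every 400 steps thereafter.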
def run_customized_training_loop(
    # pylint: disable=invalid-name
    _sentinel=None,
    # pylint: enable=invalid-name
    strategy=None,
    model_fn=None,
    loss_fn=None,
    model_dir=None,
    train_input_fn=None,
    steps_per_epoch=None,
    steps_per_loop=1,
    epochs=1,
    eval_input_fn=None,
    eval_steps=None,
    steps_between_eval=None,
    steps_before_eval_start=None,
    stop_threshold=None,
    metric_fn=None,
    init_checkpoint=None,
    custom_callbacks=None,
    run_eagerly=False,
    sub_model_export_name=None,
    explicit_allreduce=False,
    device_warmup=False,
    synthetic_train_input_fn=None,
    pre_allreduce_callbacks=None,
    post_allreduce_callbacks=None,
    allreduce_bytes_per_pack=0,
    enable_checkpoint_and_summary=False,
    num_accumulation_steps=1,
    stop_steps=None):
  """Run BERT pretrain model training using low-level API.

  Arguments:
      _sentinel: Used to prevent positional parameters. Internal, do not use.
      strategy: Distribution strategy on which to run low level training loop.
      model_fn: Function that returns a tuple (model, sub_model,
        sub_pretrain_model). The caller should add an optimizer to `model`,
        either by calling the `model.compile()` API or by manually setting the
        `model.optimizer` attribute. `sub_model` is an optional sub model that
        is exported to checkpoints when `sub_model_export_name` is set, and
        `sub_pretrain_model` is used to restore `init_checkpoint` if provided.
      loss_fn: Function with signature func(labels, logits) and returns a loss
        tensor.
      model_dir: Model directory used during training for restoring/saving model
        weights.
      train_input_fn: Function that returns a tf.data.Dataset used for training.
      steps_per_epoch: Number of steps to run per epoch. At the end of each
        epoch, model checkpoint will be saved and evaluation will be conducted
        if evaluation dataset is provided.
      steps_per_loop: Number of steps per graph-mode loop. In order to reduce
        communication in eager context, training logs are printed every
        steps_per_loop.
      epochs: Number of epochs to train.
      eval_input_fn: Function that returns evaluation dataset. If none,
        evaluation is skipped.
      eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
        is not none.
      steps_between_eval: Number of training steps between evaluations.
      steps_before_eval_start: Number of training steps to run before the
        first evaluation.
      stop_threshold: MLPerf target accuracy; training stops once the
        evaluation accuracy reaches this threshold.
      metric_fn: A metrics function that returns a Keras Metric object to record
        evaluation result using evaluation dataset or with training dataset
        after every epoch.
      init_checkpoint: Optional checkpoint to load to `sub_model` returned by
        `model_fn`.
      custom_callbacks: A list of Keras Callback objects to run during
        training. More specifically, their `on_batch_begin()` and
        `on_batch_end()` methods are invoked during training.
      run_eagerly: Whether to run model training in pure eager execution. This
        should be disabled for TPUStrategy.
      sub_model_export_name: If not None, will export `sub_model` returned by
        `model_fn` into checkpoint files. The name of an intermediate
        checkpoint file is {sub_model_export_name}_step_{step}.ckpt and the
        last checkpoint's name is {sub_model_export_name}.ckpt;
        if None, `sub_model` will not be exported as a checkpoint.
      explicit_allreduce: Whether to explicitly perform gradient allreduce,
        instead of relying on implicit allreduce in optimizer.apply_gradients().
        default is False. For now, if training using FP16 mixed precision,
        explicit allreduce will aggregate gradients in FP16 format. For TPU and
        GPU training using FP32, explicit allreduce will aggregate gradients in
        FP32 format.
      device_warmup: Whether or not to enable device warmup. This
        runs the training and eval loop on synthetic data to pre-compile XLA
        and TF tracing before accessing data.
      synthetic_train_input_fn: Function that returns synthetic training
        dataset. This is used in device warmup.
      pre_allreduce_callbacks: A list of callback functions that take gradient
        and model variable pairs as input, manipulate them, and return new
        gradient and model variable pairs. The callback functions will be
        invoked in list order, before gradients are allreduced. Default is no
        callbacks. Only used when explicit_allreduce=True.
      post_allreduce_callbacks: A list of callback functions that take
        gradient and model variable pairs as input, manipulate them, and
        return new gradient and model variable pairs. The callback functions
        will be invoked in list order, right before gradients are applied to
        variables for updates. Default is no callbacks. Only used when
        explicit_allreduce=True.
      allreduce_bytes_per_pack: A non-negative integer. Breaks collective
        operations into packs of certain size. If it's zero, all gradients are
        in one pack.
      enable_checkpoint_and_summary: Whether to save checkpoints and summaries.
      num_accumulation_steps: Number of micro-batches over which gradients are
        accumulated before each optimizer update. Defaults to 1 (no
        accumulation).
      stop_steps: The number of steps to run before stopping the training loop.

  Returns:
      Trained model.

  Raises:
      ValueError: (1) When the model returned by `model_fn` does not have an
        optimizer attribute, or when required parameters are set to None. (2)
        When eval args are not specified correctly. (3) When `metric_fn` is
        specified but not callable. (4) When `sub_model_export_name` is
        specified but `sub_model` returned by `model_fn` is None.
  """

  if _sentinel is not None:
    raise ValueError('only call `run_customized_training_loop()` '
                     'with named arguments.')

  required_arguments = [
      strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
  ]
  if [arg for arg in required_arguments if arg is None]:
    raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
                     '`steps_per_epoch` and `train_input_fn` are required '
                     'parameters.')

  if steps_between_eval % steps_per_loop != 0:
    raise ValueError(
        'steps_between_eval should be a multiple of steps_per_loop.')

  if steps_per_loop > steps_per_epoch:
    logging.error(
        'steps_per_loop: %d is specified to be greater than '
        ' steps_per_epoch: %d, we will use steps_per_epoch as'
        ' steps_per_loop.', steps_per_loop, steps_per_epoch)
    steps_per_loop = steps_per_epoch
  assert tf.executing_eagerly()

  if run_eagerly:
    if steps_per_loop > 1:
      raise ValueError(
          'steps_per_loop is used for performance optimization. When you want '
          'to run eagerly, you cannot leverage graph mode loop.')
    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
      raise ValueError(
          'TPUStrategy should not run eagerly as it heavily relies on graph'
          ' optimization for the distributed system.')

  if eval_input_fn and (eval_steps is None):
    raise ValueError(
        '`eval_steps` is required when `eval_input_fn` is not None.')
  if device_warmup and (synthetic_train_input_fn is None):
    raise ValueError('`synthetic_train_input_fn` is required when '
                     'device_warmup is enabled.')

  if metric_fn and not callable(metric_fn):
    raise ValueError(
        'if `metric_fn` is specified, metric_fn must be a callable.')

  if stop_steps:
    total_training_steps = stop_steps
  else:
    total_training_steps = steps_per_epoch * epochs

  if stop_steps and stop_steps > steps_per_epoch * epochs:
    raise ValueError('`stop_steps` should not be greater than '
                     '`num_train_steps_per_epoch` * `num_epochs`.')

  # To reduce unnecessary send/receive input pipeline operations, we place
  # input pipeline ops in the worker task.
  train_iterator = _get_input_iterator(train_input_fn, strategy)

  with distribution_utils.get_strategy_scope(strategy):
    # To correctly place the model weights on accelerators,
    # model and optimizer should be created in scope.
    model, sub_model, sub_pretrain_model = model_fn()
    if not hasattr(model, 'optimizer'):
      raise ValueError('User should set optimizer attribute to model '
                       'inside `model_fn`.')
    if sub_model_export_name and sub_model is None:
      raise ValueError('sub_model_export_name is specified as %s, but '
                       'sub_model is None.' % sub_model_export_name)

    optimizer = model.optimizer

    train_loss_metric = tf.keras.metrics.Mean(
        'training_loss', dtype=tf.float32)
    if eval_input_fn:
      eval_metric_num = tf.keras.metrics.Sum('masked_lm_num', dtype=tf.float32)
      eval_metric_denom = tf.keras.metrics.Sum(
          'masked_lm_denom', dtype=tf.float32)

    # If evaluation is required, make a copy of metric as it will be used by
    # both train and evaluation.
    train_metrics = [
        tf.keras.metrics.Mean('masked_lm_accuracy', dtype=tf.float32)
    ]

    # Create summary writers
    summary_dir = os.path.join(model_dir, 'summaries')
    if enable_checkpoint_and_summary:
      eval_summary_writer = tf.summary.create_file_writer(
          os.path.join(summary_dir, 'eval'))
    else:
      eval_summary_writer = tf.summary.create_noop_writer()
    if steps_per_loop >= _MIN_SUMMARY_STEPS and enable_checkpoint_and_summary:
      # Only writes summary when the stats are collected sufficiently over
      # enough steps.
      train_summary_writer = tf.summary.create_file_writer(
          os.path.join(summary_dir, 'train'))
    else:
      train_summary_writer = tf.summary.create_noop_writer()

    # Collects training variables.
    training_vars = model.trainable_variables

    @tf.function(experimental_compile=True)
    def _compiled_local_step(inputs, labels, training_vars, accum_vars):
      """Replicated training step."""
      with tf.GradientTape() as tape:
        model_outputs, metric_outputs = model(inputs, training=True)
        loss = loss_fn(labels, model_outputs)
      if isinstance(optimizer,
                    tf.keras.mixed_precision.experimental.LossScaleOptimizer):
        with tape:
          scaled_loss = optimizer.get_scaled_loss(loss)
        scaled_grads = tape.gradient(scaled_loss, training_vars)
        grads = optimizer.get_unscaled_gradients(scaled_grads)
      else:
        grads = tape.gradient(loss, training_vars)
      (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

      if accum_vars is None:
        return grads, loss, model_outputs, metric_outputs
      else:
        new_accum_vars = []
        for i, grad in enumerate(grads):
          new_accum_vars.append(
              accum_vars[i] +
              tf.math.scalar_mul(1.0 / num_accumulation_steps, grad))
        return new_accum_vars, loss, model_outputs, metric_outputs

    def get_input_slice(input_dict, idx):
      split_input = {}
      for key in input_dict:
        split_input[key] = input_dict[key][idx]
      return split_input

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, labels = inputs
      if explicit_allreduce:
        # TODO(b/155523821): Fix OOM issue so we use experimental_compile with
        # multi-worker mirrored strategy.
        with tf.GradientTape() as tape:
          model_outputs, metric_outputs = model(inputs, training=True)
          loss = loss_fn(labels, model_outputs)

        grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
                                                     training_vars,
                                                     pre_allreduce_callbacks,
                                                     post_allreduce_callbacks,
                                                     allreduce_bytes_per_pack)
      else:
        if num_accumulation_steps > 1:
          accum_vars = [
              tf.zeros_like(tvar, dtype=tf.float32) for tvar in training_vars
          ]
          for key in inputs:
            inputs[key] = tf.split(inputs[key], num_accumulation_steps)

          split_labels = tf.split(labels, num_accumulation_steps)
          for local_step in range(num_accumulation_steps):
            accum_vars, loss, model_outputs, metric_outputs = _compiled_local_step(
                get_input_slice(inputs, local_step), split_labels[local_step],
                training_vars, accum_vars)

          optimizer.apply_gradients(zip(accum_vars, training_vars))
        else:
          grads, loss, model_outputs, metric_outputs = _compiled_local_step(
              inputs, labels, training_vars, None)
          optimizer.apply_gradients(zip(grads, training_vars))
      # For reporting, the metric takes the mean of losses.
      train_loss_metric.update_state(loss)
      for metric in train_metrics:
        metric.update_state(metric_outputs['masked_lm_accuracy'])

    @tf.function
    def train_steps(iterator, steps):
      """Performs distributed training steps in a loop.

      Args:
        iterator: the distributed iterator of training datasets.
        steps: a tf.int32 integer tensor specifying the number of steps to run
          inside the host training loop.

      Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
      """
      if not isinstance(steps, tf.Tensor):
        raise ValueError('steps should be a Tensor. Python objects may cause '
                         'retracing.')

      for _ in tf.range(steps):
        strategy.run(_replicated_step, args=(next(iterator),))

    def train_single_step(iterator):
      """Performs a distributed training step.

      Args:
        iterator: the distributed iterator of training datasets.

      Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
      """
      strategy.run(_replicated_step, args=(next(iterator),))

    def test_step(iterator):
      """Calculates evaluation metrics on distributed devices."""

      def _test_step_fn(inputs):
        """Replicated accuracy calculation."""

        inputs, labels = inputs
        model_outputs, metric_outputs = model(inputs, training=False)
        eval_metric_num.update_state(metric_outputs['masked_lm_num'])
        eval_metric_denom.update_state(metric_outputs['masked_lm_denom'])
      strategy.run(_test_step_fn, args=(next(iterator),))

    if not run_eagerly:
      train_single_step = tf.function(train_single_step)
      test_step = tf.function(test_step)

    def _run_evaluation(current_training_step, test_iterator):
      """Runs validation steps and aggregate metrics."""
      mlperf_epoch_num = int(current_training_step / steps_between_eval)
      mlp_log.mlperf_print(
          'eval_start', None, metadata={'epoch_num': mlperf_epoch_num})
      for _ in range(eval_steps):
        test_step(test_iterator)
      mlp_log.mlperf_print(
          'eval_stop', None, metadata={'epoch_num': mlperf_epoch_num})

      with eval_summary_writer.as_default():
        masked_lm_accuracy = (
            _float_metric_value(eval_metric_num) /
            _float_metric_value(eval_metric_denom))
        logging.info('Step: [%d] Validation %s = %f', current_training_step,
                     'masked_lm_accuracy', masked_lm_accuracy)
        tf.summary.scalar(
            'masked_lm_accuracy',
            masked_lm_accuracy,
            step=current_training_step)
        mlp_log.mlperf_print(
            'eval_accuracy',
            masked_lm_accuracy,
            metadata={'epoch_num': mlperf_epoch_num})
        eval_summary_writer.flush()
      return masked_lm_accuracy

    def _run_callbacks_on_batch_begin(batch):
      """Runs custom callbacks at the start of every step."""
      # While BERT pretraining does not have epochs,
      # to make the logging consistent with other mlperf models,
      # in all the mlp_log, epochs are steps.
      mlp_log.mlperf_print(
          'block_start',
          None,
          metadata={
              'first_epoch_num': int(batch),
              'epoch_count': int(steps_per_loop),
          })
      if not custom_callbacks:
        return
      for callback in custom_callbacks:
        callback.on_batch_begin(batch)

    def _run_callbacks_on_batch_end(batch, logs):
      """Runs custom callbacks at the end of every step."""
      mlp_log.mlperf_print(
          'block_stop', None, metadata={
              'first_epoch_num': int(batch),
          })
      if not custom_callbacks:
        return
      for callback in custom_callbacks:
        callback.on_batch_end(batch, logs)

    # Training loop starts here.
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    sub_model_checkpoint = tf.train.Checkpoint(
        model=sub_model) if sub_model_export_name else None

    # TODO: commenting this out, as we always load from an initial checkpoint
    # latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
    # if latest_checkpoint_file:
    #   logging.info(
    #       'Checkpoint file %s found and restoring from '
    #       'checkpoint', latest_checkpoint_file)
    #   checkpoint.restore(latest_checkpoint_file)
    #   logging.info('Loading from checkpoint file completed')

    current_step = optimizer.iterations.numpy()
    checkpoint_name = 'ctl_step_{step}.ckpt'
    checkpoint_save_dir = model_dir if enable_checkpoint_and_summary else None

    if init_checkpoint:
      logging.info(
          'Checkpoint file %s found and restoring from '
          'initial checkpoint for core model.', init_checkpoint)
      checkpoint = tf.train.Checkpoint(model=sub_pretrain_model)
      checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
      logging.info('Loading from checkpoint file completed')

    if device_warmup:
      synthetic_train_iterator = _get_input_iterator(synthetic_train_input_fn,
                                                     strategy)
      logging.info('Running device warmup for 1 step.')
      train_steps(synthetic_train_iterator, tf.constant(1, dtype=tf.int32))
      # Reset the global step.
      tf.keras.backend.set_value(optimizer.iterations, 0)
      current_step = optimizer.iterations.numpy()

    masked_lm_accuracy = 0
    mlp_log.mlperf_print('init_stop', None)
    mlp_log.mlperf_print('run_start', None)

    while current_step < total_training_steps:
      # Training loss/metrics are averaged over the steps inside the micro
      # training loop. We reset their values before each round.
      train_loss_metric.reset_states()
      for metric in train_metrics + model.metrics:
        metric.reset_states()

      _run_callbacks_on_batch_begin(current_step)
      # Runs several steps in the host while loop.
      steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)

      train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
      train_loss = _float_metric_value(train_loss_metric)
      _run_callbacks_on_batch_end(current_step, {'loss': train_loss})
      current_step += steps

      # Updates training logging.
      training_status = 'Train Step: %d/%d  / loss = %s' % (
          current_step, total_training_steps, train_loss)

      with train_summary_writer.as_default():
        tf.summary.scalar(
            train_loss_metric.name, train_loss, step=current_step)
        for metric in train_metrics + model.metrics:
          metric_value = _float_metric_value(metric)
          training_status += '  %s = %f' % (metric.name, metric_value)
          tf.summary.scalar(metric.name, metric_value, step=current_step)
        train_summary_writer.flush()
      logging.info(training_status)

      # Saves model checkpoints and runs validation steps at every epoch end.
      if current_step % steps_per_epoch == 0:
        # To avoid repeated model saving, we do not save after the last
        # step of training.
        if current_step < total_training_steps:
          _save_checkpoint(checkpoint, checkpoint_save_dir,
                           checkpoint_name.format(step=current_step))
          if sub_model_export_name:
            _save_checkpoint(
                sub_model_checkpoint, checkpoint_save_dir,
                '%s_step_%d.ckpt' % (sub_model_export_name, current_step))
      if eval_input_fn and (current_step % (steps_between_eval) == 0) and (
          current_step >= steps_before_eval_start):
        logging.info('Running evaluation after step: %s.', current_step)
        masked_lm_accuracy = _run_evaluation(
            current_step, _get_input_iterator(eval_input_fn, strategy))
        if masked_lm_accuracy >= stop_threshold:
          mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
          break

        # Re-initialize evaluation metric.
        eval_metric_num.reset_states()
        eval_metric_denom.reset_states()

    if masked_lm_accuracy < stop_threshold:
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'aborted'})

    _save_checkpoint(checkpoint, checkpoint_save_dir,
                     checkpoint_name.format(step=current_step))
    if sub_model_export_name:
      _save_checkpoint(sub_model_checkpoint, checkpoint_save_dir,
                       '%s.ckpt' % sub_model_export_name)

    if enable_checkpoint_and_summary:
      training_summary = {
          'total_training_steps': total_training_steps,
          'train_loss': _float_metric_value(train_loss_metric),
      }
      if train_metrics:
        # TODO(hongkuny): Cleans up summary reporting in text.
        training_summary['last_train_metrics'] = _float_metric_value(
            train_metrics[0])
        #training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])

      write_txt_summary(training_summary, summary_dir)

    return model, masked_lm_accuracy, current_step
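Since _sentinel blocks positional arguments, callers pass everything by keyword. A hedged sketch of such a call; every helper name below (get_pretrain_model, get_loss_fn, the input functions, metric_fn and FLAGS) is a hypothetical stand-in, not defined above:

model, masked_lm_accuracy, last_step = run_customized_training_loop(
    strategy=strategy,
    model_fn=get_pretrain_model,   # returns (model, sub_model, sub_pretrain_model)
    loss_fn=get_loss_fn(),
    model_dir=FLAGS.model_dir,
    train_input_fn=train_input_fn,
    steps_per_epoch=1000,
    steps_per_loop=100,
    epochs=1,
    eval_input_fn=eval_input_fn,
    eval_steps=100,
    steps_between_eval=500,        # must be a multiple of steps_per_loop
    steps_before_eval_start=0,
    stop_threshold=0.712,
    metric_fn=metric_fn,
    init_checkpoint=FLAGS.init_checkpoint,
    device_warmup=False,
    custom_callbacks=None)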
Example #9
    def __init__(self,
                 batch_size,
                 steps_per_epoch,
                 train_steps,
                 initial_learning_rate=None,
                 end_learning_rate=None,
                 warmup_epochs=None,
                 compute_lr_on_cpu=False,
                 name=None):
        """Applies a polynomial decay to the learning rate with warmup."""
        super(PolynomialDecayWithWarmup, self).__init__()

        self.batch_size = batch_size
        self.steps_per_epoch = steps_per_epoch
        self.train_steps = train_steps
        self.name = name
        self.learning_rate_ops_cache = {}
        self.compute_lr_on_cpu = compute_lr_on_cpu

        if batch_size < 16384:
            self.initial_learning_rate = 10.0
            warmup_epochs_ = 5
        elif batch_size < 32768:
            self.initial_learning_rate = 25.0
            warmup_epochs_ = 5
        else:
            self.initial_learning_rate = 31.2
            warmup_epochs_ = 25

        # Override default poly learning rate and warmup epochs
        if initial_learning_rate:
            self.initial_learning_rate = initial_learning_rate

        if end_learning_rate:
            self.end_learning_rate = end_learning_rate
        else:
            self.end_learning_rate = 0.0001

        if warmup_epochs is not None:
            warmup_epochs_ = warmup_epochs
        self.warmup_epochs = warmup_epochs_

        opt_name = FLAGS.optimizer.lower()
        mlp_log.mlperf_print('opt_name', opt_name)
        if opt_name == 'lars':
            mlp_log.mlperf_print('{}_epsilon'.format(opt_name),
                                 FLAGS.lars_epsilon)
        mlp_log.mlperf_print('{}_opt_weight_decay'.format(opt_name),
                             FLAGS.weight_decay)
        mlp_log.mlperf_print('{}_opt_base_learning_rate'.format(opt_name),
                             self.initial_learning_rate)
        mlp_log.mlperf_print(
            '{}_opt_learning_rate_warmup_epochs'.format(opt_name),
            warmup_epochs_)
        mlp_log.mlperf_print('{}_opt_end_learning_rate'.format(opt_name),
                             self.end_learning_rate)
        warmup_steps = warmup_epochs_ * steps_per_epoch
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)
        self.decay_steps = train_steps - warmup_steps + 1
        mlp_log.mlperf_print(
            '{}_opt_learning_rate_decay_steps'.format(opt_name),
            int(self.decay_steps))
        mlp_log.mlperf_print(
            '{}_opt_learning_rate_decay_poly_power'.format(opt_name), 2.0)
        mlp_log.mlperf_print('{}_opt_momentum'.format(opt_name),
                             FLAGS.momentum)

        self.poly_rate_scheduler = tf.keras.optimizers.schedules.PolynomialDecay(
            initial_learning_rate=self.initial_learning_rate,
            decay_steps=self.decay_steps,
            end_learning_rate=self.end_learning_rate,
            power=2.0)
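
The schedule's __call__ is not part of this example; a minimal sketch (an assumption, the real implementation may differ) of how the linear warmup and the polynomial decay configured above would typically be combined:

    def __call__(self, step):
        """Hedged sketch: linear warmup, then the polynomial decay above."""
        step = tf.cast(step, tf.float32)
        warmup_lr = self.initial_learning_rate * step / self.warmup_steps
        decay_lr = self.poly_rate_scheduler(step - self.warmup_steps)
        return tf.where(step < self.warmup_steps, warmup_lr, decay_lr)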