Exemplo n.º 1
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If `flags_obj.optimizer` or `flags_obj.model` is not one of
        the supported choices.
      NotImplementedError: If pruning is requested with a method other than
        'polynomial_decay', or with a dtype other than tf.float32.

    Returns:
      Dictionary of training and eval stats, as built by `common.build_stats`.
    """
    import contextlib  # Only needed for the optional explicit device scope.

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    performance.set_mixed_precision_policy(
        dtype,
        flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

    data_format = flags_obj.data_format
    if data_format is None:
        # Default to channels_first on CUDA builds, channels_last otherwise.
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                             flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether to enable
        # get_next_as_optional behavior in DistributedIterator. If true, the
        # last partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    # Current resnet_model.resnet50 input format is always channel-last.
    # We use the keras_applications mobilenet model, whose input format
    # depends on the Keras backend image data format.
    # The use_keras_image_data_format flag indicates whether the image
    # preprocessor output format should match the Keras backend image data
    # format or always be channel-last.
    use_keras_image_data_format = (flags_obj.model == 'mobilenet')
    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
        training_dataset_cache=flags_obj.training_dataset_cache,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
                use_keras_image_data_format=use_keras_image_data_format),
            dtype=dtype,
            drop_remainder=drop_remainder)

    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in common.LR_SCHEDULE),
        compute_lr_on_cpu=True)
    steps_per_epoch = (imagenet_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)

    with strategy_scope:
        if flags_obj.optimizer == 'resnet50_default':
            optimizer = common.get_optimizer(lr_schedule)
        elif flags_obj.optimizer == 'mobilenet_default':
            initial_learning_rate = (
                flags_obj.initial_learning_rate_per_sample *
                flags_obj.batch_size)
            optimizer = tf.keras.optimizers.SGD(
                learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
                    initial_learning_rate,
                    decay_steps=steps_per_epoch *
                    flags_obj.num_epochs_per_decay,
                    decay_rate=flags_obj.lr_decay_factor,
                    staircase=True),
                momentum=0.9)
        else:
            # Fail fast instead of hitting an unbound `optimizer` later.
            raise ValueError(
                'Unsupported optimizer: {!r}'.format(flags_obj.optimizer))
        if flags_obj.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
            # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
            # which will ensure tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
            # up.
            optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                optimizer)

        # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
        if flags_obj.use_trivial_model:
            model = test_utils.trivial_model(
                imagenet_preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'resnet50_v1.5':
            model = resnet_model.resnet50(
                num_classes=imagenet_preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'mobilenet':
            # TODO(kimjaehong): Remove layers attribute when minimum TF version
            # support 2.0 layers by default.
            model = tf.keras.applications.mobilenet.MobileNet(
                weights=None,
                classes=imagenet_preprocessing.NUM_CLASSES,
                layers=tf.keras.layers)
        else:
            # Fail fast instead of hitting an unbound `model` later.
            raise ValueError(
                'Unsupported model: {!r}'.format(flags_obj.model))
        if flags_obj.pretrained_filepath:
            model.load_weights(flags_obj.pretrained_filepath)

        if flags_obj.pruning_method == 'polynomial_decay':
            if dtype != tf.float32:
                raise NotImplementedError(
                    'Pruning is currently only supported on dtype=tf.float32.')
            pruning_params = {
                'pruning_schedule':
                tfmot.sparsity.keras.PolynomialDecay(
                    initial_sparsity=flags_obj.pruning_initial_sparsity,
                    final_sparsity=flags_obj.pruning_final_sparsity,
                    begin_step=flags_obj.pruning_begin_step,
                    end_step=flags_obj.pruning_end_step,
                    frequency=flags_obj.pruning_frequency),
            }
            model = tfmot.sparsity.keras.prune_low_magnitude(
                model, **pruning_params)
        elif flags_obj.pruning_method:
            raise NotImplementedError(
                'Only polynomial_decay is currently supported.')

        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['sparse_categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      run_eagerly=flags_obj.run_eagerly)

    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(
        steps_per_epoch=steps_per_epoch,
        pruning_method=flags_obj.pruning_method,
        enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
        model_dir=flags_obj.model_dir)

    # If multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # The ExitStack guarantees the optional device scope is exited with the
    # proper (exc_type, exc, tb) arguments even if fit/evaluate/save raises.
    with contextlib.ExitStack() as device_scope:
        if not strategy and flags_obj.explicit_gpu_placement:
            # TODO(b/135607227): Add device scope automatically in Keras training
            # loop when not using distribution strategy.
            device_scope.enter_context(tf.device('/device:GPU:0'))

        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=steps_per_epoch,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)

        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)

        if flags_obj.pruning_method:
            model = tfmot.sparsity.keras.strip_pruning(model)
        if flags_obj.enable_checkpoint_and_export:
            if dtype == tf.bfloat16:
                logging.warning(
                    'Keras model.save does not support bfloat16 dtype.')
            else:
                # Keras model.save assumes a float32 input signature.
                export_path = os.path.join(flags_obj.model_dir, 'saved_model')
                model.save(export_path, include_optimizer=False)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
Exemplo n.º 2
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

    Besides training, this emits MLPerf logging events (INIT_START/STOP,
    RUN_START/STOP, BLOCK_START and submission metadata) through the logger
    returned by `get_mllog_mlloger`.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      Dictionary of training and eval stats, as built by `build_stats`.
    """
    tf.get_logger().propagate = False
    # MLPerf logs go to $LOG_DIR when it is set; otherwise None is passed on.
    output_dir = os.environ.get("LOG_DIR")
    mlperf_mlloger, mlperf_mllog = get_mllog_mlloger(output_dir)
    mlperf_mlloger.event(key=mlperf_mllog.constants.CACHE_CLEAR, value=True)
    mlperf_mlloger.start(key=mlperf_mllog.constants.INIT_START, value=None)
    mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_BENCHMARK,
                         value=mlperf_mllog.constants.RESNET)
    mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_ORG,
                         value='Habana')
    mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_DIVISION,
                         value='closed')
    mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_PLATFORM,
                         value='gaudi-{}'.format(flags_obj.num_gpus))
    mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_STATUS,
                         value='onprem')

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)
    performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

    # This only affects GPU.
    common.set_cudnn_batchnorm_mode()

    # TODO(anj-s): Set data_format without using Keras.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # With Horovod, each worker gets its own subdirectory of model_dir.
    if horovod_enabled():
        model_dir = os.path.join(flags_obj.model_dir,
                                 "worker_" + str(hvd.rank()))
    else:
        model_dir = flags_obj.model_dir

    global_batch_size = get_global_batch_size(flags_obj.batch_size)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    mlperf_mlloger.event(key=mlperf_mllog.constants.GLOBAL_BATCH_SIZE,
                         value=global_batch_size)
    mlperf_mlloger.event(key=mlperf_mllog.constants.TRAIN_SAMPLES,
                         value=imagenet_preprocessing.NUM_IMAGES['train'])
    mlperf_mlloger.event(key=mlperf_mllog.constants.EVAL_SAMPLES,
                         value=imagenet_preprocessing.NUM_IMAGES['validation'])
    # Batch norm statistics are computed over a single worker's batch.
    group_batch_norm = 1
    mlperf_mlloger.event(key=mlperf_mllog.constants.MODEL_BN_SPAN,
                         value=flags_obj.batch_size * group_batch_norm)

    train_writer, eval_writer = None, None
    if flags_obj.enable_tensorboard:
        train_writer = tf.summary.create_file_writer(model_dir)
        eval_writer = tf.summary.create_file_writer(
            os.path.join(model_dir, 'eval'))
        hparams = flags_obj.flag_values_dict()
        write_hparams_v2(train_writer, hparams)

    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
    train_steps = train_epochs * per_epoch_steps

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        train_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        global_batch_size,
        flags_obj.log_steps,
        summary_writer=train_writer,
        batch_size_per_node=flags_obj.batch_size)
    profiler_callback = None
    if flags_obj.profile_steps is not None:
        profiler_callback = keras_utils.get_profiler_callback(
            model_dir, flags_obj.profile_steps, flags_obj.enable_tensorboard,
            per_epoch_steps)
    with distribution_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                                  train_steps, per_epoch_steps,
                                                  profiler_callback,
                                                  mlperf_mlloger, mlperf_mllog)

    eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
    eval_offset = flags_obj.eval_offset_epochs * per_epoch_steps
    # A non-zero offset is shifted back by one interval; presumably the
    # Controller adds eval_interval to the offset for the first eval — confirm.
    if eval_offset != 0:
        eval_offset -= eval_interval
    checkpoint_interval = (per_epoch_steps
                           if flags_obj.enable_checkpoint_and_export else None)
    summary_interval = per_epoch_steps if flags_obj.enable_tensorboard else None

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    device_warmup_steps = (flags_obj.device_warmup_steps
                           if flags_obj.enable_device_warmup else 0)
    if flags_obj.enable_device_warmup:
        logging.info('Warmup for %d steps.', device_warmup_steps)

    resnet_controller = controller.Controller(
        strategy,
        runnable.train,
        runnable.evaluate,
        runnable.warmup,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        train_steps=train_steps,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        eval_steps=eval_steps,
        eval_interval=eval_interval,
        eval_offset=eval_offset,
        device_warmup_steps=device_warmup_steps,
        train_summary_writer=train_writer,
        eval_summary_writer=eval_writer)

    if flags_obj.enable_device_warmup:
        resnet_controller.warmup()

    mlperf_mlloger.end(key=mlperf_mllog.constants.INIT_STOP)

    # hvd.broadcast is a collective op, presumably used here as a barrier so
    # all workers start the run together. Only valid when Horovod is in use
    # (matching the horovod_enabled() guard above).
    if horovod_enabled():
        hvd.broadcast(0, 0)
    time_callback.on_train_begin()
    mlperf_mlloger.start(key=mlperf_mllog.constants.RUN_START)
    mlperf_mlloger.start(
        key=mlperf_mllog.constants.BLOCK_START,
        value=None,
        metadata={
            'first_epoch_num':
            1,
            'epoch_count':
            (flags_obj.eval_offset_epochs if flags_obj.eval_offset_epochs > 0
             else flags_obj.epochs_between_evals)
        })
    resnet_controller.train(evaluate=not flags_obj.skip_eval)
    # RUN_STOP is only emitted when evaluation ran, since its status depends
    # on the final eval accuracy.
    if not flags_obj.skip_eval:
        eval_accuracy = resnet_controller.last_eval_output['test_accuracy']
        run_status = ('success' if eval_accuracy >= flags_obj.target_accuracy
                      else 'fail')
        mlperf_mlloger.end(key=mlperf_mllog.constants.RUN_STOP,
                           value=None,
                           metadata={'status': run_status})
    time_callback.on_train_end()

    stats = build_stats(runnable, time_callback)
    return stats
def train_and_eval(
        params: base_configs.ExperimentConfig,
        strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
    """Runs the train and eval path using compile/fit.

    Args:
      params: The experiment config (runtime, model, train and evaluation
        settings).
      strategy_override: A distribution strategy to use; when falsy, one is
        built from `params.runtime`.

    Returns:
      The stats mapping produced by `common.build_stats` from the fit
      history, the final validation output (None if evaluation did not run)
      and the callbacks.
    """
    logging.info('Running train and eval.')

    # Note: for TPUs, strategy and scope should be created before the dataset
    strategy = strategy_override or distribution_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        all_reduce_alg=params.runtime.all_reduce_alg,
        num_gpus=params.runtime.num_gpus,
        tpu_address=params.runtime.tpu)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    logging.info('Detected %d devices.',
                 strategy.num_replicas_in_sync if strategy else 1)

    label_smoothing = params.model.loss.label_smoothing
    # One-hot labels are used only when label smoothing is positive; coerce
    # to a real bool so a 0.0/None smoothing value is not passed on as-is.
    one_hot = bool(label_smoothing and label_smoothing > 0)

    builders = _get_dataset_builders(params, strategy, one_hot)
    # Any builder may be None, in which case its dataset is None too.
    datasets = [builder.build() if builder else None for builder in builders]

    # Unpack datasets and builders based on train/val/test splits
    train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
    train_dataset, validation_dataset = datasets

    train_epochs = params.train.epochs
    train_steps = params.train.steps or train_builder.num_steps

    # Evaluation needs a validation builder; without one, skip evaluation
    # rather than failing on `None.num_steps`.
    run_eval = not params.evaluation.skip_eval
    if run_eval and not validation_builder:
        logging.warning(
            'No validation dataset is configured; skipping evaluation.')
        run_eval = False
    validation_steps = params.evaluation.steps or (
        validation_builder.num_steps if validation_builder else None)

    initialize(params, train_builder)

    logging.info('Global batch size: %d', train_builder.global_batch_size)

    with strategy_scope:
        model_params = params.model.model_params.as_dict()
        model = get_models()[params.model.name](**model_params)
        learning_rate = optimizer_factory.build_learning_rate(
            params=params.model.learning_rate,
            batch_size=train_builder.global_batch_size,
            train_steps=train_steps)
        optimizer = optimizer_factory.build_optimizer(
            optimizer_name=params.model.optimizer.name,
            base_learning_rate=learning_rate,
            params=params.model.optimizer.as_dict())

        metrics_map = _get_metrics(one_hot)
        metrics = [metrics_map[metric] for metric in params.train.metrics]

        if one_hot:
            loss_obj = tf.keras.losses.CategoricalCrossentropy(
                label_smoothing=params.model.loss.label_smoothing)
        else:
            loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
        model.compile(optimizer=optimizer, loss=loss_obj, metrics=metrics)

        initial_epoch = 0
        if params.train.resume_checkpoint:
            initial_epoch = resume_from_checkpoint(model=model,
                                                   model_dir=params.model_dir,
                                                   train_steps=train_steps)

    serialize_config(params=params, model_dir=params.model_dir)
    # TODO(dankondratyuk): callbacks significantly slow down training
    callbacks = custom_callbacks.get_callbacks(
        model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
        include_tensorboard=params.train.callbacks.enable_tensorboard,
        time_history=params.train.callbacks.enable_time_history,
        track_lr=params.train.tensorboard.track_lr,
        write_model_weights=params.train.tensorboard.write_model_weights,
        initial_step=initial_epoch * train_steps,
        batch_size=train_builder.global_batch_size,
        log_steps=params.train.time_history.log_steps,
        model_dir=params.model_dir)

    if run_eval:
        validation_kwargs = {
            'validation_data': validation_dataset,
            'validation_steps': validation_steps,
            'validation_freq': params.evaluation.epochs_between_evals,
        }
    else:
        validation_kwargs = {}

    history = model.fit(train_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        initial_epoch=initial_epoch,
                        callbacks=callbacks,
                        **validation_kwargs)

    validation_output = None
    if run_eval:
        validation_output = model.evaluate(validation_dataset,
                                           steps=validation_steps,
                                           verbose=2)

    # TODO(dankondratyuk): eval and save final test accuracy
    stats = common.build_stats(history, validation_output, callbacks)
    return stats
def run(flags_obj):
  """Train (and optionally evaluate) ResNet on ImageNet with a custom loop.

  Configures the session, mixed-precision policy and image data format, then
  builds a distribution strategy, a `ResnetRunnable`, a checkpoint manager and
  a `controller.Controller`, and runs training, with evaluation unless
  `flags_obj.skip_eval` is set.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: Documented for fp16 being unsupported. This block does not
      raise it directly; presumably a callee does. Confirm.

  Returns:
    The stats dictionary built by `build_stats` from the runnable and the
    timing callback.
  """
  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)
  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

  # cuDNN batch-norm mode; this only affects GPU.
  common.set_cudnn_batchnorm_mode()

  # TODO(anj-s): Set data_format without using Keras.
  image_format = flags_obj.data_format
  if image_format is None:
    if tf.test.is_built_with_cuda():
      image_format = 'channels_first'
    else:
      image_format = 'channels_last'
  tf.keras.backend.set_image_data_format(image_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  steps_per_epoch, num_epochs, num_eval_steps = get_num_train_iterations(
      flags_obj)
  loop_steps = min(flags_obj.steps_per_loop, steps_per_epoch)
  total_train_steps = num_epochs * steps_per_epoch

  logging.info(
      'Training %d epochs, each epoch has %d steps, '
      'total steps: %d; Eval %d steps', num_epochs, steps_per_epoch,
      total_train_steps, num_eval_steps)

  tensorboard_dir = (flags_obj.model_dir
                     if flags_obj.enable_tensorboard else None)
  time_callback = keras_utils.TimeHistory(
      flags_obj.batch_size, flags_obj.log_steps, logdir=tensorboard_dir)

  if flags_obj.profile_steps is None:
    profiler_callback = None
  else:
    profiler_callback = keras_utils.get_profiler_callback(
        flags_obj.model_dir, flags_obj.profile_steps,
        flags_obj.enable_tensorboard, steps_per_epoch)

  with distribution_utils.get_strategy_scope(strategy):
    runnable = resnet_runnable.ResnetRunnable(
        flags_obj, time_callback, steps_per_epoch, profiler_callback)

  # Checkpoints and summaries, when enabled, happen once per epoch.
  checkpoint_every = None
  if flags_obj.enable_checkpoint_and_export:
    checkpoint_every = steps_per_epoch
  summary_every = None
  if flags_obj.enable_tensorboard:
    summary_every = steps_per_epoch
  eval_every = flags_obj.epochs_between_evals * steps_per_epoch

  checkpoint_manager = tf.train.CheckpointManager(
      runnable.checkpoint,
      directory=flags_obj.model_dir,
      max_to_keep=10,
      step_counter=runnable.global_step,
      checkpoint_interval=checkpoint_every)

  resnet_controller = controller.Controller(
      strategy,
      runnable.train,
      runnable.evaluate,
      global_step=runnable.global_step,
      steps_per_loop=loop_steps,
      train_steps=total_train_steps,
      checkpoint_manager=checkpoint_manager,
      summary_interval=summary_every,
      eval_steps=num_eval_steps,
      eval_interval=eval_every)

  time_callback.on_train_begin()
  resnet_controller.train(evaluate=not flags_obj.skip_eval)
  time_callback.on_train_end()

  return build_stats(runnable, time_callback)
def run_customized_training_loop(
    # pylint: disable=invalid-name
    _sentinel=None,
    # pylint: enable=invalid-name
    strategy=None,
    model_fn=None,
    loss_fn=None,
    scale_loss=True,
    model_dir=None,
    train_input_fn=None,
    steps_per_epoch=None,
    steps_per_loop=1,
    epochs=1,
    eval_input_fn=None,
    eval_steps=None,
    metric_fn=None,
    init_checkpoint=None,
    custom_callbacks=None,
    run_eagerly=False,
    sub_model_export_name=None,
    explicit_allreduce=False,
    pre_allreduce_callbacks=None,
    post_allreduce_callbacks=None):
  """Run BERT pretrain model training using low-level API.

  Args:
      _sentinel: Used to prevent positional parameters. Internal, do not use.
      strategy: Distribution strategy on which to run low level training loop.
      model_fn: Function that returns a tuple (model, sub_model). Caller of this
        function should add optimizer to the `model` via calling
        `model.compile()` API or manually setting `model.optimizer` attribute.
        Second element of the returned tuple(sub_model) is an optional sub model
        to be used for initial checkpoint -- if provided.
      loss_fn: Function with signature func(labels, logits) and returns a loss
        tensor.
      scale_loss: Whether to divide the raw loss by number of replicas before
        gradients calculation.
      model_dir: Model directory used during training for restoring/saving model
        weights.
      train_input_fn: Function that returns a tf.data.Dataset used for training.
      steps_per_epoch: Number of steps to run per epoch. At the end of each
        epoch, model checkpoint will be saved and evaluation will be conducted
        if evaluation dataset is provided.
      steps_per_loop: Number of steps per graph-mode loop. In order to reduce
        communication in eager context, training logs are printed every
        steps_per_loop. Values larger than `steps_per_epoch` are clamped to
        `steps_per_epoch`.
      epochs: Number of epochs to train.
      eval_input_fn: Function that returns evaluation dataset. If none,
        evaluation is skipped.
      eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
        is not none.
      metric_fn: A metrics function that returns a Keras Metric object to record
        evaluation result using evaluation dataset or with training dataset
        after every epoch.
      init_checkpoint: Optional checkpoint to load to `sub_model` returned by
        `model_fn`. Requires `sub_model` to be non-None.
      custom_callbacks: A list of Keras Callbacks objects to run during
        training. More specifically, `on_batch_begin()`, `on_batch_end()`,
        methods are invoked during training.
      run_eagerly: Whether to run model training in pure eager execution. This
        should be disabled for TPUStrategy.
      sub_model_export_name: If not None, will export `sub_model` returned by
        `model_fn` into checkpoint files. The name of intermediate checkpoint
        file is {sub_model_export_name}_step_{step}.ckpt and the last
        checkpoint's name is {sub_model_export_name}.ckpt;
        if None, `sub_model` will not be exported as checkpoint.
      explicit_allreduce: Whether to explicitly perform gradient allreduce,
        instead of relying on implicit allreduce in optimizer.apply_gradients().
        default is False. For now, if training using FP16 mixed precision,
        explicit allreduce will aggregate gradients in FP16 format. For TPU and
        GPU training using FP32, explicit allreduce will aggregate gradients in
        FP32 format.
      pre_allreduce_callbacks: A list of callback functions that takes gradients
        and model variables pairs as input, manipulate them, and returns a new
        gradients and model variables pairs. The callback functions will be
        invoked in the list order and before gradients are allreduced.
        With mixed precision training, the pre_allreduce_callbacks will be
        applied on scaled_gradients. Default is no callbacks.
        Only used when explicit_allreduce=True.
      post_allreduce_callbacks: A list of callback functions that takes
        gradients and model variables pairs as input, manipulate them, and
        returns a new gradients and model variables pairs. The callback
        functions will be invoked in the list order and right before gradients
        are applied to variables for updates. Default is no callbacks. Only used
        when explicit_allreduce=True.

  Returns:
      Trained model.

  Raises:
      ValueError: (1) When positional arguments are passed. (2) When any of
        `strategy`, `model_fn`, `loss_fn`, `model_dir`, `steps_per_epoch` or
        `train_input_fn` is None. (3) When `run_eagerly` is requested with a
        TPUStrategy. (4) eval args are not specified correctly. (5) metric_fn
        is specified but not callable. (6) The model returned by `model_fn`
        has no optimizer attribute. (7) `sub_model_export_name` or
        `init_checkpoint` is specified, but `sub_model` returned by
        `model_fn` is None.
  """

  if _sentinel is not None:
    raise ValueError('only call `run_customized_training_loop()` '
                     'with named arguments.')

  required_arguments = [
      strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
  ]
  if any(arg is None for arg in required_arguments):
    # The message lists exactly the arguments checked above.
    raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
                     '`train_input_fn` and `steps_per_epoch` are required '
                     'parameters.')
  if steps_per_loop > steps_per_epoch:
    logging.error(
        'steps_per_loop: %d is specified to be greater than '
        'steps_per_epoch: %d, we will use steps_per_epoch as'
        ' steps_per_loop.', steps_per_loop, steps_per_epoch)
    steps_per_loop = steps_per_epoch
  # The loop below calls `.numpy()` on variables (e.g. `optimizer.iterations`)
  # from Python, which is only possible in eager mode.
  assert tf.executing_eagerly()

  if run_eagerly:
    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
      raise ValueError(
          'TPUStrategy should not run eagerly as it heavily relies on graph'
          ' optimization for the distributed system.')

  if eval_input_fn and (eval_steps is None or metric_fn is None):
    raise ValueError(
        '`eval_steps` and `metric_fn` are required when `eval_input_fn` '
        'is not none.')
  if metric_fn and not callable(metric_fn):
    raise ValueError(
        'if `metric_fn` is specified, metric_fn must be a callable.')

  total_training_steps = steps_per_epoch * epochs
  train_iterator = _get_input_iterator(train_input_fn, strategy)

  with distribution_utils.get_strategy_scope(strategy):
    # To correctly place the model weights on accelerators,
    # model and optimizer should be created in scope.
    model, sub_model = model_fn()
    if not hasattr(model, 'optimizer'):
      raise ValueError('User should set optimizer attribute to model '
                       'inside `model_fn`.')
    if sub_model_export_name and sub_model is None:
      raise ValueError('sub_model_export_name is specified as %s, but '
                       'sub_model is None.' % sub_model_export_name)

    optimizer = model.optimizer

    if init_checkpoint:
      # The initial checkpoint is restored into `sub_model`, so it must exist;
      # fail with a clear message instead of an opaque tf.train error.
      if sub_model is None:
        raise ValueError('init_checkpoint is specified as %s, but '
                         'sub_model is None.' % init_checkpoint)
      logging.info(
          'Checkpoint file %s found and restoring from '
          'initial checkpoint for core model.', init_checkpoint)
      checkpoint = tf.train.Checkpoint(model=sub_model)
      checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
      logging.info('Loading from checkpoint file completed')

    train_loss_metric = tf.keras.metrics.Mean(
        'training_loss', dtype=tf.float32)
    eval_metrics = [metric_fn()] if metric_fn else []
    # If evaluation is required, make a copy of metric as it will be used by
    # both train and evaluation; the copies keep independent state.
    train_metrics = [
        metric.__class__.from_config(metric.get_config())
        for metric in eval_metrics
    ]

    # Create summary writers
    if _should_export_summary(strategy):
      summary_dir = os.path.join(model_dir, 'summaries')
    else:
      # In multi worker training we need every worker to write summary, because
      # variables can trigger synchronization on read and synchronization needs
      # all workers to participate.
      summary_dir = tempfile.mkdtemp()
    eval_summary_writer = tf.summary.create_file_writer(
        os.path.join(summary_dir, 'eval'))
    if steps_per_loop >= _MIN_SUMMARY_STEPS:
      # Only writes summary when the stats are collected sufficiently over
      # enough steps.
      train_summary_writer = tf.summary.create_file_writer(
          os.path.join(summary_dir, 'train'))
    else:
      train_summary_writer = None

    # Collects training variables.
    training_vars = model.trainable_variables

    def _replicated_step(inputs):
      """Replicated training step."""

      inputs, labels = inputs
      with tf.GradientTape() as tape:
        model_outputs = model(inputs, training=True)
        loss = loss_fn(labels, model_outputs)
        # Raw loss is used for reporting in metrics/logs.
        raw_loss = loss
        if scale_loss:
          # Scales down the loss for gradients to be invariant from replicas.
          loss = loss / strategy.num_replicas_in_sync

      if explicit_allreduce:
        grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
                                                     training_vars,
                                                     pre_allreduce_callbacks,
                                                     post_allreduce_callbacks)
      else:
        if isinstance(optimizer,
                      tf.keras.mixed_precision.experimental.LossScaleOptimizer):
          # Re-enter the tape so the loss-scaling op is recorded too.
          with tape:
            scaled_loss = optimizer.get_scaled_loss(loss)
          scaled_grads = tape.gradient(scaled_loss, training_vars)
          grads = optimizer.get_unscaled_gradients(scaled_grads)
        else:
          grads = tape.gradient(loss, training_vars)
        optimizer.apply_gradients(zip(grads, training_vars))
      # For reporting, the metric takes the mean of losses.
      train_loss_metric.update_state(raw_loss)
      for metric in train_metrics:
        metric.update_state(labels, model_outputs)

    @tf.function
    def train_steps(iterator, steps):
      """Performs distributed training steps in a loop.

      Args:
        iterator: the distributed iterator of training datasets.
        steps: an tf.int32 integer tensor to specify number of steps to run
          inside host training loop.

      Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
      """
      if not isinstance(steps, tf.Tensor):
        raise ValueError('steps should be an Tensor. Python object may cause '
                         'retracing.')

      for _ in tf.range(steps):
        strategy.run(_replicated_step, args=(next(iterator),))

    def train_single_step(iterator):
      """Performs a distributed training step.

      Args:
        iterator: the distributed iterator of training datasets.

      Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
      """
      strategy.run(_replicated_step, args=(next(iterator),))

    def test_step(iterator):
      """Calculates evaluation metrics on distributed devices."""

      def _test_step_fn(inputs):
        """Replicated accuracy calculation."""

        inputs, labels = inputs
        model_outputs = model(inputs, training=False)
        for metric in eval_metrics:
          metric.update_state(labels, model_outputs)

      strategy.run(_test_step_fn, args=(next(iterator),))

    if not run_eagerly:
      train_single_step = tf.function(train_single_step)
      test_step = tf.function(test_step)

    def _run_evaluation(current_training_step, test_iterator):
      """Runs validation steps and aggregate metrics."""
      for _ in range(eval_steps):
        test_step(test_iterator)

      with eval_summary_writer.as_default():
        for metric in eval_metrics + model.metrics:
          metric_value = _float_metric_value(metric)
          logging.info('Step: [%d] Validation %s = %f', current_training_step,
                       metric.name, metric_value)
          tf.summary.scalar(
              metric.name, metric_value, step=current_training_step)
        eval_summary_writer.flush()

    def _run_callbacks_on_batch_begin(batch):
      """Runs custom callbacks at the start of every step."""
      if not custom_callbacks:
        return
      for callback in custom_callbacks:
        callback.on_batch_begin(batch)

    def _run_callbacks_on_batch_end(batch, logs):
      """Runs custom callbacks at the end of every step."""
      if not custom_callbacks:
        return
      for callback in custom_callbacks:
        callback.on_batch_end(batch, logs)

    # Training loop starts here.
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    sub_model_checkpoint = tf.train.Checkpoint(
        model=sub_model) if sub_model_export_name else None

    latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
    if latest_checkpoint_file:
      logging.info(
          'Checkpoint file %s found and restoring from '
          'checkpoint', latest_checkpoint_file)
      checkpoint.restore(latest_checkpoint_file)
      logging.info('Loading from checkpoint file completed')

    # Resume from the optimizer's step counter (restored above, if any).
    current_step = optimizer.iterations.numpy()
    checkpoint_name = 'ctl_step_{step}.ckpt'

    while current_step < total_training_steps:
      # Training loss/metric are taking average over steps inside micro
      # training loop. We reset their values before each round.
      train_loss_metric.reset_states()
      for metric in train_metrics + model.metrics:
        metric.reset_states()

      _run_callbacks_on_batch_begin(current_step)
      # Runs several steps in the host while loop.
      steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)

      if tf.test.is_built_with_cuda():
        # TODO(zongweiz): merge with train_steps once tf.while_loop
        # GPU performance bugs are fixed.
        for _ in range(steps):
          train_single_step(train_iterator)
      else:
        # Converts steps to a Tensor to avoid tf.function retracing.
        train_steps(train_iterator,
                    tf.convert_to_tensor(steps, dtype=tf.int32))
      train_loss = _float_metric_value(train_loss_metric)
      current_step += steps
      # `current_step - 1` is the index of the last step just executed.
      _run_callbacks_on_batch_end(current_step - 1, {'loss': train_loss})

      # Updates training logging.
      training_status = 'Train Step: %d/%d  / loss = %s' % (
          current_step, total_training_steps, train_loss)

      if train_summary_writer:
        with train_summary_writer.as_default():
          tf.summary.scalar(
              train_loss_metric.name, train_loss, step=current_step)
          for metric in train_metrics + model.metrics:
            metric_value = _float_metric_value(metric)
            training_status += '  %s = %f' % (metric.name, metric_value)
            tf.summary.scalar(metric.name, metric_value, step=current_step)
          train_summary_writer.flush()
      logging.info(training_status)

      # Saves model checkpoints and run validation steps at every epoch end.
      if current_step % steps_per_epoch == 0:
        # To avoid repeated model saving, we do not save after the last
        # step of training; the final save happens after the loop.
        if current_step < total_training_steps:
          _save_checkpoint(strategy, checkpoint, model_dir,
                           checkpoint_name.format(step=current_step))
          if sub_model_export_name:
            _save_checkpoint(
                strategy, sub_model_checkpoint, model_dir,
                '%s_step_%d.ckpt' % (sub_model_export_name, current_step))
        if eval_input_fn:
          logging.info('Running evaluation after step: %s.', current_step)
          _run_evaluation(current_step,
                          _get_input_iterator(eval_input_fn, strategy))
          # Re-initialize evaluation metric.
          for metric in eval_metrics + model.metrics:
            metric.reset_states()

    _save_checkpoint(strategy, checkpoint, model_dir,
                     checkpoint_name.format(step=current_step))
    if sub_model_export_name:
      _save_checkpoint(strategy, sub_model_checkpoint, model_dir,
                       '%s.ckpt' % sub_model_export_name)

    if eval_input_fn:
      logging.info('Running final evaluation after training is complete.')
      _run_evaluation(current_step,
                      _get_input_iterator(eval_input_fn, strategy))

    training_summary = {
        'total_training_steps': total_training_steps,
        'train_loss': _float_metric_value(train_loss_metric),
    }
    if eval_metrics:
      # TODO(hongkuny): Cleans up summary reporting in text.
      training_summary['last_train_metrics'] = _float_metric_value(
          train_metrics[0])
      training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])

    write_txt_summary(training_summary, summary_dir)

    # Non-exporting workers wrote into a private temp dir; clean it up.
    if not _should_export_summary(strategy):
      tf.io.gfile.rmtree(summary_dir)

    return model