def initialize(params: base_configs.ExperimentConfig,
               dataset_builder: dataset_factory.DatasetBuilder):
    """Applies the runtime/backend settings described by `params.runtime`.

    In order: configures the session (eager / XLA), optionally configures the
    GPU thread mode, sets the mixed precision policy from the dataset dtype,
    selects the Keras image data format, configures the cluster spec, and
    optionally forces `tf.function`s to run eagerly.

    Args:
      params: Experiment configuration; only its `runtime` section is read.
      dataset_builder: Dataset builder whose `dtype` selects the mixed
        precision policy.
    """
    runtime = params.runtime

    keras_utils.set_session_config(
        enable_eager=runtime.run_eagerly,
        enable_xla=runtime.enable_xla)

    if runtime.gpu_threads_enabled:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=runtime.per_gpu_thread_count,
            gpu_thread_mode=runtime.gpu_thread_mode,
            num_gpus=runtime.num_gpus,
            datasets_num_private_threads=runtime.dataset_num_private_threads)

    performance.set_mixed_precision_policy(dataset_builder.dtype)

    # channels_first is chosen whenever at least one physical GPU is visible.
    visible_gpus = tf.config.list_physical_devices('GPU')
    tf.keras.backend.set_image_data_format(
        'channels_first' if visible_gpus else 'channels_last')

    distribution_utils.configure_cluster(runtime.worker_hosts,
                                         runtime.task_index)

    if runtime.run_eagerly:
        # Run tf.functions eagerly so the code can be debugged step by step.
        tf.config.experimental_run_functions_eagerly(True)
Beispiel #2
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If `flags_obj.optimizer` or `flags_obj.model` is not one of
        the choices handled below.
      NotImplementedError: If pruning is requested with a dtype other than
        tf.float32, or with a method other than 'polynomial_decay'.

    Returns:
      Training and eval stats as assembled by `common.build_stats`.
    """
    # Function-scope import keeps this block self-contained.
    import contextlib

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    performance.set_mixed_precision_policy(
        dtype, flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                         flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    # Current resnet_model.resnet50 input format is always channel-last.
    # We use the keras_application mobilenet model, whose input format depends
    # on the keras backend image data format.
    # This use_keras_image_data_format flag indicates whether the image
    # preprocessor output format should be the same as the keras backend image
    # data format or just channel-last format.
    use_keras_image_data_format = (flags_obj.model == 'mobilenet')
    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
        training_dataset_cache=flags_obj.training_dataset_cache,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
                use_keras_image_data_format=use_keras_image_data_format),
            dtype=dtype,
            drop_remainder=drop_remainder)

    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=[p[1] for p in common.LR_SCHEDULE[1:]],
        multipliers=[p[0] for p in common.LR_SCHEDULE],
        compute_lr_on_cpu=True)
    steps_per_epoch = (imagenet_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)

    with strategy_scope:
        if flags_obj.optimizer == 'resnet50_default':
            optimizer = common.get_optimizer(lr_schedule)
        elif flags_obj.optimizer == 'mobilenet_default':
            initial_learning_rate = (
                flags_obj.initial_learning_rate_per_sample *
                flags_obj.batch_size)
            optimizer = tf.keras.optimizers.SGD(
                learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
                    initial_learning_rate,
                    decay_steps=steps_per_epoch *
                    flags_obj.num_epochs_per_decay,
                    decay_rate=flags_obj.lr_decay_factor,
                    staircase=True),
                momentum=0.9)
        else:
            # Fail with a clear message instead of an UnboundLocalError later.
            raise ValueError(
                'Unsupported optimizer: {!r}'.format(flags_obj.optimizer))
        if flags_obj.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
            # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
            # which will ensure tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
            # up.
            optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                optimizer)

        # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
        if flags_obj.use_trivial_model:
            model = test_utils.trivial_model(
                imagenet_preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'resnet50_v1.5':
            model = resnet_model.resnet50(
                num_classes=imagenet_preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'mobilenet':
            # TODO(kimjaehong): Remove layers attribute when minimum TF version
            # support 2.0 layers by default.
            model = tf.keras.applications.mobilenet.MobileNet(
                weights=None,
                classes=imagenet_preprocessing.NUM_CLASSES,
                layers=tf.keras.layers)
        else:
            # Fail with a clear message instead of an UnboundLocalError later.
            raise ValueError(
                'Unsupported model: {!r}'.format(flags_obj.model))
        if flags_obj.pretrained_filepath:
            model.load_weights(flags_obj.pretrained_filepath)

        if flags_obj.pruning_method == 'polynomial_decay':
            if dtype != tf.float32:
                raise NotImplementedError(
                    'Pruning is currently only supported on dtype=tf.float32.')
            pruning_params = {
                'pruning_schedule':
                tfmot.sparsity.keras.PolynomialDecay(
                    initial_sparsity=flags_obj.pruning_initial_sparsity,
                    final_sparsity=flags_obj.pruning_final_sparsity,
                    begin_step=flags_obj.pruning_begin_step,
                    end_step=flags_obj.pruning_end_step,
                    frequency=flags_obj.pruning_frequency),
            }
            model = tfmot.sparsity.keras.prune_low_magnitude(
                model, **pruning_params)
        elif flags_obj.pruning_method:
            raise NotImplementedError(
                'Only polynomial_decay is currently supported.')

        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['sparse_categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      run_eagerly=flags_obj.run_eagerly)

    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(
        steps_per_epoch=steps_per_epoch,
        pruning_method=flags_obj.pruning_method,
        enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
        model_dir=flags_obj.model_dir)

    # If multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # The ExitStack stays empty unless explicit GPU placement is requested.
    # Using it (rather than calling __enter__/__exit__ by hand) guarantees the
    # device scope is exited with correct arguments, including when training,
    # evaluation or export raises.
    with contextlib.ExitStack() as device_scope:
        if not strategy and flags_obj.explicit_gpu_placement:
            # TODO(b/135607227): Add device scope automatically in Keras
            # training loop when not using distribution strategy.
            device_scope.enter_context(tf.device('/device:GPU:0'))

        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=steps_per_epoch,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)

        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)

        if flags_obj.pruning_method:
            model = tfmot.sparsity.keras.strip_pruning(model)
        if flags_obj.enable_checkpoint_and_export:
            if dtype == tf.bfloat16:
                logging.warning(
                    'Keras model.save does not support bfloat16 dtype.')
            else:
                # Keras model.save assumes a float32 input signature.
                export_path = os.path.join(flags_obj.model_dir, 'saved_model')
                model.save(export_path, include_optimizer=False)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
Beispiel #3
0
def run(flags_obj):
    """Train and evaluate ResNet on ImageNet with a custom training loop.

    Sets up the session, precision policy, image data format and cluster
    spec, builds a `ResnetRunnable` inside the distribution strategy scope
    and drives it through a `controller.Controller`.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      The training and eval stats produced by `build_stats`.
    """
    keras_utils.set_session_config(
        enable_eager=flags_obj.enable_eager,
        enable_xla=flags_obj.enable_xla,
        enable_scoped_allocator=flags_obj.enable_scoped_allocator)

    # The Habana bf16 conversion pass is used only when bf16 is requested and
    # native Keras mixed precision is switched off. The `== False` comparison
    # is kept as-is: it is not equivalent to `not ...` for a None flag value.
    use_bf16_conversion_pass = (
        flags.FLAGS.dtype == 'bf16'
        and flags.FLAGS.use_keras_mixed_precision == False)
    if use_bf16_conversion_pass:
        performance.set_mixed_precision_policy(tf.float32)
        os.environ['TF_BF16_CONVERSION'] = flags.FLAGS.bf16_config_path
    else:
        policy_dtype = flags_core.get_tf_dtype(flags_obj)
        performance.set_mixed_precision_policy(policy_dtype)

    # Only set these when the caller's environment has not already done so.
    for env_name in ("TF_DISABLE_MKL",
                     "TF_ALLOW_CONTROL_EDGES_IN_HABANA_OPS"):
        os.environ.setdefault(env_name, "1")

    # This only affects GPU.
    common.set_cudnn_batchnorm_mode()

    # TODO(anj-s): Set data_format without using Keras.
    image_data_format = flags_obj.data_format
    if image_data_format is None:
        if tf.test.is_built_with_cuda():
            image_data_format = 'channels_first'
        else:
            image_data_format = 'channels_last'
    tf.keras.backend.set_image_data_format(image_data_format)
    adjusted_batch_size = adjust_batch_size(flags_obj.batch_size)

    # With Horovod each worker writes into its own sub-directory.
    model_dir = (os.path.join(flags_obj.model_dir,
                              "worker_" + str(hvd.rank()))
                 if horovod_enabled() else flags_obj.model_dir)

    # Build the "host:port" list: one entry per rank on every listed host,
    # with ports counted up from a fixed base.
    host_ips = str(os.environ.get("MULTI_HLS_IPS", "127.0.0.1")).split(",")
    base_port = 2410
    rank_in_job = comm_rank()
    job_size = comm_size()
    ranks_per_host = job_size // len(host_ips)

    per_host_entries = []
    for ip in host_ips:
        endpoints = [
            ip + ':' + str(base_port + local_rank)
            for local_rank in range(ranks_per_host)
        ]
        per_host_entries.append(",".join(endpoints))
    worker_hosts = ",".join(per_host_entries)

    # Configures cluster spec for distribution strategy.
    distribution_utils.configure_cluster(worker_hosts, rank_in_job)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    train_writer = None
    eval_writer = None
    if flags_obj.enable_tensorboard:
        train_writer = tf.summary.create_file_writer(model_dir)
        eval_writer = tf.summary.create_file_writer(
            os.path.join(model_dir, 'eval'))
        hparams = flags_obj.flag_values_dict()
        hparams.setdefault('precision', flags_obj.dtype)
        write_hparams_v2(train_writer, hparams)

    iteration_counts = get_num_train_iterations(flags_obj)
    per_epoch_steps, train_epochs, eval_steps = iteration_counts
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
    total_train_steps = train_epochs * per_epoch_steps

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        total_train_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        adjusted_batch_size,
        flags_obj.log_steps,
        summary_writer=train_writer,
        batch_size_per_node=flags_obj.batch_size)

    profiler_callback = None
    if flags_obj.profile_steps is not None:
        profiler_callback = keras_utils.get_profiler_callback(
            model_dir, flags_obj.profile_steps, flags_obj.enable_tensorboard,
            per_epoch_steps)

    with distribution_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(
            flags_obj, time_callback, total_train_steps, per_epoch_steps,
            profiler_callback)

    eval_interval = flags_obj.epochs_between_evals * per_epoch_steps

    # Checkpoint once per epoch, but only when checkpointing is enabled.
    if flags_obj.enable_checkpoint_and_export:
        checkpoint_interval = per_epoch_steps
    else:
        checkpoint_interval = None

    # Summaries follow the logging cadence when TensorBoard is enabled.
    if flags_obj.enable_tensorboard:
        summary_interval = flags_obj.log_steps
    else:
        summary_interval = None

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    resnet_controller = controller.Controller(
        strategy,
        runnable.train,
        runnable.evaluate,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        train_steps=total_train_steps,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        eval_steps=eval_steps,
        eval_interval=eval_interval,
        train_summary_writer=train_writer,
        eval_summary_writer=eval_writer)

    time_callback.on_train_begin()
    resnet_controller.train(evaluate=not flags_obj.skip_eval)
    time_callback.on_train_end()

    return build_stats(runnable, time_callback)
Beispiel #4
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval with MLPerf logging.

    Uses a custom training loop driven by `controller.Controller` and emits
    MLPerf log events (submission metadata, init/run/block markers) through
    the logger returned by `get_mllog_mlloger`.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      The training and eval stats produced by `build_stats`.
    """
    tf.get_logger().propagate = False
    # LOG_DIR is optional; None is passed through when it is not set.
    output_dir = os.environ.get("LOG_DIR")
    mlperf_mlloger, mlperf_mllog = get_mllog_mlloger(output_dir)
    mlperf_mlloger.event(key=mlperf_mllog.constants.CACHE_CLEAR, value=True)
    mlperf_mlloger.start(key=mlperf_mllog.constants.INIT_START, value=None)
    mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_BENCHMARK,
                         value=mlperf_mllog.constants.RESNET)
    mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_ORG,
                         value='Habana')
    mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_DIVISION,
                         value='closed')
    mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_PLATFORM,
                         value='gaudi-{}'.format(flags_obj.num_gpus))
    mlperf_mlloger.event(key=mlperf_mllog.constants.SUBMISSION_STATUS,
                         value='onprem')

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)
    performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

    # This only affects GPU.
    common.set_cudnn_batchnorm_mode()

    # TODO(anj-s): Set data_format without using Keras.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # With Horovod each worker writes into its own sub-directory.
    if horovod_enabled():
        model_dir = os.path.join(flags_obj.model_dir,
                                 "worker_" + str(hvd.rank()))
    else:
        model_dir = flags_obj.model_dir

    global_batch_size = get_global_batch_size(flags_obj.batch_size)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    mlperf_mlloger.event(key=mlperf_mllog.constants.GLOBAL_BATCH_SIZE,
                         value=global_batch_size)
    mlperf_mlloger.event(key=mlperf_mllog.constants.TRAIN_SAMPLES,
                         value=imagenet_preprocessing.NUM_IMAGES['train'])
    mlperf_mlloger.event(key=mlperf_mllog.constants.EVAL_SAMPLES,
                         value=imagenet_preprocessing.NUM_IMAGES['validation'])
    group_batch_norm = 1
    mlperf_mlloger.event(key=mlperf_mllog.constants.MODEL_BN_SPAN,
                         value=flags_obj.batch_size * group_batch_norm)

    train_writer, eval_writer = None, None
    if flags_obj.enable_tensorboard:
        train_writer = tf.summary.create_file_writer(model_dir)
        eval_writer = tf.summary.create_file_writer(
            os.path.join(model_dir, 'eval'))
        hparams = flags_obj.flag_values_dict()
        write_hparams_v2(train_writer, hparams)

    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
    train_steps = train_epochs * per_epoch_steps

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        train_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        global_batch_size,
        flags_obj.log_steps,
        summary_writer=train_writer,
        batch_size_per_node=flags_obj.batch_size)
    profiler_callback = None
    if flags_obj.profile_steps is not None:
        profiler_callback = keras_utils.get_profiler_callback(
            model_dir, flags_obj.profile_steps, flags_obj.enable_tensorboard,
            per_epoch_steps)
    with distribution_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                                  train_steps, per_epoch_steps,
                                                  profiler_callback,
                                                  mlperf_mlloger, mlperf_mllog)

    eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
    # A non-zero eval offset is shifted back by one eval interval before it
    # is handed to the controller.
    eval_offset = flags_obj.eval_offset_epochs * per_epoch_steps
    if eval_offset != 0:
        eval_offset -= eval_interval
    checkpoint_interval = (per_epoch_steps
                           if flags_obj.enable_checkpoint_and_export else None)
    summary_interval = per_epoch_steps if flags_obj.enable_tensorboard else None

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    device_warmup_steps = (flags_obj.device_warmup_steps
                           if flags_obj.enable_device_warmup else 0)
    if flags_obj.enable_device_warmup:
        logging.info('Warmup for %d steps.', device_warmup_steps)

    resnet_controller = controller.Controller(
        strategy,
        runnable.train,
        runnable.evaluate,
        runnable.warmup,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        train_steps=train_steps,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        eval_steps=eval_steps,
        eval_interval=eval_interval,
        eval_offset=eval_offset,
        device_warmup_steps=device_warmup_steps,
        train_summary_writer=train_writer,
        eval_summary_writer=eval_writer)

    if flags_obj.enable_device_warmup:
        resnet_controller.warmup()

    mlperf_mlloger.end(key=mlperf_mllog.constants.INIT_STOP)

    # Horovod is optional in this function (see the model_dir selection
    # above), so only touch `hvd` when it is actually enabled.
    # NOTE(review): broadcasting a dummy value looks like a barrier so that
    # all workers enter the timed run together -- confirm.
    if horovod_enabled():
        hvd.broadcast(0, 0)
    time_callback.on_train_begin()
    mlperf_mlloger.start(key=mlperf_mllog.constants.RUN_START)
    first_block_epochs = (flags_obj.eval_offset_epochs
                          if flags_obj.eval_offset_epochs > 0
                          else flags_obj.epochs_between_evals)
    mlperf_mlloger.start(
        key=mlperf_mllog.constants.BLOCK_START,
        value=None,
        metadata={
            'first_epoch_num': 1,
            'epoch_count': first_block_epochs,
        })
    resnet_controller.train(evaluate=not flags_obj.skip_eval)
    # NOTE(review): RUN_STOP is only logged when evaluation ran; with
    # skip_eval no RUN_STOP event is emitted -- confirm this is intended.
    if not flags_obj.skip_eval:
        eval_accuracy = resnet_controller.last_eval_output['test_accuracy']
        run_status = ('success'
                      if eval_accuracy >= flags_obj.target_accuracy
                      else 'fail')
        mlperf_mlloger.end(key=mlperf_mllog.constants.RUN_STOP,
                           value=None,
                           metadata={'status': run_status})
    time_callback.on_train_end()

    stats = build_stats(runnable, time_callback)
    return stats
def run(flags_obj):
  """Train and evaluate ResNet on ImageNet with a custom training loop.

  Configures the session, precision policy and image data format, builds a
  `ResnetRunnable` inside the distribution strategy scope and drives it
  through a `controller.Controller`.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    The training and eval stats produced by `build_stats`.
  """
  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)
  policy_dtype = flags_core.get_tf_dtype(flags_obj)
  performance.set_mixed_precision_policy(policy_dtype)

  # This only affects GPU.
  common.set_cudnn_batchnorm_mode()

  # TODO(anj-s): Set data_format without using Keras.
  image_data_format = flags_obj.data_format
  if image_data_format is None:
    if tf.test.is_built_with_cuda():
      image_data_format = 'channels_first'
    else:
      image_data_format = 'channels_last'
  tf.keras.backend.set_image_data_format(image_data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  iteration_counts = get_num_train_iterations(flags_obj)
  per_epoch_steps, train_epochs, eval_steps = iteration_counts
  steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
  total_train_steps = train_epochs * per_epoch_steps

  logging.info(
      'Training %d epochs, each epoch has %d steps, '
      'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
      total_train_steps, eval_steps)

  # TimeHistory only gets a log directory when TensorBoard is enabled.
  if flags_obj.enable_tensorboard:
    time_history_logdir = flags_obj.model_dir
  else:
    time_history_logdir = None
  time_callback = keras_utils.TimeHistory(
      flags_obj.batch_size,
      flags_obj.log_steps,
      logdir=time_history_logdir)

  profiler_callback = None
  if flags_obj.profile_steps is not None:
    profiler_callback = keras_utils.get_profiler_callback(
        flags_obj.model_dir, flags_obj.profile_steps,
        flags_obj.enable_tensorboard, per_epoch_steps)

  with distribution_utils.get_strategy_scope(strategy):
    runnable = resnet_runnable.ResnetRunnable(
        flags_obj, time_callback, per_epoch_steps, profiler_callback)

  eval_interval = flags_obj.epochs_between_evals * per_epoch_steps

  # Checkpoints and summaries are both written once per epoch, each only
  # when its feature flag is on.
  checkpoint_interval = None
  if flags_obj.enable_checkpoint_and_export:
    checkpoint_interval = per_epoch_steps
  summary_interval = None
  if flags_obj.enable_tensorboard:
    summary_interval = per_epoch_steps

  checkpoint_manager = tf.train.CheckpointManager(
      runnable.checkpoint,
      directory=flags_obj.model_dir,
      max_to_keep=10,
      step_counter=runnable.global_step,
      checkpoint_interval=checkpoint_interval)

  resnet_controller = controller.Controller(
      strategy,
      runnable.train,
      runnable.evaluate,
      global_step=runnable.global_step,
      steps_per_loop=steps_per_loop,
      train_steps=total_train_steps,
      checkpoint_manager=checkpoint_manager,
      summary_interval=summary_interval,
      eval_steps=eval_steps,
      eval_interval=eval_interval)

  time_callback.on_train_begin()
  resnet_controller.train(evaluate=not flags_obj.skip_eval)
  time_callback.on_train_end()

  return build_stats(runnable, time_callback)