コード例 #1
0
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Configures the session, mixed-precision policy, image data format and
  distribution strategy from `flags_obj`, builds the input datasets and the
  model, then trains one epoch per `model.fit` call. Unless
  `flags_obj.skip_eval` is set, the model is evaluated after every epoch; when
  accuracy metrics are compiled in (`flags_obj.report_accuracy_metrics`),
  training stops as soon as the evaluated accuracy reaches
  `flags_obj.target_accuracy`.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is requested on TensorFlow 1, or if `flags_obj.model`
      is neither 'resnet50_v1.5' nor 'mobilenet'.
    NotImplementedError: If some features are not currently supported
      (pruning with a non-float32 dtype, or an unknown pruning method).

  Returns:
    Dictionary of training and eval stats.
  """
  # Function-scope import: only needed for the optional device scope below.
  import contextlib

  mlp_log.mlperf_print('init_start', None)
  common.print_flags(flags_obj)

  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.float16:
    # Reject the unsupported configuration before mutating the global
    # mixed-precision policy.
    if not keras_utils.is_v2_0():
      raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
    loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_float16', loss_scale=loss_scale)
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
  elif dtype == tf.bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu,
      tpu_zone=flags_obj.tpu_zone if flags_obj.tpu else None)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=flags_obj.num_classes,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_preprocessing.input_fn

  # NOTE: dropping the remainder is driven by the explicit
  # drop_train_remainder / drop_eval_remainder flags below, not by enable_xla.
  # Callers using XLA-GPU (which doesn't support dynamic shapes) are expected
  # to set those flags themselves.

  # Current resnet_model.resnet50 input format is always channel-last.
  # We use the keras_application mobilenet model, whose input format depends
  # on the keras backend image data format.
  # This use_keras_image_data_format flag indicates whether the image
  # preprocessor output format should be the same as the keras backend image
  # data format or just channel-last format.
  use_keras_image_data_format = (flags_obj.model == 'mobilenet')
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=flags_obj.drop_train_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=flags_obj.drop_eval_remainder)

  steps_per_epoch, train_epochs = common.get_num_train_iterations(flags_obj)

  mlp_log.mlperf_print('global_batch_size', flags_obj.batch_size)
  mlp_log.mlperf_print('num_train_examples',
                       imagenet_preprocessing.NUM_IMAGES['train'])
  mlp_log.mlperf_print('num_eval_examples',
                       imagenet_preprocessing.NUM_IMAGES['validation'])

  learning_rate_schedule_fn = None
  with strategy_scope:
    optimizer, learning_rate_schedule_fn = common.get_optimizer(
        flags_obj=flags_obj,
        steps_per_epoch=steps_per_epoch,
        train_steps=train_epochs * steps_per_epoch)
    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
      # up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    if flags_obj.model == 'resnet50_v1.5':
      resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
      model = resnet_model.resnet50(num_classes=flags_obj.num_classes)
    elif flags_obj.model == 'mobilenet':
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # support 2.0 layers by default.
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=None, classes=flags_obj.num_classes, layers=tf.keras.layers)
    else:
      # Without this branch `model` would be unbound and the failure would
      # surface later as an unrelated NameError.
      raise ValueError(
          'Unsupported model: {!r}. Expected \'resnet50_v1.5\' or '
          '\'mobilenet\'.'.format(flags_obj.model))
    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError(
          'Only polynomial_decay is currently supported.')
    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
    # a valid arg for this model. Also remove as a valid flag.
    if flags_obj.force_v2_in_keras_compile is not None:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly,
          experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
    else:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly)

  callbacks = common.get_callbacks(
      steps_per_epoch=steps_per_epoch,
      learning_rate_schedule_fn=learning_rate_schedule_fn,
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  num_eval_steps = common.get_num_eval_steps(flags_obj)
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None

  # ExitStack guarantees the optional device scope is left even if training
  # raises, and calls __exit__ with the proper (type, value, traceback)
  # arguments.
  with contextlib.ExitStack() as device_scope:
    if not strategy and flags_obj.explicit_gpu_placement:
      # TODO(b/135607227): Add device scope automatically in Keras training
      # loop when not using distribution strategy.
      device_scope.enter_context(tf.device('/device:GPU:0'))

    mlp_log.mlperf_print('init_stop', None)
    mlp_log.mlperf_print('run_start', None)

    for epoch in range(train_epochs):
      mlp_log.mlperf_print('epoch_start', None,
                           metadata={'first_epoch_num': epoch,
                                     'epoch_count': 1})
      mlp_log.mlperf_print('block_start', None)
      history = model.fit(train_input_dataset,
                          epochs=1,
                          steps_per_epoch=steps_per_epoch,
                          callbacks=callbacks,
                          verbose=2)
      mlp_log.mlperf_print('block_stop', None)

      eval_output = None
      if not flags_obj.skip_eval:
        mlp_log.mlperf_print('eval_start', None)
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)
        mlp_log.mlperf_print('eval_stop', None)

        # The accuracy entry only exists when the model was compiled with the
        # accuracy metric; with metrics=None `evaluate` returns a scalar loss
        # that cannot be indexed.
        if flags_obj.report_accuracy_metrics:
          eval_accuracy = eval_output[1]
          mlp_log.mlperf_print(
              'eval_accuracy', eval_accuracy, metadata={'epoch_num': epoch})
          if eval_accuracy >= flags_obj.target_accuracy:
            # NOTE(review): leaving the loop here skips the 'epoch_stop'
            # record for this epoch - confirm this is intended.
            break

      mlp_log.mlperf_print('epoch_stop', None)

    mlp_log.mlperf_print('run_stop', None)

    if flags_obj.pruning_method:
      model = tfmot.sparsity.keras.strip_pruning(model)
    if flags_obj.enable_checkpoint_and_export:
      if dtype == tf.bfloat16:
        logging.warning('Keras model.save does not support bfloat16 dtype.')
      else:
        # Keras model.save assumes a float32 input signature.
        export_path = os.path.join(flags_obj.model_dir, 'saved_model')
        model.save(export_path, include_optimizer=False)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
コード例 #2
0
    def __init__(self, flags_obj, time_callback):
        """Set up model, optimizer, metrics and step bookkeeping.

        Args:
            flags_obj: Object holding the parsed flag values read below.
            time_callback: Stored unchanged as `self.time_callback`.

        Raises:
            ValueError: If the batch size is not divisible by the number of
                replicas of the current strategy, or if the graph-rewrite
                fp16 implementation is requested without
                `use_tf_function`.
        """
        standard_runnable.StandardRunnableWithWarmup.__init__(
            self, flags_obj.use_tf_while_loop, flags_obj.use_tf_function)

        self.strategy = tf.distribute.get_strategy()
        self.flags_obj = flags_obj
        self.dtype = flags_core.get_tf_dtype(flags_obj)
        self.time_callback = time_callback

        # --- Input pipeline -------------------------------------------------
        global_batch_size = flags_obj.batch_size
        replica_count = self.strategy.num_replicas_in_sync
        if global_batch_size % replica_count != 0:
            raise ValueError(
                'Batch size must be divisible by number of replicas : {}'.
                format(replica_count))

        steps_per_epoch, train_epochs = common.get_num_train_iterations(
            flags_obj)
        if train_epochs > 1:
            train_epochs = flags_obj.train_epochs

        # Auto rebatching is not supported by
        # `experimental_distribute_datasets_from_function()`, which is needed
        # to clone the dataset to multiple workers in eager mode, so the
        # per-replica batch size is stored instead of the global one.
        self.batch_size = int(global_batch_size / replica_count)

        self.synthetic_input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=self.flags_obj.num_classes,
            dtype=self.dtype,
            drop_remainder=True)

        self.input_fn = (self.synthetic_input_fn
                         if self.flags_obj.use_synthetic_data else
                         imagenet_preprocessing.input_fn)

        # --- Model and optimizer --------------------------------------------
        resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
        self.model = resnet_model.resnet50(
            num_classes=self.flags_obj.num_classes,
            batch_size=flags_obj.batch_size,
            use_l2_regularizer=not flags_obj.single_l2_loss_op)

        self.use_lars_optimizer = self.flags_obj.optimizer == 'LARS'
        base_optimizer, _ = common.get_optimizer(
            flags_obj=flags_obj,
            steps_per_epoch=steps_per_epoch,
            train_steps=steps_per_epoch * train_epochs)
        self.optimizer = base_optimizer
        # Touch `iterations` here so the variable is created inside the scope.
        self.global_step = self.optimizer.iterations

        if self.dtype == tf.float16:
            fp16_loss_scale = flags_core.get_loss_scale(
                flags_obj, default_for_fp16=128)
            loss_scale_optimizer_cls = (
                tf.keras.mixed_precision.experimental.LossScaleOptimizer)
            self.optimizer = loss_scale_optimizer_cls(
                self.optimizer, fp16_loss_scale)
        elif flags_obj.fp16_implementation == 'graph_rewrite':
            # `dtype` stays float32 here: the graph is built in float32 and
            # the graph rewrite converts parts of it to float16.
            if not flags_obj.use_tf_function:
                raise ValueError(
                    '--fp16_implementation=graph_rewrite requires '
                    '--use_tf_function to be true')
            fp16_loss_scale = flags_core.get_loss_scale(
                flags_obj, default_for_fp16=128)
            enable_rewrite = (
                tf.train.experimental.enable_mixed_precision_graph_rewrite)
            self.optimizer = enable_rewrite(self.optimizer, fp16_loss_scale)

        # --- Metrics --------------------------------------------------------
        self.label_smoothing = flags_obj.label_smoothing
        # One-hot labels are used only when label smoothing is positive.
        self.one_hot = bool(self.label_smoothing
                            and self.label_smoothing > 0)
        accuracy_metric_cls = (
            tf.keras.metrics.CategoricalAccuracy
            if self.one_hot else tf.keras.metrics.SparseCategoricalAccuracy)

        if flags_obj.report_accuracy_metrics:
            self.train_loss = tf.keras.metrics.Mean(
                'train_loss', dtype=tf.float32)
            self.train_accuracy = accuracy_metric_cls(
                'train_accuracy', dtype=tf.float32)
            self.test_loss = tf.keras.metrics.Mean(
                'test_loss', dtype=tf.float32)
        else:
            self.train_loss = None
            self.train_accuracy = None
            self.test_loss = None

        # The test accuracy metric is created regardless of the reporting
        # flag.
        self.test_accuracy = accuracy_metric_cls(
            'test_accuracy', dtype=tf.float32)
        # self.test_corrects = tf.keras.metrics.Sum(
        #     'test_corrects', dtype=tf.float32)
        self.num_eval_steps = common.get_num_eval_steps(flags_obj)

        self.checkpoint = tf.train.Checkpoint(
            model=self.model, optimizer=self.optimizer)

        # --- Epoch / step bookkeeping ---------------------------------------
        self.epoch_steps = steps_per_epoch
        self.epoch_helper = utils.EpochHelper(
            steps_per_epoch, self.global_step)

        self.steps_per_loop = flags_obj.steps_per_loop

        # `profile_steps` is a comma-separated string; a negative first entry
        # leaves the trace start step unset.
        raw_profile_steps = flags_obj.profile_steps
        if not raw_profile_steps:
            self.trace_start_step = None
            self.trace_end_step = None
        else:
            step_values = [
                int(token) for token in raw_profile_steps.split(',')
            ]
            self.trace_start_step = (
                step_values[0] if step_values[0] >= 0 else None)
            self.trace_end_step = step_values[1]

        self.epochs_between_evals = flags_obj.epochs_between_evals
コード例 #3
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

    Emits MLPerf log records around initialization and training, configures
    the session, mixed-precision policy, image data format and distribution
    strategy from `flags_obj`, builds a `ResnetRunnable` inside the strategy
    scope and drives it through a `controller.Controller`.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: NOTE(review): not raised directly by this function;
        presumably propagated from the helpers it calls - confirm.

    Returns:
      Dictionary of training and eval stats, as built by `build_stats`.
    """
    mlp_log.mlperf_print('cache_clear', True)
    mlp_log.mlperf_print('init_start', None)
    mlp_log.mlperf_print('submission_benchmark', 'resnet')
    mlp_log.mlperf_print('submission_division', 'closed')
    mlp_log.mlperf_print('submission_org', 'google')
    mlp_log.mlperf_print(
        'submission_platform', 'tpu-v3-{}'.format(flags_obj.num_replicas)
        if flags_obj.tpu else 'gpu-v100-{}'.format(flags_obj.num_gpus))
    mlp_log.mlperf_print('submission_status', 'cloud')

    common.print_flags(flags_obj)

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)
    performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

    if tf.config.list_physical_devices('GPU'):
        if flags_obj.tf_gpu_thread_mode:
            datasets_num_private_threads = keras_utils.set_gpu_thread_mode_and_count(
                per_gpu_thread_count=flags_obj.per_gpu_thread_count,
                gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
                num_gpus=flags_obj.num_gpus)
            # Only fill in the thread count when the user did not set one.
            if not flags_obj.datasets_num_private_threads:
                flags_obj.datasets_num_private_threads = datasets_num_private_threads
        common.set_cudnn_batchnorm_mode()

    # TODO(anj-s): Set data_format without using Keras.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu,
        tpu_zone=flags_obj.tpu_zone if flags_obj.tpu else None)
    mlp_log.mlperf_print('global_batch_size', flags_obj.batch_size)
    mlp_log.mlperf_print('train_samples',
                         imagenet_preprocessing.NUM_IMAGES['train'])
    mlp_log.mlperf_print('eval_samples',
                         imagenet_preprocessing.NUM_IMAGES['validation'])

    # Clamp the divisor to 1 so that a run with num_gpus == 0 does not fail
    # with ZeroDivisionError merely while emitting this log record.
    replica_count = (flags_obj.num_replicas
                     if flags_obj.tpu else flags_obj.num_gpus)
    mlp_log.mlperf_print(
        'model_bn_span',
        int(flags_obj.batch_size / max(1, replica_count)))

    per_epoch_steps, train_epochs = common.get_num_train_iterations(flags_obj)
    eval_steps = common.get_num_eval_steps(flags_obj)
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        train_epochs * per_epoch_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        flags_obj.batch_size,
        flags_obj.log_steps,
        logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
    with distribution_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback)

    # Intervals are expressed in steps. With evaluation disabled there is no
    # interval and the offset stays 0, so the adjustment below is skipped.
    if flags_obj.skip_eval:
        eval_interval = None
        eval_offset = 0
    else:
        eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
        eval_offset = flags_obj.eval_offset_epochs * per_epoch_steps
    if eval_offset != 0:
        eval_offset -= eval_interval
    checkpoint_interval = (per_epoch_steps
                           if flags_obj.enable_checkpoint_and_export else None)
    summary_interval = per_epoch_steps if flags_obj.enable_tensorboard else None

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=flags_obj.model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    device_warmup_steps = (flags_obj.device_warmup_steps
                           if flags_obj.enable_device_warmup else 0)
    if flags_obj.enable_device_warmup:
        logging.info('Warmup for %d steps.', device_warmup_steps)

    resnet_controller = controller.Controller(
        strategy,
        runnable.train,
        runnable.evaluate,
        runnable.warmup,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        train_steps=per_epoch_steps * train_epochs,
        device_warmup_steps=device_warmup_steps,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        eval_steps=eval_steps,
        eval_interval=eval_interval,
        eval_offset=eval_offset)

    if flags_obj.enable_device_warmup:
        resnet_controller.warmup()

    mlp_log.mlperf_print('init_stop', None)

    # A negative first profile step starts tracing before training begins.
    profile_steps = flags_obj.profile_steps
    if profile_steps:
        profile_steps = [int(i) for i in profile_steps.split(',')]
        if profile_steps[0] < 0:
            runnable.trace_start(-1)

    time_callback.on_train_begin()
    mlp_log.mlperf_print('run_start', None)
    mlp_log.mlperf_print(
        'block_start',
        None,
        metadata={
            'first_epoch_num':
            1,
            'epoch_count':
            (flags_obj.eval_offset_epochs if flags_obj.eval_offset_epochs != 0
             else flags_obj.epochs_between_evals)
        })
    resnet_controller.train(evaluate=not flags_obj.skip_eval)
    mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
    time_callback.on_train_end()
    mlp_log.mlperf_print('run_final', None)

    stats = build_stats(runnable, time_callback)
    return stats
コード例 #4
0
def run(flags_obj):
    """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

    Configures the session, image data format and distribution strategy from
    `flags_obj`, builds the training (and, unless `flags_obj.skip_eval` is
    set, evaluation) datasets, compiles a ResNet-56 model and runs
    `model.fit` followed by an optional final `model.evaluate`.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.

    Returns:
      Dictionary of training and eval stats.
    """
    # Function-scope import: only needed for the optional device scope below.
    import contextlib

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # NOTE(review): assumes get_tf_dtype returns a tf.DType, in which case a
    # comparison against the string 'fp16' can never be true and the
    # documented ValueError would never be raised - confirm.
    if dtype == tf.float16:
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=cifar_preprocessing.HEIGHT,
            width=cifar_preprocessing.WIDTH,
            num_channels=cifar_preprocessing.NUM_CHANNELS,
            num_classes=cifar_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = cifar_preprocessing.input_fn

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=cifar_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        # Setting drop_remainder to avoid the partial batch logic in normalization
        # layer, which triggers tf.where and leads to extra memory copy of input
        # sizes between host and GPU.
        drop_remainder=(not flags_obj.enable_get_next_as_optional))

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=cifar_preprocessing.parse_record)

    steps_per_epoch = (cifar_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)
    # Constant learning rate unless a tensor-based schedule is requested; in
    # the constant case the per-batch callback added below adjusts it.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        initial_learning_rate = common.BASE_LEARNING_RATE * flags_obj.batch_size / 128
        lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            boundaries=[p[1] * steps_per_epoch for p in LR_SCHEDULE],
            values=[initial_learning_rate] +
            [p[0] * initial_learning_rate for p in LR_SCHEDULE])

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        model = resnet_cifar_model.resnet56(
            classes=cifar_preprocessing.NUM_CLASSES)

        # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
        # a valid arg for this model. Also remove as a valid flag.
        if flags_obj.force_v2_in_keras_compile is not None:
            model.compile(
                loss='sparse_categorical_crossentropy',
                optimizer=optimizer,
                metrics=(['sparse_categorical_accuracy']
                         if flags_obj.report_accuracy_metrics else None),
                run_eagerly=flags_obj.run_eagerly,
                experimental_run_tf_function=flags_obj.
                force_v2_in_keras_compile)
        else:
            model.compile(
                loss='sparse_categorical_crossentropy',
                optimizer=optimizer,
                metrics=(['sparse_categorical_accuracy']
                         if flags_obj.report_accuracy_metrics else None),
                run_eagerly=flags_obj.run_eagerly)

    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(steps_per_epoch)

    if not flags_obj.use_tensor_lr:
        lr_callback = LearningRateBatchScheduler(
            schedule=learning_rate_schedule,
            batch_size=flags_obj.batch_size,
            steps_per_epoch=steps_per_epoch)
        callbacks.append(lr_callback)

    # If multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # ExitStack guarantees the optional device scope is left even if training
    # raises, and calls __exit__ with the proper (type, value, traceback)
    # arguments.
    with contextlib.ExitStack() as device_scope:
        if not strategy and flags_obj.explicit_gpu_placement:
            # TODO(b/135607227): Add device scope automatically in Keras
            # training loop when not using distribution strategy.
            device_scope.enter_context(tf.device('/device:GPU:0'))

        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=steps_per_epoch,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)
        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats