Ejemplo n.º 1
0
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Configures the session, the distribution strategy and the input datasets
  from `flags_obj`, compiles either the trivial model or ResNet-50, trains it
  with `model.fit` and, unless `flags_obj.skip_eval` is set, evaluates it
  afterwards with `model.evaluate`.

  When `flags_core.get_tf_dtype` reports 'float16', a mixed precision policy
  is installed and the optimizer is wrapped in a `LossScaleOptimizer`.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats, as produced by
    `keras_common.build_stats`.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)
  if flags_obj.data_delay_prefetch:
    keras_common.data_delay_prefetch()
  keras_common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'float16':
    policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  # Fall back to a build-dependent default when no data format was requested.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=imagenet_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        dtype=dtype,
        drop_remainder=drop_remainder)

  # A plain float learning rate is used unless a tensor-based schedule is
  # requested; the schedule is derived from the (multiplier, epoch) pairs in
  # LR_SCHEDULE.
  lr_schedule = 0.1
  if flags_obj.use_tensor_lr:
    lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=LR_SCHEDULE[0][1],
        boundaries=[p[1] for p in LR_SCHEDULE[1:]],
        multipliers=[p[0] for p in LR_SCHEDULE],
        compute_lr_on_cpu=True)

  with strategy_scope:
    optimizer = keras_common.get_optimizer(lr_schedule)
    if dtype == 'float16':
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
                                                          default_for_fp16=128))

    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(
          imagenet_preprocessing.NUM_CLASSES, dtype)
    else:
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES, dtype=dtype)

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly,
        experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])

  train_steps = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
  train_epochs = flags_obj.train_epochs

  # An explicit step count caps the run at a single (possibly shortened) epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  # TODO(b/135607227): Add device scope automatically in Keras training loop
  # when not using distribution strategy.
  no_dist_strat_device = None
  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  try:
    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)

    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(eval_input_dataset,
                                   steps=num_eval_steps,
                                   verbose=2)
  finally:
    # Leave the explicit device scope even when training or evaluation raises.
    # The context-manager protocol calls `__exit__` with the
    # (exc_type, exc_value, traceback) triple, so pass it explicitly instead
    # of relying on the manager accepting a bare `__exit__()`.
    if no_dist_strat_device is not None:
      no_dist_strat_device.__exit__(None, None, None)

  stats = keras_common.build_stats(history, eval_output, callbacks)
  return stats
Ejemplo n.º 2
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    Configures the TensorFlow session/eager mode, the distribution strategy
    and the input datasets from `flags_obj`, compiles either the trivial model
    or ResNet-50, trains it with `model.fit` and, unless `flags_obj.skip_eval`
    is set, evaluates it afterwards with `model.evaluate`.

    When `flags_core.get_tf_dtype` reports 'float16', a mixed precision policy
    is installed and the optimizer is wrapped in a `LossScaleOptimizer`.

    Args:
        flags_obj: An object containing parsed flag values.

    Returns:
        Dictionary of training and eval stats, as produced by
        `keras_common.build_stats`.
    """
    # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
    # Eager is default in tf 2.0 and should not be toggled
    if keras_common.is_v2_0():
        keras_common.set_config_v2()
    else:
        config = keras_common.get_config_proto_v1()
        if flags_obj.enable_eager:
            tf.compat.v1.enable_eager_execution(config=config)
        else:
            sess = tf.Session(config=config)
            tf.keras.backend.set_session(sess)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_common.set_gpu_thread_mode_and_count(flags_obj)
    if flags_obj.data_delay_prefetch:
        keras_common.data_delay_prefetch()
    keras_common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == 'float16':
        policy = tf.keras.mixed_precision.experimental.Policy(
            'infer_float32_vars')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    # Fall back to a build-dependent default when no data format was requested.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    # `strategy` may be falsy (the batch-size logic further down guards for
    # that as well), so only touch its attributes when one was created.
    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=imagenet_main.DEFAULT_IMAGE_SIZE,
            width=imagenet_main.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_main.NUM_CHANNELS,
            num_classes=imagenet_main.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_main.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(is_training=False,
                                      data_dir=flags_obj.data_dir,
                                      batch_size=flags_obj.batch_size,
                                      num_epochs=flags_obj.train_epochs,
                                      parse_record_fn=parse_record_keras,
                                      dtype=dtype,
                                      drop_remainder=drop_remainder)

    # A plain float learning rate is used unless a tensor-based schedule is
    # requested; the schedule is derived from the (multiplier, epoch) pairs in
    # LR_SCHEDULE.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_main.NUM_IMAGES['train'],
            warmup_epochs=LR_SCHEDULE[0][1],
            boundaries=[p[1] for p in LR_SCHEDULE[1:]],
            multipliers=[p[0] for p in LR_SCHEDULE],
            compute_lr_on_cpu=True)

    with strategy_scope:
        optimizer = keras_common.get_optimizer(lr_schedule)
        if dtype == 'float16':
            # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
            # can be enabled with a single line of code.
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

        # A static batch size is only given to the input layer for XLA in
        # graph mode with at most one replica.
        # TODO(b/129861005): Fix OOM issue in eager mode when setting
        # `batch_size` in keras.Input layer.
        # TODO(b/129791381): Specify `input_layer_batch_size` value in
        # DistributionStrategy multi-replica case.
        multi_replica = strategy and strategy.num_replicas_in_sync > 1
        if (flags_obj.enable_xla and not flags_obj.enable_eager
                and not multi_replica):
            input_layer_batch_size = flags_obj.batch_size
        else:
            input_layer_batch_size = None

        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES,
                                                dtype)
        else:
            model = resnet_model.resnet50(
                num_classes=imagenet_main.NUM_CLASSES,
                dtype=dtype,
                batch_size=input_layer_batch_size)

        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['sparse_categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      cloning=flags_obj.clone_model_in_keras_dist_strat)

    callbacks = keras_common.get_callbacks(learning_rate_schedule,
                                           imagenet_main.NUM_IMAGES['train'])

    train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
    train_epochs = flags_obj.train_epochs

    # An explicit step count caps the run at a single (possibly shortened)
    # epoch.
    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)

    eval_output = None
    if not flags_obj.skip_eval:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)
    stats = keras_common.build_stats(history, eval_output, callbacks)
    return stats