예제 #1
0
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  When the resolved dtype is float16, a mixed precision policy is set and the
  optimizer is wrapped in a `LossScaleOptimizer`.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats, as built by
    `keras_common.build_stats`.
  """
  config = keras_common.get_config_proto()
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  if not keras_common.is_v2_0():
    if flags_obj.enable_eager:
      tf.compat.v1.enable_eager_execution(config=config)
    else:
      sess = tf.Session(config=config)
      tf.keras.backend.set_session(sess)
  # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'float16':
    policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  # Fall back to a data format chosen from the CUDA build status when the flag
  # is unset.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype)

  # The eval pipeline is only consumed when evaluation runs, so do not build
  # it when evaluation is skipped.
  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        dtype=dtype)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus)

  strategy_scope = keras_common.get_strategy_scope(strategy)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    if dtype == 'float16':
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))
    model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
                                  dtype=dtype)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # An explicit train_steps flag caps the run to a single, shortened epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)
  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats
예제 #2
0
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats, as built by `common.build_stats`.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # NOTE(review): get_tf_dtype appears to return a TensorFlow dtype (other
  # run() variants compare its result against tf.float16 / 'float16'), and a
  # TensorFlow dtype does not compare equal to the flag spelling 'fp16'.
  # Check both forms so the documented ValueError is actually raised.
  if dtype == 'fp16' or dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  # Fall back to a data format chosen from GPU availability when the flag is
  # unset.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=cifar_preprocessing.HEIGHT,
        width=cifar_preprocessing.WIDTH,
        num_channels=cifar_preprocessing.NUM_CHANNELS,
        num_classes=cifar_preprocessing.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj),
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = cifar_preprocessing.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=cifar_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      # Setting drop_remainder to avoid the partial batch logic in normalization
      # layer, which triggers tf.where and leads to extra memory copy of input
      # sizes between host and GPU.
      drop_remainder=(not flags_obj.enable_get_next_as_optional))

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=cifar_preprocessing.parse_record)

  steps_per_epoch = (
      cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
  # A constant learning rate is used unless the tensor-based schedule is
  # requested; in the constant case a batch-scheduler callback is added below.
  lr_schedule = 0.1
  if flags_obj.use_tensor_lr:
    initial_learning_rate = common.BASE_LEARNING_RATE * flags_obj.batch_size / 128
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[p[1] * steps_per_epoch for p in LR_SCHEDULE],
        values=[initial_learning_rate] +
        [p[0] * initial_learning_rate for p in LR_SCHEDULE])

  with strategy_scope:
    optimizer = common.get_optimizer(lr_schedule)
    model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(steps_per_epoch)

  if not flags_obj.use_tensor_lr:
    lr_callback = LearningRateBatchScheduler(
        schedule=learning_rate_schedule,
        batch_size=flags_obj.batch_size,
        steps_per_epoch=steps_per_epoch)
    callbacks.append(lr_callback)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  # TODO(b/135607227): Add device scope automatically in Keras training loop
  # when not using distribution strategy.
  device_scope = None
  if not strategy and flags_obj.explicit_gpu_placement:
    device_scope = tf.device('/device:GPU:0')
    device_scope.__enter__()

  try:
    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)
    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(eval_input_dataset,
                                   steps=num_eval_steps,
                                   verbose=2)
  finally:
    # Always leave the device scope, even when training or evaluation raises.
    # The context-manager protocol requires the three exception arguments.
    if device_scope is not None:
      device_scope.__exit__(None, None, None)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    When the resolved dtype is float16, a mixed precision policy is set and
    the optimizer is wrapped in a `LossScaleOptimizer`.

    Args:
        flags_obj: An object containing parsed flag values.

    Returns:
        Dictionary of training and eval stats, as built by
        `common.build_stats`.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        common.set_gpu_thread_mode_and_count(flags_obj)
    if flags_obj.data_delay_prefetch:
        common.data_delay_prefetch()
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == 'float16':
        policy = tf.keras.mixed_precision.experimental.Policy(
            'infer_float32_vars')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    # Fall back to a data format chosen from the CUDA build status when the
    # flag is unset.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                       flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=num_workers,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=imagenet_preprocessing.parse_record,
            dtype=dtype,
            drop_remainder=drop_remainder)

    # A constant learning rate is used unless the tensor-based schedule is
    # requested.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=LR_SCHEDULE[0][1],
            boundaries=[p[1] for p in LR_SCHEDULE[1:]],
            multipliers=[p[0] for p in LR_SCHEDULE],
            compute_lr_on_cpu=True)

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        if dtype == 'float16':
            # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
            # can be enabled with a single line of code.
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                optimizer,
                loss_scale=flags_core.get_loss_scale(flags_obj,
                                                     default_for_fp16=128))

        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(
                imagenet_preprocessing.NUM_CLASSES, dtype)
        else:
            model = resnet_model.resnet50(
                num_classes=imagenet_preprocessing.NUM_CLASSES, dtype=dtype)

        # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
        # a valid arg for this model. Also remove as a valid flag.
        if flags_obj.force_v2_in_keras_compile is not None:
            model.compile(
                loss='sparse_categorical_crossentropy',
                optimizer=optimizer,
                metrics=(['sparse_categorical_accuracy']
                         if flags_obj.report_accuracy_metrics else None),
                run_eagerly=flags_obj.run_eagerly,
                experimental_run_tf_function=(
                    flags_obj.force_v2_in_keras_compile))
        else:
            model.compile(
                loss='sparse_categorical_crossentropy',
                optimizer=optimizer,
                metrics=(['sparse_categorical_accuracy']
                         if flags_obj.report_accuracy_metrics else None),
                run_eagerly=flags_obj.run_eagerly)

    callbacks = common.get_callbacks(
        learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])

    train_steps = (imagenet_preprocessing.NUM_IMAGES['train'] //
                   flags_obj.batch_size)
    train_epochs = flags_obj.train_epochs

    # An explicit train_steps flag caps the run to a single, shortened epoch.
    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    device_scope = None
    if not strategy and flags_obj.explicit_gpu_placement:
        device_scope = tf.device('/device:GPU:0')
        device_scope.__enter__()

    try:
        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=train_steps,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)

        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)
    finally:
        # Always leave the device scope, even when training or evaluation
        # raises. The context-manager protocol requires the three exception
        # arguments.
        if device_scope is not None:
            device_scope.__exit__(None, None, None)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
예제 #4
0
def run(flags_obj):
    """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

    Args:
        flags_obj: An object containing parsed flag values.

    Raises:
        ValueError: If fp16 is passed as it is not currently supported.

    Returns:
        Dictionary of training and eval stats, as built by
        `keras_common.build_stats`.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_common.set_gpu_thread_mode_and_count(flags_obj)
    keras_common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # NOTE(review): get_tf_dtype appears to return a TensorFlow dtype (other
    # run() variants compare its result against tf.float16 / 'float16'), and
    # a TensorFlow dtype does not compare equal to the flag spelling 'fp16'.
    # Check both forms so the documented ValueError is actually raised.
    if dtype == 'fp16' or dtype == tf.float16:
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    # Fall back to a data format chosen from the CUDA build status when the
    # flag is unset.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=cifar_preprocessing.HEIGHT,
            width=cifar_preprocessing.WIDTH,
            num_channels=cifar_preprocessing.NUM_CHANNELS,
            num_classes=cifar_preprocessing.NUM_CLASSES,
            dtype=flags_core.get_tf_dtype(flags_obj),
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = cifar_preprocessing.input_fn

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=cifar_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        # Setting drop_remainder to avoid the partial batch logic in normalization
        # layer, which triggers tf.where and leads to extra memory copy of input
        # sizes between host and GPU.
        drop_remainder=(not flags_obj.enable_get_next_as_optional))

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=cifar_preprocessing.parse_record)

    with strategy_scope:
        optimizer = keras_common.get_optimizer()
        model = resnet_cifar_model.resnet56(
            classes=cifar_preprocessing.NUM_CLASSES)

        # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
        # a valid arg for this model. Also remove as a valid flag.
        if flags_obj.force_v2_in_keras_compile is not None:
            model.compile(
                loss='categorical_crossentropy',
                optimizer=optimizer,
                metrics=(['categorical_accuracy']
                         if flags_obj.report_accuracy_metrics else None),
                run_eagerly=flags_obj.run_eagerly,
                experimental_run_tf_function=(
                    flags_obj.force_v2_in_keras_compile))
        else:
            model.compile(
                loss='categorical_crossentropy',
                optimizer=optimizer,
                metrics=(['categorical_accuracy']
                         if flags_obj.report_accuracy_metrics else None),
                run_eagerly=flags_obj.run_eagerly)

    callbacks = keras_common.get_callbacks(
        learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train'])

    train_steps = (cifar_preprocessing.NUM_IMAGES['train'] //
                   flags_obj.batch_size)
    train_epochs = flags_obj.train_epochs

    # An explicit train_steps flag caps the run to a single, shortened epoch.
    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    device_scope = None
    if not strategy and flags_obj.explicit_gpu_placement:
        device_scope = tf.device('/device:GPU:0')
        device_scope.__enter__()

    try:
        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=train_steps,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)
        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)
    finally:
        # Always leave the device scope, even when training or evaluation
        # raises. The context-manager protocol requires the three exception
        # arguments.
        if device_scope is not None:
            device_scope.__exit__(None, None, None)

    stats = keras_common.build_stats(history, eval_output, callbacks)
    return stats
예제 #5
0
def run(flags_obj):
    """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

    Args:
        flags_obj: An object containing parsed flag values.

    Raises:
        ValueError: If fp16 is passed as it is not currently supported.

    Returns:
        Dictionary of training and eval stats, as built by
        `keras_common.build_stats`.
    """
    # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
    # Eager is default in tf 2.0 and should not be toggled
    if flags_obj.enable_eager and not keras_common.is_v2_0():
        tf.compat.v1.enable_eager_execution()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # NOTE(review): get_tf_dtype appears to return a TensorFlow dtype (other
    # run() variants compare its result against tf.float16 / 'float16'), and
    # a TensorFlow dtype does not compare equal to the flag spelling 'fp16'.
    # Check both forms so the documented ValueError is actually raised.
    if dtype == 'fp16' or dtype == tf.float16:
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    # Fall back to a data format chosen from the CUDA build status when the
    # flag is unset.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=cifar_main.HEIGHT,
            width=cifar_main.WIDTH,
            num_channels=cifar_main.NUM_CHANNELS,
            num_classes=cifar_main.NUM_CLASSES,
            dtype=flags_core.get_tf_dtype(flags_obj))
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = cifar_main.input_fn

    train_input_dataset = input_fn(is_training=True,
                                   data_dir=flags_obj.data_dir,
                                   batch_size=flags_obj.batch_size,
                                   num_epochs=flags_obj.train_epochs,
                                   parse_record_fn=parse_record_keras)

    # The eval pipeline is only consumed when evaluation runs, so do not
    # build it when evaluation is skipped.
    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(is_training=False,
                                      data_dir=flags_obj.data_dir,
                                      batch_size=flags_obj.batch_size,
                                      num_epochs=flags_obj.train_epochs,
                                      parse_record_fn=parse_record_keras)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus)

    strategy_scope = keras_common.get_strategy_scope(strategy)

    with strategy_scope:
        optimizer = keras_common.get_optimizer()
        model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)

        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=['categorical_accuracy'])

    time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
        learning_rate_schedule, cifar_main.NUM_IMAGES['train'])

    train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
    train_epochs = flags_obj.train_epochs

    # An explicit train_steps flag caps the run to a single, shortened epoch.
    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    history = model.fit(
        train_input_dataset,
        epochs=train_epochs,
        steps_per_epoch=train_steps,
        callbacks=[time_callback, lr_callback, tensorboard_callback],
        validation_steps=num_eval_steps,
        validation_data=validation_data,
        validation_freq=flags_obj.epochs_between_evals,
        verbose=2)
    eval_output = None
    if not flags_obj.skip_eval:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)
    stats = keras_common.build_stats(history, eval_output, time_callback)
    return stats
예제 #6
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    The base optimizer is wrapped in a KungFu `SynchronousSGDOptimizer`, the
    per-epoch step count is divided by the cluster size, and checkpoints and
    the exported model are written by the rank-0 worker only.

    Args:
        flags_obj: An object containing parsed flag values.

    Raises:
        ValueError: If `--dtype=fp16` is requested while
            `keras_utils.is_v2_0()` is false (TensorFlow 1).

    Returns:
        Dictionary of training and eval stats.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    # Select the Keras mixed-precision policy that matches the requested dtype.
    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == tf.float16:
        loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_float16', loss_scale=loss_scale)
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
        if not keras_utils.is_v2_0():
            raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
    elif dtype == tf.bfloat16:
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_bfloat16')
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    # Default the image layout from the build type when no flag was given.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Fixed seed handed to the training input pipeline as `random_seed`.
    preprocessing_seed = 12345

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    # The training pipeline additionally receives the cluster size and this
    # worker's rank (`num_workers` / `worker_ID`) together with the seed.
    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        random_seed=preprocessing_seed,
        num_workers=current_cluster_size(),
        worker_ID=current_rank(),
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
        training_dataset_cache=flags_obj.training_dataset_cache,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=imagenet_preprocessing.parse_record,
            dtype=dtype,
            drop_remainder=drop_remainder)

    # Constant learning rate unless a tensor-based schedule is requested.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=common.LR_SCHEDULE[0][1],
            boundaries=[p[1] for p in common.LR_SCHEDULE[1:]],
            multipliers=[p[0] for p in common.LR_SCHEDULE],
            compute_lr_on_cpu=True)

    # Build KungFu optimizer
    opt = common.get_optimizer(lr_schedule)
    optimizer = SynchronousSGDOptimizer(opt, reshape=False, use_locking=True)
    # The wrapper reuses the wrapped optimizer's `_hyper` container.
    # NOTE(review): presumably so hyperparameter reads/writes (e.g. by the
    # learning-rate callback) reach the real optimizer -- confirm.
    optimizer._hyper = opt._hyper

    if flags_obj.fp16_implementation == 'graph_rewrite':
        # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
        # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
        # which will ensure tf.compat.v2.keras.mixed_precision and
        # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
        # double up.
        optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
            optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
        model = trivial_model.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    else:
        model = resnet_model.resnet50(
            num_classes=imagenet_preprocessing.NUM_CLASSES)

    metrics = ['sparse_categorical_accuracy',
               'sparse_top_k_categorical_accuracy']

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
    # a valid arg for this model. Also remove as a valid flag.
    compile_kwargs = {
        'loss': 'sparse_categorical_crossentropy',
        'optimizer': optimizer,
        'metrics': metrics,
        'run_eagerly': flags_obj.run_eagerly,
    }
    if flags_obj.force_v2_in_keras_compile is not None:
        # Only forward the flag when it was explicitly set.
        compile_kwargs['experimental_run_tf_function'] = (
            flags_obj.force_v2_in_keras_compile)
    model.compile(**compile_kwargs)

    # Adjust the number of steps: every worker runs 1/cluster_size of the
    # single-worker step count.
    cluster_size = current_cluster_size()
    steps_per_epoch = (imagenet_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)
    steps_per_epoch = steps_per_epoch // cluster_size

    train_epochs = flags_obj.train_epochs
    callbacks = common.get_callbacks(steps_per_epoch, current_rank(),
                                     cluster_size,
                                     common.learning_rate_schedule)

    # Broadcast variables for KungFu
    callbacks.append(BroadcastGlobalVariablesCallback())

    # Checkpoint callback only on worker 0
    if flags_obj.enable_checkpoint_and_export and current_rank() == 0:
        ckpt_full_path = os.path.join(flags_obj.model_dir,
                                      'model.ckpt-{epoch:04d}')
        callbacks.append(
            tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                               save_weights_only=True))

    # An explicit --train_steps can only shorten an epoch, never extend it.
    if flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)

    num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        # Validation is disabled whenever evaluation is skipped, regardless of
        # whether the learning phase was pinned above.
        num_eval_steps = None
        validation_data = None

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)

    # Export only on 0th worker
    if flags_obj.enable_checkpoint_and_export and current_rank() == 0:
        if dtype == tf.bfloat16:
            logging.warning(
                "Keras model.save does not support bfloat16 dtype.")
        else:
            # Keras model.save assumes a float32 input signature.
            export_path = os.path.join(flags_obj.model_dir, 'saved_model')
            model.save(export_path, include_optimizer=False)

    eval_output = None
    if not flags_obj.skip_eval:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
예제 #7
0
def run(flags_obj):
  """Train and evaluate ResNet on ImageNet through the native Keras APIs.

  The function configures the runtime, builds the input pipelines and the
  model under the selected distribution strategy, then calls `fit` and
  (unless evaluation is skipped) `evaluate`.

  Args:
    flags_obj: Object holding the parsed flag values.

  Returns:
    Dictionary of training and eval stats, as produced by
    `keras_common.build_stats`.
  """
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  if not keras_common.is_v2_0():
    session_config = keras_common.get_config_proto_v1()
    if flags_obj.enable_eager:
      tf.compat.v1.enable_eager_execution(config=session_config)
    else:
      tf.keras.backend.set_session(tf.Session(config=session_config))
  else:
    keras_common.set_config_v2()

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)

  dtype = flags_core.get_tf_dtype(flags_obj)
  use_fp16 = dtype == 'float16'
  if use_fp16:
    tf.keras.mixed_precision.experimental.set_policy(
        tf.keras.mixed_precision.experimental.Policy('infer_float32_vars'))

  # Fall back to a layout chosen from the build type when no flag is given.
  image_data_format = flags_obj.data_format
  if image_data_format is None:
    if tf.test.is_built_with_cuda():
      image_data_format = 'channels_first'
    else:
      image_data_format = 'channels_last'
  tf.keras.backend.set_image_data_format(image_data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster())
  strategy_scope = keras_common.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if not flags_obj.use_synthetic_data:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn
  else:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype)

  # The evaluation pipeline is only built when evaluation will actually run.
  if flags_obj.skip_eval:
    eval_input_dataset = None
  else:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        dtype=dtype)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    if use_fp16:
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

    if not flags_obj.use_trivial_model:
      model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
                                    dtype=dtype)
    else:
      model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  # One epoch covers the full training set; an explicit --train_steps caps
  # the step count and collapses the run to a single epoch.
  steps_per_epoch = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  epochs = flags_obj.train_epochs
  if flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None
  else:
    validation_data = eval_input_dataset

  history = model.fit(train_input_dataset,
                      epochs=epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  if flags_obj.skip_eval:
    eval_output = None
  else:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)
  return keras_common.build_stats(history, eval_output, callbacks)
예제 #8
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    When the requested dtype is float16 the function installs the
    'infer_float32_vars' mixed-precision policy and wraps the optimizer in a
    `LossScaleOptimizer`; it does not reject fp16.

    Args:
        flags_obj: An object containing parsed flag values.

    Returns:
        Dictionary of training and eval stats.
    """
    # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
    # Eager is default in tf 2.0 and should not be toggled
    if keras_common.is_v2_0():
        keras_common.set_config_v2()
    else:
        config = keras_common.get_config_proto_v1()
        if flags_obj.enable_eager:
            tf.compat.v1.enable_eager_execution(config=config)
        else:
            sess = tf.Session(config=config)
            tf.keras.backend.set_session(sess)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_common.set_gpu_thread_mode_and_count(flags_obj)
    if flags_obj.data_delay_prefetch:
        keras_common.data_delay_prefetch()
    keras_common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == 'float16':
        policy = tf.keras.mixed_precision.experimental.Policy(
            'infer_float32_vars')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    # Default the image layout from the build type when no flag was given.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    # `strategy` may be falsy (the model-building code below already guards
    # with `if strategy and ...`), so only touch its attributes when one was
    # actually created.
    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=imagenet_main.DEFAULT_IMAGE_SIZE,
            width=imagenet_main.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_main.NUM_CHANNELS,
            num_classes=imagenet_main.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_main.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(is_training=False,
                                      data_dir=flags_obj.data_dir,
                                      batch_size=flags_obj.batch_size,
                                      num_epochs=flags_obj.train_epochs,
                                      parse_record_fn=parse_record_keras,
                                      dtype=dtype,
                                      drop_remainder=drop_remainder)

    # Constant learning rate unless a tensor-based schedule is requested.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_main.NUM_IMAGES['train'],
            warmup_epochs=LR_SCHEDULE[0][1],
            boundaries=[p[1] for p in LR_SCHEDULE[1:]],
            multipliers=[p[0] for p in LR_SCHEDULE],
            compute_lr_on_cpu=True)

    with strategy_scope:
        optimizer = keras_common.get_optimizer(lr_schedule)
        if dtype == 'float16':
            # TODO(reedwm): Remove manually wrapping optimizer once mixed
            # precision can be enabled with a single line of code.
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

        # A static batch size is set on the input layer only for XLA in
        # non-eager mode with at most one replica.
        # TODO(b/129861005): Fix OOM issue in eager mode when setting
        # `batch_size` in keras.Input layer.
        # TODO(b/129791381): Specify `input_layer_batch_size` value in
        # DistributionStrategy multi-replica case.
        input_layer_batch_size = None
        if flags_obj.enable_xla and not flags_obj.enable_eager:
            multi_replica = strategy and strategy.num_replicas_in_sync > 1
            if not multi_replica:
                input_layer_batch_size = flags_obj.batch_size

        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES,
                                                dtype)
        else:
            model = resnet_model.resnet50(
                num_classes=imagenet_main.NUM_CLASSES,
                dtype=dtype,
                batch_size=input_layer_batch_size)

        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['sparse_categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      cloning=flags_obj.clone_model_in_keras_dist_strat)

    callbacks = keras_common.get_callbacks(learning_rate_schedule,
                                           imagenet_main.NUM_IMAGES['train'])

    train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
    train_epochs = flags_obj.train_epochs

    # An explicit --train_steps caps the step count and collapses the run to
    # a single epoch.
    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)

    eval_output = None
    if not flags_obj.skip_eval:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)
    stats = keras_common.build_stats(history, eval_output, callbacks)
    return stats
예제 #9
0
def run(flags_obj):
    """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

    Args:
        flags_obj: An object containing parsed flag values.

    Raises:
        ValueError: If fp16 is passed as it is not currently supported.

    Returns:
        Dictionary of training and eval stats.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    dtype = flags_core.get_tf_dtype(flags_obj)
    # 'fp16' is the flag spelling, not a TensorFlow dtype name, so comparing a
    # tf dtype against it alone never matches. Check tf.float16 as well.
    # NOTE(review): assumes get_tf_dtype returns a tf dtype; the string form is
    # kept so the check still fires if it returns the raw flag value.
    if dtype in (tf.float16, 'fp16'):
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    # Default the image layout from the build type when no flag was given.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=cifar_main.HEIGHT,
            width=cifar_main.WIDTH,
            num_channels=cifar_main.NUM_CHANNELS,
            num_classes=cifar_main.NUM_CLASSES,
            dtype=flags_core.get_tf_dtype(flags_obj))
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = cifar_main.input_fn

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype)

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(is_training=False,
                                      data_dir=flags_obj.data_dir,
                                      batch_size=flags_obj.batch_size,
                                      num_epochs=flags_obj.train_epochs,
                                      parse_record_fn=parse_record_keras)

    with strategy_scope:
        optimizer = keras_common.get_optimizer()
        model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)

        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      run_eagerly=flags_obj.run_eagerly)

    callbacks = keras_common.get_callbacks(learning_rate_schedule,
                                           cifar_main.NUM_IMAGES['train'])

    train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
    train_epochs = flags_obj.train_epochs

    # An explicit --train_steps caps the step count and collapses the run to
    # a single epoch.
    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = None
    if not strategy and flags_obj.explicit_gpu_placement:
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    try:
        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=train_steps,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)
        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)
    finally:
        # Leave the device scope even when training or evaluation raises, and
        # pass the (exc_type, exc_value, traceback) triple that the context
        # manager protocol requires for `__exit__`.
        if no_dist_strat_device is not None:
            no_dist_strat_device.__exit__(None, None, None)

    stats = keras_common.build_stats(history, eval_output, callbacks)
    return stats
예제 #10
0
def run(flags_obj):
    """Run ImageNet training and eval loop (ResNet50 or MobileNet) with Keras.

    Configures the session, mixed-precision policy, image data format and
    distribution strategy from `flags_obj`, builds the input datasets, the
    optimizer and the model, optionally wraps the model for magnitude
    pruning, then calls `model.fit` (and `model.evaluate` unless
    `flags_obj.skip_eval` is set). When `enable_checkpoint_and_export` is set
    the trained model is saved under `<model_dir>/saved_model`.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If --dtype=fp16 is requested on TensorFlow 1, or if
        `flags_obj.optimizer` / `flags_obj.model` is not one of the values
        handled by this function.
      NotImplementedError: If pruning is requested with a dtype other than
        tf.float32, or with a pruning method other than 'polynomial_decay'.

    Returns:
      Dictionary of training and eval stats, as built by
      `common.build_stats`.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == tf.float16:
        # Reject the unsupported combination before mutating the global
        # mixed-precision policy.
        if not keras_utils.is_v2_0():
            raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
        loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_float16', loss_scale=loss_scale)
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
    elif dtype == tf.bfloat16:
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_bfloat16')
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                       flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=num_workers,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    # The current resnet_model.resnet50 input format is always channels-last.
    # The keras_applications MobileNet model instead follows the Keras backend
    # image data format, so this flag tells the image preprocessor whether its
    # output should follow the backend data format (MobileNet) or stay
    # channels-last (everything else).
    use_keras_image_data_format = (flags_obj.model == 'mobilenet')
    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
        training_dataset_cache=flags_obj.training_dataset_cache,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
                use_keras_image_data_format=use_keras_image_data_format),
            dtype=dtype,
            drop_remainder=drop_remainder)

    # Constant learning rate unless a tensor-based schedule is requested.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=common.LR_SCHEDULE[0][1],
            boundaries=[p[1] for p in common.LR_SCHEDULE[1:]],
            multipliers=[p[0] for p in common.LR_SCHEDULE],
            compute_lr_on_cpu=True)
    steps_per_epoch = (imagenet_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)

    learning_rate_schedule_fn = None
    with strategy_scope:
        if flags_obj.optimizer == 'resnet50_default':
            optimizer = common.get_optimizer(lr_schedule)
            learning_rate_schedule_fn = common.learning_rate_schedule
        elif flags_obj.optimizer == 'mobilenet_default':
            lr_decay_factor = 0.94
            num_epochs_per_decay = 2.5
            initial_learning_rate_per_sample = 0.000007
            initial_learning_rate = (
                initial_learning_rate_per_sample * flags_obj.batch_size)
            optimizer = tf.keras.optimizers.SGD(
                learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
                    initial_learning_rate,
                    decay_steps=steps_per_epoch * num_epochs_per_decay,
                    decay_rate=lr_decay_factor,
                    staircase=True),
                momentum=0.9)
        else:
            # Previously this fell through and crashed later with an
            # UnboundLocalError on `optimizer`; fail early and clearly.
            raise ValueError(
                'Unsupported optimizer: {!r}.'.format(flags_obj.optimizer))
        if flags_obj.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
            # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
            # which will ensure tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
            # up.
            optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                optimizer)

        # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(
                imagenet_preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'resnet50_v1.5':
            resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
            model = resnet_model.resnet50(
                num_classes=imagenet_preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'mobilenet':
            # TODO(kimjaehong): Remove layers attribute when minimum TF version
            # support 2.0 layers by default.
            model = tf.keras.applications.mobilenet.MobileNet(
                weights=None,
                classes=imagenet_preprocessing.NUM_CLASSES,
                layers=tf.keras.layers)
        else:
            # Previously this fell through and crashed later with an
            # UnboundLocalError on `model`; fail early and clearly.
            raise ValueError(
                'Unsupported model: {!r}.'.format(flags_obj.model))
        if flags_obj.pretrained_filepath:
            model.load_weights(flags_obj.pretrained_filepath)

        if flags_obj.pruning_method == 'polynomial_decay':
            if dtype != tf.float32:
                raise NotImplementedError(
                    'Pruning is currently only supported on dtype=tf.float32.')
            pruning_params = {
                'pruning_schedule':
                tfmot.sparsity.keras.PolynomialDecay(
                    initial_sparsity=flags_obj.pruning_initial_sparsity,
                    final_sparsity=flags_obj.pruning_final_sparsity,
                    begin_step=flags_obj.pruning_begin_step,
                    end_step=flags_obj.pruning_end_step,
                    frequency=flags_obj.pruning_frequency),
            }
            model = tfmot.sparsity.keras.prune_low_magnitude(
                model, **pruning_params)
        elif flags_obj.pruning_method:
            raise NotImplementedError(
                'Only polynomial_decay is currently supported.')
        # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
        # a valid arg for this model. Also remove as a valid flag.
        if flags_obj.force_v2_in_keras_compile is not None:
            model.compile(
                loss='sparse_categorical_crossentropy',
                optimizer=optimizer,
                metrics=(['sparse_categorical_accuracy']
                         if flags_obj.report_accuracy_metrics else None),
                run_eagerly=flags_obj.run_eagerly,
                experimental_run_tf_function=flags_obj.
                force_v2_in_keras_compile)
        else:
            model.compile(
                loss='sparse_categorical_crossentropy',
                optimizer=optimizer,
                metrics=(['sparse_categorical_accuracy']
                         if flags_obj.report_accuracy_metrics else None),
                run_eagerly=flags_obj.run_eagerly)

    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(
        steps_per_epoch=steps_per_epoch,
        learning_rate_schedule_fn=learning_rate_schedule_fn,
        pruning_method=flags_obj.pruning_method,
        enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
        model_dir=flags_obj.model_dir)

    # If multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    no_dist_strat_device = None
    if not strategy and flags_obj.explicit_gpu_placement:
        # TODO(b/135607227): Add device scope automatically in Keras training loop
        # when not using distribution strategy.
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    try:
        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=steps_per_epoch,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)

        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)

        if flags_obj.pruning_method:
            model = tfmot.sparsity.keras.strip_pruning(model)
        if flags_obj.enable_checkpoint_and_export:
            if dtype == tf.bfloat16:
                logging.warning(
                    'Keras model.save does not support bfloat16 dtype.')
            else:
                # Keras model.save assumes a float32 input signature.
                export_path = os.path.join(flags_obj.model_dir, 'saved_model')
                model.save(export_path, include_optimizer=False)
    finally:
        # Always leave the explicit device scope, even if training raised.
        # The context-manager protocol requires the three exception
        # arguments; a bare `__exit__()` fails on generator-based managers.
        if no_dist_strat_device is not None:
            no_dist_strat_device.__exit__(None, None, None)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
예제 #11
0
def run(flags_obj):
    """Run ResNet Cifar-10 training and eval loop using Keras and Horovod.

    Configures the session, image data format and distribution strategy from
    `flags_obj`, builds the CIFAR input datasets and a ResNet56 model whose
    optimizer is wrapped in `hvd.DistributedOptimizer`, then calls
    `model.fit` (and `model.evaluate` unless `flags_obj.skip_eval` is set)
    with the Horovod callbacks.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.

    Returns:
      Dictionary of training and eval stats, as built by
      `common.build_stats`.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        common.set_gpu_thread_mode_and_count(flags_obj)
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # The other training loops in this file compare this value against
    # tf.float16, so a bare 'fp16' string comparison presumably never matched.
    # Accept both spellings so the documented ValueError is actually raised.
    if dtype in (tf.float16, 'fp16'):
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=cifar_preprocessing.HEIGHT,
            width=cifar_preprocessing.WIDTH,
            num_channels=cifar_preprocessing.NUM_CHANNELS,
            num_classes=cifar_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = cifar_preprocessing.input_fn

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=cifar_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        # Setting drop_remainder to avoid the partial batch logic in normalization
        # layer, which triggers tf.where and leads to extra memory copy of input
        # sizes between host and GPU.
        drop_remainder=(not flags_obj.enable_get_next_as_optional))

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=cifar_preprocessing.parse_record)

    with strategy_scope:
        # Horovod: scale the base learning rate by the number of workers.
        optimizer = common.get_optimizer(learning_rate=0.1 * hvd.size())
        # Horovod: add Horovod DistributedOptimizer.
        optimizer = hvd.DistributedOptimizer(optimizer)

        model = resnet_cifar_model.resnet56(
            classes=cifar_preprocessing.NUM_CLASSES)

        # The original code branched on flags_obj.force_v2_in_keras_compile
        # here, but both branches issued this exact call, so one call suffices.
        # NOTE: `run_eagerly` is not forwarded to compile in this variant, and
        # experimental_run_tf_function is always disabled.
        model.compile(
            loss='categorical_crossentropy',
            optimizer=optimizer,
            metrics=(['categorical_accuracy']
                     if flags_obj.report_accuracy_metrics else None),
            experimental_run_tf_function=False)

    # NOTE(review): this list is replaced by the Horovod callbacks built
    # further down, so these callbacks never reach model.fit or
    # common.build_stats. The call is kept to preserve behaviour; confirm
    # whether the two lists should be merged.
    callbacks = common.get_callbacks(learning_rate_schedule,
                                     cifar_preprocessing.NUM_IMAGES['train'])

    train_steps = cifar_preprocessing.NUM_IMAGES[
        'train'] // flags_obj.batch_size
    train_epochs = flags_obj.train_epochs

    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    no_dist_strat_device = None
    if not strategy and flags_obj.explicit_gpu_placement:
        # TODO(b/135607227): Add device scope automatically in Keras training loop
        # when not using distribution strategy.
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    try:
        callbacks = [
            # Horovod: broadcast initial variable states from rank 0 to all other
            # processes. This is necessary to ensure consistent initialization of
            # all workers when training is started with random weights or restored
            # from a checkpoint.
            hvd.callbacks.BroadcastGlobalVariablesCallback(0),

            # Horovod: average metrics among workers at the end of every epoch.
            #
            # Note: This callback must be in the list before the ReduceLROnPlateau,
            # TensorBoard or other metrics-based callbacks.
            hvd.callbacks.MetricAverageCallback(),

            # Horovod: using `lr = 0.1 * hvd.size()` from the very beginning leads
            # to worse final accuracy. Scale the learning rate `lr = 0.1` --->
            # `lr = 0.1 * hvd.size()` during the first three epochs. See
            # https://arxiv.org/abs/1706.02677 for details.
            hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=3, verbose=1),
        ]

        # Horovod: save checkpoints only on worker 0 to prevent other workers
        # from corrupting them.
        if hvd.rank() == 0:
            callbacks.append(
                tf.keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))

        # Horovod: write logs on worker 0.
        verbose = 1 if hvd.rank() == 0 else 0

        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=train_steps,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=verbose)
        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)
    finally:
        # Always leave the explicit device scope, even if training raised.
        # The context-manager protocol requires the three exception
        # arguments; a bare `__exit__()` fails on generator-based managers.
        if no_dist_strat_device is not None:
            no_dist_strat_device.__exit__(None, None, None)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
예제 #12
0
def run(flags_obj):
  """Run ImageNet training and eval loop on a Keras-applications ResNet50V2.

  Configures the session, mixed-precision policy, image data format and
  distribution strategy from `flags_obj`, builds the input datasets, then
  trains either the trivial model or a `tf.keras.applications.ResNet50V2`
  feature extractor topped with a dense softmax classifier. When
  `enable_checkpoint_and_export` is set, per-epoch weight checkpoints are
  written and the trained model is saved under `<model_dir>/saved_model`.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If --dtype=fp16 is requested on TensorFlow 1.

  Returns:
    Dictionary of training and eval stats, as built by `common.build_stats`.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    common.set_gpu_thread_mode_and_count(flags_obj)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.float16:
    # Reject the unsupported combination before mutating the global
    # mixed-precision policy.
    if not keras_utils.is_v2_0():
      raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
    loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_float16', loss_scale=loss_scale)
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
  elif dtype == tf.bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                     flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=num_workers,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=imagenet_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        dtype=dtype,
        drop_remainder=drop_remainder)

  # Constant learning rate unless a tensor-based schedule is requested.
  lr_schedule = 0.1
  if flags_obj.use_tensor_lr:
    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=[p[1] for p in common.LR_SCHEDULE[1:]],
        multipliers=[p[0] for p in common.LR_SCHEDULE],
        compute_lr_on_cpu=True)

  with strategy_scope:
    optimizer = common.get_optimizer(lr_schedule)
    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
      # up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(
          imagenet_preprocessing.NUM_CLASSES)
    else:
      print('data format: ', tf.keras.backend.image_data_format())
      print('num_classes: ', imagenet_preprocessing.NUM_CLASSES)
      # Keras-applications ResNet50V2 feature extractor (ImageNet weights,
      # global average pooling), used here in place of resnet_model.resnet50.
      # NOTE(review): input_shape is hard-coded channels-last (224, 224, 3)
      # while the backend data format set above may be 'channels_first' --
      # confirm this combination is intended.
      features = tf.keras.applications.ResNet50V2(
          include_top=False,
          weights='imagenet',
          pooling='avg',
          input_shape=(224, 224, 3))
      x = tf.keras.layers.Dense(
          imagenet_preprocessing.NUM_CLASSES, name='fc1000')(features.output)
      # The softmax output is explicitly float32, independent of any
      # mixed-precision policy set above.
      x = tf.keras.layers.Activation('softmax', dtype='float32')(x)
      model = tf.keras.models.Model(features.input, x, name='resnet50')

      model.summary()

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
    # a valid arg for this model. Also remove as a valid flag.
    if flags_obj.force_v2_in_keras_compile is not None:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly,
          experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
    else:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly)

  callbacks = common.get_callbacks(
      common.learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])
  if flags_obj.enable_checkpoint_and_export:
    # Write a weights-only checkpoint per epoch into model_dir.
    ckpt_full_path = os.path.join(flags_obj.model_dir, 'model.ckpt-{epoch:04d}')
    callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                                        save_weights_only=True))
  train_steps = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
  train_epochs = flags_obj.train_epochs

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  no_dist_strat_device = None
  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  try:
    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)
    if flags_obj.enable_checkpoint_and_export:
      if dtype == tf.bfloat16:
        logging.warning("Keras model.save does not support bfloat16 dtype.")
      else:
        # Keras model.save assumes a float32 input signature.
        export_path = os.path.join(flags_obj.model_dir, 'saved_model')
        model.save(export_path, include_optimizer=False)

    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(eval_input_dataset,
                                   steps=num_eval_steps,
                                   verbose=2)
  finally:
    # Always leave the explicit device scope, even if training raised. The
    # context-manager protocol requires the three exception arguments; a bare
    # `__exit__()` fails on generator-based managers.
    if no_dist_strat_device is not None:
      no_dist_strat_device.__exit__(None, None, None)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
예제 #13
0
def run(flags_obj):
  """Train ResNet on ImageNet through Keras `fit`, then optionally evaluate.

  The function configures the TensorFlow runtime from the flags, builds the
  train (and, unless `skip_eval` is set, eval) input datasets, compiles the
  model inside the distribution strategy scope, trains it with `model.fit`
  and finally runs `model.evaluate` when evaluation is not skipped.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats, as produced by
    `keras_common.build_stats`.
  """
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  if keras_common.is_v2_0():
    keras_common.set_config_v2()
  else:
    v1_config = keras_common.get_config_proto_v1()
    if not flags_obj.enable_eager:
      tf.keras.backend.set_session(tf.Session(config=v1_config))
    else:
      tf.compat.v1.enable_eager_execution(config=v1_config)

  # Flag-driven overrides meant to improve model performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)
  if flags_obj.data_prefetch_with_slack:
    keras_common.data_prefetch_with_slack()
  keras_common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  use_fp16 = dtype == 'float16'
  if use_fp16:
    tf.keras.mixed_precision.experimental.set_policy(
        tf.keras.mixed_precision.experimental.Policy('infer_float32_vars'))

  # Fall back to a build-dependent image layout when none was requested.
  data_format = flags_obj.data_format
  if data_format is None:
    if tf.test.is_built_with_cuda():
      data_format = 'channels_first'
    else:
      data_format = 'channels_last'
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)
  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if not flags_obj.use_synthetic_data:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn
  else:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder)

  # The eval dataset is only built when evaluation will actually run.
  if flags_obj.skip_eval:
    eval_input_dataset = None
  else:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        dtype=dtype,
        drop_remainder=drop_remainder)

  # Either a constant learning rate or a tensor-based piecewise schedule.
  if flags_obj.use_tensor_lr:
    lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_main.NUM_IMAGES['train'],
        warmup_epochs=LR_SCHEDULE[0][1],
        boundaries=[entry[1] for entry in LR_SCHEDULE[1:]],
        multipliers=[entry[0] for entry in LR_SCHEDULE],
        compute_lr_on_cpu=True)
  else:
    lr_schedule = 0.1

  with strategy_scope:
    optimizer = keras_common.get_optimizer(lr_schedule)
    if use_fp16:
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

    # The input layer only gets a fixed batch size for XLA with eager
    # disabled and at most one replica; otherwise it stays None.
    input_layer_batch_size = None
    if flags_obj.enable_xla and not flags_obj.enable_eager:
      # TODO(b/129861005): Fix OOM issue in eager mode when setting
      # `batch_size` in keras.Input layer.
      multi_replica = strategy and strategy.num_replicas_in_sync > 1
      if not multi_replica:
        # TODO(b/129791381): Specify `input_layer_batch_size` value in
        # DistributionStrategy multi-replica case.
        input_layer_batch_size = flags_obj.batch_size

    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES, dtype)
    else:
      model = resnet_model.resnet50(
          num_classes=imagenet_main.NUM_CLASSES,
          dtype=dtype,
          batch_size=input_layer_batch_size)

    accuracy_metrics = (['sparse_categorical_accuracy']
                        if flags_obj.report_accuracy_metrics else None)
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=accuracy_metrics,
                  cloning=flags_obj.clone_model_in_keras_dist_strat)

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  # An explicit `train_steps` flag caps the step count and forces one epoch.
  steps_per_full_epoch = (imagenet_main.NUM_IMAGES['train'] //
                          flags_obj.batch_size)
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, steps_per_full_epoch)
    train_epochs = 1
  else:
    train_steps = steps_per_full_epoch
    train_epochs = flags_obj.train_epochs

  full_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                     flags_obj.batch_size)
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None
  else:
    num_eval_steps = full_eval_steps
    validation_data = eval_input_dataset

  fit_options = dict(
      epochs=train_epochs,
      steps_per_epoch=train_steps,
      callbacks=callbacks,
      validation_steps=num_eval_steps,
      validation_data=validation_data,
      validation_freq=flags_obj.epochs_between_evals,
      verbose=2)
  history = model.fit(train_input_dataset, **fit_options)

  if flags_obj.skip_eval:
    eval_output = None
  else:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)
  return keras_common.build_stats(history, eval_output, callbacks)
예제 #14
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    Builds the train and eval datasets, compiles ResNet-50 inside the
    distribution strategy scope, trains it with `model.fit` and, unless
    `skip_eval` is set, runs a final `model.evaluate`.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.

    Returns:
      Training and eval stats as built by `keras_common.build_stats`.
    """
    if flags_obj.enable_eager:
        tf.compat.v1.enable_eager_execution()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # NOTE(review): `get_tf_dtype` appears to return a `tf.DType` (confirm
    # against flags_core). A DType never compares equal to the flag spelling
    # 'fp16', so comparing against that string alone left this guard dead.
    # Both spellings are checked so the documented error is really raised.
    if dtype in ('fp16', tf.float16):
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    # Fall back to a build-dependent image layout when none was requested.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=imagenet_main.DEFAULT_IMAGE_SIZE,
            width=imagenet_main.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_main.NUM_CHANNELS,
            num_classes=imagenet_main.NUM_CLASSES,
            dtype=dtype)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_main.input_fn

    train_input_dataset = input_fn(is_training=True,
                                   data_dir=flags_obj.data_dir,
                                   batch_size=flags_obj.batch_size,
                                   num_epochs=flags_obj.train_epochs,
                                   parse_record_fn=parse_record_keras)

    eval_input_dataset = input_fn(is_training=False,
                                  data_dir=flags_obj.data_dir,
                                  batch_size=flags_obj.batch_size,
                                  num_epochs=flags_obj.train_epochs,
                                  parse_record_fn=parse_record_keras)

    strategy = distribution_utils.get_distribution_strategy(
        num_gpus=flags_obj.num_gpus,
        turn_off_distribution_strategy=flags_obj.turn_off_distribution_strategy
    )

    strategy_scope = keras_common.get_strategy_scope(strategy)

    # Optimizer and model are created and compiled under the strategy scope.
    with strategy_scope:
        optimizer = keras_common.get_optimizer()
        model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)

        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=['sparse_categorical_accuracy'])

    time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
        learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

    train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
    train_epochs = flags_obj.train_epochs

    # An explicit `train_steps` flag caps the step count and forces one epoch.
    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    history = model.fit(
        train_input_dataset,
        epochs=train_epochs,
        steps_per_epoch=train_steps,
        callbacks=[time_callback, lr_callback, tensorboard_callback],
        validation_steps=num_eval_steps,
        validation_data=validation_data,
        validation_freq=flags_obj.epochs_between_evals,
        verbose=2)

    eval_output = None
    if not flags_obj.skip_eval:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)
    stats = keras_common.build_stats(history, eval_output, time_callback)
    return stats
예제 #15
0
def run(flags_obj):
    """Train and evaluate `Conv4_model` on in-memory CIFAR-10 using Keras.

    CIFAR-10 is loaded through `tf.keras.datasets`, converted to float32,
    divided by 255 and its labels passed through `to_categorical`. The model
    is trained with SGD wrapped in `SynchronousSGDOptimizer`; the per-epoch
    step count is divided by `current_cluster_size()`.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.

    Returns:
      Dictionary of training and eval stats.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # NOTE(review): `get_tf_dtype` appears to return a `tf.DType` (confirm
    # against flags_core). A DType never compares equal to the flag spelling
    # 'fp16', so comparing against that string alone left this guard dead.
    # Both spellings are checked so the documented error is really raised.
    if dtype in ('fp16', tf.float16):
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    # NOTE: the scope is obtained but never entered in this function; the
    # optimizer and model below are built outside of it.
    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # NOTE: `input_fn` is not consumed below because the data is loaded in
    # memory from tf.keras.datasets instead of a tf.data pipeline. The
    # set-up/undo calls are kept as in the original flow.
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=cifar_preprocessing.HEIGHT,
            width=cifar_preprocessing.WIDTH,
            num_channels=cifar_preprocessing.NUM_CHANNELS,
            num_classes=cifar_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = cifar_preprocessing.input_fn

    # `num_classes` is not defined in this function; it is expected to be
    # available at module level.
    (x_train, y_train), (x_test,
                         y_test) = tf.keras.datasets.cifar10.load_data()
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes)

    opt = tf.keras.optimizers.SGD(learning_rate=0.1)

    logging.info(opt.__dict__)
    optimizer = SynchronousSGDOptimizer(opt, use_locking=True)
    # Share the wrapped optimizer's hyperparameter table with the wrapper.
    optimizer._hyper = opt._hyper

    logging.info(optimizer.__dict__)

    model = Conv4_model(x_train, num_classes)

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
    # a valid arg for this model. Also remove as a valid flag.
    if flags_obj.force_v2_in_keras_compile is not None:
        model.compile(
            loss='categorical_crossentropy',
            optimizer=optimizer,
            metrics=(['accuracy']),
            run_eagerly=flags_obj.run_eagerly,
            experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
    else:
        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['accuracy']),
                      run_eagerly=flags_obj.run_eagerly)

    # Each epoch's steps are split across the cluster.
    cluster_size = current_cluster_size()
    steps_per_epoch = (cifar_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)
    steps_per_epoch = steps_per_epoch // cluster_size
    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(steps_per_epoch, current_rank(),
                                     cluster_size, learning_rate_schedule)
    callbacks.append(BroadcastGlobalVariablesCallback())

    if flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)

    num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    # Validation during `fit` is driven by this variable so that `skip_eval`
    # really disables it. Previously the test set was always passed to `fit`
    # and the `None` assigned below was never used.
    validation_data = (x_test, y_test)
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    tf.compat.v1.logging.info(x_train.shape)
    history = model.fit(x_train,
                        y_train,
                        batch_size=flags_obj.batch_size,
                        epochs=train_epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)
    eval_output = None
    if not flags_obj.skip_eval:
        # Features and labels are passed separately: a tuple given as `x`
        # would be read as two model inputs with no targets.
        eval_output = model.evaluate(x_test,
                                     y_test,
                                     steps=num_eval_steps,
                                     verbose=2)
    stats = common.build_stats(history, eval_output, callbacks)
    return stats
예제 #16
0
def run(flags_obj):
    """Evaluate a ResNet-56 Cifar-10 model restored from saved weights.

    This variant builds the same datasets, model and callbacks as the
    training loop, but the `model.fit` call is disabled: weights are loaded
    from `flags_obj.model_dir` and, unless `skip_eval` is set, the model is
    evaluated and the result printed.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.

    Returns:
      None. Stats are not built because no training history exists.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        common.set_gpu_thread_mode_and_count(flags_obj)
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # NOTE(review): `get_tf_dtype` appears to return a `tf.DType` (confirm
    # against flags_core). A DType never compares equal to the flag spelling
    # 'fp16', so comparing against that string alone left this guard dead.
    # Both spellings are checked so the documented error is really raised.
    if dtype in ('fp16', tf.float16):
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=cifar_preprocessing.HEIGHT,
            width=cifar_preprocessing.WIDTH,
            num_channels=cifar_preprocessing.NUM_CHANNELS,
            num_classes=cifar_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = cifar_preprocessing.input_fn

    # NOTE: the training dataset, callbacks and step/epoch counts below are
    # only consumed by the disabled `model.fit` call further down.
    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=cifar_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        # Setting drop_remainder to avoid the partial batch logic in normalization
        # layer, which triggers tf.where and leads to extra memory copy of input
        # sizes between host and GPU.
        drop_remainder=(not flags_obj.enable_get_next_as_optional))

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=cifar_preprocessing.parse_record)

    # No distribution strategy scope is used in this variant.
    optimizer = common.get_optimizer()
    model = resnet_cifar_model.resnet56(
        classes=cifar_preprocessing.NUM_CLASSES)

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
    # a valid arg for this model. Also remove as a valid flag.
    if flags_obj.force_v2_in_keras_compile is not None:
        model.compile(
            loss='categorical_crossentropy',
            optimizer=optimizer,
            metrics=(['categorical_accuracy']
                     if flags_obj.report_accuracy_metrics else None),
            run_eagerly=flags_obj.run_eagerly,
            experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
    else:
        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      run_eagerly=flags_obj.run_eagerly)

    callbacks = common.get_callbacks(learning_rate_schedule,
                                     cifar_preprocessing.NUM_IMAGES['train'])

    train_steps = cifar_preprocessing.NUM_IMAGES[
        'train'] // flags_obj.batch_size
    train_epochs = flags_obj.train_epochs

    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # Restore previously saved weights instead of training.
    model.load_weights(flags_obj.model_dir)
    # Training is intentionally disabled in this evaluation-only variant:
    # history = model.fit(train_input_dataset,
    #                     epochs=train_epochs,
    #                     steps_per_epoch=train_steps,
    #                     callbacks=callbacks,
    #                     validation_steps=num_eval_steps,
    #                     validation_data=validation_data,
    #                     validation_freq=flags_obj.epochs_between_evals,
    #                     verbose=2)

    eval_output = None
    if not flags_obj.skip_eval:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)

    print("Result")
    if eval_output is None:
        # Evaluation was skipped, so there is nothing to index into.
        print("evaluation skipped")
    elif isinstance(eval_output, (list, tuple)) and len(eval_output) > 1:
        print(f"loss: {eval_output[0]}, acc: {eval_output[1]}")
    else:
        # Without accuracy metrics `evaluate` yields only the loss value.
        print(f"loss: {eval_output}")
    stats = None
    return stats