Example #1
0
 def input_fn_eval():
   """Build the evaluation input: a single, non-training pass over the data.

   Returns:
     Whatever `input_function` returns for the evaluation configuration.
   """
   eval_data_dir = flags_obj.data_dir
   # Batch size as computed by distribution_utils from the global batch size
   # and the number of GPUs.
   eval_batch_size = distribution_utils.per_device_batch_size(
       flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
   eval_dtype = flags_core.get_tf_dtype(flags_obj)
   return input_function(
       is_training=False,
       data_dir=eval_data_dir,
       batch_size=eval_batch_size,
       num_epochs=1,
       dtype=eval_dtype)
Example #2
0
 def input_fn_train(num_epochs):
   """Build the training input for the requested number of epochs.

   Args:
     num_epochs: Number of epochs, forwarded unchanged to `input_function`.

   Returns:
     Whatever `input_function` returns for the training configuration.
   """
   # Collect the keyword arguments first so the call itself stays short.
   pipeline_kwargs = dict(
       is_training=True,
       data_dir=flags_obj.data_dir,
       batch_size=distribution_utils.per_device_batch_size(
           flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
       num_epochs=num_epochs,
       dtype=flags_core.get_tf_dtype(flags_obj),
       datasets_num_private_threads=flags_obj.datasets_num_private_threads,
       num_parallel_batches=flags_obj.datasets_num_parallel_batches)
   return input_function(**pipeline_kwargs)
def run_cifar(flags_obj):
  """Run ResNet CIFAR-10 training and eval loop.

  Selects the synthetic input function when `flags_obj.use_synthetic_data` is
  set, otherwise the module's `input_fn`, and hands it to
  `resnet_run_loop.resnet_main`.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  # Explicit branch instead of the `cond and a or b` idiom, which would
  # silently fall through to `input_fn` if the first choice were ever falsy.
  if flags_obj.use_synthetic_data:
    input_function = get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
  else:
    input_function = input_fn

  resnet_run_loop.resnet_main(
      flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
      shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
Example #4
0
def run_imagenet(flags_obj):
  """Run ResNet ImageNet training and eval loop.

  Selects the synthetic input function when `flags_obj.use_synthetic_data` is
  set, otherwise the module's `input_fn`, and hands it to
  `resnet_run_loop.resnet_main`.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  # Explicit branch instead of the `cond and a or b` idiom, which would
  # silently fall through to `input_fn` if the first choice were ever falsy.
  if flags_obj.use_synthetic_data:
    input_function = get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
  else:
    input_function = input_fn

  resnet_run_loop.resnet_main(
      flags_obj, imagenet_model_fn, input_function, DATASET_NAME,
      shape=[DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE, NUM_CHANNELS])
Example #5
0
  def test_parse_dtype_info(self):
    """Checks dtype parsing, default/explicit loss scale, and bad dtypes."""
    # (flag value, expected tf dtype, loss scale expected when none is given)
    cases = (
        ("fp16", tf.float16, 128),
        ("fp32", tf.float32, 1),
    )
    for flag_value, expected_dtype, default_scale in cases:
      # Only --dtype given: the dtype-specific default loss scale applies.
      flags_core.parse_flags([__file__, "--dtype", flag_value])
      self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), expected_dtype)
      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), default_scale)

      # An explicit --loss_scale takes precedence over the default.
      argv_with_scale = [__file__, "--dtype", flag_value, "--loss_scale", "5"]
      flags_core.parse_flags(argv_with_scale)
      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), 5)

    # An unrecognized dtype makes flag parsing exit.
    with self.assertRaises(SystemExit):
      flags_core.parse_flags([__file__, "--dtype", "int8"])
Example #6
0
def run_cifar(flags_obj):
  """Run ResNet CIFAR-10 training and eval loop.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of results. Including final accuracy. None is returned when
    `--image_bytes_as_serving_input` is set, since that flag is rejected for
    CIFAR before any training starts.
  """
  if flags_obj.image_bytes_as_serving_input:
    tf.logging.fatal('--image_bytes_as_serving_input cannot be set to True '
                     'for CIFAR. This flag is only applicable to ImageNet.')
    return None

  # Explicit branch instead of the `cond and a or b` idiom, which would
  # silently fall through to `input_fn` if the first choice were ever falsy.
  if flags_obj.use_synthetic_data:
    input_function = get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
  else:
    input_function = input_fn

  return resnet_run_loop.resnet_main(
      flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
      shape=[HEIGHT, WIDTH, NUM_CHANNELS])
Example #7
0
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Builds the train/eval inputs, compiles a ResNet-56 model inside the
  distribution strategy scope, fits it, and optionally evaluates it.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # get_tf_dtype returns a tf.DType, so the guard has to compare against
  # tf.float16; comparing with the flag string 'fp16' never matched and the
  # documented ValueError could not be raised.
  if dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
        height=cifar_main.HEIGHT,
        width=cifar_main.WIDTH,
        num_channels=cifar_main.NUM_CHANNELS,
        num_classes=cifar_main.NUM_CLASSES,
        dtype=dtype)
  else:
    input_fn = cifar_main.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  eval_input_dataset = input_fn(
      is_training=False,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  strategy = distribution_utils.get_distribution_strategy(
      num_gpus=flags_obj.num_gpus,
      turn_off_distribution_strategy=flags_obj.turn_off_distribution_strategy)

  strategy_scope = keras_common.get_strategy_scope(strategy)

  # Optimizer and model are created and compiled inside the strategy scope.
  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)

    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, cifar_main.NUM_IMAGES['train'])

  train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # An explicit --train_steps caps the steps and collapses the run to a single
  # epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # No validation pass at all: pin the learning phase to 1 and drop the
    # validation inputs handed to fit().
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      verbose=1)
  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=1)
  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats
Example #8
0
    def __init__(self, flags_obj, time_callback, epoch_steps):
        """Set up inputs, model, optimizer, metrics and checkpoint state.

        Args:
          flags_obj: Object containing parsed flag values.
          time_callback: Callback object stored on the instance as
            `self.time_callback`.
          epoch_steps: Number of steps per epoch; stored and also passed to
            `utils.EpochHelper`.

        Raises:
          ValueError: If the batch size is not divisible by the number of
            replicas, or if graph-rewrite fp16 is requested without
            `--use_tf_function`.
        """
        standard_runnable.StandardTrainable.__init__(
            self, flags_obj.use_tf_while_loop, flags_obj.use_tf_function)
        standard_runnable.StandardEvaluable.__init__(
            self, flags_obj.use_tf_function)

        self.strategy = tf.distribute.get_strategy()
        self.flags_obj = flags_obj
        self.dtype = flags_core.get_tf_dtype(flags_obj)
        self.time_callback = time_callback

        # Input pipeline related
        global_batch_size = flags_obj.batch_size
        num_replicas = self.strategy.num_replicas_in_sync
        if global_batch_size % num_replicas != 0:
            raise ValueError(
                'Batch size must be divisible by number of replicas : {}'.
                format(num_replicas))

        # As auto rebatching is not supported in
        # `experimental_distribute_datasets_from_function()` API, which is
        # required when cloning dataset to multiple workers in eager mode,
        # we use per-replica batch size.
        self.batch_size = int(global_batch_size / num_replicas)

        num_classes = imagenet_preprocessing.NUM_CLASSES
        if not self.flags_obj.use_synthetic_data:
            self.input_fn = imagenet_preprocessing.input_fn
        else:
            image_size = imagenet_preprocessing.DEFAULT_IMAGE_SIZE
            self.input_fn = common.get_synth_input_fn(
                height=image_size,
                width=image_size,
                num_channels=imagenet_preprocessing.NUM_CHANNELS,
                num_classes=num_classes,
                dtype=self.dtype,
                drop_remainder=True)

        # The model and the learning-rate schedule are given the global
        # (flag-level) batch size, not the per-replica one.
        self.model = resnet_model.resnet50(
            num_classes=num_classes,
            batch_size=flags_obj.batch_size,
            use_l2_regularizer=not flags_obj.single_l2_loss_op)

        # The first LR_SCHEDULE entry supplies the warmup epochs; the later
        # entries supply the decay boundaries.
        schedule_table = common.LR_SCHEDULE
        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=schedule_table[0][1],
            boundaries=[entry[1] for entry in schedule_table[1:]],
            multipliers=[entry[0] for entry in schedule_table],
            compute_lr_on_cpu=True)
        self.optimizer = common.get_optimizer(lr_schedule)
        # Make sure iterations variable is created inside scope.
        self.global_step = self.optimizer.iterations

        use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
        if use_graph_rewrite and not flags_obj.use_tf_function:
            raise ValueError('--fp16_implementation=graph_rewrite requires '
                             '--use_tf_function to be true')
        fp16_loss_scale = flags_core.get_loss_scale(flags_obj,
                                                    default_for_fp16=128)
        self.optimizer = performance.configure_optimizer(
            self.optimizer,
            use_float16=self.dtype == tf.float16,
            use_graph_rewrite=use_graph_rewrite,
            loss_scale=fp16_loss_scale)

        # Running metrics for the train and test phases, all kept in float32.
        self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'train_accuracy', dtype=tf.float32)
        self.test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)

        self.checkpoint = tf.train.Checkpoint(model=self.model,
                                              optimizer=self.optimizer)

        # Handling epochs.
        self.epoch_steps = epoch_steps
        self.epoch_helper = utils.EpochHelper(epoch_steps, self.global_step)
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Configures the session and distribution strategy, builds the train/eval
  inputs, compiles a ResNet-56 model inside the strategy scope, fits it, and
  optionally evaluates it.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    common.set_gpu_thread_mode_and_count(flags_obj)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # get_tf_dtype returns a tf.DType, so the guard has to compare against
  # tf.float16; comparing with the flag string 'fp16' never matched and the
  # documented ValueError could not be raised.
  if dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=cifar_preprocessing.HEIGHT,
        width=cifar_preprocessing.WIDTH,
        num_channels=cifar_preprocessing.NUM_CHANNELS,
        num_classes=cifar_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = cifar_preprocessing.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=cifar_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      # Setting drop_remainder to avoid the partial batch logic in normalization
      # layer, which triggers tf.where and leads to extra memory copy of input
      # sizes between host and GPU.
      drop_remainder=(not flags_obj.enable_get_next_as_optional))

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=cifar_preprocessing.parse_record)

  with strategy_scope:
    optimizer = common.get_optimizer()
    model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
    # a valid arg for this model. Also remove as a valid flag.
    if flags_obj.force_v2_in_keras_compile is not None:
      model.compile(
          loss='categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly,
          experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
    else:
      model.compile(
          loss='categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly)

  callbacks = common.get_callbacks(
      learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train'])

  train_steps = cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # An explicit --train_steps caps the steps and collapses the run to a single
  # epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  device_scope = None
  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    device_scope = tf.device('/device:GPU:0')
    device_scope.__enter__()

  try:
    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)
    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(eval_input_dataset,
                                   steps=num_eval_steps,
                                   verbose=2)
  finally:
    # Always leave the manually entered device scope, even when training or
    # evaluation raises. __exit__ is given the three arguments the context
    # manager protocol defines; calling it with none is not accepted by every
    # context manager implementation.
    if device_scope is not None:
      device_scope.__exit__(None, None, None)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
Example #10
0
    def __init__(self, flags_obj):
        """Init function of TransformerMain.

        Copies flag values into the model params, creates the distribution
        strategy and sets the mixed precision policy.

        Args:
          flags_obj: Object containing parsed flag values, i.e., FLAGS.

        Raises:
          ValueError: per the original contract, if not using static batch
            for input data on TPU (no such check is visible in this method
            body itself).
        """
        self.flags_obj = flags_obj
        self.predict_model = None

        # Add flag-defined parameters to params object
        num_gpus = flags_core.get_num_gpus(flags_obj)
        params = misc.get_model_params(flags_obj.param_set, num_gpus)
        self.params = params

        params.update({
            "num_gpus": num_gpus,
            "use_ctl": flags_obj.use_ctl,
            "data_dir": flags_obj.data_dir,
            "model_dir": flags_obj.model_dir,
            "static_batch": flags_obj.static_batch,
            "max_length": flags_obj.max_length,
            "decode_batch_size": flags_obj.decode_batch_size,
            "decode_max_length": flags_obj.decode_max_length,
            "padded_decode": flags_obj.padded_decode,
            # Fall back to AUTOTUNE when the flag is unset (falsy).
            "max_io_parallelism": (flags_obj.num_parallel_calls
                                   or tf.data.experimental.AUTOTUNE),
            "use_synthetic_data": flags_obj.use_synthetic_data,
            # Fall back to the param set's default when the flag is unset.
            "batch_size": (flags_obj.batch_size
                           or params["default_batch_size"]),
            "repeat_dataset": None,
            "dtype": flags_core.get_tf_dtype(flags_obj),
            "enable_tensorboard": flags_obj.enable_tensorboard,
            "enable_metrics_in_training":
                flags_obj.enable_metrics_in_training,
            "steps_between_evals": flags_obj.steps_between_evals,
            "enable_checkpointing": flags_obj.enable_checkpointing,
            "save_weights_only": flags_obj.save_weights_only,
        })

        strategy = distribute_utils.get_distribution_strategy(
            distribution_strategy=flags_obj.distribution_strategy,
            num_gpus=num_gpus,
            all_reduce_alg=flags_obj.all_reduce_alg,
            num_packs=flags_obj.num_packs,
            tpu_address=flags_obj.tpu or "")
        self.distribution_strategy = strategy

        if not self.use_tpu:
            logging.info("Running transformer with num_gpus = %d", num_gpus)
        else:
            params["num_replicas"] = strategy.num_replicas_in_sync

        if not strategy:
            logging.info("Not using any distribution strategy.")
        else:
            logging.info("For training, using distribution strategy: %s",
                         strategy)

        loss_scale = flags_core.get_loss_scale(flags_obj,
                                               default_for_fp16="dynamic")
        performance.set_mixed_precision_policy(params["dtype"], loss_scale)
Example #11
0
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared main loop for ResNet Models.

    Builds an Estimator from the flags, alternates training and evaluation
    according to `train_epochs` / `epochs_between_evals`, and optionally
    exports a SavedModel.

    Args:
      flags_obj: An object containing parsed flags. See define_resnet_flags()
        for details.
      model_function: the function that instantiates the Model and builds the
        ops for train/eval. This will be passed directly into the estimator.
      input_function: the function that processes the dataset and returns a
        dataset that the estimator can train on. This will be wrapped with
        all the relevant flags for running and passed to estimator.
      dataset_name: the name of the dataset for training and evaluation. This
        is used for logging purpose.
      shape: list of ints representing the shape of the images used for
        training. This is only used if flags_obj.export_dir is passed.
    """

    model_helpers.apply_clean(flags.FLAGS)

    # Ensures flag override logic is only executed if explicitly triggered.
    if flags_obj.tf_gpu_thread_mode:
        override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

    # Creates session config. allow_soft_placement = True, is required for
    # multi-GPU and is not harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    # Creates a `RunConfig` that checkpoints every 24 hours which essentially
    # results in checkpoints determined only by `epochs_between_evals`.
    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config,
                                        save_checkpoints_secs=60 * 60 * 24)

    # Initializes model with all but the dense layer from pretrained ResNet.
    if flags_obj.pretrained_model_checkpoint_path is not None:
        warm_start_settings = tf.estimator.WarmStartSettings(
            flags_obj.pretrained_model_checkpoint_path,
            vars_to_warm_start='^(?!.*dense)')
    else:
        warm_start_settings = None

    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        warm_start_from=warm_start_settings,
        params={
            'resnet_size': int(flags_obj.resnet_size),
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'resnet_version': int(flags_obj.resnet_version),
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'dtype': flags_core.get_tf_dtype(flags_obj),
            'fine_tune': flags_obj.fine_tune
        })

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train(num_epochs):
        """Training input for `num_epochs` epochs at per-device batch size."""
        return input_function(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=num_epochs,
            dtype=flags_core.get_tf_dtype(flags_obj),
            datasets_num_private_threads=flags_obj.
            datasets_num_private_threads,
            num_parallel_batches=flags_obj.datasets_num_parallel_batches)

    def input_fn_eval():
        """Evaluation input: a single pass at per-device batch size."""
        return input_function(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=1,
            dtype=flags_core.get_tf_dtype(flags_obj))

    if flags_obj.eval_only or not flags_obj.train_epochs:
        # If --eval_only is set, perform a single loop with zero train epochs.
        schedule, n_loops = [0], 1
    else:
        # Compute the number of times to loop while training. All but the last
        # pass will train for `epochs_between_evals` epochs, while the last will
        # train for the number needed to reach `training_epochs`. For instance if
        #   train_epochs = 25 and epochs_between_evals = 10
        # schedule will be set to [10, 10, 5]. That is to say, the loop will:
        #   Train for 10 epochs and then evaluate.
        #   Train for another 10 epochs and then evaluate.
        #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
        n_loops = math.ceil(flags_obj.train_epochs /
                            flags_obj.epochs_between_evals)
        schedule = [
            flags_obj.epochs_between_evals for _ in range(int(n_loops))
        ]
        schedule[-1] = flags_obj.train_epochs - sum(
            schedule[:-1])  # over counting.

    # The profiler hook generates a json file under the current directory. It
    # is appended to the hooks requested through `flags_obj.hooks`; previously
    # `train_hooks` was built but never passed to `train()`, so the requested
    # hooks were silently dropped.
    hooks = list(train_hooks or []) + [
        tf.train.ProfilerHook(output_dir='.', save_secs=600, show_memory=False)
    ]
    for cycle_index, num_train_epochs in enumerate(schedule):
        tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

        if num_train_epochs:
            classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                             hooks=hooks,
                             max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')

        # flags_obj.max_train_steps is generally associated with testing and
        # profiling. As a result it is frequently called with synthetic data, which
        # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
        # eval (which is generally unimportant in those circumstances) to terminate.
        # Note that eval will run for max_train_steps each loop, regardless of the
        # global_step count.
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags_obj.max_train_steps)

        benchmark_logger.log_evaluation_result(eval_results)

        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['accuracy']):
            break

    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        export_dtype = flags_core.get_tf_dtype(flags_obj)
        if flags_obj.image_bytes_as_serving_input:
            input_receiver_fn = functools.partial(image_bytes_serving_input_fn,
                                                  shape,
                                                  dtype=export_dtype)
        else:
            input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
                shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
        classifier.export_savedmodel(flags_obj.export_dir,
                                     input_receiver_fn,
                                     strip_default_attrs=True)
Example #12
0
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Configures the session and distribution strategy, builds the train/eval
  inputs, compiles a ResNet-56 model inside the strategy scope, fits it, and
  optionally evaluates it.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)

  dtype = flags_core.get_tf_dtype(flags_obj)
  # get_tf_dtype returns a tf.DType, so the guard has to compare against
  # tf.float16; comparing with the flag string 'fp16' never matched and the
  # documented ValueError could not be raised.
  if dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=cifar_main.HEIGHT,
        width=cifar_main.WIDTH,
        num_channels=cifar_main.NUM_CHANNELS,
        num_classes=cifar_main.NUM_CLASSES,
        dtype=dtype)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = cifar_main.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  eval_input_dataset = input_fn(
      is_training=False,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  # Optimizer and model are created and compiled inside the strategy scope.
  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)

    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  run_eagerly=flags_obj.run_eagerly,
                  metrics=['categorical_accuracy'])

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, cifar_main.NUM_IMAGES['train'])

  train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # An explicit --train_steps caps the steps and collapses the run to a single
  # epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  device_scope = None
  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    device_scope = tf.device('/device:GPU:0')
    device_scope.__enter__()

  try:
    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)
    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(eval_input_dataset,
                                   steps=num_eval_steps,
                                   verbose=2)
  finally:
    # Always leave the manually entered device scope, even when training or
    # evaluation raises. __exit__ is given the three arguments the context
    # manager protocol defines; calling it with none is not accepted by every
    # context manager implementation.
    if device_scope is not None:
      device_scope.__exit__(None, None, None)

  stats = keras_common.build_stats(history, eval_output, callbacks)
  return stats
def get_input_dataset(flags_obj, strategy):
    """Build the training and evaluation input datasets.

    Args:
      flags_obj: parsed flags. Reads `batch_size`, `use_synthetic_data`,
        `data_dir`, `datasets_num_private_threads` and `skip_eval`.
      strategy: the distribution strategy in use, or a falsy value when
        training without one.

    Returns:
      A `(train_ds, test_ds)` tuple. `test_ds` is None when
      `flags_obj.skip_eval` is set.

    Raises:
      ValueError: under TPUStrategy, when the batch size is not a multiple
        of the number of replicas.
    """
    dtype = flags_core.get_tf_dtype(flags_obj)
    is_tpu_strategy = isinstance(strategy,
                                 tf.distribute.experimental.TPUStrategy)

    batch_size = flags_obj.batch_size
    if is_tpu_strategy:
        replica_count = strategy.num_replicas_in_sync
        if batch_size % replica_count != 0:
            raise ValueError(
                'Batch size must be divisible by number of replicas : {}'
                .format(replica_count))

        # `experimental_distribute_datasets_from_function()` does not
        # rebatch automatically, and it is the API needed to clone the
        # dataset to multiple workers in eager mode, so switch to the
        # per-replica batch size here.
        batch_size = int(batch_size / replica_count)

    if not flags_obj.use_synthetic_data:
        input_fn = imagenet_preprocessing.input_fn
    else:
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)

    def _make_train_dataset(ctx=None):
        # Training batches always drop the remainder.
        return input_fn(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=batch_size,
            parse_record_fn=imagenet_preprocessing.parse_record,
            datasets_num_private_threads=(
                flags_obj.datasets_num_private_threads),
            dtype=dtype,
            input_context=ctx,
            drop_remainder=True)

    def _make_eval_dataset(ctx=None):
        return input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=batch_size,
            parse_record_fn=imagenet_preprocessing.parse_record,
            dtype=dtype,
            input_context=ctx)

    def _distribute(dataset_fn):
        # No strategy: use the plain dataset. TPUStrategy: hand over the
        # factory so it is invoked with an input context. Any other
        # strategy: distribute an already-built dataset.
        if not strategy:
            return dataset_fn()
        if is_tpu_strategy:
            return strategy.experimental_distribute_datasets_from_function(
                dataset_fn)
        return strategy.experimental_distribute_dataset(dataset_fn())

    train_ds = _distribute(_make_train_dataset)
    test_ds = None if flags_obj.skip_eval else _distribute(_make_eval_dataset)
    return train_ds, test_ds
Example #14
0
def resnet_main(flags_obj, model_function, input_function, shape=None):
    """Build a ResNet Estimator and alternate training with evaluation.

    Args:
      flags_obj: An object containing parsed flags. See
        define_resnet_flags() for details.
      model_function: the function that instantiates the Model and builds
        the ops for train/eval. It is handed to the Estimator (wrapped for
        replication first when `flags_obj.multi_gpu` is set).
      input_function: the function that processes the dataset and returns a
        dataset the Estimator can consume. It is called positionally with
        (is_training, data_dir, batch_size, num_epochs, num_parallel_calls,
        multi_gpu).
      shape: list of ints representing the shape of the images used for
        training. Only used when `flags_obj.export_dir` is set.
    """
    # The non-fused Winograd algorithms give a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    if flags_obj.multi_gpu:
        validate_batch_size_for_multi_gpu(flags_obj.batch_size)

        # Multi-GPU needs two wrappers: (1) around the model_fn, done here,
        # and (2) around the optimizer, done inside the model_fn where the
        # optimizer is created.
        model_function = tf.contrib.estimator.replicate_model_fn(
            model_function, loss_reduction=tf.losses.Reduction.MEAN)

    # Session threading comes from the flags. allow_soft_placement is always
    # on: multi-GPU requires it and it does no harm in the other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    # RunConfig carries the checkpoint interval and the session config.
    run_config = tf.estimator.RunConfig().replace(
        save_checkpoints_secs=1e9, session_config=session_config)

    estimator_params = {
        'resnet_size': int(flags_obj.resnet_size),
        'data_format': flags_obj.data_format,
        'batch_size': flags_obj.batch_size,
        'multi_gpu': flags_obj.multi_gpu,
        'version': int(flags_obj.version),
        'loss_scale': flags_core.get_loss_scale(flags_obj),
        'dtype': flags_core.get_tf_dtype(flags_obj)
    }
    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        params=estimator_params)

    benchmark_logger = logger.config_benchmark_logger(
        flags_obj.benchmark_log_dir)
    benchmark_logger.log_run_info('resnet')

    train_hooks = hooks_helper.get_train_hooks(
        flags_obj.hooks,
        batch_size=flags_obj.batch_size,
        benchmark_log_dir=flags_obj.benchmark_log_dir)

    def input_fn_train():
        # One call covers `epochs_between_evals` epochs of training data.
        return input_function(True, flags_obj.data_dir, flags_obj.batch_size,
                              flags_obj.epochs_between_evals,
                              flags_obj.num_parallel_calls,
                              flags_obj.multi_gpu)

    def input_fn_eval():
        # Evaluation reads the data for a single epoch.
        return input_function(False, flags_obj.data_dir, flags_obj.batch_size,
                              1, flags_obj.num_parallel_calls,
                              flags_obj.multi_gpu)

    cycle_count = flags_obj.train_epochs // flags_obj.epochs_between_evals
    for cycle_index in range(cycle_count):
        tf.logging.info('Starting a training cycle: %d/%d', cycle_index,
                        cycle_count)

        classifier.train(input_fn=input_fn_train,
                         hooks=train_hooks,
                         max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')
        # flags.max_train_steps is generally associated with testing and
        # profiling, and is therefore often combined with synthetic data,
        # which iterates forever. Passing steps=flags.max_train_steps lets
        # the (generally unimportant) eval terminate in that situation.
        # Eval runs for max_train_steps in every cycle, regardless of the
        # global_step count.
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags_obj.max_train_steps)

        benchmark_logger.log_evaluation_result(eval_results)

        reached_threshold = model_helpers.past_stop_threshold(
            flags_obj.stop_threshold, eval_results['accuracy'])
        if reached_threshold:
            break

    if flags_obj.export_dir is None:
        return

    warn_on_multi_gpu_export(flags_obj.multi_gpu)

    # Export a SavedModel for the trained classifier.
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
Example #15
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    The function configures the session, mixed precision and the
    distribution strategy from the flags, builds the input datasets, the
    optimizer and the model, then calls `model.fit` (optionally followed by
    `model.evaluate`), and finally strips pruning wrappers / saves the model
    when the corresponding flags are set.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If `flags_obj.optimizer` or `flags_obj.model` does not
        select a supported choice, or if the requested Keras application
        model name is unknown.
      NotImplementedError: If pruning is requested with a dtype other than
        float32, or with a method other than 'polynomial_decay'.

    Returns:
      The stats built by `common.build_stats` from the fit history, the
      eval output and the callbacks.
    """
    # The CUDA runtime is optional; it is only used to bracket `model.fit`
    # with profiler start/stop calls. Loading fails with OSError when the
    # shared library is unavailable, in which case profiling is skipped.
    try:
        _cudart = ctypes.CDLL('libcudart.so')
    except OSError:
        _cudart = None

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    performance.set_mixed_precision_policy(
        dtype,
        flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                             flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # Current resnet_model.resnet50 input format is always channel-last.
    # The keras_application models' input format depends on the Keras
    # backend image data format. This flag indicates whether the image
    # preprocessor output should follow the Keras backend image data format
    # or always be channel-last.
    use_keras_image_data_format = flags_obj.keras_application_models

    # Module providing the dataset constants and input pipeline.
    preprocessing = (imagenet_preprocessing
                     if flags_obj.dataset == "imagenet"
                     else cifar_preprocessing)

    input_shape = (preprocessing.HEIGHT, preprocessing.WIDTH,
                   preprocessing.NUM_CHANNELS)

    if use_keras_image_data_format:
        if tf.keras.backend.image_data_format() == 'channels_first':
            input_shape = (preprocessing.NUM_CHANNELS, preprocessing.HEIGHT,
                           preprocessing.WIDTH)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        assert flags_obj.dataset == "imagenet", (
            f"Expect to only work with ImageNet, but have {flags_obj.dataset}")
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=preprocessing.DEFAULT_IMAGE_SIZE,
            width=preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=preprocessing.NUM_CHANNELS,
            num_classes=preprocessing.NUM_CLASSES,
            use_keras_image_data_format=use_keras_image_data_format,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
        training_dataset_cache=flags_obj.training_dataset_cache,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            parse_record_fn=preprocessing.get_parse_record_fn(
                use_keras_image_data_format=use_keras_image_data_format),
            dtype=dtype,
            drop_remainder=drop_remainder)

    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=preprocessing.NUM_IMAGES['train'],
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=[p[1] for p in common.LR_SCHEDULE[1:]],
        multipliers=[p[0] for p in common.LR_SCHEDULE],
        compute_lr_on_cpu=True)
    steps_per_epoch = (preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)

    with strategy_scope:
        if flags_obj.optimizer == 'resnet50_default':
            optimizer = common.get_optimizer(lr_schedule)
        elif flags_obj.optimizer == 'mobilenet_default':
            initial_learning_rate = (
                flags_obj.initial_learning_rate_per_sample *
                flags_obj.batch_size)
            optimizer = tf.keras.optimizers.SGD(
                learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
                    initial_learning_rate,
                    decay_steps=steps_per_epoch *
                    flags_obj.num_epochs_per_decay,
                    decay_rate=flags_obj.lr_decay_factor,
                    staircase=True),
                momentum=0.9)
        else:
            # Fail with a clear message instead of an UnboundLocalError when
            # `optimizer` is first used further down.
            raise ValueError('Unsupported optimizer: %s' %
                             flags_obj.optimizer)
        if flags_obj.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
            # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
            # which will ensure tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
            # up.
            optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                optimizer)

        # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'resnet50_v1.5':
            resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
            model = resnet_model.resnet50(
                num_classes=preprocessing.NUM_CLASSES,
                input_shape=input_shape)
        elif flags_obj.model == 'mobilenet':
            # TODO(kimjaehong): Remove layers attribute when minimum TF version
            # support 2.0 layers by default.
            model = tf.keras.applications.mobilenet.MobileNet(
                weights=None,
                classes=preprocessing.NUM_CLASSES,
                input_shape=input_shape,
                layers=tf.keras.layers)
        elif flags_obj.keras_application_models:
            model_kfn = keras_app_models.get(flags_obj.model, None)
            if model_kfn is None:
                raise ValueError("No keras application model name %s" %
                                 flags_obj.model)
            model = model_kfn(weights=None,
                              input_shape=input_shape,
                              classes=preprocessing.NUM_CLASSES)
        else:
            # Fail with a clear message instead of an UnboundLocalError when
            # `model` is first used further down.
            raise ValueError('Unsupported model: %s' % flags_obj.model)
        if flags_obj.pretrained_filepath:
            model.load_weights(flags_obj.pretrained_filepath)

        if flags_obj.pruning_method == 'polynomial_decay':
            if dtype != tf.float32:
                raise NotImplementedError(
                    'Pruning is currently only supported on dtype=tf.float32.')
            pruning_params = {
                'pruning_schedule':
                tfmot.sparsity.keras.PolynomialDecay(
                    initial_sparsity=flags_obj.pruning_initial_sparsity,
                    final_sparsity=flags_obj.pruning_final_sparsity,
                    begin_step=flags_obj.pruning_begin_step,
                    end_step=flags_obj.pruning_end_step,
                    frequency=flags_obj.pruning_frequency),
            }
            model = tfmot.sparsity.keras.prune_low_magnitude(
                model, **pruning_params)
        elif flags_obj.pruning_method:
            raise NotImplementedError(
                'Only polynomial_decay is currently supported.')

        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['sparse_categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      run_eagerly=flags_obj.run_eagerly)

    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(
        steps_per_epoch=steps_per_epoch,
        pruning_method=flags_obj.pruning_method,
        enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
        model_dir=flags_obj.model_dir)

    # If multiple epochs are requested, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    no_dist_strat_device = None
    if not strategy and flags_obj.explicit_gpu_placement:
        # TODO(b/135607227): Add device scope automatically in Keras training loop
        # when not using distribution strategy.
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    try:
        cuda_status = _cudart.cudaProfilerStart() if _cudart else None
        try:
            history = model.fit(train_input_dataset,
                                epochs=train_epochs,
                                steps_per_epoch=steps_per_epoch,
                                callbacks=callbacks,
                                validation_steps=num_eval_steps,
                                validation_data=validation_data,
                                validation_freq=flags_obj.epochs_between_evals,
                                verbose=2)
        finally:
            # Stop the profiler only if it was started successfully
            # (status 0), and do so even when `fit` raises.
            if cuda_status == 0:
                _cudart.cudaProfilerStop()

        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)

        if flags_obj.pruning_method:
            model = tfmot.sparsity.keras.strip_pruning(model)
        if flags_obj.enable_checkpoint_and_export:
            if dtype == tf.bfloat16:
                logging.warning(
                    'Keras model.save does not support bfloat16 dtype.')
            else:
                # Keras model.save assumes a float32 input signature.
                export_path = os.path.join(flags_obj.model_dir, 'saved_model')
                model.save(export_path, include_optimizer=False)
    finally:
        # Always leave the manually entered device scope. The context
        # manager protocol expects the three exception arguments; calling
        # `__exit__()` with none fails for generator-based managers.
        if no_dist_strat_device is not None:
            no_dist_strat_device.__exit__(None, None, None)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
Example #16
0
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared main loop for ResNet Models.

    Depending on the flags this either runs prediction only (writing the
    predictions to "validate_result.txt" and returning), runs evaluation
    only, or alternates training and evaluation, and finally exports a
    SavedModel when `flags_obj.export_dir` is set.

    Args:
      flags_obj: An object containing parsed flags. See
        define_resnet_flags() for details.
      model_function: the function that instantiates the Model and builds
        the ops for train/eval. This will be passed directly into the
        estimator.
      input_function: the function that processes the dataset and returns a
        dataset that the estimator can train on. It is called with a `mode`
        of "train", "validate" or "predict" plus the relevant flag values.
      dataset_name: the name of the dataset for training and evaluation.
        This is used for logging purpose.
      shape: list of ints representing the shape of the images used for
        training. This is only used if flags_obj.export_dir is passed.

    Returns:
      None.
    """

    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # Create session config based on values of inter_op_parallelism_threads and
    # intra_op_parallelism_threads. Note that we default to having
    # allow_soft_placement = True, which is required for multi-GPU and not
    # harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    if flags_core.get_num_gpus(flags_obj) == 0:
        distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
    elif flags_core.get_num_gpus(flags_obj) == 1:
        distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
    else:
        distribution = tf.contrib.distribute.MirroredStrategy(
            num_gpus=flags_core.get_num_gpus(flags_obj))

    run_config = tf.estimator.RunConfig(train_distribute=distribution,
                                        session_config=session_config)

    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        params={
            'resnet_size': int(flags_obj.resnet_size),
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'resnet_version': int(flags_obj.resnet_version),
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'dtype': flags_core.get_tf_dtype(flags_obj)
        })

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    benchmark_logger = logger.config_benchmark_logger(flags_obj)
    benchmark_logger.log_run_info('resnet', dataset_name, run_params)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               batch_size=flags_obj.batch_size)

    def _per_device_batch_size():
        # Evaluated lazily, when an input function is actually invoked, and
        # always through `distribution_utils` so that all three input
        # functions resolve the helper the same way.
        return distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))

    def input_fn_train(num_epochs):
        return input_function(
            mode="train",
            data_dir=flags_obj.data_dir,
            batch_size=_per_device_batch_size(),
            num_epochs=num_epochs,
            num_gpus=flags_core.get_num_gpus(flags_obj),
            dtype=flags_core.get_tf_dtype(flags_obj))

    def input_fn_eval():
        return input_function(mode="validate",
                              data_dir=flags_obj.data_dir,
                              batch_size=_per_device_batch_size(),
                              num_epochs=1)

    def input_fn_pred():
        return input_function(mode="predict",
                              data_dir=flags_obj.data_dir,
                              batch_size=_per_device_batch_size(),
                              num_epochs=1)

    # Prediction-only mode: dump the stacked "predictions" outputs to a CSV
    # file and skip training, evaluation and export entirely.
    if flags_obj.predict_only:
        result = classifier.predict(input_fn=input_fn_pred)
        predicted_values = np.stack([r["predictions"] for r in result], axis=0)
        df = pd.DataFrame(predicted_values)
        df.to_csv("validate_result.txt")
        return

    if flags_obj.eval_only or not flags_obj.train_epochs:
        # If --eval_only is set, perform a single loop with zero train epochs.
        schedule, n_loops = [0], 1
    else:
        # Compute the number of times to loop while training. All but the last
        # pass will train for `epochs_between_evals` epochs, while the last will
        # train for the number needed to reach `training_epochs`. For instance if
        #   train_epochs = 25 and epochs_between_evals = 10
        # schedule will be set to [10, 10, 5]. That is to say, the loop will:
        #   Train for 10 epochs and then evaluate.
        #   Train for another 10 epochs and then evaluate.
        #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
        n_loops = math.ceil(flags_obj.train_epochs /
                            flags_obj.epochs_between_evals)
        schedule = [
            flags_obj.epochs_between_evals for _ in range(int(n_loops))
        ]
        # Trim the last entry so the schedule sums to exactly train_epochs.
        schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1])

    for cycle_index, num_train_epochs in enumerate(schedule):
        tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

        if num_train_epochs:
            # The lambda is consumed by `train` within this iteration, so it
            # always sees the current `num_train_epochs`.
            classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                             hooks=train_hooks,
                             max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')

        # Evaluation is capped at a fixed number of steps so that it
        # terminates even when the input (e.g. synthetic data) iterates
        # forever.
        # NOTE: the cap is hard-coded to 100 here; it is not derived from
        # flags_obj.max_train_steps.
        eval_results = classifier.evaluate(input_fn=input_fn_eval, steps=100)

        benchmark_logger.log_evaluation_result(eval_results)

        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['mse']):
            break

    # save model for serving
    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
            shape, batch_size=flags_obj.batch_size)
        classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
Example #17
0
def net_main(flags_obj,
             model_function,
             input_function,
             net_data_configs,
             shape=None):
    """Shared main loop: train, evaluate, log metrics and optionally export.

    Builds a `tf.estimator.Estimator` around `model_function`, alternates
    training and evaluation following the epoch schedule derived from the
    flags, appends one metrics line per evaluation to
    `<model_dir>/log_metric.txt` and, when `flags_obj.export_dir` is set,
    exports a SavedModel.

    Args:
      flags_obj: An object containing parsed flags.
      model_function: the function that instantiates the Model and builds the
        ops for train/eval. This will be passed directly into the estimator.
      input_function: the function that processes the dataset and returns a
        dataset that the estimator can train on. This will be wrapped with
        all the relevant flags for running and passed to estimator.
      net_data_configs: dict of network/dataset settings. The keys
        'dataset_name', 'sg_settings' and 'dset_metas' are read here and the
        whole dict is forwarded to the model through the estimator `params`.
      shape: list of ints representing the shape of the inputs used for
        training. This is only used if flags_obj.export_dir is passed.
    """
    model_helpers.apply_clean(flags.FLAGS)

    metric_logfn = os.path.join(flags_obj.model_dir, 'log_metric.txt')
    # The context manager closes the metrics log on every exit path,
    # including exceptions raised while training or evaluating.
    with open(metric_logfn, 'a') as metric_logf:
        _run_net_main(flags_obj, model_function, input_function,
                      net_data_configs, shape, metric_logf)


def _run_net_main(flags_obj, model_function, input_function,
                  net_data_configs, shape, metric_logf):
    """Run the train/eval/export work of `net_main` with an open metrics log.

    Args:
      flags_obj: see `net_main`.
      model_function: see `net_main`.
      input_function: see `net_main`.
      net_data_configs: see `net_main`.
      shape: see `net_main`.
      metric_logf: open, writable text file. One line is written and flushed
        after every evaluation; the caller owns the file and closes it.
    """
    from tensorflow.contrib.memory_stats.ops import gen_memory_stats_ops
    max_memory_usage = gen_memory_stats_ops.max_bytes_in_use()

    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # Create session config based on values of inter_op_parallelism_threads and
    # intra_op_parallelism_threads. Note that we default to having
    # allow_soft_placement = True, which is required for multi-GPU and not
    # harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config)

    # Warm-start every variable whose name does not contain "dense" from the
    # pretrained checkpoint, when one is given.
    if flags_obj.pretrained_model_checkpoint_path is not None:
        warm_start_settings = tf.estimator.WarmStartSettings(
            flags_obj.pretrained_model_checkpoint_path,
            vars_to_warm_start='^(?!.*dense)')
    else:
        warm_start_settings = None

    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        warm_start_from=warm_start_settings,
        params={
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'weight_decay': flags_obj.weight_decay,
            'dtype': flags_core.get_tf_dtype(flags_obj),
            'fine_tune': flags_obj.fine_tune,
            'examples_per_epoch': flags_obj.examples_per_epoch,
            'net_data_configs': net_data_configs
        })

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    dataset_name = net_data_configs['dataset_name']
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('meshnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train(num_epochs):
        """Return the training input for `num_epochs` epochs."""
        return input_function(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=num_epochs,
            num_gpus=flags_core.get_num_gpus(flags_obj),
            examples_per_epoch=flags_obj.examples_per_epoch,
            sg_settings=net_data_configs['sg_settings'])

    def input_fn_eval():
        """Return the evaluation input (a single epoch)."""
        return input_function(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=1,
            sg_settings=net_data_configs['sg_settings'])

    if flags_obj.eval_only or flags_obj.pred_ply or not flags_obj.train_epochs:
        # Evaluation/prediction only: a single loop with zero train epochs.
        schedule, n_loops = [0], 1
    else:
        # Compute the number of times to loop while training. All but the last
        # pass will train for `epochs_between_evals` epochs, while the last will
        # train for the number needed to reach `training_epochs`. For instance if
        #   train_epochs = 25 and epochs_between_evals = 10
        # schedule will be set to [10, 10, 5]. That is to say, the loop will:
        #   Train for 10 epochs and then evaluate.
        #   Train for another 10 epochs and then evaluate.
        #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
        n_loops = math.ceil(flags_obj.train_epochs /
                            flags_obj.epochs_between_evals)
        schedule = [
            flags_obj.epochs_between_evals for _ in range(int(n_loops))
        ]
        schedule[-1] = flags_obj.train_epochs - sum(
            schedule[:-1])  # over counting.

        # Short warm-up run (until the global step reaches 10) followed by a
        # report of the peak memory counter.
        # NOTE(review): the counter is read in a fresh session, not the one
        # the estimator trained in -- confirm it reports what is intended.
        classifier.train(input_fn=lambda: input_fn_train(1),
                         hooks=train_hooks,
                         max_steps=10)
        with tf.Session() as sess:
            max_memory_usage_v = sess.run(max_memory_usage)
            tf.logging.info('\n\nmemory usage: %0.3f G\n\n',
                            max_memory_usage_v * 1.0 / 1e9)

    best_acc, _ = load_saved_best(flags_obj.model_dir)
    for cycle_index, num_train_epochs in enumerate(schedule):
        tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

        t0 = time.time()
        train_t = 0
        if num_train_epochs:
            classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                             hooks=train_hooks,
                             max_steps=flags_obj.max_train_steps)
            # Average wall-clock seconds per training epoch in this cycle.
            train_t = (time.time() - t0) / num_train_epochs

        tf.logging.info('Starting to evaluate.')

        # flags_obj.max_train_steps is generally associated with testing and
        # profiling. As a result it is frequently called with synthetic data, which
        # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
        # eval (which is generally unimportant in those circumstances) to terminate.
        # Note that eval will run for max_train_steps each loop, regardless of the
        # global_step count.
        t0 = time.time()
        eval_results = classifier.evaluate(
            input_fn=input_fn_eval,
            steps=flags_obj.max_train_steps,
        )
        eval_t = time.time() - t0

        if flags_obj.pred_ply:
            pred_generator = classifier.predict(input_fn=input_fn_eval)
            num_classes = net_data_configs['dset_metas'].num_classes
            gen_pred_ply(eval_results, pred_generator, flags_obj.model_dir,
                         num_classes)

        benchmark_logger.log_evaluation_result(eval_results)

        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['accuracy']):
            break

        # Only a cycle that actually trained can produce a new best model.
        cur_is_best = ''
        if num_train_epochs and eval_results['accuracy'] > best_acc:
            best_acc = eval_results['accuracy']
            save_cur_model_as_best_acc(flags_obj.model_dir, best_acc)
            cur_is_best = 'best'
        global_step = cur_global_step(flags_obj.model_dir)
        # NOTE(review): epoch estimate from the global step; the formula does
        # not involve the batch size -- confirm the units of
        # examples_per_epoch.
        epoch = int(global_step / flags_obj.examples_per_epoch *
                    flags_obj.num_gpus)
        ious_str = get_ious_str(eval_results['cm'],
                                net_data_configs['dset_metas'],
                                eval_results['mean_iou'])
        metric_logf.write(
            '\n{} train t:{:.1f}  eval t:{:.1f} \teval acc:{:.3f} \tmean_iou:{:.3f} {} {}\n'
            .format(epoch, train_t, eval_t, eval_results['accuracy'],
                    eval_results['mean_iou'], cur_is_best, ious_str))
        # Flush so the log can be followed while training is still running.
        metric_logf.flush()

    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
            shape, batch_size=flags_obj.batch_size)
        classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
Example #18
0
def run(flags_obj):
    """Train and evaluate ResNet-50 on ImageNet with a hand-written loop.

    The model, optimizer and metrics are created under the distribution
    strategy scope; every epoch runs `train_steps` training steps and, unless
    evaluation is skipped, `eval_steps` evaluation steps on the epochs
    selected by `flags_obj.epochs_between_evals`.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      The result of `build_stats` for the final training results, the final
      evaluation results (both None when `flags_obj.skip_eval` is set) and
      the time callback.
    """
    dtype = flags_core.get_tf_dtype(flags_obj)

    # TODO(anj-s): Set data_format without using Keras.
    data_format = flags_obj.data_format
    if data_format is None:
        if tf.test.is_built_with_cuda():
            data_format = 'channels_first'
        else:
            data_format = 'channels_last'
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    train_ds, test_ds = get_input_dataset(flags_obj, strategy)
    train_steps, train_epochs, eval_steps = get_num_train_iterations(flags_obj)

    time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
                                            flags_obj.log_steps)

    with distribution_utils.get_strategy_scope(strategy):
        model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
                                      dtype=dtype,
                                      batch_size=flags_obj.batch_size)

        optimizer = tf.keras.optimizers.SGD(
            learning_rate=keras_common.BASE_LEARNING_RATE,
            momentum=0.9,
            nesterov=True)

        training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'training_accuracy', dtype=tf.float32)
        test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
        test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)

        def train_step(train_ds_inputs):
            """Run one training step and return its (summed) loss."""
            def replica_train_step(batch):
                """Forward/backward pass on a single replica."""
                images, labels = batch
                with tf.GradientTape() as tape:
                    logits = model(images, training=True)

                    per_example_loss = (
                        tf.keras.losses.sparse_categorical_crossentropy(
                            labels, logits))
                    # Summed data loss scaled by the configured batch size.
                    data_loss = tf.reduce_sum(per_example_loss) * (
                        1.0 / flags_obj.batch_size)
                    # The model's own losses, split across the replicas.
                    replica_count = (
                        tf.distribute.get_strategy().num_replicas_in_sync)
                    model_loss = tf.reduce_sum(model.losses) / replica_count
                    loss = data_loss + model_loss

                gradients = tape.gradient(loss, model.trainable_variables)
                optimizer.apply_gradients(
                    zip(gradients, model.trainable_variables))

                training_accuracy.update_state(labels, logits)
                return loss

            if not strategy:
                return replica_train_step(train_ds_inputs)

            per_replica_losses = strategy.experimental_run_v2(
                replica_train_step, args=(train_ds_inputs, ))
            return strategy.reduce(tf.distribute.ReduceOp.SUM,
                                   per_replica_losses,
                                   axis=None)

        def test_step(test_ds_inputs):
            """Run one evaluation step, updating the test metrics."""
            def replica_test_step(batch):
                """Evaluate a single batch on one replica."""
                images, labels = batch
                logits = model(images, training=False)
                per_example_loss = (
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits))
                batch_loss = tf.reduce_sum(per_example_loss) * (
                    1.0 / flags_obj.batch_size)
                test_loss.update_state(batch_loss)
                test_accuracy.update_state(labels, logits)

            if not strategy:
                replica_test_step(test_ds_inputs)
            else:
                strategy.experimental_run_v2(replica_test_step,
                                             args=(test_ds_inputs, ))

        if flags_obj.use_tf_function:
            train_step = tf.function(train_step)
            test_step = tf.function(test_step)

        time_callback.on_train_begin()
        for epoch in range(train_epochs):

            train_iter = iter(train_ds)
            total_loss = 0.0
            training_accuracy.reset_states()

            for step in range(train_steps):
                optimizer.lr = keras_imagenet_main.learning_rate_schedule(
                    epoch, step, train_steps, flags_obj.batch_size)

                # Batch index counted from the start of training.
                overall_step = step + epoch * train_steps
                time_callback.on_batch_begin(overall_step)
                total_loss += train_step(next(train_iter))
                time_callback.on_batch_end(overall_step)

            train_loss = total_loss / train_steps
            logging.info('Training loss: %s, accuracy: %s%% at epoch: %d',
                         train_loss.numpy(),
                         training_accuracy.result().numpy(), epoch)

            # Evaluate only on the scheduled epochs, and never when skipped.
            if (flags_obj.skip_eval
                    or (epoch + 1) % flags_obj.epochs_between_evals != 0):
                continue

            test_loss.reset_states()
            test_accuracy.reset_states()

            test_iter = iter(test_ds)
            for _ in range(eval_steps):
                test_step(next(test_iter))

            logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
                         test_loss.result().numpy(),
                         test_accuracy.result().numpy(), epoch)

        time_callback.on_train_end()

        if flags_obj.skip_eval:
            eval_result = None
            train_result = None
        else:
            eval_result = [
                test_loss.result().numpy(),
                test_accuracy.result().numpy()
            ]
            train_result = [
                train_loss.numpy(),
                training_accuracy.result().numpy()
            ]

        return build_stats(train_result, eval_result, time_callback)
Example #19
0
    def gen_estimator(period=None):
        """Build the estimator and, optionally, a Hyperdash experiment.

        The model hyper-parameters are collected from `flags_obj` into the
        estimator `params`. When Hyperdash is installed and enabled by the
        flags, an `Experiment` named after the last component of the model
        directory is created and a subset of the parameters is registered
        with it; the values it returns are the ones handed to the estimator.

        Args:
          period: not used by this function.

        Returns:
          A `(classifier, exp)` tuple, where `exp` is None when Hyperdash is
          not in use.
        """
        params = {
            'resnet_size': int(flags_obj.resnet_size),
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'resnet_version': int(flags_obj.resnet_version),
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'dtype': flags_core.get_tf_dtype(flags_obj),
            'num_epochs_per_decay': flags_obj.num_epochs_per_decay,
            'learning_rate_decay_factor': flags_obj.learning_rate_decay_factor,
            'end_learning_rate': flags_obj.end_learning_rate,
            'learning_rate_decay_type': flags_obj.learning_rate_decay_type,
            'weight_decay': flags_obj.weight_decay,
            'zero_gamma': flags_obj.zero_gamma,
            'lr_warmup_epochs': flags_obj.lr_warmup_epochs,
            'base_learning_rate': flags_obj.base_learning_rate,
            'use_resnet_d': flags_obj.use_resnet_d,
            'use_dropblock': flags_obj.use_dropblock,
            'dropblock_kp': [float(kp) for kp in flags_obj.dropblock_kp],
            'label_smoothing': flags_obj.label_smoothing,
            'momentum': flags_obj.momentum,
            'bn_momentum': flags_obj.bn_momentum,
            'embedding_size': flags_obj.embedding_size,
            'train_epochs': flags_obj.train_epochs,
            'piecewise_lr_boundary_epochs': [
                int(boundary)
                for boundary in flags_obj.piecewise_lr_boundary_epochs
            ],
            'piecewise_lr_decay_rates': [
                float(rate) for rate in flags_obj.piecewise_lr_decay_rates
            ],
            'with_drawing_bbox': flags_obj.with_drawing_bbox,
            'use_ranking_loss': flags_obj.use_ranking_loss,
            'use_se_block': flags_obj.use_se_block,
            'use_sk_block': flags_obj.use_sk_block,
            'mixup_type': flags_obj.mixup_type,
            'kd_temp': flags_obj.kd_temp,
            'no_downsample': flags_obj.no_downsample,
            'dataset_name': flags_obj.dataset_name,
            'anti_alias_filter_size': flags_obj.anti_alias_filter_size,
            'anti_alias_type': flags_obj.anti_alias_type,
            'cls_loss_type': flags_obj.cls_loss_type,
            'logit_type': flags_obj.logit_type,
            'arc_s': flags_obj.arc_s,
            'arc_m': flags_obj.arc_m,
            'pool_type': flags_obj.pool_type,
            'bl_alpha': flags_obj.bl_alpha,
            'bl_beta': flags_obj.bl_beta,
            'train_steps': total_train_steps,
        }

        exp = None
        if install_hyperdash and flags_obj.use_hyperdash:
            exp = Experiment(flags_obj.model_dir.split("/")[-1])

            def track(key):
                """Register params[key] and keep the value handed back."""
                params[key] = exp.param(key, params[key])

            # Registration order is kept as-is; "dtype" and
            # "autoaugment_type" are only reported, never fed to the model.
            track('resnet_size')
            track('batch_size')
            exp.param("dtype", flags_obj.dtype)
            for tracked_key in ('learning_rate_decay_type',
                                'weight_decay',
                                'zero_gamma',
                                'lr_warmup_epochs',
                                'base_learning_rate',
                                'use_dropblock',
                                'dropblock_kp',
                                'piecewise_lr_boundary_epochs',
                                'piecewise_lr_decay_rates',
                                'mixup_type',
                                'dataset_name'):
                track(tracked_key)
            exp.param("autoaugment_type", flags_obj.autoaugment_type)

        classifier = tf.estimator.Estimator(model_fn=model_function,
                                            model_dir=flags_obj.model_dir,
                                            config=run_config,
                                            params=params)
        return classifier, exp
Example #20
0
def run(flags_obj):
    """Train and evaluate `Conv4_model` on CIFAR-10 with the Keras fit API.

    The CIFAR-10 arrays are loaded through `tf.keras.datasets`, scaled to
    [0, 1] and one-hot encoded; the model is trained with an SGD optimizer
    wrapped in `SynchronousSGDOptimizer`, with the number of steps per epoch
    divided by the cluster size.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.

    Returns:
      The result of `common.build_stats` for the fit history, the final
      evaluation output (None when `flags_obj.skip_eval` is set) and the
      callbacks.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # Compare against the dtype object: a TensorFlow dtype never compares
    # equal to the flag spelling 'fp16', so a string check could not fire.
    if dtype == tf.float16:
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    # NOTE(review): the scope is created but the model below is not built
    # inside it -- confirm whether that is intentional.
    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # NOTE(review): `input_fn` is selected here but not used further down;
    # the data always comes from tf.keras.datasets.cifar10.
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=cifar_preprocessing.HEIGHT,
            width=cifar_preprocessing.WIDTH,
            num_channels=cifar_preprocessing.NUM_CHANNELS,
            num_classes=cifar_preprocessing.NUM_CLASSES,
            dtype=flags_core.get_tf_dtype(flags_obj),
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = cifar_preprocessing.input_fn

    # Load CIFAR-10 as arrays, scale pixels to [0, 1] and one-hot the labels.
    (x_train, y_train), (x_test,
                         y_test) = tf.keras.datasets.cifar10.load_data()
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes)

    opt = tf.keras.optimizers.SGD(learning_rate=0.1)

    logging.info(opt.__dict__)
    optimizer = SynchronousSGDOptimizer(opt, use_locking=True)
    # Expose the wrapped optimizer's hyper-parameters on the wrapper.
    optimizer._hyper = opt._hyper

    logging.info(optimizer.__dict__)

    model = Conv4_model(x_train, num_classes)

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
    # a valid arg for this model. Also remove as a valid flag.
    if flags_obj.force_v2_in_keras_compile is not None:
        model.compile(
            loss='categorical_crossentropy',
            optimizer=optimizer,
            metrics=(['accuracy']),
            run_eagerly=flags_obj.run_eagerly,
            experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
    else:
        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['accuracy']),
                      run_eagerly=flags_obj.run_eagerly)

    # Each worker runs its share of the steps in an epoch.
    cluster_size = current_cluster_size()
    steps_per_epoch = (cifar_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)
    steps_per_epoch = steps_per_epoch // cluster_size
    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(steps_per_epoch, current_rank(),
                                     cluster_size, learning_rate_schedule)
    callbacks.append(BroadcastGlobalVariablesCallback())

    if flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)

    num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = (x_test, y_test)
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        # Skipping evaluation also disables validation during fit.
        num_eval_steps = None
        validation_data = None

    tf.compat.v1.logging.info(x_train.shape)
    history = model.fit(x_train,
                        y_train,
                        batch_size=flags_obj.batch_size,
                        epochs=train_epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)
    eval_output = None
    if not flags_obj.skip_eval:
        # Inputs and targets are separate arguments for array data.
        eval_output = model.evaluate(x_test,
                                     y_test,
                                     steps=num_eval_steps,
                                     verbose=2)
    stats = common.build_stats(history, eval_output, callbacks)
    return stats
Example #21
0
def run(flags_obj):
  """Train and evaluate ResNet-50 on ImageNet with Keras fit/evaluate.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    The result of `keras_common.build_stats` for the fit history, the final
    evaluation output (None when `flags_obj.skip_eval` is set) and the time
    callback.
  """
  config = keras_common.get_config_proto()
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  if not keras_common.is_v2_0():
    if flags_obj.enable_eager:
      tf.compat.v1.enable_eager_execution(config=config)
    else:
      tf.keras.backend.set_session(tf.Session(config=config))
  # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.

  dtype = flags_core.get_tf_dtype(flags_obj)
  use_float16 = dtype == 'float16'
  if use_float16:
    tf.keras.mixed_precision.experimental.set_policy(
        tf.keras.mixed_precision.experimental.Policy('infer_float32_vars'))

  data_format = flags_obj.data_format
  if data_format is None:
    if tf.test.is_built_with_cuda():
      data_format = 'channels_first'
    else:
      data_format = 'channels_last'
  tf.keras.backend.set_image_data_format(data_format)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn

  # Keyword arguments common to the training and evaluation inputs.
  shared_input_kwargs = {
      'data_dir': flags_obj.data_dir,
      'batch_size': flags_obj.batch_size,
      'num_epochs': flags_obj.train_epochs,
      'parse_record_fn': parse_record_keras,
      'dtype': dtype,
  }
  train_input_dataset = input_fn(
      is_training=True,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      **shared_input_kwargs)
  eval_input_dataset = input_fn(is_training=False, **shared_input_kwargs)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus)

  with keras_common.get_strategy_scope(strategy):
    optimizer = keras_common.get_optimizer()
    if use_float16:
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))
    model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
                                  dtype=dtype)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])
  fit_callbacks = [time_callback, lr_callback, tensorboard_callback]

  train_epochs = flags_obj.train_epochs
  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  if flags_obj.train_steps:
    # An explicit step budget caps the steps and collapses to one epoch.
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)
  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=fit_callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  if flags_obj.skip_eval:
    eval_output = None
  else:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)
  return keras_common.build_stats(history, eval_output, time_callback)
Example #22
0
def run(flags_obj):
    """Train and evaluate ResNet-50 on ImageNet with the Keras `fit` loop.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      The result of `keras_common.build_stats` for the fit history, the final
      evaluation output (None when evaluation is skipped) and the time
      callback.
    """
    session_config = keras_common.get_config_proto()
    # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
    # Eager is default in tf 2.0 and should not be toggled
    if not keras_common.is_v2_0():
        if flags_obj.enable_eager:
            tf.compat.v1.enable_eager_execution(config=session_config)
        else:
            tf.keras.backend.set_session(tf.Session(config=session_config))
    # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == 'float16':
        # float16 runs go through the experimental mixed precision policy.
        tf.keras.mixed_precision.experimental.set_policy(
            tf.keras.mixed_precision.experimental.Policy('infer_float32_vars'))

    # Without an explicit flag, choose the layout from the CUDA build status.
    image_data_format = flags_obj.data_format
    if image_data_format is None:
        if tf.test.is_built_with_cuda():
            image_data_format = 'channels_first'
        else:
            image_data_format = 'channels_last'
    tf.keras.backend.set_image_data_format(image_data_format)

    # pylint: disable=protected-access
    if not flags_obj.use_synthetic_data:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_main.input_fn
    else:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=imagenet_main.DEFAULT_IMAGE_SIZE,
            width=imagenet_main.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_main.NUM_CHANNELS,
            num_classes=imagenet_main.NUM_CLASSES,
            dtype=dtype)

    # Arguments common to the training and the evaluation pipelines.
    shared_input_args = dict(
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        dtype=dtype)

    train_input_dataset = input_fn(
        is_training=True,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        **shared_input_args)
    eval_input_dataset = input_fn(is_training=False, **shared_input_args)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster())

    # The optimizer and model are created and compiled inside the strategy
    # scope.
    with keras_common.get_strategy_scope(strategy):
        optimizer = keras_common.get_optimizer()
        if dtype == 'float16':
            # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
            # can be enabled with a single line of code.
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))
        model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
                                      dtype=dtype)
        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=['sparse_categorical_accuracy'])

    num_train_images = imagenet_main.NUM_IMAGES['train']
    time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
        learning_rate_schedule, num_train_images)

    steps_per_epoch = num_train_images // flags_obj.batch_size
    epochs = flags_obj.train_epochs
    if flags_obj.train_steps:
        # An explicit step budget caps the epoch length and forces one epoch.
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        epochs = 1

    num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)
    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        tf.keras.backend.set_learning_phase(1)
        num_eval_steps, validation_data = None, None

    history = model.fit(
        train_input_dataset,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=[time_callback, lr_callback, tensorboard_callback],
        validation_steps=num_eval_steps,
        validation_data=validation_data,
        validation_freq=flags_obj.epochs_between_evals,
        verbose=2)

    if flags_obj.skip_eval:
        eval_output = None
    else:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)
    return keras_common.build_stats(history, eval_output, time_callback)
def run(flags_obj):
    """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If the resolved dtype compares equal to 'fp16'.

    Returns:
      The stats built by `common.build_stats` from the fit history, the final
      evaluation output (None when evaluation is skipped) and the callbacks.
    """
    keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # NOTE(review): if get_tf_dtype returns a tf.DType rather than the raw flag
    # string, this comparison may never match -- confirm against flags_core.
    if dtype == 'fp16':
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        # Default layout depends on whether a GPU is visible to TensorFlow.
        data_format = ('channels_first'
                       if tf.config.list_physical_devices('GPU') else
                       'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    if flags_obj.use_synthetic_data:
        synthetic_util.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=cifar_preprocessing.HEIGHT,
            width=cifar_preprocessing.WIDTH,
            num_channels=cifar_preprocessing.NUM_CHANNELS,
            num_classes=cifar_preprocessing.NUM_CLASSES,
            dtype=flags_core.get_tf_dtype(flags_obj),
            drop_remainder=True)
    else:
        synthetic_util.undo_set_up_synthetic_data()
        input_fn = cifar_preprocessing.input_fn

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=cifar_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        # Setting drop_remainder to avoid the partial batch logic in normalization
        # layer, which triggers tf.where and leads to extra memory copy of input
        # sizes between host and GPU.
        drop_remainder=(not flags_obj.enable_get_next_as_optional))

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            parse_record_fn=cifar_preprocessing.parse_record)
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = (
            tf.data.experimental.AutoShardPolicy.DATA)
        eval_input_dataset = eval_input_dataset.with_options(options)

    steps_per_epoch = (cifar_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)
    # Constant rate by default; replaced by a piecewise schedule when
    # use_tensor_lr is set.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        initial_learning_rate = (
            common.BASE_LEARNING_RATE * flags_obj.batch_size / 128)
        lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            boundaries=[p[1] * steps_per_epoch for p in LR_SCHEDULE],
            values=[initial_learning_rate] +
            [p[0] * initial_learning_rate for p in LR_SCHEDULE])

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        model = resnet_cifar_model.resnet56(
            classes=cifar_preprocessing.NUM_CLASSES)
        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['sparse_categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      run_eagerly=flags_obj.run_eagerly)

    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks()

    if not flags_obj.use_tensor_lr:
        # Without a tensor schedule the rate is driven by a per-batch callback.
        lr_callback = LearningRateBatchScheduler(
            schedule=learning_rate_schedule,
            batch_size=flags_obj.batch_size,
            steps_per_epoch=steps_per_epoch)
        callbacks.append(lr_callback)

    # If multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    no_dist_strat_device = None
    if not strategy and flags_obj.explicit_gpu_placement:
        # TODO(b/135607227): Add device scope automatically in Keras training loop
        # when not using distribution strategy.
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    try:
        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=steps_per_epoch,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)
        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)
    finally:
        # Always leave the manually entered device scope, even when training or
        # evaluation raises. The context-manager protocol requires the three
        # exception arguments; (None, None, None) signals a normal exit.
        if no_dist_strat_device is not None:
            no_dist_strat_device.__exit__(None, None, None)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
Example #24
0
def run(flags_obj):
  """Train and evaluate ResNet-56 on CIFAR-10 with the Keras `fit` loop.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If the resolved dtype compares equal to 'fp16'.

  Returns:
    The result of `keras_common.build_stats` for the fit history, the final
    evaluation output (None when evaluation is skipped) and the time callback.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  if flags_core.get_tf_dtype(flags_obj) == 'fp16':
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  # Without an explicit flag, choose the layout from the CUDA build status.
  image_data_format = flags_obj.data_format
  if image_data_format is None:
    if tf.test.is_built_with_cuda():
      image_data_format = 'channels_first'
    else:
      image_data_format = 'channels_last'
  tf.keras.backend.set_image_data_format(image_data_format)

  if not flags_obj.use_synthetic_data:
    input_fn = cifar_main.input_fn
  else:
    input_fn = keras_common.get_synth_input_fn(
        height=cifar_main.HEIGHT,
        width=cifar_main.WIDTH,
        num_channels=cifar_main.NUM_CHANNELS,
        num_classes=cifar_main.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj))

  # Both pipelines take the same arguments apart from `is_training`.
  dataset_args = dict(
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)
  train_input_dataset = input_fn(is_training=True, **dataset_args)
  eval_input_dataset = input_fn(is_training=False, **dataset_args)

  strategy = distribution_utils.get_distribution_strategy(
      num_gpus=flags_obj.num_gpus,
      turn_off_distribution_strategy=flags_obj.turn_off_distribution_strategy)

  # The optimizer and model are created and compiled inside the strategy scope.
  with keras_common.get_strategy_scope(strategy):
    optimizer = keras_common.get_optimizer()
    model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['categorical_accuracy'])

  num_train_images = cifar_main.NUM_IMAGES['train']
  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, num_train_images)

  steps_per_epoch = num_train_images // flags_obj.batch_size
  epochs = flags_obj.train_epochs
  if flags_obj.train_steps:
    # An explicit step budget caps the epoch length and forces a single epoch.
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    epochs = 1

  num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)
  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # No evaluation: pin the learning phase to training and drop validation.
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps, validation_data = None, None

  history = model.fit(
      train_input_dataset,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      callbacks=[time_callback, lr_callback, tensorboard_callback],
      validation_steps=num_eval_steps,
      validation_data=validation_data,
      verbose=2)

  if flags_obj.skip_eval:
    eval_output = None
  else:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=1)
  return keras_common.build_stats(history, eval_output, time_callback)
Example #25
0
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Build the ResNet estimator and optionally export it as a SavedModel.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset. Here it is only wrapped by the local eval input function.
    dataset_name: the name of the dataset for training and evaluation. Not
      read in this function body.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

  print("RESNET MAIN")
  model_helpers.apply_clean(flags.FLAGS)

  # Ensures flag override logic is only executed if explicitly triggered.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # allow_soft_placement = True is required for multi-GPU and is not harmful
  # for other modes.
  run_config = tf.estimator.RunConfig(
      session_config=tf.ConfigProto(allow_soft_placement=True),
      save_checkpoints_secs=60*60*24)

  # Initializes model with all but the dense layer from pretrained ResNet.
  pretrained_checkpoint = flags_obj.pretrained_model_checkpoint_path
  warm_start_settings = None
  if pretrained_checkpoint is not None:
    warm_start_settings = tf.estimator.WarmStartSettings(
        pretrained_checkpoint,
        vars_to_warm_start='^(?!.*dense)')

  model_params = {
      'resnet_size': int(flags_obj.resnet_size),
      'data_format': flags_obj.data_format,
      'batch_size': flags_obj.batch_size,
      'resnet_version': int(flags_obj.resnet_version),
      'loss_scale': flags_core.get_loss_scale(flags_obj),
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'fine_tune': flags_obj.fine_tune
  }
  classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      config=run_config,
      warm_start_from=warm_start_settings,
      params=model_params)

  # The run summary, eval input function and loop schedule below are built but
  # not consumed anywhere in this function.
  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }

  def input_fn_eval():
    per_device_batch = distribution_utils.per_device_batch_size(
        flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=per_device_batch,
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  schedule = [0]
  n_loops = 1

  if flags_obj.export_dir is None:
    return

  # Exports a saved model for the given classifier.
  export_dtype = flags_core.get_tf_dtype(flags_obj)
  if not flags_obj.image_bytes_as_serving_input:
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
  else:
    input_receiver_fn = functools.partial(
        image_bytes_serving_input_fn, shape, dtype=export_dtype)
  classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                               strip_default_attrs=True)
Example #26
0
def run(flags_obj):
    """Train and evaluate ResNet on ImageNet with a custom training loop.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      The result of `build_stats` for the runnable and the time callback.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)
    performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

    # This only affects GPU.
    common.set_cudnn_batchnorm_mode()

    # TODO(anj-s): Set data_format without using Keras.
    image_data_format = flags_obj.data_format
    if image_data_format is None:
        if tf.config.list_physical_devices('GPU'):
            image_data_format = 'channels_first'
        else:
            image_data_format = 'channels_last'
    tf.keras.backend.set_image_data_format(image_data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    steps_per_epoch, num_epochs, num_eval_steps = get_num_train_iterations(
        flags_obj)
    total_train_steps = num_epochs * steps_per_epoch
    # Never run more steps per loop than one epoch contains.
    loop_steps = min(flags_obj.steps_per_loop, steps_per_epoch)

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', num_epochs, steps_per_epoch,
        total_train_steps, num_eval_steps)

    # TensorBoard output goes to the model directory only when enabled.
    if flags_obj.enable_tensorboard:
        tensorboard_dir = flags_obj.model_dir
    else:
        tensorboard_dir = None
    time_callback = keras_utils.TimeHistory(
        flags_obj.batch_size,
        flags_obj.log_steps,
        logdir=tensorboard_dir)

    with distribution_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                                  steps_per_epoch)

    eval_interval = flags_obj.epochs_between_evals * steps_per_epoch

    # Checkpointing and summaries happen once per epoch when enabled.
    if flags_obj.enable_checkpoint_and_export:
        checkpoint_interval = steps_per_epoch
    else:
        checkpoint_interval = None
    if flags_obj.enable_tensorboard:
        summary_interval = steps_per_epoch
    else:
        summary_interval = None

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=flags_obj.model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    resnet_controller = controller.Controller(
        strategy,
        runnable.train,
        runnable.evaluate,
        global_step=runnable.global_step,
        steps_per_loop=loop_steps,
        train_steps=steps_per_epoch * num_epochs,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        eval_steps=num_eval_steps,
        eval_interval=eval_interval)

    # Bracket the controller run with the timing callback hooks.
    time_callback.on_train_begin()
    resnet_controller.train(evaluate=not flags_obj.skip_eval)
    time_callback.on_train_end()

    return build_stats(runnable, time_callback)
Example #27
0
def use_float16():
  """Return whether the dtype resolved from the global flags equals tf.float16."""
  selected_dtype = flags_core.get_tf_dtype(flags.FLAGS)
  return selected_dtype == tf.float16
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                percent,
                model_class,
                shape=None):
    """Shared main loop for ResNet Models with clean/adv/attack estimators.

    Three estimators are built over the same model directory, differing only
    in the `adv_train` and `attack` entries of their params.

    Args:
      flags_obj: An object containing parsed flags. See define_resnet_flags()
        for details.
      model_function: the function that instantiates the Model and builds the
        ops for train/eval. This will be passed directly into the estimator.
      input_function: the function that processes the dataset and returns a
        dataset that the estimator can train on. This will be wrapped with
        all the relevant flags for running and passed to estimator.
      dataset_name: the name of the dataset for training and evaluation. This
        is used for logging purpose.
      percent: forwarded as `percent` to `input_function` for training input.
        Evaluation input uses 0 for the clean/adv pass and 100 for the attack
        pass.
      model_class: accepted for interface compatibility; not read in this
        function body.
      shape: list of ints representing the shape of the images used for
        training. This is only used if flags_obj.export_dir is passed.

    Returns:
      Dict with the keys `eval_results`, `eval_atttack_results`,
      `eval_adv_results` and `train_hooks`. The three eval entries are empty
      dicts when `tf.estimator.train_and_evaluate` is used, since it returns
      nothing to report. `train_hooks` is the list of hooks used in training.
    """

    model_helpers.apply_clean(flags.FLAGS)

    # Ensures flag override logic is only executed if explicitly triggered.
    if flags_obj.tf_gpu_thread_mode:
        override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

    # Configures cluster spec for distribution strategy.
    num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                       flags_obj.task_index)

    # Creates session config. allow_soft_placement = True, is required for
    # multi-GPU and is not harmful for other modes.
    session_config = tf.compat.v1.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_core.get_num_gpus(flags_obj),
        num_workers=num_workers,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    # Creates a `RunConfig` that checkpoints every 24 hours which essentially
    # results in checkpoints determined only by `epochs_between_evals`.
    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config,
                                        save_checkpoints_secs=60 * 60 * 24,
                                        save_checkpoints_steps=None)

    # Initializes model with all but the dense layer from pretrained ResNet.
    if flags_obj.pretrained_model_checkpoint_path is not None:
        warm_start_settings = tf.estimator.WarmStartSettings(
            flags_obj.pretrained_model_checkpoint_path,
            vars_to_warm_start='^(?!.*dense)')
    else:
        warm_start_settings = None

    # Params shared by all three estimators; the adv/attack switches are added
    # per estimator so no dict is mutated after being handed to an Estimator.
    base_params = {
        'resnet_size': int(flags_obj.resnet_size),
        'data_format': 'channels_last',
        'batch_size': flags_obj.batch_size,
        'resnet_version': int(flags_obj.resnet_version),
        'loss_scale': flags_core.get_loss_scale(flags_obj,
                                                default_for_fp16=128),
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'fine_tune': flags_obj.fine_tune,
        'num_workers': num_workers,
    }

    def _build_estimator(adv_train, attack):
        """Create an Estimator differing only in the adv_train/attack params."""
        return tf.compat.v1.estimator.Estimator(
            model_fn=model_function,
            model_dir=flags_obj.model_dir,
            config=run_config,
            warm_start_from=warm_start_settings,
            params=dict(base_params, adv_train=adv_train, attack=attack))

    classifier = _build_estimator(adv_train=False, attack=False)
    classifier_adv = _build_estimator(adv_train=True, attack=False)
    classifier_attack = _build_estimator(adv_train=False, attack=True)

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
        'num_workers': num_workers,
    }
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train(num_epochs, input_context=None):
        return input_function(
            is_training=True,
            percent=percent,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_replica_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=num_epochs,
            dtype=flags_core.get_tf_dtype(flags_obj),
            datasets_num_private_threads=flags_obj.
            datasets_num_private_threads,
            input_context=input_context)

    def input_fn_eval():
        return input_function(
            is_training=False,
            percent=0,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_replica_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=1,
            dtype=flags_core.get_tf_dtype(flags_obj))

    def input_fn_eval_attack():
        return input_function(
            is_training=False,
            percent=100,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_replica_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=1,
            dtype=flags_core.get_tf_dtype(flags_obj))

    train_epochs = (0 if flags_obj.eval_only or not flags_obj.train_epochs else
                    flags_obj.train_epochs)
    # Use the compat.v1 endpoint, consistent with the other TF1-style calls in
    # this function.
    tf.compat.v1.logging.info(tf.compat.v1.global_variables())

    # Default the adv/attack results so that the stats dict below can always be
    # built, including on the train_and_evaluate path which never runs them.
    eval_results_adv = {}
    eval_results_attack = {}

    use_train_and_evaluate = flags_obj.use_train_and_evaluate or num_workers > 1
    if use_train_and_evaluate:
        train_spec = tf.estimator.TrainSpec(
            input_fn=lambda input_context=None: input_fn_train(
                train_epochs, input_context=input_context),
            hooks=train_hooks,
            max_steps=flags_obj.max_train_steps)
        eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval)
        tf.compat.v1.logging.info('Starting to train and evaluate.')
        tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
        # tf.estimator.train_and_evaluate doesn't return anything in
        # multi-worker case.
        eval_results = {}
    else:
        if train_epochs == 0:
            # If --eval_only is set, perform a single loop with zero train epochs.
            schedule, n_loops = [0], 1
        else:
            # Compute the number of times to loop while training. All but the last
            # pass will train for `epochs_between_evals` epochs, while the last will
            # train for the number needed to reach `training_epochs`. For instance if
            #   train_epochs = 25 and epochs_between_evals = 10
            # schedule will be set to [10, 10, 5]. That is to say, the loop will:
            #   Train for 10 epochs and then evaluate.
            #   Train for another 10 epochs and then evaluate.
            #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
            n_loops = math.ceil(train_epochs / flags_obj.epochs_between_evals)
            schedule = [flags_obj.epochs_between_evals] * int(n_loops)
            schedule[-1] = train_epochs - sum(schedule[:-1])  # over counting.

        for cycle_index, num_train_epochs in enumerate(schedule):
            tf.compat.v1.logging.info('Starting cycle: %d/%d', cycle_index,
                                      int(n_loops))

            if num_train_epochs:
                # Only one of the two estimators is trained per cycle, selected
                # by the adv_train flag.
                trainee = classifier_adv if flags_obj.adv_train else classifier
                # Since we are calling train immediately in each loop, the
                # value of num_train_epochs in the lambda function will not be changed
                # before it is used. So it is safe to ignore the pylint error here
                # pylint: disable=cell-var-from-loop
                trainee.train(
                    input_fn=lambda input_context=None: input_fn_train(
                        num_train_epochs, input_context=input_context),
                    hooks=train_hooks,
                    max_steps=flags_obj.max_train_steps)

            # flags_obj.max_train_steps is generally associated with testing and
            # profiling. As a result it is frequently called with synthetic data,
            # which will iterate forever. Passing steps=flags_obj.max_train_steps
            # allows the eval (which is generally unimportant in those circumstances)
            # to terminate.  Note that eval will run for max_train_steps each loop,
            # regardless of the global_step count.
            tf.compat.v1.logging.info('Starting to evaluate clean.')
            eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                               steps=flags_obj.max_train_steps)
            tf.compat.v1.logging.info('Starting to evaluate adv.')
            eval_results_adv = classifier_adv.evaluate(
                input_fn=input_fn_eval, steps=flags_obj.max_train_steps)
            tf.compat.v1.logging.info('Starting to evaluate attack.')
            eval_results_attack = classifier_attack.evaluate(
                input_fn=input_fn_eval_attack, steps=flags_obj.max_train_steps)
            print(
                '########################## clean #############################'
            )
            benchmark_logger.log_evaluation_result(eval_results)
            print(
                '########################## adv #############################')
            benchmark_logger.log_evaluation_result(eval_results_adv)
            print(
                '########################## attack #############################'
            )
            benchmark_logger.log_evaluation_result(eval_results_attack)

            # Early stopping is decided on the clean evaluation accuracy only.
            if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                                 eval_results['accuracy']):
                break

    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        export_dtype = flags_core.get_tf_dtype(flags_obj)
        if flags_obj.image_bytes_as_serving_input:
            input_receiver_fn = functools.partial(image_bytes_serving_input_fn,
                                                  shape,
                                                  dtype=export_dtype)
        else:
            input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
                shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
        classifier.export_savedmodel(flags_obj.export_dir,
                                     input_receiver_fn,
                                     strip_default_attrs=True)

    stats = {}
    stats['eval_results'] = eval_results
    # The key spelling ('atttack') is kept as-is because callers may read it.
    stats['eval_atttack_results'] = eval_results_attack
    stats['eval_adv_results'] = eval_results_adv
    stats['train_hooks'] = train_hooks

    return stats
Example #29
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    When the requested dtype is float16, a mixed precision policy is installed
    and the optimizer is wrapped in a LossScaleOptimizer.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      Training and eval stats as built by `common.build_stats` from the fit
      history, the final evaluation output (None when eval is skipped) and the
      callbacks.
    """
    keras_utils.set_session_config(
        enable_eager=flags_obj.enable_eager,
        enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        common.set_gpu_thread_mode_and_count(flags_obj)
    if flags_obj.data_delay_prefetch:
        common.data_delay_prefetch()
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == 'float16':
        policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    # Fall back to a build-dependent default when no data format is requested.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                       flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=num_workers,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional
        )

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
    )

    # The eval dataset is only built when evaluation will actually run.
    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=imagenet_preprocessing.parse_record,
            dtype=dtype,
            drop_remainder=drop_remainder)

    # Constant learning rate unless a tensor-based schedule is requested.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=LR_SCHEDULE[0][1],
            boundaries=list(p[1] for p in LR_SCHEDULE[1:]),
            multipliers=list(p[0] for p in LR_SCHEDULE),
            compute_lr_on_cpu=True)

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        if dtype == 'float16':
            # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
            # can be enabled with a single line of code.
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
                                                                default_for_fp16=128))

        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(
                imagenet_preprocessing.NUM_CLASSES, dtype)
        else:
            model = resnet_model.resnet50(
                num_classes=imagenet_preprocessing.NUM_CLASSES, dtype=dtype)

        # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
        # a valid arg for this model. Also remove as a valid flag.
        if flags_obj.force_v2_in_keras_compile is not None:
            model.compile(
                loss='sparse_categorical_crossentropy',
                optimizer=optimizer,
                metrics=(['sparse_categorical_accuracy']
                         if flags_obj.report_accuracy_metrics else None),
                run_eagerly=flags_obj.run_eagerly,
                experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
        else:
            model.compile(
                loss='sparse_categorical_crossentropy',
                optimizer=optimizer,
                metrics=(['sparse_categorical_accuracy']
                         if flags_obj.report_accuracy_metrics else None),
                run_eagerly=flags_obj.run_eagerly)

    callbacks = common.get_callbacks(
        learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])

    train_steps = (
            imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
    train_epochs = flags_obj.train_epochs

    # An explicit step budget caps the per-epoch steps and forces one epoch.
    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (
            imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = None
    if not strategy and flags_obj.explicit_gpu_placement:
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    eval_output = None
    try:
        # NOTE(review): only train_steps // 15 steps run per Keras epoch; this
        # looks like a local experiment setting and yields 0 when
        # train_steps < 15 -- confirm it is intended.
        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=train_steps//15,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=1)

        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=1)
    finally:
        # Always leave the manually entered device scope, even when training
        # or evaluation raises. The context manager protocol requires the
        # three exception arguments; passing None means "no exception".
        if no_dist_strat_device is not None:
            no_dist_strat_device.__exit__(None, None, None)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
Example #30
0
  def __init__(self, flags_obj):
    """Init function of TransformerMain.

    Copies flag values into the model params dict, installs the mixed
    precision policy when the dtype is float16, and creates the distribution
    strategy.

    Args:
      flags_obj: Object containing parsed flag values, i.e., FLAGS.

    Raises:
      ValueError: if not using static batch for input data on TPU.
    """
    self.flags_obj = flags_obj
    self.predict_model = None

    # Add flag-defined parameters to params object
    num_gpus = flags_core.get_num_gpus(flags_obj)
    self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)

    params["num_gpus"] = num_gpus
    params["use_ctl"] = flags_obj.use_ctl
    params["data_dir"] = flags_obj.data_dir
    params["model_dir"] = flags_obj.model_dir
    params["static_batch"] = flags_obj.static_batch
    params["max_length"] = flags_obj.max_length
    params["decode_batch_size"] = flags_obj.decode_batch_size
    params["decode_max_length"] = flags_obj.decode_max_length
    params["padded_decode"] = flags_obj.padded_decode
    # Fall back to tf.data autotuning when the flag is unset (or zero).
    params["num_parallel_calls"] = (
        flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)

    params["use_synthetic_data"] = flags_obj.use_synthetic_data
    # A falsy batch_size flag selects the default of the chosen param set.
    params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
    params["repeat_dataset"] = None
    params["dtype"] = flags_core.get_tf_dtype(flags_obj)
    params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training

    if params["dtype"] == tf.float16:
      # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
      # like this. What if multiple instances of TransformerTask are created?
      # We should have a better way in the tf.keras.mixed_precision API of doing
      # this.
      loss_scale = flags_core.get_loss_scale(flags_obj,
                                             default_for_fp16="dynamic")
      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
          "mixed_float16", loss_scale=loss_scale)
      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    self.distribution_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=num_gpus,
        tpu_address=flags_obj.tpu or "")
    if self.use_tpu:
      params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
      if not params["static_batch"]:
        raise ValueError("TPU requires static batch for input data.")
    else:
      # Logging calls take a format string plus lazy arguments; without a
      # placeholder the extra argument makes message formatting fail.
      logging.info("Running transformer with num_gpus = %s", num_gpus)

    if self.distribution_strategy:
      logging.info("For training, using distribution strategy: %s",
                   self.distribution_strategy)
    else:
      logging.info("Not using any distribution strategy.")
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Training and eval stats as built by `keras_common.build_stats` from the
    fit history, the final evaluation output (None when eval is skipped) and
    the time callback.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # Compare against the TensorFlow dtype: 'fp16' is only the flag spelling and
  # is not a valid dtype name, so comparing with it never matches.
  if dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  # Fall back to a build-dependent default when no data format is requested.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)
  else:
    input_fn = imagenet_main.input_fn

  train_input_dataset = input_fn(is_training=True,
                                 data_dir=flags_obj.data_dir,
                                 batch_size=flags_obj.batch_size,
                                 num_epochs=flags_obj.train_epochs,
                                 parse_record_fn=parse_record_keras)

  eval_input_dataset = input_fn(is_training=False,
                                data_dir=flags_obj.data_dir,
                                batch_size=flags_obj.batch_size,
                                num_epochs=flags_obj.train_epochs,
                                parse_record_fn=parse_record_keras)

  strategy = distribution_utils.get_distribution_strategy(
      num_gpus=flags_obj.num_gpus,
      turn_off_distribution_strategy=flags_obj.turn_off_distribution_strategy)

  strategy_scope = keras_common.get_strategy_scope(strategy)

  # Optimizer and model are created and compiled inside the strategy scope.
  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # An explicit step budget caps the per-epoch steps and forces one epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      verbose=1)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=1)
  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats
Example #32
0
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for ResNet Models.

  Training is split into cycles of at most `flags_obj.epochs_between_evals`
  epochs, each followed by an evaluation. When `flags_obj.train_epochs` is not
  a multiple of `flags_obj.epochs_between_evals`, a final shorter cycle trains
  the remaining epochs, so the requested number of epochs is honored instead
  of being rounded down.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

  model_helpers.apply_clean(flags.FLAGS)

  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy, session_config=session_config)

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj)
      })

  # Parameters recorded with the benchmark run info.
  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs):
    """Build the training input for one cycle of `num_epochs` epochs."""
    return input_function(
        is_training=True, data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
        num_gpus=flags_core.get_num_gpus(flags_obj))

  def input_fn_eval():
    """Build the evaluation input: a single pass over the eval data."""
    return input_function(
        is_training=False, data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1)

  # Per-cycle epoch counts: as many full cycles as fit, plus one shorter cycle
  # for any leftover epochs. The `full_cycles >= 0` guard keeps non-positive
  # flag values from scheduling a cycle that the floor division alone would
  # not have run.
  epochs_between_evals = flags_obj.epochs_between_evals
  full_cycles, leftover_epochs = divmod(flags_obj.train_epochs,
                                        epochs_between_evals)
  schedule = [epochs_between_evals] * full_cycles
  if full_cycles >= 0 and leftover_epochs > 0:
    schedule.append(leftover_epochs)
  total_training_cycle = len(schedule)

  for cycle_index, num_train_epochs in enumerate(schedule):
    tf.logging.info('Starting a training cycle: %d/%d',
                    cycle_index, total_training_cycle)

    # The lambda is invoked by train() before the next iteration rebinds
    # num_train_epochs, so capturing the loop variable is safe here.
    classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                     hooks=train_hooks,
                     max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data, which
    # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
    # eval (which is generally unimportant in those circumstances) to terminate.
    # Note that eval will run for max_train_steps each loop, regardless of the
    # global_step count.
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

    benchmark_logger.log_evaluation_result(eval_results)

    # Stop early once the accuracy threshold (if any) has been passed.
    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
Example #33
0
def run(flags_obj):
  """Train and evaluate an ImageNet model through the Keras fit/evaluate APIs.

  Builds either the trivial model or ResNet-50 depending on the flags, trains
  it with `model.fit` and, unless eval is skipped, evaluates it afterwards.

  Args:
    flags_obj: Object holding the parsed flag values.

  Returns:
    The stats assembled by `keras_common.build_stats` from the fit history,
    the final evaluation output (None when eval is skipped) and the callbacks.
  """
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  if not keras_common.is_v2_0():
    v1_config = keras_common.get_config_proto_v1()
    if not flags_obj.enable_eager:
      tf.keras.backend.set_session(tf.Session(config=v1_config))
    else:
      tf.compat.v1.enable_eager_execution(config=v1_config)
  else:
    keras_common.set_config_v2()

  # Flag-driven performance overrides.
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)
  if flags_obj.data_prefetch_with_slack:
    keras_common.data_prefetch_with_slack()
  keras_common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'float16':
    tf.keras.mixed_precision.experimental.set_policy(
        tf.keras.mixed_precision.experimental.Policy('infer_float32_vars'))

  # Pick a build-dependent default when no data format was requested.
  image_data_format = flags_obj.data_format
  if image_data_format is None:
    if tf.test.is_built_with_cuda():
      image_data_format = 'channels_first'
    else:
      image_data_format = 'channels_last'
  tf.keras.backend.set_image_data_format(image_data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)
  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if not flags_obj.use_synthetic_data:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn
  else:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)

  # XLA-GPU doesn't support dynamic shapes, so with `enable_xla` the
  # remainder of the batches in the dataset is always dropped.
  drop_remainder = flags_obj.enable_xla

  # Keyword arguments common to the training and evaluation datasets.
  shared_input_kwargs = dict(
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      dtype=dtype,
      drop_remainder=drop_remainder)

  train_input_dataset = input_fn(
      is_training=True,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      **shared_input_kwargs)

  if flags_obj.skip_eval:
    eval_input_dataset = None
  else:
    eval_input_dataset = input_fn(is_training=False, **shared_input_kwargs)

  # Constant learning rate unless the tensor-based schedule is requested.
  if flags_obj.use_tensor_lr:
    lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_main.NUM_IMAGES['train'],
        warmup_epochs=LR_SCHEDULE[0][1],
        boundaries=[entry[1] for entry in LR_SCHEDULE[1:]],
        multipliers=[entry[0] for entry in LR_SCHEDULE],
        compute_lr_on_cpu=True)
  else:
    lr_schedule = 0.1

  with strategy_scope:
    optimizer = keras_common.get_optimizer(lr_schedule)
    if dtype == 'float16':
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      fp16_loss_scale = flags_core.get_loss_scale(flags_obj)
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=fp16_loss_scale)

    # The Input layer only gets a fixed batch size for XLA in non-eager mode
    # with at most one replica.
    # TODO(b/129861005): Fix OOM issue in eager mode when setting
    # `batch_size` in keras.Input layer.
    # TODO(b/129791381): Specify `input_layer_batch_size` value in
    # DistributionStrategy multi-replica case.
    input_layer_batch_size = None
    if flags_obj.enable_xla and not flags_obj.enable_eager:
      multi_replica = strategy and strategy.num_replicas_in_sync > 1
      if not multi_replica:
        input_layer_batch_size = flags_obj.batch_size

    if not flags_obj.use_trivial_model:
      model = resnet_model.resnet50(
          num_classes=imagenet_main.NUM_CLASSES,
          dtype=dtype,
          batch_size=input_layer_batch_size)
    else:
      model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES, dtype)

    if flags_obj.report_accuracy_metrics:
      compile_metrics = ['sparse_categorical_accuracy']
    else:
      compile_metrics = None
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=compile_metrics,
                  cloning=flags_obj.clone_model_in_keras_dist_strat)

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  steps_per_epoch = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  epochs_to_fit = flags_obj.train_epochs
  if flags_obj.train_steps:
    # An explicit step budget caps the steps and collapses to one epoch.
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    epochs_to_fit = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)
  validation_data = eval_input_dataset

  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(
      train_input_dataset,
      epochs=epochs_to_fit,
      steps_per_epoch=steps_per_epoch,
      callbacks=callbacks,
      validation_steps=num_eval_steps,
      validation_data=validation_data,
      validation_freq=flags_obj.epochs_between_evals,
      verbose=2)

  if flags_obj.skip_eval:
    eval_output = None
  else:
    eval_output = model.evaluate(
        eval_input_dataset, steps=num_eval_steps, verbose=2)

  return keras_common.build_stats(history, eval_output, callbacks)
Example #34
0
def run(flags_obj):
    """Train and evaluate an ImageNet model through the Keras fit/evaluate APIs.

    Builds either the trivial model or ResNet-50 depending on the flags,
    trains it with `model.fit` and, unless eval is skipped, evaluates it
    afterwards.

    Args:
      flags_obj: Object holding the parsed flag values.

    Returns:
      The stats assembled by `keras_common.build_stats` from the fit history,
      the final evaluation output (None when eval is skipped) and the
      callbacks.
    """
    # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
    # Eager is default in tf 2.0 and should not be toggled
    if not keras_common.is_v2_0():
        v1_config = keras_common.get_config_proto_v1()
        if not flags_obj.enable_eager:
            tf.keras.backend.set_session(tf.Session(config=v1_config))
        else:
            tf.compat.v1.enable_eager_execution(config=v1_config)
    else:
        keras_common.set_config_v2()

    # Flag-driven performance override.
    if flags_obj.tf_gpu_thread_mode:
        keras_common.set_gpu_thread_mode_and_count(flags_obj)

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == 'float16':
        tf.keras.mixed_precision.experimental.set_policy(
            tf.keras.mixed_precision.experimental.Policy('infer_float32_vars'))

    # Pick a build-dependent default when no data format was requested.
    image_data_format = flags_obj.data_format
    if image_data_format is None:
        if tf.test.is_built_with_cuda():
            image_data_format = 'channels_first'
        else:
            image_data_format = 'channels_last'
    tf.keras.backend.set_image_data_format(image_data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster())
    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if not flags_obj.use_synthetic_data:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_main.input_fn
    else:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=imagenet_main.DEFAULT_IMAGE_SIZE,
            width=imagenet_main.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_main.NUM_CHANNELS,
            num_classes=imagenet_main.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)

    # XLA-GPU doesn't support dynamic shapes, so with `enable_xla` the
    # remainder of the batches in the dataset is always dropped.
    drop_remainder = flags_obj.enable_xla

    def build_dataset(is_training, **extra_kwargs):
        """Call input_fn with the arguments shared by train and eval."""
        return input_fn(
            is_training=is_training,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=parse_record_keras,
            dtype=dtype,
            drop_remainder=drop_remainder,
            **extra_kwargs)

    train_input_dataset = build_dataset(
        True,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)

    if flags_obj.skip_eval:
        eval_input_dataset = None
    else:
        eval_input_dataset = build_dataset(False)

    with strategy_scope:
        optimizer = keras_common.get_optimizer()
        if dtype == 'float16':
            # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
            # can be enabled with a single line of code.
            fp16_loss_scale = flags_core.get_loss_scale(flags_obj)
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                optimizer, loss_scale=fp16_loss_scale)

        # The Input layer only gets a fixed batch size for XLA in non-eager
        # mode with at most one replica.
        # TODO(b/129861005): Fix OOM issue in eager mode when setting
        # `batch_size` in keras.Input layer.
        # TODO(b/129791381): Specify `input_layer_batch_size` value in
        # DistributionStrategy multi-replica case.
        input_layer_batch_size = None
        if flags_obj.enable_xla and not flags_obj.enable_eager:
            multi_replica = strategy and strategy.num_replicas_in_sync > 1
            if not multi_replica:
                input_layer_batch_size = flags_obj.batch_size

        if not flags_obj.use_trivial_model:
            model = resnet_model.resnet50(
                num_classes=imagenet_main.NUM_CLASSES,
                dtype=dtype,
                batch_size=input_layer_batch_size)
        else:
            model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES)

        model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer=optimizer,
            metrics=['sparse_categorical_accuracy'])

    callbacks = keras_common.get_callbacks(
        learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

    steps_per_epoch = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
    epochs_to_fit = flags_obj.train_epochs
    if flags_obj.train_steps:
        # An explicit step budget caps the steps and collapses to one epoch.
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        epochs_to_fit = 1

    num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)
    validation_data = eval_input_dataset

    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    history = model.fit(
        train_input_dataset,
        epochs=epochs_to_fit,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
        validation_steps=num_eval_steps,
        validation_data=validation_data,
        validation_freq=flags_obj.epochs_between_evals,
        verbose=2)

    if flags_obj.skip_eval:
        eval_output = None
    else:
        eval_output = model.evaluate(
            eval_input_dataset, steps=num_eval_steps, verbose=2)

    return keras_common.build_stats(history, eval_output, callbacks)
Example #35
0
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for ResNet Models.

  Builds a `tf.estimator.Estimator` around `model_function`, alternates
  training and evaluation until `flags_obj.train_epochs` epochs have been
  scheduled or the accuracy stop threshold is passed, and finally exports a
  SavedModel when `flags_obj.export_dir` is set.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  # Pick the distribution strategy from the number of GPUs requested.
  num_gpus = flags_core.get_num_gpus(flags_obj)
  if num_gpus == 0:
    distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
  elif num_gpus == 1:
    distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
  else:
    distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=num_gpus)

  run_config = tf.estimator.RunConfig(train_distribute=distribution,
                                      session_config=session_config)

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj)
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  benchmark_logger = logger.config_benchmark_logger(flags_obj)
  benchmark_logger.log_run_info('resnet', dataset_name, run_params)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs):
    """Builds the training input for `num_epochs` epochs."""
    return input_function(
        is_training=True, data_dir=flags_obj.data_dir,
        batch_size=per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs)

  def input_fn_eval():
    """Builds the evaluation input (a single pass over the data)."""
    return input_function(
        is_training=False, data_dir=flags_obj.data_dir,
        batch_size=per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1)

  # Train in cycles of `epochs_between_evals` epochs, evaluating after each
  # one. A plain floor division would silently drop the remainder (and train
  # nothing at all when train_epochs < epochs_between_evals), so a shorter
  # trailing cycle is scheduled to cover the leftover epochs.
  full_cycles, leftover_epochs = divmod(
      flags_obj.train_epochs, flags_obj.epochs_between_evals)
  schedule = [flags_obj.epochs_between_evals] * full_cycles
  if leftover_epochs:
    schedule.append(leftover_epochs)

  for cycle_index, num_train_epochs in enumerate(schedule):
    tf.logging.info('Starting a training cycle: %d/%d',
                    cycle_index, len(schedule))

    # The lambda is consumed synchronously by train(), so it always sees the
    # epoch count of the current cycle.
    classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                     hooks=train_hooks,
                     max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data, which
    # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
    # eval (which is generally unimportant in those circumstances) to terminate.
    # Note that eval will run for max_train_steps each loop, regardless of the
    # global_step count.
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If `--fp16_implementation=graph_rewrite` is requested
        without `--use_tf_function`.

    Returns:
      The value produced by `build_stats` from the final training results,
      the final evaluation results and the timing callback.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == tf.bfloat16:
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_bfloat16')
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    # TODO(anj-s): Set data_format without using Keras.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    train_ds, test_ds = get_input_dataset(flags_obj, strategy)
    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    # Never run more steps in one host loop than there are in an epoch.
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
    logging.info(
        "Training %d epochs, each epoch has %d steps, "
        "total steps: %d; Eval %d steps", train_epochs, per_epoch_steps,
        train_epochs * per_epoch_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
                                            flags_obj.log_steps)

    with distribution_utils.get_strategy_scope(strategy):
        model = resnet_model.resnet50(
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            batch_size=flags_obj.batch_size,
            use_l2_regularizer=not flags_obj.single_l2_loss_op)

        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=common.LR_SCHEDULE[0][1],
            boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
            multipliers=list(p[0] for p in common.LR_SCHEDULE),
            compute_lr_on_cpu=True)
        optimizer = common.get_optimizer(lr_schedule)

        if flags_obj.fp16_implementation == 'graph_rewrite':
            if not flags_obj.use_tf_function:
                raise ValueError(
                    '--fp16_implementation=graph_rewrite requires '
                    '--use_tf_function to be true')
            loss_scale = flags_core.get_loss_scale(flags_obj,
                                                   default_for_fp16=128)
            optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                optimizer, loss_scale)

        train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'training_accuracy', dtype=tf.float32)
        test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
        test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)

        trainable_variables = model.trainable_variables

        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            use_loss_scaling = flags_obj.dtype == "fp16"
            with tf.GradientTape() as tape:
                logits = model(images, training=True)

                prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits)
                # Divide by the global batch size so that summing the
                # per-replica losses yields the batch mean.
                loss = tf.reduce_sum(prediction_loss) * (1.0 /
                                                         flags_obj.batch_size)
                num_replicas = tf.distribute.get_strategy(
                ).num_replicas_in_sync

                if flags_obj.single_l2_loss_op:
                    # Compute weight decay with a single fused op over all
                    # variables whose name does not contain 'bn'.
                    filtered_variables = [
                        tf.reshape(v, (-1, )) for v in trainable_variables
                        if 'bn' not in v.name
                    ]
                    l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
                        tf.concat(filtered_variables, axis=0))
                    loss += (l2_loss / num_replicas)
                else:
                    loss += (tf.reduce_sum(model.losses) / num_replicas)

                # Scale only the value that gets differentiated. `loss` itself
                # stays unscaled so the train_loss metric reports the true
                # loss rather than the loss multiplied by the loss scale.
                if use_loss_scaling:
                    scaled_loss = optimizer.get_scaled_loss(loss)
                else:
                    scaled_loss = loss

            grads = tape.gradient(scaled_loss, trainable_variables)

            # Unscale the grads
            if use_loss_scaling:
                grads = optimizer.get_unscaled_gradients(grads)

            optimizer.apply_gradients(zip(grads, trainable_variables))
            train_loss.update_state(loss)
            training_accuracy.update_state(labels, logits)

        @tf.function
        def train_steps(iterator, steps):
            """Performs distributed training steps in a loop."""
            for _ in tf.range(steps):
                strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

        def train_single_step(iterator):
            """Performs one training step, distributed when a strategy is set."""
            if strategy:
                strategy.experimental_run_v2(step_fn, args=(next(iterator), ))
            else:
                return step_fn(next(iterator))

        def test_step(iterator):
            """Evaluation StepFn."""
            def step_fn(inputs):
                images, labels = inputs
                logits = model(images, training=False)
                loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits)
                loss = tf.reduce_sum(loss) * (1.0 / flags_obj.batch_size)
                test_loss.update_state(loss)
                test_accuracy.update_state(labels, logits)

            if strategy:
                strategy.experimental_run_v2(step_fn, args=(next(iterator), ))
            else:
                step_fn(next(iterator))

        if flags_obj.use_tf_function:
            train_single_step = tf.function(train_single_step)
            test_step = tf.function(test_step)

        train_iter = iter(train_ds)
        time_callback.on_train_begin()
        for epoch in range(train_epochs):
            train_loss.reset_states()
            training_accuracy.reset_states()

            steps_in_current_epoch = 0
            while steps_in_current_epoch < per_epoch_steps:
                time_callback.on_batch_begin(steps_in_current_epoch +
                                             epoch * per_epoch_steps)
                steps = _steps_to_run(steps_in_current_epoch, per_epoch_steps,
                                      steps_per_loop)
                if steps == 1:
                    train_single_step(train_iter)
                else:
                    # Converts steps to a Tensor to avoid tf.function retracing.
                    train_steps(train_iter,
                                tf.convert_to_tensor(steps, dtype=tf.int32))
                time_callback.on_batch_end(steps_in_current_epoch +
                                           epoch * per_epoch_steps)
                steps_in_current_epoch += steps

            logging.info('Training loss: %s, accuracy: %s at epoch %d',
                         train_loss.result().numpy(),
                         training_accuracy.result().numpy(), epoch + 1)

            if (not flags_obj.skip_eval
                    and (epoch + 1) % flags_obj.epochs_between_evals == 0):
                test_loss.reset_states()
                test_accuracy.reset_states()

                test_iter = iter(test_ds)
                for _ in range(eval_steps):
                    test_step(test_iter)

                # The accuracy metric is a fraction, so it is logged without a
                # percent sign, matching the training log line above.
                logging.info('Test loss: %s, accuracy: %s at epoch: %d',
                             test_loss.result().numpy(),
                             test_accuracy.result().numpy(), epoch + 1)

        time_callback.on_train_end()

        # Both result lists are only reported when evaluation is enabled;
        # otherwise build_stats receives None for each.
        eval_result = None
        train_result = None
        if not flags_obj.skip_eval:
            eval_result = [
                test_loss.result().numpy(),
                test_accuracy.result().numpy()
            ]
            train_result = [
                train_loss.result().numpy(),
                training_accuracy.result().numpy()
            ]

        stats = build_stats(train_result, eval_result, time_callback)
        return stats
Example #37
0
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None,
                num_images=None,
                zeroshot_eval=False):
    """Train, evaluate and optionally export a ResNet model with an Estimator.

    Runs a schedule of train/evaluate cycles, records each cycle's metric with
    a checkpoint keeper and the benchmark logger, and stops early when either
    the metric or the global step passes its threshold.

    Args:
      flags_obj: Object holding the parsed flag values read throughout.
      model_function: Passed to `tf.estimator.Estimator` as `model_fn`.
      input_function: Callable invoked with the keyword arguments
        `is_training`, `use_random_crop`, `num_epochs` and `flags_obj` to
        build the training and evaluation inputs.
      dataset_name: Name given to the benchmark logger (suffixed with
        '-synthetic' when synthetic data is used) and, for zero-shot
        evaluation, to `data_config.get_config`.
      shape: Forwarded to `export_utils.export_pb` when
        `flags_obj.export_dir` is set.
      num_images: Mapping indexed with 'train' (and with 'validation' for
        zero-shot evaluation). NOTE(review): the default of None is
        subscripted unconditionally below, so callers must always pass it.
      zeroshot_eval: When true, evaluation uses `recall_metric.recall_at_k`
        and the 'recall_at_1' metric instead of `classifier.evaluate` and
        'accuracy'.
    """
    model_helpers.apply_clean(flags.FLAGS)

    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # Create session config based on values of inter_op_parallelism_threads and
    # intra_op_parallelism_threads. Note that we default to having
    # allow_soft_placement = True, which is required for multi-GPU and not
    # harmful for other modes.
    session_config = config_utils.get_session_config(flags_obj)
    run_config = config_utils.get_run_config(flags_obj, flags_core,
                                             session_config,
                                             num_images['train'])

    def gen_estimator(period=None):
        """Build the Estimator and, if enabled, the Experiment tracking it.

        Reads `total_train_steps` from the enclosing scope; it is assigned
        further down, before this function is called.
        NOTE(review): `period` is never used in the body.

        Returns:
          A `(classifier, exp)` tuple; `exp` is None unless the hyperdash
          branch below ran.
        """
        # Hyper-parameters are first read straight from the flags. The local
        # `dataset_name` shadows the enclosing function's argument.
        resnet_size = int(flags_obj.resnet_size)
        data_format = flags_obj.data_format
        batch_size = flags_obj.batch_size
        resnet_version = int(flags_obj.resnet_version)
        loss_scale = flags_core.get_loss_scale(flags_obj)
        dtype_tf = flags_core.get_tf_dtype(flags_obj)
        num_epochs_per_decay = flags_obj.num_epochs_per_decay
        learning_rate_decay_factor = flags_obj.learning_rate_decay_factor
        end_learning_rate = flags_obj.end_learning_rate
        learning_rate_decay_type = flags_obj.learning_rate_decay_type
        weight_decay = flags_obj.weight_decay
        zero_gamma = flags_obj.zero_gamma
        lr_warmup_epochs = flags_obj.lr_warmup_epochs
        base_learning_rate = flags_obj.base_learning_rate
        use_resnet_d = flags_obj.use_resnet_d
        use_dropblock = flags_obj.use_dropblock
        dropblock_kp = [float(be) for be in flags_obj.dropblock_kp]
        label_smoothing = flags_obj.label_smoothing
        momentum = flags_obj.momentum
        bn_momentum = flags_obj.bn_momentum
        train_epochs = flags_obj.train_epochs
        piecewise_lr_boundary_epochs = [
            int(be) for be in flags_obj.piecewise_lr_boundary_epochs
        ]
        piecewise_lr_decay_rates = [
            float(dr) for dr in flags_obj.piecewise_lr_decay_rates
        ]
        use_ranking_loss = flags_obj.use_ranking_loss
        use_se_block = flags_obj.use_se_block
        use_sk_block = flags_obj.use_sk_block
        mixup_type = flags_obj.mixup_type
        dataset_name = flags_obj.dataset_name
        kd_temp = flags_obj.kd_temp
        no_downsample = flags_obj.no_downsample
        anti_alias_filter_size = flags_obj.anti_alias_filter_size
        anti_alias_type = flags_obj.anti_alias_type
        cls_loss_type = flags_obj.cls_loss_type
        logit_type = flags_obj.logit_type
        embedding_size = flags_obj.embedding_size
        pool_type = flags_obj.pool_type
        arc_s = flags_obj.arc_s
        arc_m = flags_obj.arc_m
        bl_alpha = flags_obj.bl_alpha
        bl_beta = flags_obj.bl_beta
        exp = None

        # When experiment tracking is on, a subset of the values above is
        # re-assigned from exp.param(...), which also registers them with the
        # experiment (presumably returning the value unchanged -- confirm).
        if install_hyperdash and flags_obj.use_hyperdash:
            # The experiment is named after the last component of model_dir.
            exp = Experiment(flags_obj.model_dir.split("/")[-1])
            resnet_size = exp.param("resnet_size", int(flags_obj.resnet_size))
            batch_size = exp.param("batch_size", flags_obj.batch_size)
            exp.param("dtype", flags_obj.dtype)
            learning_rate_decay_type = exp.param(
                "learning_rate_decay_type", flags_obj.learning_rate_decay_type)
            weight_decay = exp.param("weight_decay", flags_obj.weight_decay)
            zero_gamma = exp.param("zero_gamma", flags_obj.zero_gamma)
            lr_warmup_epochs = exp.param("lr_warmup_epochs",
                                         flags_obj.lr_warmup_epochs)
            base_learning_rate = exp.param("base_learning_rate",
                                           flags_obj.base_learning_rate)
            use_dropblock = exp.param("use_dropblock", flags_obj.use_dropblock)
            dropblock_kp = exp.param(
                "dropblock_kp", [float(be) for be in flags_obj.dropblock_kp])
            piecewise_lr_boundary_epochs = exp.param(
                "piecewise_lr_boundary_epochs",
                [int(be) for be in flags_obj.piecewise_lr_boundary_epochs])
            piecewise_lr_decay_rates = exp.param(
                "piecewise_lr_decay_rates",
                [float(dr) for dr in flags_obj.piecewise_lr_decay_rates])
            mixup_type = exp.param("mixup_type", flags_obj.mixup_type)
            dataset_name = exp.param("dataset_name", flags_obj.dataset_name)
            exp.param("autoaugment_type", flags_obj.autoaugment_type)

        # Everything collected above is handed to the model function through
        # the Estimator's `params` dictionary.
        classifier = tf.estimator.Estimator(
            model_fn=model_function,
            model_dir=flags_obj.model_dir,
            config=run_config,
            params={
                'resnet_size': resnet_size,
                'data_format': data_format,
                'batch_size': batch_size,
                'resnet_version': resnet_version,
                'loss_scale': loss_scale,
                'dtype': dtype_tf,
                'num_epochs_per_decay': num_epochs_per_decay,
                'learning_rate_decay_factor': learning_rate_decay_factor,
                'end_learning_rate': end_learning_rate,
                'learning_rate_decay_type': learning_rate_decay_type,
                'weight_decay': weight_decay,
                'zero_gamma': zero_gamma,
                'lr_warmup_epochs': lr_warmup_epochs,
                'base_learning_rate': base_learning_rate,
                'use_resnet_d': use_resnet_d,
                'use_dropblock': use_dropblock,
                'dropblock_kp': dropblock_kp,
                'label_smoothing': label_smoothing,
                'momentum': momentum,
                'bn_momentum': bn_momentum,
                'embedding_size': embedding_size,
                'train_epochs': train_epochs,
                'piecewise_lr_boundary_epochs': piecewise_lr_boundary_epochs,
                'piecewise_lr_decay_rates': piecewise_lr_decay_rates,
                'with_drawing_bbox': flags_obj.with_drawing_bbox,
                'use_ranking_loss': use_ranking_loss,
                'use_se_block': use_se_block,
                'use_sk_block': use_sk_block,
                'mixup_type': mixup_type,
                'kd_temp': kd_temp,
                'no_downsample': no_downsample,
                'dataset_name': dataset_name,
                'anti_alias_filter_size': anti_alias_filter_size,
                'anti_alias_type': anti_alias_type,
                'cls_loss_type': cls_loss_type,
                'logit_type': logit_type,
                'arc_s': arc_s,
                'arc_m': arc_m,
                'pool_type': pool_type,
                'bl_alpha': bl_alpha,
                'bl_beta': bl_beta,
                'train_steps': total_train_steps,
            })
        return classifier, exp

    # Summary of the run handed to the benchmark logger.
    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    # NOTE(review): the suffixed name is also what the zero-shot branch below
    # passes to data_config.get_config -- confirm that lookup accepts it.
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train(num_epochs):
        """Build the training input for `num_epochs` epochs."""
        return input_function(is_training=True,
                              use_random_crop=flags_obj.training_random_crop,
                              num_epochs=num_epochs,
                              flags_obj=flags_obj)

    def input_fn_eval():
        """Build the evaluation input: one epoch, no random crop."""
        return input_function(is_training=False,
                              use_random_crop=False,
                              num_epochs=1,
                              flags_obj=flags_obj)

    # Receives each cycle's metric via save() in the loop below.
    ckpt_keeper = checkpoint_utils.CheckpointKeeper(
        save_dir=flags_obj.model_dir,
        num_to_keep=flags_obj.num_best_ckpt_to_keep,
        keep_epoch=flags_obj.keep_ckpt_every_eval,
        maximize=True)

    # `model` is only created, and only read (inside train_and_evaluate), when
    # zero-shot evaluation is requested.
    if zeroshot_eval:
        dataset = data_config.get_config(dataset_name)
        model = model_fns.Model(
            int(flags_obj.resnet_size),
            flags_obj.data_format,
            resnet_version=int(flags_obj.resnet_version),
            num_classes=dataset.num_classes,
            zero_gamma=flags_obj.zero_gamma,
            use_se_block=flags_obj.use_se_block,
            use_sk_block=flags_obj.use_sk_block,
            no_downsample=flags_obj.no_downsample,
            anti_alias_filter_size=flags_obj.anti_alias_filter_size,
            anti_alias_type=flags_obj.anti_alias_type,
            bn_momentum=flags_obj.bn_momentum,
            embedding_size=flags_obj.embedding_size,
            pool_type=flags_obj.pool_type,
            bl_alpha=flags_obj.bl_alpha,
            bl_beta=flags_obj.bl_beta,
            dtype=flags_core.get_tf_dtype(flags_obj),
            loss_type=flags_obj.cls_loss_type)

    def train_and_evaluate(hooks):
        """Run one train-then-evaluate cycle and return the eval results.

        Reads `cycle_index`, `num_train_epochs`, `n_loops`, `classifier` and
        (for zero-shot evaluation) `model` from the enclosing scope; all of
        them are assigned further down, before the first call.
        """
        tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

        # A cycle with zero epochs (the eval-only schedule) skips training.
        if num_train_epochs:
            classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                             hooks=hooks,
                             steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')

        if zeroshot_eval:
            tf.reset_default_graph()
            eval_results = recall_metric.recall_at_k(
                flags_obj,
                flags_core,
                input_fns.input_fn_ir_eval,
                model,
                num_images['validation'],
                eval_similarity=flags_obj.eval_similarity,
                return_embedding=True)
        else:
            eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                               steps=flags_obj.max_train_steps)

        return eval_results

    # Steps per epoch (truncated to an int) times the number of epochs.
    total_train_steps = flags_obj.train_epochs * int(
        num_images['train'] / flags_obj.batch_size)

    # `schedule` holds the number of training epochs for each cycle.
    if flags_obj.eval_only or not flags_obj.train_epochs:
        # If --eval_only is set, perform a single loop with zero train epochs.
        schedule, n_loops = [0], 1
    elif flags_obj.export_only:
        # No cycles at all: the loop below is skipped entirely.
        schedule, n_loops = [], 0
    else:
        n_loops = math.ceil(flags_obj.train_epochs /
                            flags_obj.epochs_between_evals)
        schedule = [
            flags_obj.epochs_between_evals for _ in range(int(n_loops))
        ]
        # Shrink the last cycle so the schedule sums to exactly train_epochs.
        schedule[-1] = flags_obj.train_epochs - sum(
            schedule[:-1])  # over counting.

        schedule = config_utils.get_epoch_schedule(flags_obj, schedule,
                                                   num_images)
        tf.logging.info('epoch schedule:')
        tf.logging.info(schedule)

    classifier, exp = gen_estimator()
    if flags_obj.pretrained_model_checkpoint_path:
        warm_start_hook = WarmStartHook(
            flags_obj.pretrained_model_checkpoint_path)
        train_hooks.append(warm_start_hook)

    for cycle_index, num_train_epochs in enumerate(schedule):
        eval_results = train_and_evaluate(train_hooks)
        if zeroshot_eval:
            metric = eval_results['recall_at_1']
        else:
            metric = eval_results['accuracy']
        ckpt_keeper.save(metric, flags_obj.model_dir)
        # The experiment metric is always named "accuracy", even when the
        # value is recall@1.
        if exp:
            exp.metric("accuracy", metric)
        benchmark_logger.log_evaluation_result(eval_results)
        if model_helpers.past_stop_threshold(flags_obj.stop_threshold, metric):
            break
        # NOTE(review): assumes the zero-shot results also contain a
        # 'global_step' entry -- confirm against recall_at_k.
        if model_helpers.past_stop_threshold(total_train_steps,
                                             eval_results['global_step']):
            break

    if exp:
        exp.end()

    if flags_obj.export_dir is not None:
        export_utils.export_pb(flags_core, flags_obj, shape, classifier)
def run(flags_obj):
    """Run ResNet ImageNet training and evaluation through an orbit Controller.

    Configures the session, mixed-precision policy, image data format and
    distribution strategy from the flags, builds a `ResnetRunnable` under the
    strategy scope and drives it with `orbit.Controller`.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      The value produced by `build_stats` from the runnable and the timing
      callback.
    """
    keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)
    performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

    # Query the physical devices once; the answer drives both the GPU tuning
    # below and the default data format.
    gpu_available = bool(tf.config.list_physical_devices('GPU'))

    if gpu_available:
        if flags_obj.tf_gpu_thread_mode:
            keras_utils.set_gpu_thread_mode_and_count(
                per_gpu_thread_count=flags_obj.per_gpu_thread_count,
                gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
                num_gpus=flags_obj.num_gpus,
                datasets_num_private_threads=(
                    flags_obj.datasets_num_private_threads))
        common.set_cudnn_batchnorm_mode()

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = 'channels_first' if gpu_available else 'channels_last'
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)

    # Default to one loop per epoch, and never let a loop span more than one
    # epoch.
    requested_steps_per_loop = flags_obj.steps_per_loop
    if requested_steps_per_loop is None:
        steps_per_loop = per_epoch_steps
    elif requested_steps_per_loop > per_epoch_steps:
        steps_per_loop = per_epoch_steps
        logging.warning(
            'Setting steps_per_loop to %d to respect epoch boundary.',
            steps_per_loop)
    else:
        steps_per_loop = requested_steps_per_loop

    total_train_steps = per_epoch_steps * train_epochs
    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        train_epochs * per_epoch_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        flags_obj.batch_size,
        flags_obj.log_steps,
        logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
    with distribute_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                                  per_epoch_steps)

    # Intervals are expressed in steps; None disables the feature.
    eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
    checkpoint_interval = (steps_per_loop * 5
                           if flags_obj.enable_checkpoint_and_export else None)
    summary_interval = steps_per_loop if flags_obj.enable_tensorboard else None

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=flags_obj.model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    # The same runnable serves as trainer and, unless evaluation is skipped,
    # as evaluator.
    resnet_controller = orbit.Controller(
        strategy=strategy,
        trainer=runnable,
        evaluator=runnable if not flags_obj.skip_eval else None,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        summary_dir=flags_obj.model_dir,
        eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))

    time_callback.on_train_begin()
    if not flags_obj.skip_eval:
        resnet_controller.train_and_evaluate(train_steps=total_train_steps,
                                             eval_steps=eval_steps,
                                             eval_interval=eval_interval)
    else:
        resnet_controller.train(steps=total_train_steps)
    time_callback.on_train_end()

    stats = build_stats(runnable, time_callback)
    return stats
Example #39
0
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Drive the Estimator train / evaluate / export cycle for a ResNet model.

  Args:
    flags_obj: Parsed flags object. See define_resnet_flags() for details.
    model_function: Model function handed unchanged to the Estimator as its
      `model_fn`.
    input_function: Dataset-building callable. It is invoked with
      `is_training`, `data_dir`, a per-device `batch_size`, `num_epochs` and
      `dtype`; training calls additionally pass
      `datasets_num_private_threads` and `num_parallel_batches`.
    dataset_name: Dataset label used when logging the run info. The suffix
      '-synthetic' is appended when synthetic data is enabled.
    shape: List of ints giving the image shape. Only read when
      flags_obj.export_dir is set.

  Returns:
    The evaluation results of the last cycle that ran.
  """

  model_helpers.apply_clean(flags.FLAGS)

  # The GPU thread-pool overrides only run when the mode flag is set.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Soft placement is required for multi-GPU and harmless in other modes.
  sess_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

  # Time-based checkpointing is pushed out to 24 hours, so in practice
  # checkpoints are determined only by `epochs_between_evals`.
  one_day_secs = 60 * 60 * 24
  estimator_config = tf.estimator.RunConfig(
      train_distribute=strategy,
      session_config=sess_config,
      save_checkpoints_secs=one_day_secs)

  # Optionally warm-start every variable except the dense layer from a
  # pretrained ResNet checkpoint.
  warm_start = None
  pretrained_path = flags_obj.pretrained_model_checkpoint_path
  if pretrained_path is not None:
    warm_start = tf.estimator.WarmStartSettings(
        pretrained_path, vars_to_warm_start='^(?!.*dense)')

  estimator = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      config=estimator_config,
      warm_start_from=warm_start,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune
      })

  logged_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  if flags_obj.use_synthetic_data:
    dataset_name += '-synthetic'

  bench_logger = logger.get_benchmark_logger()
  bench_logger.log_run_info('resnet', dataset_name, logged_params,
                            test_id=flags_obj.benchmark_test_id)

  hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def _per_device_batch_size():
    # Global batch size divided across the GPUs in use.
    return distribution_utils.per_device_batch_size(
        flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))

  def _train_input(num_epochs):
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=_per_device_batch_size(),
        num_epochs=num_epochs,
        dtype=flags_core.get_tf_dtype(flags_obj),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        num_parallel_batches=flags_obj.datasets_num_parallel_batches)

  def _eval_input():
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=_per_device_batch_size(),
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  if not flags_obj.eval_only and flags_obj.train_epochs:
    # Every cycle trains for `epochs_between_evals` epochs except the last,
    # which trains for whatever is left to reach `train_epochs`. Example:
    #   train_epochs = 25, epochs_between_evals = 10  ->  [10, 10, 5]
    # i.e. train 10 + evaluate, train 10 + evaluate, train 5 + evaluate.
    num_cycles = math.ceil(
        flags_obj.train_epochs / flags_obj.epochs_between_evals)
    epochs_per_cycle = [flags_obj.epochs_between_evals] * int(num_cycles)
    # Trim the final entry so the total is exactly `train_epochs`.
    epochs_per_cycle[-1] = flags_obj.train_epochs - sum(epochs_per_cycle[:-1])
  else:
    # Eval-only (or zero train epochs): one cycle that skips training.
    epochs_per_cycle, num_cycles = [0], 1

  for cycle_index, num_train_epochs in enumerate(epochs_per_cycle):
    tf.logging.info('Starting cycle: %d/%d', cycle_index, int(num_cycles))

    if num_train_epochs:
      estimator.train(input_fn=lambda: _train_input(num_train_epochs),
                      hooks=hooks, max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # max_train_steps is mostly used for testing/profiling, often together
    # with synthetic data that never runs out. Bounding eval by the same step
    # count lets it terminate. Eval therefore runs max_train_steps steps in
    # every cycle, independent of the global_step count.
    eval_results = estimator.evaluate(input_fn=_eval_input,
                                      steps=flags_obj.max_train_steps)

    bench_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy']):
      break

  if flags_obj.export_dir is None:
    return eval_results

  # Export a SavedModel for the trained estimator.
  export_dtype = flags_core.get_tf_dtype(flags_obj)
  if flags_obj.image_bytes_as_serving_input:
    receiver_fn = functools.partial(
        image_bytes_serving_input_fn, shape, dtype=export_dtype)
  else:
    receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
  estimator.export_savedmodel(flags_obj.export_dir, receiver_fn,
                              strip_default_attrs=True)
  return eval_results
def dtype():
    """Return the TensorFlow dtype selected by the globally parsed flags."""
    parsed_flags = flags.FLAGS
    return flags_core.get_tf_dtype(parsed_flags)
Example #41
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      The stats produced by `keras_common.build_stats` from the fit history,
      the evaluation output (None when --skip_eval is set) and the time
      callback.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.
    """
    if flags_obj.enable_eager:
        tf.enable_eager_execution()

    # get_tf_dtype resolves the --dtype flag to a tf.DType, so the guard has
    # to compare against tf.float16. Comparing against the raw flag string
    # 'fp16' never matches, which silently disabled this check.
    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == tf.float16:
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    # Fall back to the layout that suits the build: channels_first for CUDA
    # builds, channels_last otherwise.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    if flags_obj.use_synthetic_data:
        input_fn = keras_common.get_synth_input_fn(
            height=imagenet_main.DEFAULT_IMAGE_SIZE,
            width=imagenet_main.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_main.NUM_CHANNELS,
            num_classes=imagenet_main.NUM_CLASSES,
            dtype=dtype)
    else:
        input_fn = imagenet_main.input_fn

    train_input_dataset = input_fn(is_training=True,
                                   data_dir=flags_obj.data_dir,
                                   batch_size=flags_obj.batch_size,
                                   num_epochs=flags_obj.train_epochs,
                                   parse_record_fn=parse_record_keras)

    # NOTE(review): the eval dataset is also built with train_epochs repeats,
    # presumably because validation runs once per training epoch - confirm.
    eval_input_dataset = input_fn(is_training=False,
                                  data_dir=flags_obj.data_dir,
                                  batch_size=flags_obj.batch_size,
                                  num_epochs=flags_obj.train_epochs,
                                  parse_record_fn=parse_record_keras)

    strategy = distribution_utils.get_distribution_strategy(
        num_gpus=flags_obj.num_gpus,
        turn_off_distribution_strategy=flags_obj.turn_off_distribution_strategy
    )

    strategy_scope = keras_common.get_strategy_scope(strategy)

    # The optimizer and model are created and compiled inside the strategy
    # scope.
    with strategy_scope:
        optimizer = keras_common.get_optimizer()
        model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)

        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=['sparse_categorical_accuracy'])

    time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
        learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

    train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
    train_epochs = flags_obj.train_epochs

    # An explicit --train_steps caps the run to a single, possibly shortened,
    # epoch.
    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    history = model.fit(
        train_input_dataset,
        epochs=train_epochs,
        steps_per_epoch=train_steps,
        callbacks=[time_callback, lr_callback, tensorboard_callback],
        validation_steps=num_eval_steps,
        validation_data=validation_data,
        verbose=2)

    eval_output = None
    if not flags_obj.skip_eval:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=1)
    stats = keras_common.build_stats(history, eval_output, time_callback)
    return stats
Example #42
0
def convinh_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for convinh Models.

  Evaluates the initial model once, then alternates training and evaluation
  in cycles, exporting a saved model after each cycle when
  flags_obj.export_dir is set.

  Args:
    flags_obj: An object containing parsed flags. See define_convinh_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

  model_helpers.apply_clean(flags.FLAGS)

  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)
  
  # The run config carries the random seed and caps how many checkpoints are
  # kept (flags_obj.num_ckpt).
  run_config = tf.estimator.RunConfig(
      tf_random_seed=flags_obj.seed,
      train_distribute=distribution_strategy, 
      session_config=session_config,
      keep_checkpoint_max = flags_obj.num_ckpt
      )

  # Architecture settings are grouped under 'model_params'; the list-valued
  # flags are converted to lists of ints before being handed to the model_fn.
  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
      params={
          'model_params':{
                'data_format':flags_obj.data_format, 
                'filters':list(map(int,flags_obj.filters)),
                'ratio_PV': flags_obj.ratio_PV,
                'ratio_SST': flags_obj.ratio_SST,
                'conv_kernel_size':list(map(int,flags_obj.conv_kernel_size)),
                'conv_kernel_size_inh':list(map(int,flags_obj.conv_kernel_size_inh)),
                'conv_strides':list(map(int,flags_obj.conv_strides)),
                'pool_size':list(map(int,flags_obj.pool_size)),
                'pool_strides':list(map(int,flags_obj.pool_strides)),
                'num_ff_layers':flags_obj.num_ff_layers,
                'num_rnn_layers':flags_obj.num_rnn_layers,
                'connection':flags_obj.connection,
                'n_time':flags_obj.n_time,
                'cell_fn':flags_obj.cell_fn,
                'act_fn':flags_obj.act_fn, 
                'pvsst_circuit':flags_obj.pvsst_circuit,
                'gating':flags_obj.gating,
                'normalize':flags_obj.normalize,
                'num_classes':flags_obj.num_classes
              },
          'batch_size' : flags_obj.batch_size,
          'weight_decay': flags_obj.weight_decay,
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj)
      })

  # Parameters recorded with the benchmark run info (logging only).
  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'convinh_size': flags_obj.convinh_size, # deprecated
      'convinh_version': flags_obj.convinh_version, # deprecated
      'synthetic_data': flags_obj.use_synthetic_data, # deprecated
      'train_epochs': flags_obj.train_epochs,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('convinh', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)
  
  # Callable wrapper that captures the epoch count for one training cycle; an
  # instance is handed to classifier.train as its input_fn.
  class input_fn_train(object):
    def __init__(self,num_epochs):
      self._num_epochs = num_epochs
    def __call__(self):
      return input_function(
          is_training=True, data_dir=flags_obj.data_dir,
          batch_size=distribution_utils.per_device_batch_size(
              flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
          num_epochs=self._num_epochs,
          num_gpus=flags_core.get_num_gpus(flags_obj))

  # Evaluation always makes a single pass (num_epochs=1) over the data.
  def input_fn_eval():
    return input_function(
        is_training=False, data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1)
 
  # Evaluate once before any training so the untrained model's metrics are
  # logged as a baseline.
  tf.logging.info('Evaluate the intial model.')                          
  eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

  benchmark_logger.log_evaluation_result(eval_results)
  
  # Training: one extra leading cycle is added on top of
  # train_epochs // epochs_between_evals.
  total_training_cycle = (flags_obj.train_epochs //
                          flags_obj.epochs_between_evals) + 1
                          
  for cycle_index in range(total_training_cycle):
    
    # Cycle 0 trains for a single epoch; every later cycle trains for
    # epochs_between_evals epochs.
    cur_train_epochs = flags_obj.epochs_between_evals if cycle_index else 1
    
    tf.logging.info('Starting a training cycle: %d/%d, with %d epochs',
                    cycle_index, total_training_cycle, cur_train_epochs)
    
    classifier.train(input_fn=input_fn_train(cur_train_epochs), 
                     hooks=train_hooks, max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data, which
    # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
    # eval (which is generally unimportant in those circumstances) to terminate.
    # Note that eval will run for max_train_steps each loop, regardless of the
    # global_step count.
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

    benchmark_logger.log_evaluation_result(eval_results)

    # Stop early once accuracy passes the threshold; note the break also skips
    # the export below for that final cycle.
    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy']):
      break

    # The export sits inside the loop, so a saved model is written after every
    # cycle that does not break out above.
    if flags_obj.export_dir is not None:
      # Exports a saved model for the given classifier.
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=1)
      # On the first cycle, additionally export from the step-0 checkpoint
      # path in model_dir before the regular export.
      if cycle_index==0:
        classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                checkpoint_path='{}/model.ckpt-0'.format(flags_obj.model_dir))
      # No checkpoint_path given here - presumably the estimator's latest
      # checkpoint is exported; confirm against the Estimator docs.
      classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)