Example #1
  def test_parse_dtype_info(self):
    for dtype_str, tf_dtype, loss_scale in [["fp16", tf.float16, 128],
                                            ["fp32", tf.float32, 1]]:
      flags_core.parse_flags([__file__, "--dtype", dtype_str])

      self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf_dtype)
      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), loss_scale)

      flags_core.parse_flags(
          [__file__, "--dtype", dtype_str, "--loss_scale", "5"])

      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), 5)

    with self.assertRaises(SystemExit):
      flags_core.parse_flags([__file__, "--dtype", "int8"])
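The behavior this test pins down can be summarized in a small sketch. The helper below is hypothetical (the real logic lives in flags_core); it shows the resolution order the assertions above rely on: an explicit --loss_scale wins, otherwise the dtype picks the default.

# Hypothetical simplification of flags_core.get_loss_scale, for illustration
# only; names and defaults here are assumptions based on the test above.
def get_loss_scale_sketch(flags_obj, default_for_fp16=128):
  if flags_obj.loss_scale is not None:
    return flags_obj.loss_scale  # An explicit --loss_scale always wins.
  if flags_obj.dtype == "fp16":
    return default_for_fp16      # fp16 needs loss scaling by default.
  return 1                       # fp32 trains unscaled.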
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purposes.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.

  Returns:
    Dict of results of the run. Contains the keys `eval_results` and
    `train_hooks`. `eval_results` contains accuracy (top_1) and accuracy_top_5.
    `train_hooks` is a list of the hook instances used during training.
  """

  model_helpers.apply_clean(flags.FLAGS)

  # Ensures flag override logic is only executed if explicitly triggered.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Configures cluster spec for distribution strategy.
  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                     flags_obj.task_index)

  # Creates session config. allow_soft_placement=True is required for
  # multi-GPU and is not harmful for other modes.
  session_config = tf.compat.v1.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_core.get_num_gpus(flags_obj),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  # Creates a `RunConfig` that checkpoints every 24 hours which essentially
  # results in checkpoints determined only by `epochs_between_evals`.
  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy,
      session_config=session_config,
      save_checkpoints_secs=60*60*24,
      save_checkpoints_steps=None)

  # Initializes model with all but the dense layer from pretrained ResNet.
  if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start_settings = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')
  else:
    warm_start_settings = None

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
      warm_start_from=warm_start_settings, params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj,
                                                  default_for_fp16=128),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune,
          'num_workers': num_workers,
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
      'num_workers': num_workers,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs, input_context=None):
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=per_replica_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
        dtype=flags_core.get_tf_dtype(flags_obj),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        input_context=input_context)

  def input_fn_eval():
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=per_replica_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  train_epochs = (0 if flags_obj.eval_only or not flags_obj.train_epochs else
                  flags_obj.train_epochs)

  use_train_and_evaluate = flags_obj.use_train_and_evaluate or num_workers > 1
  if use_train_and_evaluate:
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda input_context=None: input_fn_train(
            train_epochs, input_context=input_context),
        hooks=train_hooks,
        max_steps=flags_obj.max_train_steps)
    eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval)
    logging.info('Starting to train and evaluate.')
    tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
    # tf.estimator.train_and_evaluate doesn't return anything in the
    # multi-worker case.
    eval_results = {}
  else:
    if train_epochs == 0:
      # If --eval_only is set, perform a single loop with zero train epochs.
      schedule, n_loops = [0], 1
    else:
      # Compute the number of times to loop while training. All but the last
      # pass will train for `epochs_between_evals` epochs, while the last will
      # train for the number needed to reach `training_epochs`. For instance if
      #   train_epochs = 25 and epochs_between_evals = 10
      # schedule will be set to [10, 10, 5]. That is to say, the loop will:
      #   Train for 10 epochs and then evaluate.
      #   Train for another 10 epochs and then evaluate.
      #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
      n_loops = math.ceil(train_epochs / flags_obj.epochs_between_evals)
      schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
      schedule[-1] = train_epochs - sum(schedule[:-1])  # Remainder for last pass.

    for cycle_index, num_train_epochs in enumerate(schedule):
      logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

      if num_train_epochs:
        # Since we are calling classifier.train immediately in each loop, the
        # value of num_train_epochs in the lambda function will not be changed
        # before it is used. So it is safe to ignore the pylint error here
        # pylint: disable=cell-var-from-loop
        classifier.train(
            input_fn=lambda input_context=None: input_fn_train(
                num_train_epochs, input_context=input_context),
            hooks=train_hooks,
            max_steps=flags_obj.max_train_steps)

      # flags_obj.max_train_steps is generally associated with testing and
      # profiling. As a result it is frequently called with synthetic data,
      # which will iterate forever. Passing steps=flags_obj.max_train_steps
      # allows the eval (which is generally unimportant in those circumstances)
      # to terminate.  Note that eval will run for max_train_steps each loop,
      # regardless of the global_step count.
      logging.info('Starting to evaluate.')
      eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                         steps=flags_obj.max_train_steps)

      benchmark_logger.log_evaluation_result(eval_results)

      if model_helpers.past_stop_threshold(
          flags_obj.stop_threshold, eval_results['accuracy']):
        break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      input_receiver_fn = functools.partial(
          image_bytes_serving_input_fn, shape, dtype=export_dtype)
    else:
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                                 strip_default_attrs=True)

  stats = {}
  stats['eval_results'] = eval_results
  stats['train_hooks'] = train_hooks

  return stats
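For clarity, the evaluation-schedule logic above can be lifted into a standalone helper. This sketch (our naming, not the module's) reproduces the [10, 10, 5] example from the comment in the code.

import math

def make_eval_schedule(train_epochs, epochs_between_evals):
  """Splits train_epochs into chunks of at most epochs_between_evals."""
  if train_epochs == 0:
    return [0]  # Single loop with zero train epochs (the --eval_only case).
  n_loops = math.ceil(train_epochs / epochs_between_evals)
  schedule = [epochs_between_evals] * int(n_loops)
  schedule[-1] = train_epochs - sum(schedule[:-1])  # Remainder for last pass.
  return schedule

assert make_eval_schedule(25, 10) == [10, 10, 5]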
Example #3
def run_ncf(_):
  """Run NCF training and eval with Keras."""

  keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

  if FLAGS.seed is not None:
    print("Setting tf seed")
    tf.random.set_seed(FLAGS.seed)

  model_helpers.apply_clean(FLAGS)

  if FLAGS.dtype == "fp16" and FLAGS.fp16_implementation == "keras":
    policy = tf.keras.mixed_precision.experimental.Policy(
        "mixed_float16",
        loss_scale=flags_core.get_loss_scale(FLAGS, default_for_fp16="dynamic"))
    tf.keras.mixed_precision.experimental.set_policy(policy)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)

  params = ncf_common.parse_flags(FLAGS)
  params["distribute_strategy"] = strategy

  if not keras_utils.is_v2_0() and strategy is not None:
    logging.error("NCF Keras only works with distribution strategy in TF 2.0")
    return
  if (params["keras_use_ctl"] and (
      not keras_utils.is_v2_0() or strategy is None)):
    logging.error(
        "Custom training loop only works with tensorflow 2.0 and dist strat.")
    return
  if params["use_tpu"] and not params["keras_use_ctl"]:
    logging.error("Custom training loop must be used when using TPUStrategy.")
    return

  batch_size = params["batch_size"]
  time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
  callbacks = [time_callback]

  producer, input_meta_data = None, None
  generate_input_online = params["train_dataset_path"] is None

  if generate_input_online:
    # Start data producing thread.
    num_users, num_items, _, _, producer = ncf_common.get_inputs(params)
    producer.start()
    per_epoch_callback = IncrementEpochCallback(producer)
    callbacks.append(per_epoch_callback)
  else:
    assert params["eval_dataset_path"] and params["input_meta_data_path"]
    with tf.io.gfile.GFile(params["input_meta_data_path"], "rb") as reader:
      input_meta_data = json.loads(reader.read().decode("utf-8"))
      num_users = input_meta_data["num_users"]
      num_items = input_meta_data["num_items"]

  params["num_users"], params["num_items"] = num_users, num_items

  if FLAGS.early_stopping:
    early_stopping_callback = CustomEarlyStopping(
        "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
    callbacks.append(early_stopping_callback)

  (train_input_dataset, eval_input_dataset,
   num_train_steps, num_eval_steps) = ncf_input_pipeline.create_ncf_input_data(
       params, producer, input_meta_data, strategy)
  steps_per_epoch = None if generate_input_online else num_train_steps

  with distribution_utils.get_strategy_scope(strategy):
    keras_model = _get_keras_model(params)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=params["learning_rate"],
        beta_1=params["beta1"],
        beta_2=params["beta2"],
        epsilon=params["epsilon"])
    if FLAGS.fp16_implementation == "graph_rewrite":
      optimizer = (
          tf.compat.v1.train.experimental
          .enable_mixed_precision_graph_rewrite(
              optimizer,
              loss_scale=flags_core.get_loss_scale(
                  FLAGS, default_for_fp16="dynamic")))
    elif FLAGS.dtype == "fp16" and params["keras_use_ctl"]:
      # When keras_use_ctl is False, Model.fit() applies loss scaling
      # automatically, so a LossScaleOptimizer is only needed for the custom
      # training loop.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer,
          tf.keras.mixed_precision.experimental.global_policy().loss_scale)

    if params["keras_use_ctl"]:
      train_loss, eval_results = run_ncf_custom_training(
          params,
          strategy,
          keras_model,
          optimizer,
          callbacks,
          train_input_dataset,
          eval_input_dataset,
          num_train_steps,
          num_eval_steps,
          generate_input_online=generate_input_online)
    else:
      keras_model.compile(optimizer=optimizer, run_eagerly=FLAGS.run_eagerly)

      if not FLAGS.ml_perf:
        # Create Tensorboard summary and checkpoint callbacks.
        summary_dir = os.path.join(FLAGS.model_dir, "summaries")
        summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
        checkpoint_path = os.path.join(FLAGS.model_dir, "checkpoint")
        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            checkpoint_path, save_weights_only=True)

        callbacks += [summary_callback, checkpoint_callback]

      history = keras_model.fit(
          train_input_dataset,
          epochs=FLAGS.train_epochs,
          steps_per_epoch=steps_per_epoch,
          callbacks=callbacks,
          validation_data=eval_input_dataset,
          validation_steps=num_eval_steps,
          verbose=2)

      logging.info("Training done. Start evaluating")

      eval_loss_and_metrics = keras_model.evaluate(
          eval_input_dataset, steps=num_eval_steps, verbose=2)

      logging.info("Keras evaluation is done.")

      # Keras evaluate() API returns scalar loss and metric values from
      # evaluation as a list. Here, the returned list would contain
      # [evaluation loss, hr sum, hr count].
      eval_hit_rate = eval_loss_and_metrics[1] / eval_loss_and_metrics[2]

      # Format evaluation result into [eval loss, eval hit accuracy].
      eval_results = [eval_loss_and_metrics[0], eval_hit_rate]

      if history and history.history:
        train_history = history.history
        train_loss = train_history["loss"][-1]

  stats = build_stats(train_loss, eval_results, time_callback)
  return stats
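The mixed-precision setup used in run_ncf follows the experimental TF 2.x Keras API of the time. Isolated, the pattern looks like the sketch below; these experimental symbols match the source but were removed in later TF releases.

import tensorflow as tf

policy = tf.keras.mixed_precision.experimental.Policy(
    "mixed_float16", loss_scale="dynamic")
tf.keras.mixed_precision.experimental.set_policy(policy)

optimizer = tf.keras.optimizers.Adam(1e-3)
# Only needed for a custom training loop; Model.fit() applies loss scaling
# automatically under a mixed_float16 policy.
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    optimizer,
    tf.keras.mixed_precision.experimental.global_policy().loss_scale)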
Example #4
    def __init__(self, flags_obj):
        """Init function.

        Args:
          flags_obj: Object containing parsed flag values, i.e., FLAGS.
        """
        self.flags_obj = flags_obj

        # Add flag-defined parameters to params object
        num_gpus = flags_core.get_num_gpus(flags_obj)
        self.params = params = misc.get_model_params(flags_obj.param_set,
                                                     num_gpus)
        params["num_gpus"] = num_gpus
        params["data_dir"] = flags_obj.data_dir
        params["val_data_dir"] = flags_obj.val_data_dir
        params["model_dir"] = flags_obj.model_dir
        params["static_batch"] = flags_obj.static_batch
        params["max_input_length"] = flags_obj.max_input_length
        params["max_target_length"] = flags_obj.max_target_length
        params["decode_batch_size"] = flags_obj.decode_batch_size
        params["decode_max_length"] = flags_obj.decode_max_length
        params["padded_decode"] = flags_obj.padded_decode
        params["num_parallel_calls"] = (flags_obj.num_parallel_calls
                                        or tf.data.experimental.AUTOTUNE)

        params["use_synthetic_data"] = flags_obj.use_synthetic_data
        params["batch_size"] = flags_obj.batch_size * max(num_gpus, 1)
        logging.info('actual batch_size = {} * {}'.format(
            flags_obj.batch_size, max(num_gpus, 1)))
        params["repeat_dataset"] = None
        params["dtype"] = flags_core.get_tf_dtype(flags_obj)
        params["enable_metrics_in_training"] = (
            flags_obj.enable_metrics_in_training)
        params["num_hashes"] = flags_obj.num_hashes
        params["test_num_hashes"] = flags_obj.test_num_hashes
        params["use_full_attention_in_reformer"] = (
            flags_obj.use_full_attention_in_reformer)
        params["bucket_size"] = flags_obj.bucket_size

        if flags_obj.one_dropout is not None:
            params['layer_postprocess_dropout'] = flags_obj.one_dropout
            params['attention_dropout'] = flags_obj.one_dropout
            params['relu_dropout'] = flags_obj.one_dropout
        if flags_obj.attention_dropout is not None:
            params['attention_dropout'] = flags_obj.attention_dropout
        params['lsh_attention_dropout'] = (
            params['attention_dropout']
            if params["use_full_attention_in_reformer"]
            else flags_obj.lsh_attention_dropout)
        logging.info(
            'dropouts (postprocess, attention, lsh_attention, relu) = %s',
            [params[k] for k in ['layer_postprocess_dropout',
                                 'attention_dropout',
                                 'lsh_attention_dropout',
                                 'relu_dropout']])
        logging.info(
            f'attention_padding_strategy = {flags_obj.attention_padding_strategy}'
        )

        assert self.flags_obj.vocab_file, 'vocab file is None'
        self.tokenizer = tfds.features.text.SubwordTextEncoder.load_from_file(
            self.flags_obj.vocab_file)
        self.EOS_id = self.tokenizer.encode('<EOS>')[0]
        params["vocab_size"] = self.tokenizer.vocab_size
        logging.info(
            'loaded vocab from {}, vocab_size={} and EOS_id={}'.format(
                self.flags_obj.vocab_file, self.tokenizer.vocab_size,
                self.EOS_id))
        logging.info(f'training_schema = [{self.flags_obj.training_schema}]')

        if params["dtype"] == tf.float16:
            # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
            # like this. What if multiple instances of Seq2SeqTask are created?
            # We should have a better way in the tf.keras.mixed_precision API of doing
            # this.
            loss_scale = flags_core.get_loss_scale(flags_obj,
                                                   default_for_fp16="dynamic")
            policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
                "mixed_float16", loss_scale=loss_scale)
            tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

        self.distribution_strategy = distribution_utils.get_distribution_strategy(
            distribution_strategy=flags_obj.distribution_strategy,
            num_gpus=num_gpus,
            tpu_address=flags_obj.tpu or "")
        logging.info("Running dtitle model with num_gpus = %d", num_gpus)

        if self.distribution_strategy:
            logging.info("For training, using distribution strategy: %s",
                         self.distribution_strategy)
        else:
            logging.info("Not using any distribution strategy.")
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  if keras_common.is_v2_0():
    keras_common.set_config_v2()
  else:
    config = keras_common.get_config_proto_v1()
    if flags_obj.enable_eager:
      tf.compat.v1.enable_eager_execution(config=config)
    else:
      sess = tf.Session(config=config)
      tf.keras.backend.set_session(sess)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'float16':
    policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster())

  strategy_scope = keras_common.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype)

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        dtype=dtype)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    if dtype == 'float16':
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES)
    else:
      model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
                                    dtype=dtype)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  callbacks = [time_callback, lr_callback]
  if flags_obj.enable_tensorboard:
    callbacks.append(tensorboard_callback)
  if flags_obj.enable_e2e_xprof:
    profiler.start()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  if flags_obj.enable_e2e_xprof:
    results = profiler.stop()
    profiler.save(flags_obj.model_dir, results)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)
  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats
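The steps-per-epoch arithmetic above uses integer division over the ImageNet split sizes. Assuming the usual imagenet_main constants, it works out as follows:

NUM_TRAIN_IMAGES = 1281167   # Assumed imagenet_main.NUM_IMAGES['train'].
NUM_VAL_IMAGES = 50000       # Assumed imagenet_main.NUM_IMAGES['validation'].
batch_size = 256

train_steps = NUM_TRAIN_IMAGES // batch_size   # 5004 steps per epoch.
num_eval_steps = NUM_VAL_IMAGES // batch_size  # 195 eval steps.
assert (train_steps, num_eval_steps) == (5004, 195)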
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
    # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
    # Eager is default in tf 2.0 and should not be toggled
    if keras_common.is_v2_0():
        keras_common.set_config_v2()
    else:
        config = keras_common.get_config_proto_v1()
        if flags_obj.enable_eager:
            tf.compat.v1.enable_eager_execution(config=config)
        else:
            sess = tf.Session(config=config)
            tf.keras.backend.set_session(sess)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_common.set_gpu_thread_mode_and_count(flags_obj)
    if flags_obj.data_delay_prefetch:
        keras_common.data_delay_prefetch()
    keras_common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == 'float16':
        policy = tf.keras.mixed_precision.experimental.Policy(
            'infer_float32_vars')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    # flags_obj.enable_get_next_as_optional controls whether to enable the
    # get_next_as_optional behavior in DistributedIterator. If True, the last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=imagenet_main.DEFAULT_IMAGE_SIZE,
            width=imagenet_main.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_main.NUM_CHANNELS,
            num_classes=imagenet_main.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_main.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(is_training=False,
                                      data_dir=flags_obj.data_dir,
                                      batch_size=flags_obj.batch_size,
                                      num_epochs=flags_obj.train_epochs,
                                      parse_record_fn=parse_record_keras,
                                      dtype=dtype,
                                      drop_remainder=drop_remainder)

    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_main.NUM_IMAGES['train'],
            warmup_epochs=LR_SCHEDULE[0][1],
            boundaries=list(p[1] for p in LR_SCHEDULE[1:]),
            multipliers=list(p[0] for p in LR_SCHEDULE),
            compute_lr_on_cpu=True)

    with strategy_scope:
        optimizer = keras_common.get_optimizer(lr_schedule)
        if dtype == 'float16':
            # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
            # can be enabled with a single line of code.
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

        if flags_obj.enable_xla and not flags_obj.enable_eager:
            # TODO(b/129861005): Fix OOM issue in eager mode when setting
            # `batch_size` in keras.Input layer.
            if strategy and strategy.num_replicas_in_sync > 1:
                # TODO(b/129791381): Specify `input_layer_batch_size` value in
                # DistributionStrategy multi-replica case.
                input_layer_batch_size = None
            else:
                input_layer_batch_size = flags_obj.batch_size
        else:
            input_layer_batch_size = None

        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES,
                                                dtype)
        else:
            model = resnet_model.resnet50(
                num_classes=imagenet_main.NUM_CLASSES,
                dtype=dtype,
                batch_size=input_layer_batch_size)

        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['sparse_categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      cloning=flags_obj.clone_model_in_keras_dist_strat)

    callbacks = keras_common.get_callbacks(learning_rate_schedule,
                                           imagenet_main.NUM_IMAGES['train'])

    train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
    train_epochs = flags_obj.train_epochs

    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)

    eval_output = None
    if not flags_obj.skip_eval:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)
    stats = keras_common.build_stats(history, eval_output, callbacks)
    return stats
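The warmup/boundary/multiplier unpacking of LR_SCHEDULE above assumes a table of (multiplier, epoch) pairs. With the schedule commonly used in this codebase (the values below are an assumption), the unpacking yields:

LR_SCHEDULE = [(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)]  # Assumed values.

warmup_epochs = LR_SCHEDULE[0][1]             # 5
boundaries = [p[1] for p in LR_SCHEDULE[1:]]  # [30, 60, 80]
multipliers = [p[0] for p in LR_SCHEDULE]     # [1.0, 0.1, 0.01, 0.001]
assert (warmup_epochs, boundaries) == (5, [30, 60, 80])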
Example #7
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purposes.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.

  Returns:
    Dict of results of the run.
  """

    model_helpers.apply_clean(flags.FLAGS)

    # Ensures flag override logic is only executed if explicitly triggered.
    if flags_obj.tf_gpu_thread_mode:
        override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

    # Creates session config. allow_soft_placement=True is required for
    # multi-GPU and is not harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    # Creates a `RunConfig` that checkpoints every 24 hours which essentially
    # results in checkpoints determined only by `epochs_between_evals`.
    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config,
                                        save_checkpoints_secs=60 * 60 * 24)

    # Initializes model with all but the dense layer from pretrained ResNet.
    if flags_obj.pretrained_model_checkpoint_path is not None:
        warm_start_settings = tf.estimator.WarmStartSettings(
            flags_obj.pretrained_model_checkpoint_path,
            vars_to_warm_start='^(?!.*dense)')
    else:
        warm_start_settings = None

    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        warm_start_from=warm_start_settings,
        params={
            'resnet_size': int(flags_obj.resnet_size),
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'resnet_version': int(flags_obj.resnet_version),
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'dtype': flags_core.get_tf_dtype(flags_obj),
            'fine_tune': flags_obj.fine_tune
        })

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    train_hooks = list(train_hooks) + lottery.hooks_from_flags(
        flags_obj.flag_values_dict())

    def input_fn_train(num_epochs):
        return input_function(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=num_epochs,
            dtype=flags_core.get_tf_dtype(flags_obj),
            datasets_num_private_threads=(
                flags_obj.datasets_num_private_threads),
            num_parallel_batches=flags_obj.datasets_num_parallel_batches)

    def input_fn_eval():
        return input_function(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=1,
            dtype=flags_core.get_tf_dtype(flags_obj))

    if flags_obj.lth_generate_predictions:
        ckpt = tf.train.latest_checkpoint(flags_obj.model_dir)

        if flags_obj.lth_no_pruning:
            m_hooks = []
        else:
            m_hooks = lottery.hooks_from_flags(flags_obj.flag_values_dict())

        eval_results = classifier.predict(
            input_fn=input_fn_eval,
            checkpoint_path=ckpt,
            hooks=m_hooks,
        )

        assert flags_obj.lth_prediction_result_dir
        with tf.gfile.Open(os.path.join(flags_obj.data_dir, 'test_batch.bin'),
                           'rb') as f:
            labels = list(f.read()[::32 * 32 * 3 + 1])

        eval_results = list(eval_results)
        if not tf.gfile.Exists(flags_obj.lth_prediction_result_dir):
            tf.gfile.MakeDirs(flags_obj.lth_prediction_result_dir)
        with tf.gfile.Open(
                os.path.join(flags_obj.lth_prediction_result_dir,
                             'predictions'), 'wb') as f:
            for label, res in zip(labels, eval_results):
                res['label'] = label
            pickle.dump(eval_results, f)
        return

    try:
        cpr = tf.train.NewCheckpointReader(
            tf.train.latest_checkpoint(flags_obj.model_dir))
        current_step = cpr.get_tensor('global_step')
    except Exception:  # No usable checkpoint yet; start from step 0.
        current_step = 0

    while current_step < flags_obj.max_train_steps:
        next_checkpoint = min(current_step + 10000, flags_obj.max_train_steps)
        classifier.train(input_fn=lambda: input_fn_train(1000),
                         hooks=train_hooks,
                         max_steps=next_checkpoint)
        current_step = next_checkpoint
        tf.logging.info('Starting to evaluate.')
        eval_results = classifier.evaluate(input_fn=input_fn_eval)
        benchmark_logger.log_evaluation_result(eval_results)

    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        export_dtype = flags_core.get_tf_dtype(flags_obj)
        if flags_obj.image_bytes_as_serving_input:
            input_receiver_fn = functools.partial(image_bytes_serving_input_fn,
                                                  shape,
                                                  dtype=export_dtype)
        else:
            input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
                shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
        classifier.export_savedmodel(flags_obj.export_dir,
                                     input_receiver_fn,
                                     strip_default_attrs=True)
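The label slice `f.read()[::32 * 32 * 3 + 1]` above depends on the CIFAR-10 binary layout, where every record is one label byte followed by a 32x32x3 image. A self-contained check with synthetic bytes:

record = bytes([7]) + bytes(32 * 32 * 3)  # One record: label 7 + zeroed image.
blob = record * 3                         # Three consecutive records.
labels = list(blob[::32 * 32 * 3 + 1])    # Stride 3073 hits each label byte.
assert labels == [7, 7, 7]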
Example #8
def run_predict(flags_obj, datasets_override=None, strategy_override=None):
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    performance.set_mixed_precision_policy(
        flags_core.get_tf_dtype(flags_obj),
        flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                             flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether to enable the
        # get_next_as_optional behavior in DistributedIterator. If True, the
        # last partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    distribution_utils.undo_set_up_synthetic_data()

    train_input_dataset, eval_input_dataset, tr_dataset, te_dataset = setup_datasets(
        flags_obj, shuffle=False)

    pred_input_dataset, pred_dataset = eval_input_dataset, te_dataset

    with strategy_scope:
        # num_classes is hard-coded to 100 here; the dataset-derived
        # alternative is kept for reference:
        # model = build_model(tr_dataset.num_classes, mode='resnet50_features')
        model = build_model(100, mode='resnet50_features')

        load_path = GB_OPTIONS.pretrained_filepath
        if load_path is None:
            load_path = GB_OPTIONS.checkpoint_folder
        latest = tf.train.latest_checkpoint(load_path)
        print(latest)
        model.load_weights(latest)

        num_eval_examples = pred_dataset.num_examples_per_epoch()
        num_eval_steps = num_eval_examples // GB_OPTIONS.batch_size
        print(GB_OPTIONS.batch_size)

        pred = model.predict(pred_input_dataset,
                             batch_size=GB_OPTIONS.batch_size,
                             steps=num_eval_steps)

        lab = np.asarray(pred_dataset.data[1])
        if hasattr(pred_dataset, 'ori_labels'):
            ori_lab = pred_dataset.ori_labels
        else:
            ori_lab = lab

        print(pred.shape)

        np.save('out_X', pred)
        np.save('out_labels', lab)
        np.save('out_ori_labels', ori_lab)

        return 'good'
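One pitfall in the checkpoint-loading block above: tf.train.latest_checkpoint returns None when the directory holds no checkpoint, and model.load_weights(None) then fails. A defensive sketch (the directory name is an assumption):

import tensorflow as tf

latest = tf.train.latest_checkpoint("/tmp/checkpoint_folder")
if latest is None:
  raise ValueError("No checkpoint found; cannot run prediction.")
# model.load_weights(latest) would be safe to call here.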
Example #9
def resnet_main(flags_obj, model_function, input_function, shape=None):
    """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    if flags_obj.multi_gpu:
        validate_batch_size_for_multi_gpu(flags_obj.batch_size)

        # There are two steps required if using multi-GPU: (1) wrap the model_fn,
        # and (2) wrap the optimizer. The first happens here, and (2) happens
        # in the model_fn itself when the optimizer is defined.
        model_function = tf.contrib.estimator.replicate_model_fn(
            model_function, loss_reduction=tf.losses.Reduction.MEAN)

    # Create session config based on values of inter_op_parallelism_threads and
    # intra_op_parallelism_threads. Note that we default to having
    # allow_soft_placement = True, which is required for multi-GPU and not
    # harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    # Set up a RunConfig to save checkpoint and set session config.
    run_config = tf.estimator.RunConfig().replace(
        save_checkpoints_secs=1e9, session_config=session_config)
    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        params={
            'resnet_size': int(flags_obj.resnet_size),
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'multi_gpu': flags_obj.multi_gpu,
            'version': int(flags_obj.version),
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'dtype': flags_core.get_tf_dtype(flags_obj)
        })

    benchmark_logger = logger.config_benchmark_logger(
        flags_obj.benchmark_log_dir)
    benchmark_logger.log_run_info('resnet')

    train_hooks = hooks_helper.get_train_hooks(
        flags_obj.hooks,
        batch_size=flags_obj.batch_size,
        benchmark_log_dir=flags_obj.benchmark_log_dir)

    def input_fn_train():
        return input_function(True, flags_obj.data_dir, flags_obj.batch_size,
                              flags_obj.epochs_between_evals,
                              flags_obj.num_parallel_calls,
                              flags_obj.multi_gpu)

    def input_fn_eval():
        return input_function(False, flags_obj.data_dir, flags_obj.batch_size,
                              1, flags_obj.num_parallel_calls,
                              flags_obj.multi_gpu)

    total_training_cycle = (flags_obj.train_epochs //
                            flags_obj.epochs_between_evals)
    for cycle_index in range(total_training_cycle):
        tf.logging.info('Starting a training cycle: %d/%d', cycle_index,
                        total_training_cycle)

        classifier.train(input_fn=input_fn_train,
                         hooks=train_hooks,
                         max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')
        # flags.max_train_steps is generally associated with testing and profiling.
        # As a result it is frequently called with synthetic data, which will
        # iterate forever. Passing steps=flags.max_train_steps allows the eval
        # (which is generally unimportant in those circumstances) to terminate.
        # Note that eval will run for max_train_steps each loop, regardless of the
        # global_step count.
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags_obj.max_train_steps)

        benchmark_logger.log_evaluation_result(eval_results)

        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['accuracy']):
            break

    if flags_obj.export_dir is not None:
        warn_on_multi_gpu_export(flags_obj.multi_gpu)

        # Exports a saved model for the given classifier.
        input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
            shape, batch_size=flags_obj.batch_size)
        classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
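validate_batch_size_for_multi_gpu, called above but not shown, presumably enforces that the global batch divides evenly across GPUs, as replicate_model_fn requires. A hypothetical sketch of that check:

def validate_batch_size_for_multi_gpu_sketch(batch_size, num_gpus):
  # Hypothetical reconstruction; the real helper is defined elsewhere.
  if num_gpus and batch_size % num_gpus != 0:
    raise ValueError(
        "batch_size (%d) must be divisible by the GPU count (%d)"
        % (batch_size, num_gpus))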
Example #10
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purposes.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # Create session config based on values of inter_op_parallelism_threads and
    # intra_op_parallelism_threads. Note that we default to having
    # allow_soft_placement = True, which is required for multi-GPU and not
    # harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    if flags_core.get_num_gpus(flags_obj) == 0:
        distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
    elif flags_core.get_num_gpus(flags_obj) == 1:
        distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
    else:
        distribution = tf.contrib.distribute.MirroredStrategy(
            num_gpus=flags_core.get_num_gpus(flags_obj))

    run_config = tf.estimator.RunConfig(train_distribute=distribution,
                                        session_config=session_config)

    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        params={
            'resnet_size': int(flags_obj.resnet_size),
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'resnet_version': int(flags_obj.resnet_version),
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'dtype': flags_core.get_tf_dtype(flags_obj)
        })

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    benchmark_logger = logger.config_benchmark_logger(flags_obj)
    benchmark_logger.log_run_info('resnet', dataset_name, run_params)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train(num_epochs):
        return input_function(
            mode="train",
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=num_epochs,
            num_gpus=flags_core.get_num_gpus(flags_obj),
            dtype=flags_core.get_tf_dtype(flags_obj))

    def input_fn_eval():
        return input_function(mode="validate",
                              data_dir=flags_obj.data_dir,
                              batch_size=per_device_batch_size(
                                  flags_obj.batch_size,
                                  flags_core.get_num_gpus(flags_obj)),
                              num_epochs=1)

    def input_fn_pred():
        return input_function(mode="predict",
                              data_dir=flags_obj.data_dir,
                              batch_size=per_device_batch_size(
                                  flags_obj.batch_size,
                                  flags_core.get_num_gpus(flags_obj)),
                              num_epochs=1)

    # Predict only: write predictions to CSV and return.
    if flags_obj.predict_only:
        result = classifier.predict(input_fn=lambda: input_fn_pred())
        predicted_values = np.stack([r["predictions"] for r in result], axis=0)
        df = pd.DataFrame(predicted_values)
        df.to_csv("validate_result.txt")
        return

    # train
    if flags_obj.eval_only or not flags_obj.train_epochs:
        # If --eval_only is set, perform a single loop with zero train epochs.
        schedule, n_loops = [0], 1
    else:
        # Compute the number of times to loop while training. All but the last
        # pass will train for `epochs_between_evals` epochs, while the last will
        # train for the number needed to reach `training_epochs`. For instance if
        #   train_epochs = 25 and epochs_between_evals = 10
        # schedule will be set to [10, 10, 5]. That is to say, the loop will:
        #   Train for 10 epochs and then evaluate.
        #   Train for another 10 epochs and then evaluate.
        #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
        n_loops = math.ceil(flags_obj.train_epochs /
                            flags_obj.epochs_between_evals)
        schedule = [
            flags_obj.epochs_between_evals for _ in range(int(n_loops))
        ]
        schedule[-1] = flags_obj.train_epochs - sum(
            schedule[:-1])  # Remainder for the last pass.

    for cycle_index, num_train_epochs in enumerate(schedule):
        tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

        if num_train_epochs:
            classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                             hooks=train_hooks,
                             max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')

        # flags_obj.max_train_steps is generally associated with testing and
        # profiling. As a result it is frequently called with synthetic data, which
        # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
        # eval (which is generally unimportant in those circumstances) to terminate.
        # Note that eval will run for max_train_steps each loop, regardless of the
        # global_step count.
        eval_results = classifier.evaluate(input_fn=input_fn_eval, steps=100)

        benchmark_logger.log_evaluation_result(eval_results)

        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['mse']):
            break

    # save model for serving
    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
            shape, batch_size=flags_obj.batch_size)
        classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
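distribution_utils.per_device_batch_size, used above, presumably splits the global batch evenly across devices. A hypothetical sketch of the expected contract:

def per_device_batch_size_sketch(batch_size, num_gpus):
  # Hypothetical reconstruction of distribution_utils.per_device_batch_size.
  if num_gpus <= 1:
    return batch_size
  if batch_size % num_gpus != 0:
    raise ValueError("batch_size must be divisible by num_gpus")
  return batch_size // num_gpus

assert per_device_batch_size_sketch(256, 4) == 64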
Example #11
def net_main(flags_obj,
             model_function,
             input_function,
             net_data_configs,
             shape=None):
    """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    net_data_configs: dict of network and dataset configuration, including
      'dataset_name' (used for logging) and 'sg_settings'.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

    model_helpers.apply_clean(flags.FLAGS)
    is_metriclog = True
    if is_metriclog:
        metric_logfn = os.path.join(flags_obj.model_dir, 'log_metric.txt')
        metric_logf = open(metric_logfn, 'a')

    from tensorflow.contrib.memory_stats.ops import gen_memory_stats_ops
    max_memory_usage = gen_memory_stats_ops.max_bytes_in_use()

    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # Create session config based on values of inter_op_parallelism_threads and
    # intra_op_parallelism_threads. Note that we default to having
    # allow_soft_placement = True, which is required for multi-GPU and not
    # harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config)

    # initialize our model with all but the dense layer from pretrained resnet
    if flags_obj.pretrained_model_checkpoint_path is not None:
        warm_start_settings = tf.estimator.WarmStartSettings(
            flags_obj.pretrained_model_checkpoint_path,
            vars_to_warm_start='^(?!.*dense)')
    else:
        warm_start_settings = None

    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        warm_start_from=warm_start_settings,
        params={
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'weight_decay': flags_obj.weight_decay,
            'dtype': flags_core.get_tf_dtype(flags_obj),
            'fine_tune': flags_obj.fine_tune,
            'examples_per_epoch': flags_obj.examples_per_epoch,
            'net_data_configs': net_data_configs
        })

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    dataset_name = net_data_configs['dataset_name']
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('meshnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train(num_epochs):
        return input_function(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=num_epochs,
            num_gpus=flags_core.get_num_gpus(flags_obj),
            examples_per_epoch=flags_obj.examples_per_epoch,
            sg_settings=net_data_configs['sg_settings'])

    def input_fn_eval():
        return input_function(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=1,
            sg_settings=net_data_configs['sg_settings'])

    if flags_obj.eval_only or flags_obj.pred_ply or not flags_obj.train_epochs:
        # If --eval_only or --pred_ply is set, or no train epochs are
        # requested, perform a single loop with zero train epochs.
        schedule, n_loops = [0], 1
    else:
        # Compute the number of times to loop while training. All but the last
        # pass will train for `epochs_between_evals` epochs, while the last will
        # train for the number needed to reach `train_epochs`. For instance, if
        #   train_epochs = 25 and epochs_between_evals = 10
        # schedule will be set to [10, 10, 5]. That is to say, the loop will:
        #   Train for 10 epochs and then evaluate.
        #   Train for another 10 epochs and then evaluate.
        #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
        n_loops = math.ceil(flags_obj.train_epochs /
                            flags_obj.epochs_between_evals)
        schedule = [
            flags_obj.epochs_between_evals for _ in range(int(n_loops))
        ]
        schedule[-1] = flags_obj.train_epochs - sum(
            schedule[:-1])  # Trim the final cycle to correct the over-count.

        # Brief warm-up training run so that peak GPU memory can be measured.
        classifier.train(input_fn=lambda: input_fn_train(1),
                         hooks=train_hooks,
                         max_steps=10)
        with tf.Session() as sess:
            max_memory_usage_v = sess.run(max_memory_usage)
            tf.logging.info('\n\nmemory usage: %0.3f G\n\n' %
                            (max_memory_usage_v * 1.0 / 1e9))

    best_acc, best_acc_checkpoint = load_saved_best(flags_obj.model_dir)
    for cycle_index, num_train_epochs in enumerate(schedule):
        tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

        t0 = time.time()
        train_t = 0
        if num_train_epochs:
            classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                             hooks=train_hooks,
                             max_steps=flags_obj.max_train_steps)
            train_t = (time.time() - t0) / num_train_epochs

        tf.logging.info('Starting to evaluate.')

        # flags_obj.max_train_steps is generally associated with testing and
        # profiling. As a result it is frequently called with synthetic data, which
        # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
        # eval (which is generally unimportant in those circumstances) to terminate.
        # Note that eval will run for max_train_steps each loop, regardless of the
        # global_step count.
        # Hard-wired off: evaluation always runs. Change the leading `False`
        # to skip evaluation during pure training runs.
        only_train = False and (not flags_obj.eval_only) and (
            not flags_obj.pred_ply)
        if not only_train:
            t0 = time.time()
            eval_results = classifier.evaluate(
                input_fn=input_fn_eval,
                steps=flags_obj.max_train_steps,
            )
            # Optionally pass checkpoint_path=best_acc_checkpoint above to
            # evaluate the best checkpoint rather than the latest one.
            eval_t = time.time() - t0

            if flags_obj.pred_ply:
                pred_generator = classifier.predict(input_fn=input_fn_eval)
                num_classes = net_data_configs['dset_metas'].num_classes
                gen_pred_ply(eval_results, pred_generator, flags_obj.model_dir,
                             num_classes)

            benchmark_logger.log_evaluation_result(eval_results)

            if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                                 eval_results['accuracy']):
                break

            cur_is_best = ''
            if num_train_epochs and eval_results['accuracy'] > best_acc:
                best_acc = eval_results['accuracy']
                save_cur_model_as_best_acc(flags_obj.model_dir, best_acc)
                cur_is_best = 'best'
            global_step = cur_global_step(flags_obj.model_dir)
            epoch = int(global_step / flags_obj.examples_per_epoch *
                        flags_obj.num_gpus)
            ious_str = get_ious_str(eval_results['cm'],
                                    net_data_configs['dset_metas'],
                                    eval_results['mean_iou'])
            metric_logf.write(
                '\n{} train t:{:.1f}  eval t:{:.1f} \teval acc:{:.3f} \tmean_iou:{:.3f} {} {}\n'
                .format(epoch, train_t, eval_t, eval_results['accuracy'],
                        eval_results['mean_iou'], cur_is_best, ious_str))
            metric_logf.flush()

    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
            shape, batch_size=flags_obj.batch_size)
        classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
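The memory probe at the top of net_main relies on TF 1.x contrib ops. A minimal sketch of the same probe in isolation (assumes a GPU-enabled TensorFlow 1.x build; `MaxBytesInUse` is the public wrapper around the `gen_memory_stats_ops` call used above):

import tensorflow as tf
from tensorflow.contrib.memory_stats import MaxBytesInUse

max_memory_usage = MaxBytesInUse()  # reports the allocator's peak, in bytes
with tf.Session() as sess:
    sess.run(tf.ones([1024, 1024]))  # run something so there is a peak to report
    peak_bytes = sess.run(max_memory_usage)
    tf.logging.info('memory usage: %0.3f GB', peak_bytes / 1e9)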
Exemple #12
0
def convinh_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for convinh Models.

  Args:
    flags_obj: An object containing parsed flags. See define_convinh_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

  model_helpers.apply_clean(flags.FLAGS)

  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

  run_config = tf.estimator.RunConfig(
      tf_random_seed=flags_obj.seed,
      train_distribute=distribution_strategy,
      session_config=session_config,
      keep_checkpoint_max=flags_obj.num_ckpt)

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
      params={
          'model_params': {
              'data_format': flags_obj.data_format,
              'filters': list(map(int, flags_obj.filters)),
              'ratio_PV': flags_obj.ratio_PV,
              'ratio_SST': flags_obj.ratio_SST,
              'conv_kernel_size': list(map(int, flags_obj.conv_kernel_size)),
              'conv_kernel_size_inh': list(map(int, flags_obj.conv_kernel_size_inh)),
              'conv_strides': list(map(int, flags_obj.conv_strides)),
              'pool_size': list(map(int, flags_obj.pool_size)),
              'pool_strides': list(map(int, flags_obj.pool_strides)),
              'num_ff_layers': flags_obj.num_ff_layers,
              'num_rnn_layers': flags_obj.num_rnn_layers,
              'connection': flags_obj.connection,
              'n_time': flags_obj.n_time,
              'cell_fn': flags_obj.cell_fn,
              'act_fn': flags_obj.act_fn,
              'pvsst_circuit': flags_obj.pvsst_circuit,
              'gating': flags_obj.gating,
              'normalize': flags_obj.normalize,
              'num_classes': flags_obj.num_classes
          },
          'batch_size': flags_obj.batch_size,
          'weight_decay': flags_obj.weight_decay,
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj)
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'convinh_size': flags_obj.convinh_size, # deprecated
      'convinh_version': flags_obj.convinh_version, # deprecated
      'synthetic_data': flags_obj.use_synthetic_data, # deprecated
      'train_epochs': flags_obj.train_epochs,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('convinh', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)
  
  class input_fn_train(object):

    def __init__(self, num_epochs):
      self._num_epochs = num_epochs

    def __call__(self):
      return input_function(
          is_training=True, data_dir=flags_obj.data_dir,
          batch_size=distribution_utils.per_device_batch_size(
              flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
          num_epochs=self._num_epochs,
          num_gpus=flags_core.get_num_gpus(flags_obj))

  def input_fn_eval():
    return input_function(
        is_training=False, data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1)
 
  tf.logging.info('Evaluate the initial model.')
  eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                     steps=flags_obj.max_train_steps)

  benchmark_logger.log_evaluation_result(eval_results)
  
  # training
  total_training_cycle = (flags_obj.train_epochs //
                          flags_obj.epochs_between_evals) + 1

  for cycle_index in range(total_training_cycle):

    cur_train_epochs = flags_obj.epochs_between_evals if cycle_index else 1

    tf.logging.info('Starting a training cycle: %d/%d, with %d epochs',
                    cycle_index, total_training_cycle, cur_train_epochs)

    classifier.train(input_fn=input_fn_train(cur_train_epochs),
                     hooks=train_hooks, max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data, which
    # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
    # eval (which is generally unimportant in those circumstances) to terminate.
    # Note that eval will run for max_train_steps each loop, regardless of the
    # global_step count.
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy']):
      break

    if flags_obj.export_dir is not None:
      # Exports a saved model for the given classifier.
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=1)
      if cycle_index == 0:
        classifier.export_savedmodel(
            flags_obj.export_dir, input_receiver_fn,
            checkpoint_path='{}/model.ckpt-0'.format(flags_obj.model_dir))
      classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
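Exemple #12 wraps its training input in a callable class rather than a lambda; one plausible motivation is that the class binds `num_epochs` at construction time, whereas a lambda created in a loop closes over the loop variable late. A self-contained illustration (names are illustrative only):

class Repeater(object):
    """Callable that binds its argument at construction time."""

    def __init__(self, num_epochs):
        self._num_epochs = num_epochs

    def __call__(self):
        return self._num_epochs

fns = [Repeater(n) for n in (1, 2, 3)]
assert [f() for f in fns] == [1, 2, 3]   # each call sees its own bound value

fns = [lambda: n for n in (1, 2, 3)]
assert [f() for f in fns] == [3, 3, 3]   # lambdas all see the final value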
Exemple #13
0
    def gen_estimator(period=None):
        resnet_size = int(flags_obj.resnet_size)
        data_format = flags_obj.data_format
        batch_size = flags_obj.batch_size
        resnet_version = int(flags_obj.resnet_version)
        loss_scale = flags_core.get_loss_scale(flags_obj)
        dtype_tf = flags_core.get_tf_dtype(flags_obj)
        num_epochs_per_decay = flags_obj.num_epochs_per_decay
        learning_rate_decay_factor = flags_obj.learning_rate_decay_factor
        end_learning_rate = flags_obj.end_learning_rate
        learning_rate_decay_type = flags_obj.learning_rate_decay_type
        weight_decay = flags_obj.weight_decay
        zero_gamma = flags_obj.zero_gamma
        lr_warmup_epochs = flags_obj.lr_warmup_epochs
        base_learning_rate = flags_obj.base_learning_rate
        use_resnet_d = flags_obj.use_resnet_d
        use_dropblock = flags_obj.use_dropblock
        dropblock_kp = [float(be) for be in flags_obj.dropblock_kp]
        label_smoothing = flags_obj.label_smoothing
        momentum = flags_obj.momentum
        bn_momentum = flags_obj.bn_momentum
        train_epochs = flags_obj.train_epochs
        piecewise_lr_boundary_epochs = [
            int(be) for be in flags_obj.piecewise_lr_boundary_epochs
        ]
        piecewise_lr_decay_rates = [
            float(dr) for dr in flags_obj.piecewise_lr_decay_rates
        ]
        use_ranking_loss = flags_obj.use_ranking_loss
        use_se_block = flags_obj.use_se_block
        use_sk_block = flags_obj.use_sk_block
        mixup_type = flags_obj.mixup_type
        dataset_name = flags_obj.dataset_name
        kd_temp = flags_obj.kd_temp
        no_downsample = flags_obj.no_downsample
        anti_alias_filter_size = flags_obj.anti_alias_filter_size
        anti_alias_type = flags_obj.anti_alias_type
        cls_loss_type = flags_obj.cls_loss_type
        logit_type = flags_obj.logit_type
        embedding_size = flags_obj.embedding_size
        pool_type = flags_obj.pool_type
        arc_s = flags_obj.arc_s
        arc_m = flags_obj.arc_m
        bl_alpha = flags_obj.bl_alpha
        bl_beta = flags_obj.bl_beta
        exp = None

        if install_hyperdash and flags_obj.use_hyperdash:
            exp = Experiment(flags_obj.model_dir.split("/")[-1])
            resnet_size = exp.param("resnet_size", int(flags_obj.resnet_size))
            batch_size = exp.param("batch_size", flags_obj.batch_size)
            exp.param("dtype", flags_obj.dtype)
            learning_rate_decay_type = exp.param(
                "learning_rate_decay_type", flags_obj.learning_rate_decay_type)
            weight_decay = exp.param("weight_decay", flags_obj.weight_decay)
            zero_gamma = exp.param("zero_gamma", flags_obj.zero_gamma)
            lr_warmup_epochs = exp.param("lr_warmup_epochs",
                                         flags_obj.lr_warmup_epochs)
            base_learning_rate = exp.param("base_learning_rate",
                                           flags_obj.base_learning_rate)
            use_dropblock = exp.param("use_dropblock", flags_obj.use_dropblock)
            dropblock_kp = exp.param(
                "dropblock_kp", [float(be) for be in flags_obj.dropblock_kp])
            piecewise_lr_boundary_epochs = exp.param(
                "piecewise_lr_boundary_epochs",
                [int(be) for be in flags_obj.piecewise_lr_boundary_epochs])
            piecewise_lr_decay_rates = exp.param(
                "piecewise_lr_decay_rates",
                [float(dr) for dr in flags_obj.piecewise_lr_decay_rates])
            mixup_type = exp.param("mixup_type", flags_obj.mixup_type)
            dataset_name = exp.param("dataset_name", flags_obj.dataset_name)
            exp.param("autoaugment_type", flags_obj.autoaugment_type)

        classifier = tf.estimator.Estimator(
            model_fn=model_function,
            model_dir=flags_obj.model_dir,
            config=run_config,
            params={
                'resnet_size': resnet_size,
                'data_format': data_format,
                'batch_size': batch_size,
                'resnet_version': resnet_version,
                'loss_scale': loss_scale,
                'dtype': dtype_tf,
                'num_epochs_per_decay': num_epochs_per_decay,
                'learning_rate_decay_factor': learning_rate_decay_factor,
                'end_learning_rate': end_learning_rate,
                'learning_rate_decay_type': learning_rate_decay_type,
                'weight_decay': weight_decay,
                'zero_gamma': zero_gamma,
                'lr_warmup_epochs': lr_warmup_epochs,
                'base_learning_rate': base_learning_rate,
                'use_resnet_d': use_resnet_d,
                'use_dropblock': use_dropblock,
                'dropblock_kp': dropblock_kp,
                'label_smoothing': label_smoothing,
                'momentum': momentum,
                'bn_momentum': bn_momentum,
                'embedding_size': embedding_size,
                'train_epochs': train_epochs,
                'piecewise_lr_boundary_epochs': piecewise_lr_boundary_epochs,
                'piecewise_lr_decay_rates': piecewise_lr_decay_rates,
                'with_drawing_bbox': flags_obj.with_drawing_bbox,
                'use_ranking_loss': use_ranking_loss,
                'use_se_block': use_se_block,
                'use_sk_block': use_sk_block,
                'mixup_type': mixup_type,
                'kd_temp': kd_temp,
                'no_downsample': no_downsample,
                'dataset_name': dataset_name,
                'anti_alias_filter_size': anti_alias_filter_size,
                'anti_alias_type': anti_alias_type,
                'cls_loss_type': cls_loss_type,
                'logit_type': logit_type,
                'arc_s': arc_s,
                'arc_m': arc_m,
                'pool_type': pool_type,
                'bl_alpha': bl_alpha,
                'bl_beta': bl_beta,
                'train_steps': total_train_steps,
            })
        return classifier, exp
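gen_estimator above optionally mirrors its hyperparameters to Hyperdash via Experiment.param, which records a value and returns it unchanged. A minimal sketch of that pattern (assumes the hyperdash package is installed; the experiment name and values are hypothetical):

from hyperdash import Experiment

exp = Experiment('resnet-tuning')           # hypothetical experiment name
batch_size = exp.param('batch_size', 256)   # logs the value and returns it
exp.metric('val_accuracy', 0.761)           # streams a metric to the dashboard
exp.end()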
Exemple #14
0
    def __init__(self, flags_obj, time_callback, epoch_steps):
        standard_runnable.StandardTrainable.__init__(
            self, flags_obj.use_tf_while_loop, flags_obj.use_tf_function)
        standard_runnable.StandardEvaluable.__init__(self,
                                                     flags_obj.use_tf_function)

        self.strategy = tf.distribute.get_strategy()
        self.flags_obj = flags_obj
        self.dtype = flags_core.get_tf_dtype(flags_obj)
        self.time_callback = time_callback

        # Input pipeline related
        batch_size = flags_obj.batch_size
        if batch_size % self.strategy.num_replicas_in_sync != 0:
            raise ValueError(
                'Batch size must be divisible by the number of replicas: {}'.
                format(self.strategy.num_replicas_in_sync))

        # As auto rebatching is not supported in
        # `experimental_distribute_datasets_from_function()` API, which is
        # required when cloning dataset to multiple workers in eager mode,
        # we use per-replica batch size.
        self.batch_size = int(batch_size / self.strategy.num_replicas_in_sync)

        if self.flags_obj.use_synthetic_data:
            self.input_fn = common.get_synth_input_fn(
                height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
                width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
                num_channels=imagenet_preprocessing.NUM_CHANNELS,
                num_classes=imagenet_preprocessing.NUM_CLASSES,
                dtype=self.dtype,
                drop_remainder=True)
        else:
            self.input_fn = imagenet_preprocessing.input_fn

        resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
        self.model = resnet_model.resnet50(
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            batch_size=flags_obj.batch_size,
            use_l2_regularizer=not flags_obj.single_l2_loss_op)

        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=common.LR_SCHEDULE[0][1],
            boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
            multipliers=list(p[0] for p in common.LR_SCHEDULE),
            compute_lr_on_cpu=True)
        self.optimizer = common.get_optimizer(lr_schedule)
        # Make sure iterations variable is created inside scope.
        self.global_step = self.optimizer.iterations

        if self.dtype == tf.float16:
            loss_scale = flags_core.get_loss_scale(flags_obj,
                                                   default_for_fp16=128)
            self.optimizer = (
                tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                    self.optimizer, loss_scale))
        elif flags_obj.fp16_implementation == 'graph_rewrite':
            # `dtype` is still float32 in this case. We build the graph in float32
            # and let the graph rewrite change parts of it to float16.
            if not flags_obj.use_tf_function:
                raise ValueError(
                    '--fp16_implementation=graph_rewrite requires '
                    '--use_tf_function to be true')
            loss_scale = flags_core.get_loss_scale(flags_obj,
                                                   default_for_fp16=128)
            self.optimizer = (
                tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    self.optimizer, loss_scale))

        self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'train_accuracy', dtype=tf.float32)
        self.test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)

        self.checkpoint = tf.train.Checkpoint(model=self.model,
                                              optimizer=self.optimizer)

        # Handling epochs.
        self.epoch_steps = epoch_steps
        self.epoch_helper = utils.EpochHelper(epoch_steps, self.global_step)
        self.examples_per_second_history = []
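The divisibility check in __init__ above is the standard guard before deriving a per-replica batch size from a global one. A minimal sketch of the same computation against the current tf.distribute strategy (the batch size is an assumption):

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # default (no-op) strategy outside a scope
global_batch_size = 256
if global_batch_size % strategy.num_replicas_in_sync != 0:
    raise ValueError('Batch size must be divisible by the number of replicas: {}'
                     .format(strategy.num_replicas_in_sync))
per_replica_batch_size = global_batch_size // strategy.num_replicas_in_sync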
Exemple #15
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.

    Returns:
      Dictionary of training and eval stats.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        common.set_gpu_thread_mode_and_count(flags_obj)
    if flags_obj.data_delay_prefetch:
        common.data_delay_prefetch()
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == 'float16':
        policy = tf.keras.mixed_precision.experimental.Policy(
            'infer_float32_vars')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                       flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=num_workers,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=imagenet_preprocessing.parse_record,
            dtype=dtype,
            drop_remainder=drop_remainder)

    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=LR_SCHEDULE[0][1],
            boundaries=list(p[1] for p in LR_SCHEDULE[1:]),
            multipliers=list(p[0] for p in LR_SCHEDULE),
            compute_lr_on_cpu=True)

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        if dtype == 'float16':
            # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
            # can be enabled with a single line of code.
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                optimizer,
                loss_scale=flags_core.get_loss_scale(flags_obj,
                                                     default_for_fp16=128))

        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(
                imagenet_preprocessing.NUM_CLASSES, dtype)
        else:
            model = resnet_model.resnet50(
                num_classes=imagenet_preprocessing.NUM_CLASSES, dtype=dtype)

        # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
        # a valid arg for this model. Also remove as a valid flag.
        if flags_obj.force_v2_in_keras_compile is not None:
            model.compile(
                loss='sparse_categorical_crossentropy',
                optimizer=optimizer,
                metrics=(['sparse_categorical_accuracy']
                         if flags_obj.report_accuracy_metrics else None),
                run_eagerly=flags_obj.run_eagerly,
                experimental_run_tf_function=flags_obj.
                force_v2_in_keras_compile)
        else:
            model.compile(
                loss='sparse_categorical_crossentropy',
                optimizer=optimizer,
                metrics=(['sparse_categorical_accuracy']
                         if flags_obj.report_accuracy_metrics else None),
                run_eagerly=flags_obj.run_eagerly)

    callbacks = common.get_callbacks(
        learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])

    train_steps = (imagenet_preprocessing.NUM_IMAGES['train'] //
                   flags_obj.batch_size)
    train_epochs = flags_obj.train_epochs

    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    if not strategy and flags_obj.explicit_gpu_placement:
        # TODO(b/135607227): Add device scope automatically in Keras training loop
        # when not using distribution strategy.
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        # Note: only 1/15 of an epoch's steps run per Keras epoch.
                        steps_per_epoch=train_steps // 15,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=1)

    eval_output = None
    if not flags_obj.skip_eval:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=1)

    if not strategy and flags_obj.explicit_gpu_placement:
        no_dist_strat_device.__exit__()

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
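The PiecewiseConstantDecayWithWarmup schedule used above is the repo's warmup-augmented variant of piecewise-constant decay. A minimal sketch of the core idea with the stock Keras schedule (step counts are assumptions for ImageNet at batch size 256; the repo's version additionally ramps the rate up linearly over the first epochs):

import tensorflow as tf

steps_per_epoch = 1281167 // 256  # ImageNet train images / batch size (assumed)
boundaries = [30 * steps_per_epoch, 60 * steps_per_epoch, 80 * steps_per_epoch]
values = [0.1, 0.01, 0.001, 0.0001]  # one more value than boundaries
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries, values)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)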
Exemple #16
0
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  config = keras_common.get_config_proto()
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  if not keras_common.is_v2_0():
    if flags_obj.enable_eager:
      tf.compat.v1.enable_eager_execution(config=config)
    else:
      sess = tf.Session(config=config)
      tf.keras.backend.set_session(sess)
  # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'float16':
    policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype)

  eval_input_dataset = input_fn(
      is_training=False,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      dtype=dtype)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus)

  strategy_scope = keras_common.get_strategy_scope(strategy)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    if dtype == 'float16':
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))
    model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
                                  dtype=dtype)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)
  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats
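Both Keras runs above wrap the optimizer in a LossScaleOptimizer when the dtype is float16. A minimal sketch of that wrapping (assumes the TF 1.14/2.0-era experimental API used in these examples; 128 mirrors the fp16 default passed to get_loss_scale elsewhere in this document):

import tensorflow as tf

optimizer = tf.keras.optimizers.SGD(0.1, momentum=0.9)
# Fixed loss scale of 128; this experimental API also accepts 'dynamic'.
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    optimizer, loss_scale=128)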
Exemple #17
0
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purposes.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

  print("RESNET MAIN")
  model_helpers.apply_clean(flags.FLAGS)

  # Ensures flag override logic is only executed if explicitly triggered.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Creates session config. allow_soft_placement = True, is required for
  # multi-GPU and is not harmful for other modes.
  session_config = tf.ConfigProto(allow_soft_placement=True)

  run_config = tf.estimator.RunConfig(
      session_config=session_config,
      save_checkpoints_secs=60*60*24)

  # Initializes model with all but the dense layer from pretrained ResNet.
  if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start_settings = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')
  else:
    warm_start_settings = None

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
      warm_start_from=warm_start_settings, params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }

  def input_fn_eval():
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  schedule, n_loops = [0], 1  # export-only: no training cycles run in this example
  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      input_receiver_fn = functools.partial(
          image_bytes_serving_input_fn, shape, dtype=export_dtype)
    else:
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                                 strip_default_attrs=True)
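export_savedmodel above needs a serving input receiver. A sketch of what build_tensor_serving_input_receiver_fn plausibly constructs for shape=[224, 224, 3] and batch_size=1 (TF 1.x API; the actual helper lives in the repo's export module):

import tensorflow as tf

def serving_input_receiver_fn():
    # Placeholder shaped [batch_size] + shape, fed directly as the features.
    features = tf.placeholder(dtype=tf.float32, shape=[1, 224, 224, 3])
    return tf.estimator.export.TensorServingInputReceiver(
        features=features, receiver_tensors=features)

# classifier.export_savedmodel(export_dir, serving_input_receiver_fn)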
Exemple #18
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.
      NotImplementedError: If some features are not currently supported.

    Returns:
      Dictionary of training and eval stats.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == tf.float16:
        loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_float16', loss_scale=loss_scale)
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
        if not keras_utils.is_v2_0():
            raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
    elif dtype == tf.bfloat16:
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_bfloat16')
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                       flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=num_workers,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    # Current resnet_model.resnet50 input format is always channels-last.
    # We use the keras_applications MobileNet model, whose input format depends
    # on the Keras backend image data format. The use_keras_image_data_format
    # flag indicates whether the image preprocessor output should match the
    # Keras backend image data format or stay channels-last.
    use_keras_image_data_format = (flags_obj.model == 'mobilenet')
    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
        training_dataset_cache=flags_obj.training_dataset_cache,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
                use_keras_image_data_format=use_keras_image_data_format),
            dtype=dtype,
            drop_remainder=drop_remainder)

    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=common.LR_SCHEDULE[0][1],
            boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
            multipliers=list(p[0] for p in common.LR_SCHEDULE),
            compute_lr_on_cpu=True)
    steps_per_epoch = (imagenet_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)

    learning_rate_schedule_fn = None
    with strategy_scope:
        if flags_obj.optimizer == 'resnet50_default':
            optimizer = common.get_optimizer(lr_schedule)
            learning_rate_schedule_fn = common.learning_rate_schedule
        elif flags_obj.optimizer == 'mobilenet_default':
            lr_decay_factor = 0.94
            num_epochs_per_decay = 2.5
            initial_learning_rate_per_sample = 0.000007
            initial_learning_rate = \
                initial_learning_rate_per_sample * flags_obj.batch_size
            optimizer = tf.keras.optimizers.SGD(
                learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
                    initial_learning_rate,
                    decay_steps=steps_per_epoch * num_epochs_per_decay,
                    decay_rate=lr_decay_factor,
                    staircase=True),
                momentum=0.9)
        if flags_obj.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
            # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
            # which will ensure tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
            # up.
            optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                optimizer)

        # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(
                imagenet_preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'resnet50_v1.5':
            resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
            model = resnet_model.resnet50(
                num_classes=imagenet_preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'mobilenet':
            # TODO(kimjaehong): Remove layers attribute when minimum TF version
            # support 2.0 layers by default.
            model = tf.keras.applications.mobilenet.MobileNet(
                weights=None,
                classes=imagenet_preprocessing.NUM_CLASSES,
                layers=tf.keras.layers)
        if flags_obj.pretrained_filepath:
            model.load_weights(flags_obj.pretrained_filepath)

        if flags_obj.pruning_method == 'polynomial_decay':
            if dtype != tf.float32:
                raise NotImplementedError(
                    'Pruning is currently only supported on dtype=tf.float32.')
            pruning_params = {
                'pruning_schedule':
                tfmot.sparsity.keras.PolynomialDecay(
                    initial_sparsity=flags_obj.pruning_initial_sparsity,
                    final_sparsity=flags_obj.pruning_final_sparsity,
                    begin_step=flags_obj.pruning_begin_step,
                    end_step=flags_obj.pruning_end_step,
                    frequency=flags_obj.pruning_frequency),
            }
            model = tfmot.sparsity.keras.prune_low_magnitude(
                model, **pruning_params)
        elif flags_obj.pruning_method:
            raise NotImplementedError(
                'Only polynomial_decay is currently supported.')
        # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
        # a valid arg for this model. Also remove as a valid flag.
        if flags_obj.force_v2_in_keras_compile is not None:
            model.compile(
                loss='sparse_categorical_crossentropy',
                optimizer=optimizer,
                metrics=(['sparse_categorical_accuracy']
                         if flags_obj.report_accuracy_metrics else None),
                run_eagerly=flags_obj.run_eagerly,
                experimental_run_tf_function=flags_obj.
                force_v2_in_keras_compile)
        else:
            model.compile(
                loss='sparse_categorical_crossentropy',
                optimizer=optimizer,
                metrics=(['sparse_categorical_accuracy']
                         if flags_obj.report_accuracy_metrics else None),
                run_eagerly=flags_obj.run_eagerly)

    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(
        steps_per_epoch=steps_per_epoch,
        learning_rate_schedule_fn=learning_rate_schedule_fn,
        pruning_method=flags_obj.pruning_method,
        enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
        model_dir=flags_obj.model_dir)

    # If training for multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    if not strategy and flags_obj.explicit_gpu_placement:
        # TODO(b/135607227): Add device scope automatically in Keras training loop
        # when not using distribution strategy.
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)

    eval_output = None
    if not flags_obj.skip_eval:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)

    if flags_obj.pruning_method:
        model = tfmot.sparsity.keras.strip_pruning(model)
    if flags_obj.enable_checkpoint_and_export:
        if dtype == tf.bfloat16:
            logging.warning(
                'Keras model.save does not support bfloat16 dtype.')
        else:
            # Keras model.save assumes a float32 input signature.
            export_path = os.path.join(flags_obj.model_dir, 'saved_model')
            model.save(export_path, include_optimizer=False)

    if not strategy and flags_obj.explicit_gpu_placement:
        no_dist_strat_device.__exit__()

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
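The pruning path above builds a polynomial-decay sparsity schedule and strips the wrappers before export. A minimal standalone sketch of that flow (assumes the tensorflow_model_optimization package; the toy model and step counts are illustrative):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, input_shape=(4,), activation='softmax')])
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0, final_sparsity=0.5,
    begin_step=0, end_step=1000, frequency=100)
pruned = tfmot.sparsity.keras.prune_low_magnitude(
    model, pruning_schedule=pruning_schedule)
# ... train with the tfmot.sparsity.keras.UpdatePruningStep() callback ...
exportable = tfmot.sparsity.keras.strip_pruning(pruned)  # drop wrappers for export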
Exemple #19
0
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

    model_helpers.apply_clean(flags.FLAGS)

    # Ensures flag override logic is only executed if explicitly triggered.
    if flags_obj.tf_gpu_thread_mode:
        override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

    # Creates session config. allow_soft_placement = True, is required for
    # multi-GPU and is not harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    # Creates a `RunConfig` that checkpoints every 24 hours which essentially
    # results in checkpoints determined only by `epochs_between_evals`.
    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config,
                                        save_checkpoints_secs=60 * 60 * 24)

    # Initializes model with all but the dense layer from pretrained ResNet.
    if flags_obj.pretrained_model_checkpoint_path is not None:
        warm_start_settings = tf.estimator.WarmStartSettings(
            flags_obj.pretrained_model_checkpoint_path,
            vars_to_warm_start='^(?!.*dense)')
    else:
        warm_start_settings = None

    PKL_PATH = 'data/vocab_short_list.pkl'
    with open(PKL_PATH, 'rb') as pkl_file:
        vocab = pickle.load(pkl_file)
    if len(vocab) == 2:
        vocab = vocab[0]
    vocab = list(vocab)
    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        warm_start_from=warm_start_settings,
        params={
            'resnet_size': int(flags_obj.resnet_size),
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'resnet_version': int(flags_obj.resnet_version),
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'dtype': flags_core.get_tf_dtype(flags_obj),
            'vocab': vocab,
            'fine_tune': flags_obj.fine_tune,
            'learning_rate': flags_obj.learning_rate
        })

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)
    eval_hooks = hooks_helper.get_train_hooks(
        ['featuresloggerhook'],
        model_dir=flags_obj.model_dir,
        batch_size=flags_obj.batch_size,
        use_train_data=flags_obj.use_train_data)
    if flags_obj.eval_only:
        train_hooks += eval_hooks

    def input_fn_train(num_epochs, shuffle=True):
        return input_function(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=num_epochs,
            dtype=flags_core.get_tf_dtype(flags_obj),
            datasets_num_private_threads=flags_obj.
            datasets_num_private_threads,
            num_parallel_batches=flags_obj.datasets_num_parallel_batches,
            shuffle=shuffle)

    def input_fn_eval():
        return input_function(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=1,
            dtype=flags_core.get_tf_dtype(flags_obj))

    if flags_obj.eval_only or not flags_obj.train_epochs:
        # If --eval_only is set or no train epochs are requested, perform a
        # single loop with zero train epochs.
        schedule, n_loops = [0], 1
    else:
        # Compute the number of times to loop while training. All but the last
        # pass will train for `epochs_between_evals` epochs, while the last will
        # train for the number needed to reach `training_epochs`. For instance if
        #   train_epochs = 25 and epochs_between_evals = 10
        # schedule will be set to [10, 10, 5]. That is to say, the loop will:
        #   Train for 10 epochs and then evaluate.
        #   Train for another 10 epochs and then evaluate.
        #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
        n_loops = math.ceil(flags_obj.train_epochs /
                            flags_obj.epochs_between_evals)
        schedule = [
            flags_obj.epochs_between_evals for _ in range(int(n_loops))
        ]
        schedule[-1] = flags_obj.train_epochs - sum(
            schedule[:-1])  # Correct for over-counting.

    for cycle_index, num_train_epochs in enumerate(schedule):
        tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

        if num_train_epochs:
            classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                             hooks=train_hooks,
                             max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')

        # flags_obj.max_train_steps is generally associated with testing and
        # profiling. As a result it is frequently called with synthetic data, which
        # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
        # eval (which is generally unimportant in those circumstances) to terminate.
        # Note that eval will run for max_train_steps each loop, regardless of the
        # global_step count.
        eval_input_fn = (lambda: input_fn_train(1, False)) if (
            flags_obj.eval_only and flags_obj.use_train_data) else input_fn_eval
        eval_results = classifier.evaluate(input_fn=eval_input_fn,
                                           hooks=eval_hooks,
                                           steps=flags_obj.max_train_steps)

        benchmark_logger.log_evaluation_result(eval_results)

        if flags_obj.eval_only:
            pred_results = classifier.predict(eval_input_fn)
            dir_name = 'train' if flags_obj.use_train_data else 'val'
            if not os.path.exists(os.path.join(flags_obj.model_dir, dir_name)):
                os.makedirs(os.path.join(flags_obj.model_dir, dir_name))
            preds_path = os.path.join(flags_obj.model_dir, dir_name,
                                      'preds.pkl')
            with open(preds_path, 'wb') as pkl_file:
                pickle.dump(list(pred_results), pkl_file)
            return

        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['accuracy']):
            break

    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        export_dtype = flags_core.get_tf_dtype(flags_obj)
        if flags_obj.image_bytes_as_serving_input:
            input_receiver_fn = functools.partial(image_bytes_serving_input_fn,
                                                  shape,
                                                  dtype=export_dtype)
        else:
            input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
                shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
        classifier.export_savedmodel(flags_obj.export_dir,
                                     input_receiver_fn,
                                     strip_default_attrs=True)
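
Note: the epoch-schedule arithmetic used above is easy to verify in isolation. A minimal sketch with a hypothetical helper name and no flags object:

import math

def epoch_schedule(train_epochs, epochs_between_evals):
    """Split train_epochs into chunks of at most epochs_between_evals."""
    n_loops = math.ceil(train_epochs / epochs_between_evals)
    # All but the last chunk train for epochs_between_evals epochs; the last
    # chunk trains for whatever remains, correcting the over-counting.
    schedule = [epochs_between_evals] * int(n_loops)
    schedule[-1] = train_epochs - sum(schedule[:-1])
    return schedule

# epoch_schedule(25, 10) -> [10, 10, 5]
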
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == tf.float16:
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_float16')
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
    elif dtype == tf.bfloat16:
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_bfloat16')
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    # This only affects GPU.
    common.set_cudnn_batchnorm_mode()

    # TODO(anj-s): Set data_format without using Keras.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    train_ds, test_ds = get_input_dataset(flags_obj, strategy)
    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
    logging.info(
        "Training %d epochs, each epoch has %d steps, "
        "total steps: %d; Eval %d steps", train_epochs, per_epoch_steps,
        train_epochs * per_epoch_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
                                            flags_obj.log_steps)

    with distribution_utils.get_strategy_scope(strategy):
        resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
        model = resnet_model.resnet50(
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            batch_size=flags_obj.batch_size,
            use_l2_regularizer=not flags_obj.single_l2_loss_op)

        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=common.LR_SCHEDULE[0][1],
            boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
            multipliers=list(p[0] for p in common.LR_SCHEDULE),
            compute_lr_on_cpu=True)
        optimizer = common.get_optimizer(lr_schedule)

        if dtype == tf.float16:
            loss_scale = flags_core.get_loss_scale(flags_obj,
                                                   default_for_fp16=128)
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                optimizer, loss_scale)
        elif flags_obj.fp16_implementation == 'graph_rewrite':
            # `dtype` is still float32 in this case. The graph is built in
            # float32, and the graph rewrite changes parts of it to float16.
            if not flags_obj.use_tf_function:
                raise ValueError(
                    '--fp16_implementation=graph_rewrite requires '
                    '--use_tf_function to be true')
            loss_scale = flags_core.get_loss_scale(flags_obj,
                                                   default_for_fp16=128)
            optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                optimizer, loss_scale)

        current_step = 0
        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
        if latest_checkpoint:
            checkpoint.restore(latest_checkpoint)
            logging.info("Load checkpoint %s", latest_checkpoint)
            current_step = optimizer.iterations.numpy()

        train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'training_accuracy', dtype=tf.float32)
        test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
        test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)

        trainable_variables = model.trainable_variables

        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            with tf.GradientTape() as tape:
                logits = model(images, training=True)

                prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits)
                loss = tf.reduce_sum(prediction_loss) * (1.0 /
                                                         flags_obj.batch_size)
                num_replicas = tf.distribute.get_strategy(
                ).num_replicas_in_sync

                if flags_obj.single_l2_loss_op:
                    l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([
                        tf.nn.l2_loss(v)
                        for v in trainable_variables if 'bn' not in v.name
                    ])

                    loss += (l2_loss / num_replicas)
                else:
                    loss += (tf.reduce_sum(model.losses) / num_replicas)

                # Scale the loss
                if flags_obj.dtype == "fp16":
                    loss = optimizer.get_scaled_loss(loss)

            grads = tape.gradient(loss, trainable_variables)

            # Unscale the grads
            if flags_obj.dtype == "fp16":
                grads = optimizer.get_unscaled_gradients(grads)

            optimizer.apply_gradients(zip(grads, trainable_variables))
            train_loss.update_state(loss)
            training_accuracy.update_state(labels, logits)

        @tf.function
        def train_steps(iterator, steps):
            """Performs distributed training steps in a loop."""
            for _ in tf.range(steps):
                strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

        def train_single_step(iterator):
            if strategy:
                strategy.experimental_run_v2(step_fn, args=(next(iterator), ))
            else:
                return step_fn(next(iterator))

        def test_step(iterator):
            """Evaluation StepFn."""
            def step_fn(inputs):
                images, labels = inputs
                logits = model(images, training=False)
                loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits)
                loss = tf.reduce_sum(loss) * (1.0 / flags_obj.batch_size)
                test_loss.update_state(loss)
                test_accuracy.update_state(labels, logits)

            if strategy:
                strategy.experimental_run_v2(step_fn, args=(next(iterator), ))
            else:
                step_fn(next(iterator))

        if flags_obj.use_tf_function:
            train_single_step = tf.function(train_single_step)
            test_step = tf.function(test_step)

        if flags_obj.enable_tensorboard:
            summary_writer = tf.summary.create_file_writer(flags_obj.model_dir)
        else:
            summary_writer = None

        train_iter = iter(train_ds)
        time_callback.on_train_begin()
        for epoch in range(current_step // per_epoch_steps, train_epochs):
            train_loss.reset_states()
            training_accuracy.reset_states()

            steps_in_current_epoch = 0
            while steps_in_current_epoch < per_epoch_steps:
                time_callback.on_batch_begin(steps_in_current_epoch +
                                             epoch * per_epoch_steps)
                steps = _steps_to_run(steps_in_current_epoch, per_epoch_steps,
                                      steps_per_loop)
                if steps == 1:
                    train_single_step(train_iter)
                else:
                    # Converts steps to a Tensor to avoid tf.function retracing.
                    train_steps(train_iter,
                                tf.convert_to_tensor(steps, dtype=tf.int32))
                time_callback.on_batch_end(steps_in_current_epoch +
                                           epoch * per_epoch_steps)
                steps_in_current_epoch += steps

            logging.info('Training loss: %s, accuracy: %s at epoch %d',
                         train_loss.result().numpy(),
                         training_accuracy.result().numpy(), epoch + 1)

            if (not flags_obj.skip_eval
                    and (epoch + 1) % flags_obj.epochs_between_evals == 0):
                test_loss.reset_states()
                test_accuracy.reset_states()

                test_iter = iter(test_ds)
                for _ in range(eval_steps):
                    test_step(test_iter)

                logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
                             test_loss.result().numpy(),
                             test_accuracy.result().numpy(), epoch + 1)

            if flags_obj.enable_checkpoint_and_export:
                checkpoint_name = checkpoint.save(
                    os.path.join(flags_obj.model_dir,
                                 'model.ckpt-{}'.format(epoch + 1)))
                logging.info('Saved checkpoint to %s', checkpoint_name)

            if summary_writer:
                current_steps = steps_in_current_epoch + (epoch *
                                                          per_epoch_steps)
                with summary_writer.as_default():
                    tf.summary.scalar('train_loss', train_loss.result(),
                                      current_steps)
                    tf.summary.scalar('train_accuracy',
                                      training_accuracy.result(),
                                      current_steps)
                    tf.summary.scalar('eval_loss', test_loss.result(),
                                      current_steps)
                    tf.summary.scalar('eval_accuracy', test_accuracy.result(),
                                      current_steps)

        time_callback.on_train_end()
        if summary_writer:
            summary_writer.close()

        eval_result = None
        train_result = None
        if not flags_obj.skip_eval:
            eval_result = [
                test_loss.result().numpy(),
                test_accuracy.result().numpy()
            ]
            train_result = [
                train_loss.result().numpy(),
                training_accuracy.result().numpy()
            ]

        stats = build_stats(train_result, eval_result, time_callback)
        return stats
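
Note: the training loop above calls a `_steps_to_run` helper that is defined outside this snippet. A plausible implementation, consistent with its call sites (never run past the current epoch boundary; a return value of 1 routes through `train_single_step`):

def _steps_to_run(steps_in_current_epoch, steps_per_epoch, steps_per_loop):
    """Number of steps for the next inner loop (assumed implementation)."""
    if steps_per_loop <= 0:
        raise ValueError('steps_per_loop should be a positive integer.')
    if steps_per_loop == 1:
        return steps_per_loop
    # Cap the chunk at the end of the current epoch.
    return min(steps_per_loop, steps_per_epoch - steps_in_current_epoch)
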
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared main loop for ResNet Models.
  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purposes.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.

  Returns:
    The eval accuracy from the final (or early-stopped) evaluation.
  """

    model_helpers.apply_clean(flags.FLAGS)

    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # Create session config based on values of inter_op_parallelism_threads and
    # intra_op_parallelism_threads. Note that we default to having
    # allow_soft_placement = True, which is required for multi-GPU and not
    # harmful for other modes.
    gpu_memory_fraction = tf.GPUOptions(
        per_process_gpu_memory_fraction=0.4)  # Xinyi add
    session_config = tf.ConfigProto(
        gpu_options=gpu_memory_fraction,  # Xinyi add
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        # flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)
        1,
        flags_obj.all_reduce_alg
    )  # Xinyi modified, get_num_gpus() will occupy all GPUs

    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config)

    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        params={
            'resnet_size': int(flags_obj.resnet_size),
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'resnet_version': int(flags_obj.resnet_version),
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'dtype': flags_core.get_tf_dtype(flags_obj)
        })

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train():
        return input_function(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                #flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
                flags_obj.batch_size,
                1),  # Xinyi modified, get_num_gpus() will occupy all GPUs
            num_epochs=flags_obj.epochs_between_evals,
            #num_gpus=flags_core.get_num_gpus(flags_obj))
            num_gpus=1)  # Xinyi modified, get_num_gpus() will occupy all GPUs

    def input_fn_eval():
        return input_function(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                #flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
                flags_obj.batch_size,
                1),  # Xinyi modified, get_num_gpus() will occupy all GPUs
            num_epochs=1)

    total_training_cycle = (flags_obj.train_epochs //
                            flags_obj.epochs_between_evals)
    for cycle_index in range(total_training_cycle):
        tf.logging.info('Starting a training cycle: %d/%d', cycle_index,
                        total_training_cycle)

        classifier.train(input_fn=input_fn_train,
                         hooks=train_hooks,
                         max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')

        # flags_obj.max_train_steps is generally associated with testing and
        # profiling. As a result it is frequently called with synthetic data, which
        # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
        # eval (which is generally unimportant in those circumstances) to terminate.
        # Note that eval will run for max_train_steps each loop, regardless of the
        # global_step count.
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags_obj.max_train_steps)

        benchmark_logger.log_evaluation_result(eval_results)

        # Xinyi add, writing accuracy after every training epoch
        filename = os.path.join(flags.FLAGS.model_dir, 'learning_curve.csv')
        file_exists = os.path.isfile(filename)
        steps_per_epoch = int(50000 / flags.FLAGS.batch_size)
        fields = [
            'epochs', 'eval_accuracy', 'optimizer', 'learning_rate',
            'decay_rate', 'decay_steps', 'initializer', 'regularizer',
            'weight_decay', 'batch_size', 'model_id'
        ]
        csv_data = {
            'epochs': flags.FLAGS.epoch_index,
            'eval_accuracy': eval_results['accuracy'],
            'optimizer': flags.FLAGS.optimizer,
            'learning_rate': flags.FLAGS.learning_rate,
            'decay_rate': flags.FLAGS.decay_rate,
            'decay_steps': flags.FLAGS.decay_steps,
            'initializer': flags.FLAGS.initializer,
            'regularizer': flags.FLAGS.regularizer,
            'weight_decay': flags.FLAGS.weight_decay,
            'batch_size': flags.FLAGS.batch_size,
            'model_id': flags.FLAGS.model_id
        }

        if flags.FLAGS.optimizer in ('Momentum', 'RMSProp'):
            fields.append('momentum')
            csv_data['momentum'] = flags.FLAGS.momentum
        if flags.FLAGS.optimizer == 'RMSProp':
            fields.append('grad_decay')
            csv_data['grad_decay'] = flags.FLAGS.grad_decay

        with open(filename, 'a') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fields)
            if not file_exists:
                writer.writeheader()
            writer.writerow(csv_data)

        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['accuracy']):
            # break
            return eval_results['accuracy']  # Xinyi modified

    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
            shape, batch_size=flags_obj.batch_size)
        classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)

    return eval_results['accuracy']  # Xinyi modified
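
Note: the learning-curve logging above relies on a write-header-once CSV append pattern. Isolated as a minimal sketch with a hypothetical helper and illustrative field names:

import csv
import os

def append_csv_row(filename, fields, row):
    """Append one row to a CSV file, writing the header only on creation."""
    file_exists = os.path.isfile(filename)
    with open(filename, 'a', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fields)
        if not file_exists:
            writer.writeheader()
        writer.writerow(row)

# append_csv_row('learning_curve.csv', ['epochs', 'eval_accuracy'],
#                {'epochs': 1, 'eval_accuracy': 0.42})
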
  def __init__(self, flags_obj, time_callback, epoch_steps):
    self.strategy = tf.distribute.get_strategy()
    self.flags_obj = flags_obj
    self.dtype = flags_core.get_tf_dtype(flags_obj)
    self.time_callback = time_callback

    # Input pipeline related
    batch_size = flags_obj.batch_size
    if batch_size % self.strategy.num_replicas_in_sync != 0:
      raise ValueError(
          'Batch size must be divisible by number of replicas : {}'.format(
              self.strategy.num_replicas_in_sync))

    # Auto-rebatching is not supported by the
    # `experimental_distribute_datasets_from_function()` API, which is
    # required when cloning the dataset to multiple workers in eager mode,
    # so we use the per-replica batch size.
    self.batch_size = int(batch_size / self.strategy.num_replicas_in_sync)

    if self.flags_obj.use_synthetic_data:
      self.input_fn = common.get_synth_input_fn(
          height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
          width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
          num_channels=imagenet_preprocessing.NUM_CHANNELS,
          num_classes=imagenet_preprocessing.NUM_CLASSES,
          dtype=self.dtype,
          drop_remainder=True)
    else:
      self.input_fn = imagenet_preprocessing.input_fn

    self.model = resnet_model.resnet50(
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        use_l2_regularizer=not flags_obj.single_l2_loss_op)

    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in common.LR_SCHEDULE),
        compute_lr_on_cpu=True)
    self.optimizer = common.get_optimizer(lr_schedule)
    # Make sure iterations variable is created inside scope.
    self.global_step = self.optimizer.iterations

    use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
    if use_graph_rewrite and not flags_obj.use_tf_function:
      raise ValueError('--fp16_implementation=graph_rewrite requires '
                       '--use_tf_function to be true')
    self.optimizer = performance.configure_optimizer(
        self.optimizer,
        use_float16=self.dtype == tf.float16,
        use_graph_rewrite=use_graph_rewrite,
        loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

    self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'train_accuracy', dtype=tf.float32)
    self.test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
    self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

    self.checkpoint = tf.train.Checkpoint(
        model=self.model, optimizer=self.optimizer)

    # Handling epochs.
    self.epoch_steps = epoch_steps
    self.epoch_helper = orbit.utils.EpochHelper(epoch_steps, self.global_step)
    train_dataset = orbit.utils.make_distributed_dataset(
        self.strategy,
        self.input_fn,
        is_training=True,
        data_dir=self.flags_obj.data_dir,
        batch_size=self.batch_size,
        parse_record_fn=imagenet_preprocessing.parse_record,
        datasets_num_private_threads=self.flags_obj
        .datasets_num_private_threads,
        dtype=self.dtype,
        drop_remainder=True)
    orbit.StandardTrainer.__init__(
        self,
        train_dataset,
        options=orbit.StandardTrainerOptions(
            use_tf_while_loop=flags_obj.use_tf_while_loop,
            use_tf_function=flags_obj.use_tf_function))
    if not flags_obj.skip_eval:
      eval_dataset = orbit.utils.make_distributed_dataset(
          self.strategy,
          self.input_fn,
          is_training=False,
          data_dir=self.flags_obj.data_dir,
          batch_size=self.batch_size,
          parse_record_fn=imagenet_preprocessing.parse_record,
          dtype=self.dtype)
      orbit.StandardEvaluator.__init__(
          self,
          eval_dataset,
          options=orbit.StandardEvaluatorOptions(
              use_tf_function=flags_obj.use_tf_function))
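
Note: the divisibility check and per-replica batch size in this constructor reduce to a few lines. A minimal standalone sketch, assuming a synchronous strategy:

def per_replica_batch_size(global_batch_size, num_replicas):
    """Split a global batch evenly across replicas, as the trainer above does."""
    if global_batch_size % num_replicas != 0:
        raise ValueError(
            'Batch size must be divisible by number of replicas: {}'.format(
                num_replicas))
    return global_batch_size // num_replicas

# per_replica_batch_size(256, 8) -> 32
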
Exemple #23
0
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purposes.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

    model_helpers.apply_clean(flags.FLAGS)

    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # Create session config based on values of inter_op_parallelism_threads and
    # intra_op_parallelism_threads. Note that we default to having
    # allow_soft_placement = True, which is required for multi-GPU and not
    # harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config)

    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        params={
            'resnet_size': int(flags_obj.resnet_size),
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'resnet_version': int(flags_obj.resnet_version),
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'dtype': flags_core.get_tf_dtype(flags_obj)
        })

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train():
        return input_function(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=flags_obj.epochs_between_evals,
            num_gpus=flags_core.get_num_gpus(flags_obj))

    def input_fn_eval():
        return input_function(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=1)

    total_training_cycle = (flags_obj.train_epochs //
                            flags_obj.epochs_between_evals)
    profiler_hook = tf.train.ProfilerHook(save_steps=100,
                                          save_secs=None,
                                          output_dir="profs",
                                          show_memory=True,
                                          show_dataflow=True)

    # DOGA DEBUG: parse the serialized graph and dump per-op tensor sizes.
    gdef = gpb.GraphDef()

    with open('/tmp/cifar10_model/graph.pbtxt', 'r') as fh:
        graph_str = fh.read()

    pbtf.Parse(graph_str, gdef)
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(gdef)

        operations_tensors = {}
        operations_names = graph.get_operations()
        count1 = 0
        count2 = 0
        #print(operations_names)
        for operation in operations_names:
            operation_name = operation.name
            operations_info = graph.get_operation_by_name(
                operation_name).values()
            if len(operations_info) > 0:
                if operations_info[0].shape.ndims is not None:
                    operation_shape = operations_info[0].shape.as_list()
                    operation_dtype_size = operations_info[0].dtype.size
                    if operation_dtype_size is not None:
                        operation_no_of_elements = 1
                        for dim in operation_shape:
                            if dim is not None:
                                operation_no_of_elements *= dim
                        total_size = (operation_no_of_elements *
                                      operation_dtype_size)
                        operations_tensors[operation_name] = total_size
                    else:
                        # Unknown dtype size: counted but not recorded.
                        count1 = count1 + 1
                else:
                    # Unknown rank: record a -1 sentinel.
                    count1 = count1 + 1
                    operations_tensors[operation_name] = -1
            else:
                # Op has no output tensors: record a -1 sentinel.
                count2 = count2 + 1
                operations_tensors[operation_name] = -1

        print(count1)
        print(count2)

        with open('tensors_sz.json', 'w') as f:
            json.dump(operations_tensors, f)

    for cycle_index in range(total_training_cycle):
        tf.logging.info('Starting a training cycle: %d/%d', cycle_index,
                        total_training_cycle)

        classifier.train(input_fn=input_fn_train,
                         hooks=[profiler_hook],
                         max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')

        # flags_obj.max_train_steps is generally associated with testing and
        # profiling. As a result it is frequently called with synthetic data, which
        # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
        # eval (which is generally unimportant in those circumstances) to terminate.
        # Note that eval will run for max_train_steps each loop, regardless of the
        # global_step count.
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags_obj.max_train_steps)

        benchmark_logger.log_evaluation_result(eval_results)

        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['accuracy']):
            break

    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
            shape, batch_size=flags_obj.batch_size)
        classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
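
Note: the debug block above writes a name-to-bytes map to tensors_sz.json, recording -1 for ops with unknown rank or no output tensors. Reading it back for a quick top-N report is a short sketch:

import json

with open('tensors_sz.json') as f:
    sizes = json.load(f)

# Largest tensors first; -1 entries mark ops with unknown rank or no outputs.
for name, size in sorted(sizes.items(), key=lambda kv: kv[1],
                         reverse=True)[:10]:
    print(name, size)
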
Exemple #24
0
    def __init__(self, flags_obj):
        """Init function of TransformerMain.

    Args:
      flags_obj: Object containing parsed flag values, i.e., FLAGS.

    Raises:
      ValueError: if not using static batch for input data on TPU.
    """
        self.flags_obj = flags_obj
        self.predict_model = None

        # Add flag-defined parameters to params object
        num_gpus = flags_core.get_num_gpus(flags_obj)
        self.params = params = misc.get_model_params(flags_obj.param_set,
                                                     num_gpus)

        params["num_gpus"] = num_gpus
        params["use_ctl"] = flags_obj.use_ctl
        params["data_dir"] = flags_obj.data_dir
        params["model_dir"] = flags_obj.model_dir
        params["static_batch"] = flags_obj.static_batch
        params["max_length"] = flags_obj.max_length
        params["decode_batch_size"] = flags_obj.decode_batch_size
        params["decode_max_length"] = flags_obj.decode_max_length
        params["padded_decode"] = flags_obj.padded_decode
        params["num_parallel_calls"] = (flags_obj.num_parallel_calls
                                        or tf.data.experimental.AUTOTUNE)

        params["use_synthetic_data"] = flags_obj.use_synthetic_data
        params["batch_size"] = flags_obj.batch_size or params[
            "default_batch_size"]
        params["repeat_dataset"] = None
        params["dtype"] = flags_core.get_tf_dtype(flags_obj)
        params["enable_tensorboard"] = flags_obj.enable_tensorboard
        params["enable_metrics_in_training"] = (
            flags_obj.enable_metrics_in_training)
        params["steps_between_evals"] = flags_obj.steps_between_evals
        params["enable_checkpointing"] = flags_obj.enable_checkpointing

        self.distribution_strategy = distribution_utils.get_distribution_strategy(
            distribution_strategy=flags_obj.distribution_strategy,
            num_gpus=num_gpus,
            all_reduce_alg=flags_obj.all_reduce_alg,
            num_packs=flags_obj.num_packs,
            tpu_address=flags_obj.tpu or "")
        if self.use_tpu:
            params["num_replicas"] = (
                self.distribution_strategy.num_replicas_in_sync)
            if not params["static_batch"]:
                raise ValueError("TPU requires static batch for input data.")
        else:
            logging.info("Running transformer with num_gpus = %d", num_gpus)

        if self.distribution_strategy:
            logging.info("For training, using distribution strategy: %s",
                         self.distribution_strategy)
        else:
            logging.info("Not using any distribution strategy.")

        performance.set_mixed_precision_policy(
            params["dtype"],
            flags_core.get_loss_scale(flags_obj, default_for_fp16="dynamic"))
Exemple #25
0
    (ncf_input_pipeline.create_ncf_input_data(
        params, producer, input_meta_data, strategy))
  steps_per_epoch = None if generate_input_online else num_train_steps

  with distribution_utils.get_strategy_scope(strategy):
    keras_model = _get_keras_model(params)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=params["learning_rate"],
        beta_1=params["beta1"],
        beta_2=params["beta2"],
        epsilon=params["epsilon"])
    if FLAGS.fp16_implementation == "graph_rewrite":
      optimizer = \
        tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
            optimizer,
            loss_scale=flags_core.get_loss_scale(FLAGS,
                                                 default_for_fp16="dynamic"))
    elif FLAGS.dtype == "fp16" and params["keras_use_ctl"]:
      # When keras_use_ctl is False, Model.fit() applies loss scaling
      # automatically, so we don't need to create a LossScaleOptimizer here.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer,
          tf.keras.mixed_precision.experimental.global_policy().loss_scale)

    if params["keras_use_ctl"]:
      train_loss, eval_results = run_ncf_custom_training(
          params,
          strategy,
          keras_model,
          optimizer,
          callbacks,
          train_input_dataset,
Exemple #26
0
def get_loss_scale():
    return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
Exemple #27
0
def run_ncf(_):
    """Run NCF training and eval with Keras."""

    keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

    if FLAGS.seed is not None:
        print("Setting tf seed")
        tf.random.set_seed(FLAGS.seed)

    params = ncf_common.parse_flags(FLAGS)
    model_helpers.apply_clean(flags.FLAGS)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        tpu_address=FLAGS.tpu)
    params["distribute_strategy"] = strategy

    if not keras_utils.is_v2_0() and strategy is not None:
        logging.error(
            "NCF Keras only works with distribution strategy in TF 2.0")
        return
    if (params["keras_use_ctl"]
            and (not keras_utils.is_v2_0() or strategy is None)):
        logging.error(
            "Custom training loop only works with tensorflow 2.0 and dist strat."
        )
        return
    if params["use_tpu"] and not params["keras_use_ctl"]:
        logging.error(
            "Custom training loop must be used when using TPUStrategy.")
        return

    batch_size = params["batch_size"]
    time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
    callbacks = [time_callback]

    producer, input_meta_data = None, None
    generate_input_online = params["train_dataset_path"] is None

    if generate_input_online:
        # Start data producing thread.
        num_users, num_items, _, _, producer = ncf_common.get_inputs(params)
        producer.start()
        per_epoch_callback = IncrementEpochCallback(producer)
        callbacks.append(per_epoch_callback)
    else:
        assert params["eval_dataset_path"] and params["input_meta_data_path"]
        with tf.io.gfile.GFile(params["input_meta_data_path"], "rb") as reader:
            input_meta_data = json.loads(reader.read().decode("utf-8"))
            num_users = input_meta_data["num_users"]
            num_items = input_meta_data["num_items"]

    params["num_users"], params["num_items"] = num_users, num_items

    if FLAGS.early_stopping:
        early_stopping_callback = CustomEarlyStopping(
            "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
        callbacks.append(early_stopping_callback)

    use_remote_tpu = params["use_tpu"] and FLAGS.tpu
    primary_cpu_task = tpu_lib.get_primary_cpu_task(use_remote_tpu)

    with tf.device(primary_cpu_task):
        (train_input_dataset, eval_input_dataset,
         num_train_steps, num_eval_steps) = \
          (ncf_input_pipeline.create_ncf_input_data(
              params, producer, input_meta_data, strategy))
        steps_per_epoch = None if generate_input_online else num_train_steps

        with distribution_utils.get_strategy_scope(strategy):
            keras_model = _get_keras_model(params)
            optimizer = tf.keras.optimizers.Adam(
                learning_rate=params["learning_rate"],
                beta_1=params["beta1"],
                beta_2=params["beta2"],
                epsilon=params["epsilon"])
            if FLAGS.dtype == "fp16":
                optimizer = \
                  tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
                      optimizer,
                      loss_scale=flags_core.get_loss_scale(FLAGS,
                                                           default_for_fp16="dynamic"))

            if params["keras_use_ctl"]:
                train_loss, eval_results = run_ncf_custom_training(
                    params,
                    strategy,
                    keras_model,
                    optimizer,
                    callbacks,
                    train_input_dataset,
                    eval_input_dataset,
                    num_train_steps,
                    num_eval_steps,
                    generate_input_online=generate_input_online)
            else:
                # TODO(b/138957587): Remove when force_v2_in_keras_compile is
                # no longer a valid arg for this model. Also remove as a valid
                # flag.
                if FLAGS.force_v2_in_keras_compile is not None:
                    keras_model.compile(optimizer=optimizer,
                                        run_eagerly=FLAGS.run_eagerly,
                                        experimental_run_tf_function=FLAGS.
                                        force_v2_in_keras_compile)
                else:
                    keras_model.compile(optimizer=optimizer,
                                        run_eagerly=FLAGS.run_eagerly)

                history = keras_model.fit(train_input_dataset,
                                          epochs=FLAGS.train_epochs,
                                          steps_per_epoch=steps_per_epoch,
                                          callbacks=callbacks,
                                          validation_data=eval_input_dataset,
                                          validation_steps=num_eval_steps,
                                          verbose=2)

                logging.info("Training done. Start evaluating")

                eval_loss_and_metrics = keras_model.evaluate(
                    eval_input_dataset, steps=num_eval_steps, verbose=2)

                logging.info("Keras evaluation is done.")

                # Keras evaluate() API returns scalar loss and metric values from
                # evaluation as a list. Here, the returned list would contain
                # [evaluation loss, hr sum, hr count].
                eval_hit_rate = (eval_loss_and_metrics[1] /
                                 eval_loss_and_metrics[2])

                # Format evaluation result into [eval loss, eval hit accuracy].
                eval_results = [eval_loss_and_metrics[0], eval_hit_rate]

                if history and history.history:
                    train_history = history.history
                    train_loss = train_history["loss"][-1]

        stats = build_stats(train_loss, eval_results, time_callback)
        return stats
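
Note: the post-evaluation arithmetic above converts the `[loss, hr sum, hr count]` list returned by Keras evaluate() into `[eval loss, hit rate]`. As a minimal sketch:

def hit_rate_from_eval(eval_loss_and_metrics):
    """Turn evaluate() output [loss, hr sum, hr count] into [loss, hit rate]."""
    eval_loss, hr_sum, hr_count = eval_loss_and_metrics
    return [eval_loss, hr_sum / hr_count]

# hit_rate_from_eval([0.31, 482.0, 1000.0]) -> [0.31, 0.482]
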
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purposes.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy, session_config=session_config)

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj)
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      batch_size=flags_obj.batch_size)

  def input_fn_train():
    return input_function(
        is_training=True, data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=flags_obj.epochs_between_evals,
        num_gpus=flags_core.get_num_gpus(flags_obj))

  def input_fn_eval():
    return input_function(
        is_training=False, data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1)

  total_training_cycle = (flags_obj.train_epochs //
                          flags_obj.epochs_between_evals)
  for cycle_index in range(total_training_cycle):
    tf.logging.info('Starting a training cycle: %d/%d',
                    cycle_index, total_training_cycle)

    classifier.train(input_fn=input_fn_train, hooks=train_hooks,
                     max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data, which
    # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
    # eval (which is generally unimportant in those circumstances) to terminate.
    # Note that eval will run for max_train_steps each loop, regardless of the
    # global_step count.
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
Exemple #29
0
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  if keras_common.is_v2_0():
    keras_common.set_config_v2()
  else:
    config = keras_common.get_config_proto_v1()
    if flags_obj.enable_eager:
      tf.compat.v1.enable_eager_execution(config=config)
    else:
      sess = tf.Session(config=config)
      tf.keras.backend.set_session(sess)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)
  if flags_obj.data_prefetch_with_slack:
    keras_common.data_prefetch_with_slack()
  keras_common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'float16':
    policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder)

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        dtype=dtype,
        drop_remainder=drop_remainder)

  lr_schedule = 0.1
  if flags_obj.use_tensor_lr:
    lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_main.NUM_IMAGES['train'],
        warmup_epochs=LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in LR_SCHEDULE),
        compute_lr_on_cpu=True)

  with strategy_scope:
    optimizer = keras_common.get_optimizer(lr_schedule)
    if dtype == 'float16':
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

    if flags_obj.enable_xla and not flags_obj.enable_eager:
      # TODO(b/129861005): Fix OOM issue in eager mode when setting
      # `batch_size` in keras.Input layer.
      if strategy and strategy.num_replicas_in_sync > 1:
        # TODO(b/129791381): Specify `input_layer_batch_size` value in
        # DistributionStrategy multi-replica case.
        input_layer_batch_size = None
      else:
        input_layer_batch_size = flags_obj.batch_size
    else:
      input_layer_batch_size = None

    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES, dtype)
    else:
      model = resnet_model.resnet50(
          num_classes=imagenet_main.NUM_CLASSES,
          dtype=dtype,
          batch_size=input_layer_batch_size)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=(['sparse_categorical_accuracy']
                           if flags_obj.report_accuracy_metrics else None),
                  cloning=flags_obj.clone_model_in_keras_dist_strat)

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)
  stats = keras_common.build_stats(history, eval_output, callbacks)
  return stats
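
A note on the learning-rate pieces referenced above: `LR_SCHEDULE` and
`PiecewiseConstantDecayWithWarmup` are defined elsewhere in the source file. A
minimal, framework-free sketch of the schedule they presumably encode, assuming
`LR_SCHEDULE` holds (multiplier, start_epoch) pairs as in the TensorFlow model
garden; the constants below are illustrative only:

BASE_LEARNING_RATE = 0.1
LR_SCHEDULE = [(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)]

def learning_rate_at(epoch, batch_size, base_batch_size=256):
  """Scales the base LR by batch size, warms up linearly, then decays."""
  scaled_lr = BASE_LEARNING_RATE * batch_size / base_batch_size
  warmup_epochs = LR_SCHEDULE[0][1]
  if epoch < warmup_epochs:
    # Linear warmup from 0 up to the scaled base rate.
    return scaled_lr * epoch / warmup_epochs
  lr = scaled_lr
  for multiplier, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      lr = scaled_lr * multiplier
  return lr

print(learning_rate_at(epoch=0, batch_size=256))   # 0.0 (warmup start)
print(learning_rate_at(epoch=31, batch_size=256))  # 0.01 (after first decay)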
Exemple #30
0
    def __init__(self, flags_obj):
        """Init function of TransformerMain.

        Args:
          flags_obj: Object containing parsed flag values, i.e., FLAGS.

        Raises:
          ValueError: if not using static batch for input data on TPU.
        """
        self.flags_obj = flags_obj
        self.predict_model = None

        # Add flag-defined parameters to params object
        num_gpus = flags_core.get_num_gpus(flags_obj)
        self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)

        params["embedding_size"] = params["hidden_size"] / 8
        params["num_gpus"] = num_gpus
        params["use_ctl"] = flags_obj.use_ctl
        params["data_dir"] = flags_obj.data_dir
        params["model_dir"] = flags_obj.model_dir
        params["static_batch"] = flags_obj.static_batch
        params["max_length"] = flags_obj.max_length
        params["decode_batch_size"] = flags_obj.decode_batch_size
        params["decode_max_length"] = flags_obj.decode_max_length
        params["padded_decode"] = flags_obj.padded_decode
        params["num_parallel_calls"] = (
                flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)

        params["use_synthetic_data"] = flags_obj.use_synthetic_data
        params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
        params["repeat_dataset"] = None
        params["dtype"] = flags_core.get_tf_dtype(flags_obj)
        params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training

        if params["dtype"] == tf.float16:
            # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
            # like this. What if multiple instances of TransformerTask are created?
            # We should have a better way in the tf.keras.mixed_precision API of doing
            # this.
            loss_scale = flags_core.get_loss_scale(flags_obj,
                                                   default_for_fp16="dynamic")
            policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
                "mixed_float16", loss_scale=loss_scale)
            tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

        if params["dtype"] == tf.bfloat16:
            policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
                "mixed_bfloat16")
            tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

        self.distribution_strategy = distribution_utils.get_distribution_strategy(
            distribution_strategy=flags_obj.distribution_strategy,
            num_gpus=num_gpus,
            tpu_address=flags_obj.tpu or "")
        if self.use_tpu:
            params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
            if not params["static_batch"]:
                raise ValueError("TPU requires static batch for input data.")
        else:
            logging.info("Running transformer with num_gpus = %d", num_gpus)

        if self.distribution_strategy:
            logging.info("For training, using distribution strategy: %s",
                         self.distribution_strategy)
        else:
            logging.info("Not using any distribution strategy.")
Exemple #31
0
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                train_size,
                shape=None):
    """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    train_size: the number of training examples in the dataset; used to
      derive the step-based evaluation schedule below.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

    model_helpers.apply_clean(flags.FLAGS)

    # Ensures flag override logic is only executed if explicitly triggered.
    if flags_obj.tf_gpu_thread_mode:
        override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

    # Creates session config. allow_soft_placement = True, is required for
    # multi-GPU and is not harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    # Creates a `RunConfig` that checkpoints every 24 hours which essentially
    # results in checkpoints determined only by `epochs_between_evals`.
    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config,
                                        save_checkpoints_secs=60 * 60 * 24)

    # Initializes model with all but the dense layer from pretrained ResNet.
    if flags_obj.pretrained_model_checkpoint_path is not None:
        warm_start_settings = tf.estimator.WarmStartSettings(
            flags_obj.pretrained_model_checkpoint_path,
            vars_to_warm_start='^(?!.*dense)')
    else:
        warm_start_settings = None

    # Change from the stock loop: learning_rate and train_data_size are
    # passed through to the model via `params`.
    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        warm_start_from=warm_start_settings,
        params={
            'resnet_size': int(flags_obj.resnet_size),
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'resnet_version': int(flags_obj.resnet_version),
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'dtype': flags_core.get_tf_dtype(flags_obj),
            'fine_tune': flags_obj.fine_tune,
            'learning_rate': flags_obj.learning_rate,
            'train_data_size': flags_obj.train_data_size
        })

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    # Change from the stock loop: evaluate by steps rather than by epochs, so
    # input_fn_train takes (start_index, num_steps) instead of num_epochs.
    def input_fn_train(start_index, num_steps):

        return input_function(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            # Change from the stock loop: replace num_epochs with
            # start_index/num_steps (num_epochs is effectively always 1),
            # and pass image_size through to the input function.
            start_index=start_index,
            num_steps=num_steps,
            image_size=flags_obj.image_size,
            dtype=flags_core.get_tf_dtype(flags_obj),
            datasets_num_private_threads=flags_obj.datasets_num_private_threads,
            num_parallel_batches=flags_obj.datasets_num_parallel_batches)

    def input_fn_eval():
        return input_function(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            # Change from the stock loop: pass image_size through here too.
            image_size=flags_obj.image_size,
            num_epochs=1,
            dtype=flags_core.get_tf_dtype(flags_obj))

    # Change from the stock loop: build the training schedule in steps rather
    # than in epochs.
    if flags_obj.eval_only or not flags_obj.train_epochs:
        # If --eval_only is set, perform a single loop with zero train epochs.
        schedule, n_loops = [], 1
    else:
        n_loops = math.ceil(
            flags_obj.train_epochs * train_size /
            (flags_obj.steps_between_evals * flags_obj.batch_size))
        schedule = [
            flags_obj.steps_between_evals * flags_obj.batch_size * i % train_size
            for i in range(int(n_loops))
        ]

    if not schedule:
        # An empty schedule means --eval_only: run a single evaluation pass.
        tf.logging.info('Starting to evaluate.')
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags_obj.max_train_steps)
        benchmark_logger.log_evaluation_result(eval_results)
    else:
        for cycle_index, start_index in enumerate(schedule):
            tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

            classifier.train(input_fn=lambda: input_fn_train(
                start_index=start_index,
                num_steps=flags_obj.steps_between_evals),
                             hooks=train_hooks,
                             max_steps=flags_obj.max_train_steps)

            tf.logging.info('Starting to evaluate.')

            eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                               steps=flags_obj.max_train_steps)

            benchmark_logger.log_evaluation_result(eval_results)

            if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                                 eval_results['accuracy']):
                break


    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        export_dtype = flags_core.get_tf_dtype(flags_obj)
        if flags_obj.image_bytes_as_serving_input:
            input_receiver_fn = functools.partial(image_bytes_serving_input_fn,
                                                  shape,
                                                  dtype=export_dtype)
        else:
            input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
                shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
        classifier.export_savedmodel(flags_obj.export_dir,
                                     input_receiver_fn,
                                     strip_default_attrs=True)
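
To see what the step-based schedule in this example produces, here is the
arithmetic with small, made-up flag values:

import math

train_epochs = 2
train_size = 1000          # training examples per epoch
steps_between_evals = 10
batch_size = 25

examples_per_cycle = steps_between_evals * batch_size                 # 250
n_loops = math.ceil(train_epochs * train_size / examples_per_cycle)   # 8
schedule = [examples_per_cycle * i % train_size for i in range(n_loops)]
print(schedule)  # [0, 250, 500, 750, 0, 250, 500, 750]

Each entry is the example offset at which the next `steps_between_evals`
training steps start, wrapping around the dataset once a full epoch has been
consumed.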
Exemple #32
0
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  if flags_core.get_num_gpus(flags_obj) == 0:
    distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
  elif flags_core.get_num_gpus(flags_obj) == 1:
    distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
  else:
    distribution = tf.contrib.distribute.MirroredStrategy(
        num_gpus=flags_core.get_num_gpus(flags_obj)
    )

  run_config = tf.estimator.RunConfig(train_distribute=distribution,
                                      session_config=session_config)

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj)
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  benchmark_logger = logger.config_benchmark_logger(flags_obj)
  benchmark_logger.log_run_info('resnet', dataset_name, run_params)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      batch_size=flags_obj.batch_size)

  def input_fn_train():
    return input_function(
        is_training=True, data_dir=flags_obj.data_dir,
        batch_size=per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=flags_obj.epochs_between_evals)

  def input_fn_eval():
    return input_function(
        is_training=False, data_dir=flags_obj.data_dir,
        batch_size=per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1)

  total_training_cycle = (flags_obj.train_epochs //
                          flags_obj.epochs_between_evals)
  for cycle_index in range(total_training_cycle):
    tf.logging.info('Starting a training cycle: %d/%d',
                    cycle_index, total_training_cycle)

    classifier.train(input_fn=input_fn_train, hooks=train_hooks,
                     max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data, which
    # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
    # eval (which is generally unimportant in those circumstances) to terminate.
    # Note that eval will run for max_train_steps each loop, regardless of the
    # global_step count.
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
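
`per_device_batch_size` is imported from the model-garden distribution
utilities; it splits the global batch across GPUs and rejects sizes that do
not divide evenly. A sketch of that behavior (error wording paraphrased):

def per_device_batch_size(batch_size, num_gpus):
  """Splits a global batch size evenly across GPUs."""
  if num_gpus <= 1:
    # With 0 or 1 GPUs the global batch is used as-is.
    return batch_size
  if batch_size % num_gpus:
    raise ValueError(
        'Batch size (%d) must be divisible by the number of GPUs (%d).'
        % (batch_size, num_gpus))
  return batch_size // num_gpus

assert per_device_batch_size(256, 4) == 64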
Exemple #33
0
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    common.set_gpu_thread_mode_and_count(flags_obj)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.float16:
    loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_float16', loss_scale=loss_scale)
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
    if not keras_utils.is_v2_0():
      raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
  elif dtype == tf.bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                     flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=num_workers,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=imagenet_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        dtype=dtype,
        drop_remainder=drop_remainder)

  lr_schedule = 0.1
  if flags_obj.use_tensor_lr:
    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in common.LR_SCHEDULE),
        compute_lr_on_cpu=True)

  with strategy_scope:
    optimizer = common.get_optimizer(lr_schedule)
    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
      # up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(
          imagenet_preprocessing.NUM_CLASSES)
    else:
      print('data format: ', tf.keras.backend.image_data_format())
      print('num_classes: ', imagenet_preprocessing.NUM_CLASSES)
      # Change from the stock loop: a pretrained ResNet50V2 backbone with a
      # fresh classification head replaces resnet_model.resnet50. The final
      # softmax is kept in float32 for numerical stability under mixed
      # precision.
      backbone = tf.keras.applications.ResNet50V2(
          include_top=False, weights='imagenet', pooling='avg',
          input_shape=(224, 224, 3))
      x = tf.keras.layers.Dense(
          imagenet_preprocessing.NUM_CLASSES, name='fc1000')(backbone.output)
      x = tf.keras.layers.Activation('softmax', dtype='float32')(x)
      model = tf.keras.models.Model(backbone.input, x, name='resnet50')

      model.summary()

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
    # a valid arg for this model. Also remove as a valid flag.
    if flags_obj.force_v2_in_keras_compile is not None:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly,
          experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
    else:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly)

  callbacks = common.get_callbacks(
      common.learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])
  if flags_obj.enable_checkpoint_and_export:
    ckpt_full_path = os.path.join(flags_obj.model_dir, 'model.ckpt-{epoch:04d}')
    callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                                        save_weights_only=True))
  train_steps = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
  train_epochs = flags_obj.train_epochs

  # If training for multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)
  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning("Keras model.save does not support bfloat16 dtype.")
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
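
One stylistic note on the example above: it calls __enter__/__exit__ on a
tf.device scope by hand so that both model.fit and model.evaluate run inside
it. A contextlib.ExitStack expresses the same intent without manual dunder
calls; a sketch reusing the names from the example (other fit/evaluate
arguments elided):

import contextlib

with contextlib.ExitStack() as stack:
  if not strategy and flags_obj.explicit_gpu_placement:
    # Pin the whole fit/evaluate sequence to GPU:0 when no distribution
    # strategy is in play; ExitStack closes the device scope on exit.
    stack.enter_context(tf.device('/device:GPU:0'))
  history = model.fit(train_input_dataset, epochs=train_epochs,
                      steps_per_epoch=train_steps, callbacks=callbacks,
                      verbose=2)
  eval_output = model.evaluate(eval_input_dataset, steps=num_eval_steps,
                               verbose=2)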
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.

  Returns:
    Dict of results of the run.
  """

  model_helpers.apply_clean(flags.FLAGS)

  # Ensures flag override logic is only executed if explicitly triggered.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Creates session config. allow_soft_placement = True, is required for
  # multi-GPU and is not harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

  # Creates a `RunConfig` that checkpoints every 24 hours which essentially
  # results in checkpoints determined only by `epochs_between_evals`.
  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy,
      session_config=session_config,
      save_checkpoints_secs=60*60*24)

  # Initializes model with all but the dense layer from pretrained ResNet.
  if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start_settings = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')
  else:
    warm_start_settings = None

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
      warm_start_from=warm_start_settings, params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs):
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
        dtype=flags_core.get_tf_dtype(flags_obj),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        num_parallel_batches=flags_obj.datasets_num_parallel_batches)

  def input_fn_eval():
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  if flags_obj.eval_only or not flags_obj.train_epochs:
    # If --eval_only is set, perform a single loop with zero train epochs.
    schedule, n_loops = [0], 1
  else:
    # Compute the number of times to loop while training. All but the last
    # pass will train for `epochs_between_evals` epochs, while the last will
    # train for the number needed to reach `training_epochs`. For instance if
    #   train_epochs = 25 and epochs_between_evals = 10
    # schedule will be set to [10, 10, 5]. That is to say, the loop will:
    #   Train for 10 epochs and then evaluate.
    #   Train for another 10 epochs and then evaluate.
    #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
    n_loops = math.ceil(flags_obj.train_epochs / flags_obj.epochs_between_evals)
    schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
    schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1])  # Correct over-counting.

  for cycle_index, num_train_epochs in enumerate(schedule):
    tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

    if num_train_epochs:
      classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                       hooks=train_hooks, max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data, which
    # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
    # eval (which is generally unimportant in those circumstances) to terminate.
    # Note that eval will run for max_train_steps each loop, regardless of the
    # global_step count.
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      input_receiver_fn = functools.partial(
          image_bytes_serving_input_fn, shape, dtype=export_dtype)
    else:
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                                 strip_default_attrs=True)
  return eval_results
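
The vars_to_warm_start pattern '^(?!.*dense)' used in these examples is a
negative lookahead: it matches any variable name that does not contain
"dense", so warm-starting restores everything except the final dense layer. A
quick check of the regex in plain Python (variable names invented for
illustration):

import re

pattern = re.compile(r'^(?!.*dense)')
for name in ['resnet_model/conv2d/kernel',
             'resnet_model/dense/kernel',
             'resnet_model/dense/bias']:
  # Only the conv variable matches, so only it would be warm-started.
  print(name, bool(pattern.match(name)))
# resnet_model/conv2d/kernel True
# resnet_model/dense/kernel False
# resnet_model/dense/bias False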
Exemple #35
0
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.
  """

  model_helpers.apply_clean(flags.FLAGS)

  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy, session_config=session_config)

  # initialize our model with all but the dense layer from pretrained resnet
  if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start_settings = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')
  else:
    warm_start_settings = None

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
      warm_start_from=warm_start_settings, params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  def input_fn_train(num_epochs):
    return input_function(
        is_training=True, data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
        num_gpus=flags_core.get_num_gpus(flags_obj),
        dtype=flags_core.get_tf_dtype(flags_obj))

  def input_fn_eval():
    return input_function(
        is_training=False, data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

  if flags_obj.eval_only or not flags_obj.train_epochs:
    # If --eval_only is set, perform a single loop with zero train epochs.
    schedule, n_loops = [0], 1
  else:
    # Compute the number of times to loop while training. All but the last
    # pass will train for `epochs_between_evals` epochs, while the last will
    # train for the number needed to reach `training_epochs`. For instance if
    #   train_epochs = 25 and epochs_between_evals = 10
    # schedule will be set to [10, 10, 5]. That is to say, the loop will:
    #   Train for 10 epochs and then evaluate.
    #   Train for another 10 epochs and then evaluate.
    #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
    n_loops = math.ceil(flags_obj.train_epochs / flags_obj.epochs_between_evals)
    schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
    schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1])  # Correct over-counting.

  for cycle_index, num_train_epochs in enumerate(schedule):
    tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

    if num_train_epochs:
      classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                       hooks=train_hooks, max_steps=flags_obj.max_train_steps)

    tf.logging.info('Starting to evaluate.')

    # flags_obj.max_train_steps is generally associated with testing and
    # profiling. As a result it is frequently called with synthetic data, which
    # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
    # eval (which is generally unimportant in those circumstances) to terminate.
    # Note that eval will run for max_train_steps each loop, regardless of the
    # global_step count.
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags_obj.max_train_steps)

    benchmark_logger.log_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(
        flags_obj.stop_threshold, eval_results['accuracy']):
      break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      input_receiver_fn = functools.partial(
          image_bytes_serving_input_fn, shape, dtype=export_dtype)
    else:
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                                 strip_default_attrs=True)
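
The epoch-based schedule computed above is easiest to see with the numbers
from its own comment (train_epochs = 25, epochs_between_evals = 10):

import math

train_epochs = 25
epochs_between_evals = 10

n_loops = math.ceil(train_epochs / epochs_between_evals)  # 3
schedule = [epochs_between_evals] * int(n_loops)          # [10, 10, 10]
schedule[-1] = train_epochs - sum(schedule[:-1])          # [10, 10, 5]
print(schedule)  # train 10 epochs, eval; train 10, eval; train 5, eval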