def export_tfhub(model_path, hub_destination):
  """Exports a weight-restored ResNet-50 as two SavedModels for TF-Hub.

  The full classifier is written to `<hub_destination>/classification`, and a
  sub-model whose output is the `reduce_mean` layer is written to
  `<hub_destination>/feature-vector`. Neither export includes the optimizer.

  Args:
    model_path: Path passed to `model.load_weights` to restore the weights.
    hub_destination: Directory under which both exports are written.
  """
  classifier = resnet_model.resnet50(
      num_classes=imagenet_preprocessing.NUM_CLASSES, rescale_inputs=True)
  classifier.load_weights(model_path)

  classification_dir = os.path.join(hub_destination, "classification")
  classifier.save(classification_dir, include_optimizer=False)

  # Build a sub-model that maps the input of the first layer to the output of
  # the "reduce_mean" layer, so the pooled feature vector is the model output.
  input_layer = classifier.get_layer(index=0)
  pooling_layer = classifier.get_layer(name="reduce_mean")
  feature_extractor = tf.keras.Model(
      input_layer.get_output_at(0), pooling_layer.get_output_at(0))

  # Export the feature extractor as a second SavedModel.
  feature_vector_dir = os.path.join(hub_destination, "feature-vector")
  feature_extractor.save(feature_vector_dir, include_optimizer=False)
Beispiel #2
0
def build_model(num_classes, mode='normal', save_labels=False):
  """Builds the Keras model selected by the keywords contained in `mode`.

  Args:
    num_classes: Number of output classes.
    mode: String of keywords selecting the architecture:
      * contains 'trivial': the trivial test model is returned as-is (no extra
        head and no label pass-through, regardless of the other arguments).
      * contains 'resnet50': a ResNet-50 backbone is used.
      * additionally contains 'features': the model outputs the 256-unit
        'features' layer instead of the softmax class probabilities.
    save_labels: If True, the model gets two extra scalar int32 inputs (label
      and original label) which are passed through unchanged as extra outputs.

  Returns:
    A `tf.keras.Model`.

  Raises:
    ValueError: If `mode` contains neither 'trivial' nor 'resnet50'. Without
      this check no backbone would be selected and the function would fail
      later with an opaque `UnboundLocalError`.
  """
  if 'trivial' in mode:
    return test_utils.trivial_model(num_classes)

  if 'resnet50' not in mode:
    raise ValueError(
        "Unsupported mode {!r}: it must contain 'trivial' or "
        "'resnet50'.".format(mode))

  backbone = resnet_model.resnet50(num_classes)
  # Re-root the backbone at its third-from-last layer, i.e. drop its last two
  # layers so a new head can be attached.
  base_model = tf.keras.models.Model(
      inputs=backbone.input, outputs=backbone.layers[-3].output)

  x = tf.keras.layers.Dropout(0.5)(base_model.output)
  features = tf.keras.layers.Dense(
      256,
      kernel_initializer=initializers.RandomNormal(stddev=0.01),
      kernel_regularizer=_gen_l2_regularizer(),
      bias_regularizer=_gen_l2_regularizer(),
      name='features'
  )(x)

  if 'features' not in mode:
    # Classification head on top of the feature layer.
    outputs = tf.keras.layers.Dense(
        num_classes,
        kernel_initializer=initializers.RandomNormal(stddev=0.01),
        kernel_regularizer=_gen_l2_regularizer(),
        bias_regularizer=_gen_l2_regularizer(),
        activation='softmax',
        name='logits'
    )(features)
  else:
    outputs = features

  model = tf.keras.models.Model(
      inputs=base_model.input, outputs=outputs, name='imagenet')

  if save_labels:
    # Identity pass-through of both label tensors alongside the prediction.
    label = tf.keras.layers.Input(shape=(), dtype=tf.int32)
    ori_label = tf.keras.layers.Input(shape=(), dtype=tf.int32)
    model = tf.keras.models.Model(
        inputs=[model.input, label, ori_label],
        outputs=[model.output, label, ori_label],
        name='imagenet')

  return model
Beispiel #3
0
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  The function configures the session, mixed precision, image data format and
  distribution strategy from `flags_obj`, builds the training (and optionally
  evaluation) input datasets, creates the optimizer and model inside the
  strategy scope, optionally wraps the model for pruning, then calls
  `model.fit` and, unless `flags_obj.skip_eval` is set, `model.evaluate`.
  When `flags_obj.enable_checkpoint_and_export` is set and the dtype is not
  bfloat16, the model is saved under `<model_dir>/saved_model`.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    NotImplementedError: If pruning is requested with a dtype other than
      tf.float32, or with a pruning method other than 'polynomial_decay'.
    ValueError: Not raised directly in this function body.
      NOTE(review): may originate from the called helpers -- confirm.

  Returns:
    The result of `common.build_stats(history, eval_output, callbacks)`,
    described by the original authors as a dictionary of training and eval
    stats.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  # The dtype is reused below for the input pipelines, the pruning check and
  # the export decision.
  dtype = flags_core.get_tf_dtype(flags_obj)
  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj),
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  # Without an explicit flag, pick channels_first when a GPU is visible and
  # channels_last otherwise.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # Select the dataset factory: synthetic data or the real ImageNet pipeline.
  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # Current resnet_model.resnet50 input format is always channel-last.
  # We use the keras_application mobilenet model, whose input format depends
  # on the keras backend image data format.
  # This use_keras_image_data_format flag indicates whether the image
  # preprocessor output format should be the same as the keras backend image
  # data format or just channel-last format.
  use_keras_image_data_format = (flags_obj.model == 'mobilenet')
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  # The evaluation dataset is only built when evaluation is not skipped.
  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

  # Learning-rate schedule whose warmup length, boundaries and multipliers are
  # all taken from common.LR_SCHEDULE.
  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  # Number of full batches per epoch (integer division drops a partial batch).
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

  with strategy_scope:
    # NOTE(review): `optimizer` is only bound for 'resnet50_default' and
    # 'mobilenet_default'; any other flag value leaves it undefined for the
    # code below.
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer == 'mobilenet_default':
      initial_learning_rate = \
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)
    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
      # up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    # NOTE(review): `model` is only bound when use_trivial_model is set or the
    # model flag is 'resnet50_v1.5' / 'mobilenet'.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'mobilenet':
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # support 2.0 layers by default.
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=None,
          classes=imagenet_preprocessing.NUM_CLASSES,
          layers=tf.keras.layers)
    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    # Optionally wrap the model for magnitude pruning; only the
    # polynomial-decay schedule on float32 is accepted.
    if flags_obj.pruning_method == 'polynomial_decay':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError(
          'Only polynomial_decay is currently supported.')

    # Accuracy metrics are only attached when explicitly requested.
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    # The device context is entered manually here and exited manually near the
    # end of the function, under the same condition.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  # Final evaluation pass after training, unless evaluation is skipped.
  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  # Remove the pruning wrappers before exporting.
  if flags_obj.pruning_method:
    model = tfmot.sparsity.keras.strip_pruning(model)
  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  # NOTE(review): `__exit__` is called without exception arguments and is not
  # reached if anything between `__enter__` and this point raises.
  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
Beispiel #4
0
    def __init__(self, flags_obj, time_callback, epoch_steps):
        """Creates the model, optimizer, metrics, checkpoint and datasets.

        Args:
            flags_obj: Parsed flags object; read for batch size, data
                location, precision and Orbit loop options.
            time_callback: Stored unchanged as `self.time_callback`.
            epoch_steps: Steps per epoch; stored as `self.epoch_steps` and
                handed to `orbit.utils.EpochHelper`.

        Raises:
            ValueError: If the batch size is not divisible by the number of
                replicas in sync, or if the 'graph_rewrite' fp16
                implementation is requested without `use_tf_function`.
        """
        self.strategy = tf.distribute.get_strategy()
        self.flags_obj = flags_obj
        self.dtype = flags_core.get_tf_dtype(flags_obj)
        self.time_callback = time_callback

        # Input pipeline related.
        global_batch_size = flags_obj.batch_size
        num_replicas = self.strategy.num_replicas_in_sync
        if global_batch_size % num_replicas:
            raise ValueError(
                'Batch size must be divisible by number of replicas : {}'.
                format(num_replicas))

        # `distribute_datasets_from_function()` does not support auto
        # rebatching, and it is required when cloning the dataset to multiple
        # workers in eager mode, so the per-replica batch size is used.
        self.batch_size = int(global_batch_size / num_replicas)

        if not self.flags_obj.use_synthetic_data:
            self.input_fn = imagenet_preprocessing.input_fn
        else:
            self.input_fn = common.get_synth_input_fn(
                height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
                width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
                num_channels=imagenet_preprocessing.NUM_CHANNELS,
                num_classes=imagenet_preprocessing.NUM_CLASSES,
                dtype=self.dtype,
                drop_remainder=True)

        self.model = resnet_model.resnet50(
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            use_l2_regularizer=not flags_obj.single_l2_loss_op)

        # The schedule is parameterised by the global batch size, not the
        # per-replica one.
        schedule_boundaries = [p[1] for p in common.LR_SCHEDULE[1:]]
        schedule_multipliers = [p[0] for p in common.LR_SCHEDULE]
        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=common.LR_SCHEDULE[0][1],
            boundaries=schedule_boundaries,
            multipliers=schedule_multipliers,
            compute_lr_on_cpu=True)
        self.optimizer = common.get_optimizer(lr_schedule)
        # Touch `iterations` now so the variable is created inside the scope.
        self.global_step = self.optimizer.iterations

        use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
        if use_graph_rewrite and not flags_obj.use_tf_function:
            raise ValueError('--fp16_implementation=graph_rewrite requires '
                             '--use_tf_function to be true')
        loss_scale = flags_core.get_loss_scale(flags_obj,
                                               default_for_fp16=128)
        self.optimizer = performance.configure_optimizer(
            self.optimizer,
            use_float16=self.dtype == tf.float16,
            use_graph_rewrite=use_graph_rewrite,
            loss_scale=loss_scale)

        # Running metrics for the training and evaluation loops.
        self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'train_accuracy', dtype=tf.float32)
        self.test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)

        self.checkpoint = tf.train.Checkpoint(model=self.model,
                                              optimizer=self.optimizer)

        # Handling epochs.
        self.epoch_steps = epoch_steps
        self.epoch_helper = orbit.utils.EpochHelper(epoch_steps,
                                                    self.global_step)

        def _distributed_dataset(is_training, **extra_kwargs):
            # Arguments shared by the train and eval datasets are read from
            # the instance at call time.
            return orbit.utils.make_distributed_dataset(
                self.strategy,
                self.input_fn,
                is_training=is_training,
                data_dir=self.flags_obj.data_dir,
                batch_size=self.batch_size,
                parse_record_fn=imagenet_preprocessing.parse_record,
                dtype=self.dtype,
                **extra_kwargs)

        train_dataset = _distributed_dataset(
            True,
            datasets_num_private_threads=(
                self.flags_obj.datasets_num_private_threads),
            drop_remainder=True)
        trainer_options = orbit.StandardTrainerOptions(
            use_tf_while_loop=flags_obj.use_tf_while_loop,
            use_tf_function=flags_obj.use_tf_function)
        orbit.StandardTrainer.__init__(
            self, train_dataset, options=trainer_options)

        if flags_obj.skip_eval:
            return

        eval_dataset = _distributed_dataset(False)
        evaluator_options = orbit.StandardEvaluatorOptions(
            use_tf_function=flags_obj.use_tf_function)
        orbit.StandardEvaluator.__init__(
            self, eval_dataset, options=evaluator_options)
def main(_):
  """Trains and evaluates ResNet-50 on ImageNet through a Tarantella model.

  Reads every setting from `flags.FLAGS`, builds the training and validation
  datasets, wraps the ResNet-50 model in `tnt.Model`, runs `fit`, logs the
  training history and finally runs `evaluate` on the validation dataset.

  Args:
    _: Unused positional argument.
  """
  flags_obj = flags.FLAGS

  # Rank of this process and total number of participating processes.
  rank = tnt.get_rank()
  comm_size = tnt.get_size()

  # If Tarantella does not distribute the dataset automatically, each rank
  # works on its own micro batch.
  batch_size = (flags_obj.batch_size if flags_obj.auto_distributed
                else flags_obj.batch_size // comm_size)

  # Load and preprocess datasets; both share everything but `is_training`.
  dataset_kwargs = dict(
      data_dir=flags_obj.data_dir,
      batch_size=batch_size,
      shuffle_seed=42,
      drop_remainder=True,
      rank=rank,
      comm_size=comm_size,
      auto_distributed=flags_obj.auto_distributed)
  train_dataset = imagenet_preprocessing.input_fn(
      is_training=True, **dataset_kwargs)
  validation_dataset = imagenet_preprocessing.input_fn(
      is_training=False, **dataset_kwargs)

  # Create the model and wrap it into a Tarantella model.
  model = tnt.Model(
      resnet_model.resnet50(num_classes=imagenet_preprocessing.NUM_CLASSES))

  # The optimizer is configured from the global (not micro) batch size.
  model.compile(optimizer=get_optimizer(flags_obj.batch_size),
                loss='sparse_categorical_crossentropy',
                metrics=['sparse_categorical_accuracy'])
  model.summary()

  callbacks = []
  if flags_obj.enable_tensorboard:
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=flags_obj.model_dir, profile_batch=2)
    callbacks.append(tensorboard_callback)
  if flags_obj.profile_runtime:
    profiler_callback = RuntimeProfiler(
        batch_size=batch_size,
        logging_freq=flags_obj.logging_freq,
        print_freq=flags_obj.print_freq)
    callbacks.append(profiler_callback)
  if (flags_obj.enable_checkpoint_and_export
      and flags_obj.model_dir is not None):
    # Weights-only checkpoints, one file name per epoch.
    ckpt_full_path = os.path.join(flags_obj.model_dir, 'model.ckpt-{epoch:04d}')
    callbacks.append(
        tf.keras.callbacks.ModelCheckpoint(
            ckpt_full_path, save_weights_only=True))

  logging.info("Start training")
  fit_kwargs = {
      'tnt_distribute_dataset': flags_obj.auto_distributed,
      'tnt_distribute_validation_dataset': flags_obj.auto_distributed,
  }
  history = model.fit(train_dataset,
                      epochs=flags_obj.train_epochs,
                      callbacks=callbacks,
                      validation_data=validation_dataset,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=flags_obj.verbose,
                      **fit_kwargs)
  logging.info("Train history: {}".format(history.history))

  # Final evaluation; the result is not used further.
  model.evaluate(validation_dataset,
                 verbose=flags_obj.verbose,
                 tnt_distribute_dataset=flags_obj.auto_distributed)
Beispiel #6
0
def load_resnet(batch_size, num_classes):
    """Builds the TF Model Garden ResNet-50.

    Args:
        batch_size: Forwarded to `resnet_model.resnet50` as `batch_size`.
        num_classes: Forwarded to `resnet_model.resnet50` as `num_classes`.

    Returns:
        Whatever `resnet_model.resnet50` returns for these arguments.
    """
    model_kwargs = {'batch_size': batch_size, 'num_classes': num_classes}
    return resnet_model.resnet50(**model_kwargs)