def train_and_eval(
        params: base_configs.ExperimentConfig,
        strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
    """Trains the configured model via Keras compile/fit and evaluates it.

    Args:
        params: Experiment configuration. The runtime, model, train and
            evaluation sections as well as `model_dir` are read here.
        strategy_override: Distribution strategy to use; when falsy, one is
            created from `params.runtime`.

    Returns:
        The result of `common.build_stats` for the fit history, the final
        evaluation output (None when evaluation is skipped) and the callbacks.
    """
    logging.info('Running train and eval.')

    # Note: for TPUs, strategy and scope should be created before the dataset
    strategy = strategy_override
    if not strategy:
        strategy = distribution_utils.get_distribution_strategy(
            distribution_strategy=params.runtime.distribution_strategy,
            all_reduce_alg=params.runtime.all_reduce_alg,
            num_gpus=params.runtime.num_gpus,
            tpu_address=params.runtime.tpu)
    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    replica_count = strategy.num_replicas_in_sync if strategy else 1
    logging.info('Detected %d devices.', replica_count)

    # Labels are one-hot encoded only when label smoothing is enabled.
    smoothing = params.model.loss.label_smoothing
    one_hot = smoothing and smoothing > 0

    builders = _get_dataset_builders(params, strategy, one_hot)
    datasets = []
    for builder in builders:
        datasets.append(builder.build(strategy) if builder else None)

    # Split builders and datasets into their train / validation parts.
    train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
    train_dataset, validation_dataset = datasets

    train_epochs = params.train.epochs
    # Explicit step counts in the config win over the builder defaults.
    train_steps = params.train.steps or train_builder.num_steps
    validation_steps = params.evaluation.steps or validation_builder.num_steps

    initialize(params, train_builder)

    logging.info('Global batch size: %d', train_builder.global_batch_size)

    with strategy_scope:
        model_kwargs = params.model.model_params.as_dict()
        model_factory = get_models()[params.model.name]
        model = model_factory(**model_kwargs)

        learning_rate = optimizer_factory.build_learning_rate(
            params=params.model.learning_rate,
            batch_size=train_builder.global_batch_size,
            train_steps=train_steps)
        optimizer = optimizer_factory.build_optimizer(
            optimizer_name=params.model.optimizer.name,
            base_learning_rate=learning_rate,
            params=params.model.optimizer.as_dict())

        available_metrics = _get_metrics(one_hot)
        metrics = [available_metrics[name] for name in params.train.metrics]

        # The loss has to match the label encoding chosen above.
        if not one_hot:
            loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
        else:
            loss_obj = tf.keras.losses.CategoricalCrossentropy(
                label_smoothing=params.model.loss.label_smoothing)

        model.compile(
            optimizer=optimizer,
            loss=loss_obj,
            metrics=metrics,
            experimental_steps_per_execution=params.train.steps_per_loop)

        initial_epoch = (
            resume_from_checkpoint(model=model,
                                   model_dir=params.model_dir,
                                   train_steps=train_steps)
            if params.train.resume_checkpoint else 0)

        callback_flags = params.train.callbacks
        tensorboard_flags = params.train.tensorboard
        callbacks = custom_callbacks.get_callbacks(
            model_checkpoint=callback_flags.enable_checkpoint_and_export,
            include_tensorboard=callback_flags.enable_tensorboard,
            time_history=callback_flags.enable_time_history,
            track_lr=tensorboard_flags.track_lr,
            write_model_weights=tensorboard_flags.write_model_weights,
            initial_step=initial_epoch * train_steps,
            batch_size=train_builder.global_batch_size,
            log_steps=params.train.time_history.log_steps,
            model_dir=params.model_dir)

    serialize_config(params=params, model_dir=params.model_dir)

    # Validation arguments are only handed to fit() when eval is enabled.
    validation_kwargs = {}
    if not params.evaluation.skip_eval:
        validation_kwargs = {
            'validation_data': validation_dataset,
            'validation_steps': validation_steps,
            'validation_freq': params.evaluation.epochs_between_evals,
        }

    history = model.fit(train_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        initial_epoch=initial_epoch,
                        callbacks=callbacks,
                        verbose=2,
                        **validation_kwargs)

    if params.evaluation.skip_eval:
        validation_output = None
    else:
        validation_output = model.evaluate(validation_dataset,
                                           steps=validation_steps,
                                           verbose=2)

    # TODO(dankondratyuk): eval and save final test accuracy
    return common.build_stats(history, validation_output, callbacks)
# Example #2
# 0
def run(flags_obj):
    """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.

    Returns:
      Dictionary of training and eval stats.
    """
    # Local import: only needed for the exception-safe device scope below.
    import contextlib

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # NOTE(review): this compares against the string 'fp16'; confirm that
    # flags_core.get_tf_dtype can actually return that value.
    if dtype == 'fp16':
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        # Default to channels_first only when a physical GPU is present.
        data_format = ('channels_first'
                       if tf.config.list_physical_devices('GPU') else
                       'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    if flags_obj.use_synthetic_data:
        synthetic_util.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=cifar_preprocessing.HEIGHT,
            width=cifar_preprocessing.WIDTH,
            num_channels=cifar_preprocessing.NUM_CHANNELS,
            num_classes=cifar_preprocessing.NUM_CLASSES,
            dtype=flags_core.get_tf_dtype(flags_obj),
            drop_remainder=True)
    else:
        synthetic_util.undo_set_up_synthetic_data()
        input_fn = cifar_preprocessing.input_fn

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=cifar_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        # Setting drop_remainder to avoid the partial batch logic in normalization
        # layer, which triggers tf.where and leads to extra memory copy of input
        # sizes between host and GPU.
        drop_remainder=(not flags_obj.enable_get_next_as_optional))

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            parse_record_fn=cifar_preprocessing.parse_record)

    steps_per_epoch = (cifar_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)
    # A constant learning rate is used unless a tensor schedule is requested;
    # in that case LearningRateBatchScheduler below drives the rate instead.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        initial_learning_rate = common.BASE_LEARNING_RATE * flags_obj.batch_size / 128
        lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            boundaries=[p[1] * steps_per_epoch for p in LR_SCHEDULE],
            values=[initial_learning_rate] +
            [p[0] * initial_learning_rate for p in LR_SCHEDULE])

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        model = resnet_cifar_model.resnet56(
            classes=cifar_preprocessing.NUM_CLASSES)
        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['sparse_categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      run_eagerly=flags_obj.run_eagerly)

    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(steps_per_epoch)

    if not flags_obj.use_tensor_lr:
        lr_callback = LearningRateBatchScheduler(
            schedule=learning_rate_schedule,
            batch_size=flags_obj.batch_size,
            steps_per_epoch=steps_per_epoch)
        callbacks.append(lr_callback)

    # If multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # ExitStack guarantees the optional device scope is left again, with a
    # well-formed __exit__ call, even if fit() or evaluate() raises.
    with contextlib.ExitStack() as device_scope:
        if not strategy and flags_obj.explicit_gpu_placement:
            # TODO(b/135607227): Add device scope automatically in Keras training
            # loop when not using distribution strategy.
            device_scope.enter_context(tf.device('/device:GPU:0'))

        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=steps_per_epoch,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)
        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
# Example #3
# 0
def run_train(flags_obj, datasets_override=None, strategy_override=None):
    """Trains, checkpoints, exports and evaluates the model from `build_model`.

    Args:
      flags_obj: An object containing parsed flag values. The distribution
        settings and `epochs_between_evals` are read here; batch size, epoch
        count and the checkpoint folder come from the module-level
        `GB_OPTIONS`.
      datasets_override: Currently unused; datasets always come from
        `setup_datasets()`.
      strategy_override: Distribution strategy to use; when falsy, one is
        built from `flags_obj`.

    Returns:
      The stats produced by `common.build_stats` from the fit history, the
      final evaluation output and the callbacks.
    """
    strategy = strategy_override or distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        tpu_address=flags_obj.tpu)
    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    (train_input_dataset, eval_input_dataset,
     tr_dataset, te_dataset) = setup_datasets()

    with strategy_scope:
        model = build_model(mode='normal')

        # The model is compiled with the stock 'sgd' optimizer; the
        # LearningRateScheduler callback below sets the rate each epoch.
        losses = {"logits": "sparse_categorical_crossentropy"}
        loss_weights = {"logits": 1.0}
        model.compile(optimizer='sgd',
                      loss=losses,
                      loss_weights=loss_weights,
                      metrics=['sparse_categorical_accuracy'])

        num_train_examples = tr_dataset.num_examples_per_epoch()
        train_steps = num_train_examples // GB_OPTIONS.batch_size
        train_epochs = GB_OPTIONS.num_epochs

        # Datasets without poison bookkeeping are tagged as p0-c0 in the
        # checkpoint file name.
        if hasattr(tr_dataset, "n_poison"):
            n_poison = tr_dataset.n_poison
            n_cover = tr_dataset.n_cover
        else:
            n_poison = 0
            n_cover = 0

        # '{epoch:04d}' is left for ModelCheckpoint to fill in per epoch.
        ckpt_full_path = os.path.join(
            GB_OPTIONS.checkpoint_folder,
            'model.ckpt-{epoch:04d}-p%d-c%d' % (n_poison, n_cover))
        callbacks = [
            tf.keras.callbacks.LearningRateScheduler(lr_scheduler),
            tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                               save_weights_only=True,
                                               save_best_only=False),
        ]

        num_eval_examples = te_dataset.num_examples_per_epoch()
        num_eval_steps = num_eval_examples // GB_OPTIONS.batch_size

        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=train_steps,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=eval_input_dataset,
                            validation_freq=flags_obj.epochs_between_evals)

        export_path = os.path.join(GB_OPTIONS.checkpoint_folder, 'saved_model')
        model.save(export_path, include_optimizer=False)

        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)

        stats = common.build_stats(history, eval_output, callbacks)

        logging.info('Run stats:\n%s', stats)

        from utils import save_options_to_file
        # Join with os.path.join so the options file lands inside the
        # checkpoint folder whether or not the folder ends in a separator.
        save_options_to_file(
            GB_OPTIONS,
            os.path.join(GB_OPTIONS.checkpoint_folder, 'config.json'))

        return stats
# Example #4
# 0
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If `flags_obj.optimizer` or `flags_obj.model` names an
      unsupported choice.
    NotImplementedError: If some features are not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  # Local import: only needed for the exception-safe device scope below.
  import contextlib

  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj),
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  data_format = flags_obj.data_format
  if data_format is None:
    # Default to channels_first only when a physical GPU is present.
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # Current resnet_model.resnet50 input format is always channel-last.
  # We use the keras_application mobilenet model, whose input format depends
  # on the keras backend image data format.
  # This use_keras_image_data_format flag indicates whether the image
  # preprocessor output format should be the same as the keras backend image
  # data format or just channel-last format.
  use_keras_image_data_format = (flags_obj.model == 'mobilenet')
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=[p[1] for p in common.LR_SCHEDULE[1:]],
      multipliers=[p[0] for p in common.LR_SCHEDULE],
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer == 'mobilenet_default':
      initial_learning_rate = (
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size)
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)
    else:
      # Fail with a clear message instead of an UnboundLocalError later on.
      raise ValueError(
          'Unsupported optimizer: %r.' % (flags_obj.optimizer,))
    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
      # up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'mobilenet':
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # support 2.0 layers by default.
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=None,
          classes=imagenet_preprocessing.NUM_CLASSES,
          layers=tf.keras.layers)
    else:
      # Fail with a clear message instead of an UnboundLocalError later on.
      raise ValueError('Unsupported model: %r.' % (flags_obj.model,))
    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError(
          'Only polynomial_decay is currently supported.')

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  # ExitStack guarantees the optional device scope is left again, with a
  # well-formed __exit__ call, even if training, evaluation or export raises.
  with contextlib.ExitStack() as device_scope:
    if not strategy and flags_obj.explicit_gpu_placement:
      # TODO(b/135607227): Add device scope automatically in Keras training
      # loop when not using distribution strategy.
      device_scope.enter_context(tf.device('/device:GPU:0'))

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)

    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(eval_input_dataset,
                                   steps=num_eval_steps,
                                   verbose=2)

    if flags_obj.pruning_method:
      model = tfmot.sparsity.keras.strip_pruning(model)
    if flags_obj.enable_checkpoint_and_export:
      if dtype == tf.bfloat16:
        logging.warning('Keras model.save does not support bfloat16 dtype.')
      else:
        # Keras model.save assumes a float32 input signature.
        export_path = os.path.join(flags_obj.model_dir, 'saved_model')
        model.save(export_path, include_optimizer=False)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
# Example #5
# 0
def run_train(flags_obj):
  """Trains, exports and (unless skipped) evaluates the 'resnet50' model.

  Args:
    flags_obj: An object containing parsed flag values. Batch size, epoch
      count, the checkpoint folder and the optional pretrained checkpoint
      location come from the module-level `GB_OPTIONS`.

  Returns:
    The stats produced by `common.build_stats` from the fit history, the
    final evaluation output (None when `flags_obj.skip_eval` is set) and the
    callbacks.
  """
  # Local import: only needed to copy config.py next to the checkpoints.
  import shutil

  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj),
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  distribution_utils.undo_set_up_synthetic_data()

  (train_input_dataset, eval_input_dataset,
   tr_dataset, te_dataset) = setup_datasets(flags_obj)

  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=GB_OPTIONS.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=[p[1] for p in common.LR_SCHEDULE[1:]],
      multipliers=[p[0] for p in common.LR_SCHEDULE],
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // GB_OPTIONS.batch_size)

  with strategy_scope:
    optimizer = common.get_optimizer(lr_schedule)
    model = build_model(imagenet_preprocessing.NUM_CLASSES, mode='resnet50')

    if GB_OPTIONS.pretrained_filepath is not None:
      # pretrained_filepath is treated as a checkpoint directory here.
      latest = tf.train.latest_checkpoint(GB_OPTIONS.pretrained_filepath)
      print(latest)
      model.load_weights(latest)

    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=['sparse_categorical_accuracy'])

    train_epochs = GB_OPTIONS.num_epochs

    # Datasets without poison bookkeeping are tagged as p0-c0 in the
    # checkpoint file name.
    if hasattr(tr_dataset, "n_poison"):
      n_poison = tr_dataset.n_poison
      n_cover = tr_dataset.n_cover
    else:
      n_poison = 0
      n_cover = 0

    callbacks = common.get_callbacks(
        steps_per_epoch=steps_per_epoch,
        pruning_method=flags_obj.pruning_method,
        enable_checkpoint_and_export=False,
        model_dir=GB_OPTIONS.checkpoint_folder
    )
    # '{epoch:04d}' is left for ModelCheckpoint to fill in per epoch.
    ckpt_full_path = os.path.join(
        GB_OPTIONS.checkpoint_folder,
        'model.ckpt-{epoch:04d}-p%d-c%d' % (n_poison, n_cover))
    callbacks.append(
        tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                           save_weights_only=True,
                                           save_best_only=True))

    num_eval_steps = (
        imagenet_preprocessing.NUM_IMAGES['validation'] //
        GB_OPTIONS.batch_size)

    if flags_obj.skip_eval:
      # Only build the training graph. This reduces memory usage introduced by
      # control flow ops in layers that have different implementations for
      # training and inference (e.g., batch norm).
      if flags_obj.set_learning_phase_to_train:
        # TODO(haoyuzhang): Understand slowdown of setting learning phase when
        # not using distribution strategy.
        tf.keras.backend.set_learning_phase(1)
      num_eval_steps = None
      eval_input_dataset = None

    history = model.fit(
        train_input_dataset,
        epochs=train_epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
        validation_steps=num_eval_steps,
        validation_data=eval_input_dataset,
        validation_freq=flags_obj.epochs_between_evals
    )

    export_path = os.path.join(GB_OPTIONS.checkpoint_folder, 'saved_model')
    model.save(export_path, include_optimizer=False)

    # With skip_eval the eval dataset was cleared above, so evaluating would
    # be called without data; skip it and report no eval output instead.
    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(
          eval_input_dataset, steps=num_eval_steps, verbose=2
      )

    stats = common.build_stats(history, eval_output, callbacks)

    # Keep a copy of the configuration next to the checkpoints. This stays
    # best-effort: a failed copy must not discard the finished run.
    try:
      shutil.copy('config.py', GB_OPTIONS.checkpoint_folder)
    except OSError as err:
      print('Could not copy config.py to %s: %s'
            % (GB_OPTIONS.checkpoint_folder, err))

    return stats
# Example #6
# 0
def run(flags_obj, datasets_override=None, strategy_override=None):
    """Trains and evaluates the MNIST model with the native Keras APIs.

    Args:
      flags_obj: An object containing parsed flag values.
      datasets_override: Optional (train, test) pair of `tf.data.Dataset`
        objects used in place of the TFDS MNIST splits.
      strategy_override: Optional `tf.distribute.Strategy` to use instead of
        the one derived from the flags.

    Returns:
      Dictionary of training and eval stats.
    """
    # Start TF profiler server.
    tf.profiler.experimental.server.start(flags_obj.profiler_port)

    strategy = strategy_override
    if not strategy:
        strategy = distribute_utils.get_distribution_strategy(
            distribution_strategy=flags_obj.distribution_strategy,
            num_gpus=flags_obj.num_gpus,
            tpu_address=flags_obj.tpu)
    strategy_scope = distribute_utils.get_strategy_scope(strategy)

    mnist = tfds.builder('mnist', data_dir=flags_obj.data_dir)
    if flags_obj.download:
        mnist.download_and_prepare()

    splits = datasets_override
    if not splits:
        splits = mnist.as_dataset(
            split=['train', 'test'],
            decoders={'image': decode_image()},  # pylint: disable=no-value-for-parameter
            as_supervised=True)
    mnist_train, mnist_test = splits

    batch_size = flags_obj.batch_size
    train_input_dataset = mnist_train.cache().repeat()
    train_input_dataset = train_input_dataset.shuffle(buffer_size=50000)
    train_input_dataset = train_input_dataset.batch(batch_size)
    eval_input_dataset = mnist_test.cache().repeat().batch(batch_size)

    with strategy_scope:
        decayed_lr = tf.keras.optimizers.schedules.ExponentialDecay(
            0.05, decay_steps=100000, decay_rate=0.96)
        sgd = tf.keras.optimizers.SGD(learning_rate=decayed_lr)

        model = build_model()
        model.compile(optimizer=sgd,
                      loss='sparse_categorical_crossentropy',
                      metrics=['sparse_categorical_accuracy'])

    # Step counts always come from the TFDS split metadata.
    train_steps = (
        mnist.info.splits['train'].num_examples // flags_obj.batch_size)
    train_epochs = flags_obj.train_epochs

    # '{epoch:04d}' is left for ModelCheckpoint to fill in per epoch.
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(flags_obj.model_dir, 'model.ckpt-{epoch:04d}'),
        save_weights_only=True)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=flags_obj.model_dir)
    callbacks = [checkpoint_callback, tensorboard_callback]

    num_eval_steps = (
        mnist.info.splits['test'].num_examples // flags_obj.batch_size)

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=eval_input_dataset,
                        validation_freq=flags_obj.epochs_between_evals)

    model.save(os.path.join(flags_obj.model_dir, 'saved_model'),
               include_optimizer=False)

    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

    return common.build_stats(history, eval_output, callbacks)