Esempio n. 1
0
def run(flags_obj):
  """Train (and optionally evaluate) ResNet on ImageNet via Keras `fit`.

  The function configures the TF runtime from flags, builds the input
  datasets, creates and compiles the model under the distribution strategy
  scope, runs `model.fit`, and finally runs `model.evaluate` unless
  `flags_obj.skip_eval` is set.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    The training and eval stats assembled by `keras_common.build_stats`.
  """
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  if keras_common.is_v2_0():
    keras_common.set_config_v2()
  else:
    session_config = keras_common.get_config_proto_v1()
    if flags_obj.enable_eager:
      tf.compat.v1.enable_eager_execution(config=session_config)
    else:
      tf.keras.backend.set_session(tf.Session(config=session_config))

  # Flag-driven overrides intended to improve model performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)
  if flags_obj.data_prefetch_with_slack:
    keras_common.data_prefetch_with_slack()
  keras_common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  use_fp16 = dtype == 'float16'
  if use_fp16:
    tf.keras.mixed_precision.experimental.set_policy(
        tf.keras.mixed_precision.experimental.Policy('infer_float32_vars'))

  # Fall back to a build-dependent default when no data format was requested.
  image_data_format = flags_obj.data_format
  if image_data_format is None:
    if tf.test.is_built_with_cuda():
      image_data_format = 'channels_first'
    else:
      image_data_format = 'channels_last'
  tf.keras.backend.set_image_data_format(image_data_format)

  dist_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)
  scope = distribution_utils.get_strategy_scope(dist_strategy)

  # pylint: disable=protected-access
  if not flags_obj.use_synthetic_data:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn
  else:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)

  # Arguments shared by the training and evaluation input pipelines.
  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  shared_input_kwargs = dict(
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      dtype=dtype,
      drop_remainder=flags_obj.enable_xla)

  train_dataset = input_fn(
      is_training=True,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      **shared_input_kwargs)

  if flags_obj.skip_eval:
    eval_dataset = None
  else:
    eval_dataset = input_fn(is_training=False, **shared_input_kwargs)

  # Either a constant learning rate or a tensor-based piecewise schedule.
  if flags_obj.use_tensor_lr:
    learning_rate = keras_common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_main.NUM_IMAGES['train'],
        warmup_epochs=LR_SCHEDULE[0][1],
        boundaries=[entry[1] for entry in LR_SCHEDULE[1:]],
        multipliers=[entry[0] for entry in LR_SCHEDULE],
        compute_lr_on_cpu=True)
  else:
    learning_rate = 0.1

  with scope:
    optimizer = keras_common.get_optimizer(learning_rate)
    if use_fp16:
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

    # A fixed input-layer batch size is only used for XLA in graph mode on a
    # single replica.
    # TODO(b/129861005): Fix OOM issue in eager mode when setting
    # `batch_size` in keras.Input layer.
    # TODO(b/129791381): Specify `input_layer_batch_size` value in
    # DistributionStrategy multi-replica case.
    input_layer_batch_size = None
    if flags_obj.enable_xla and not flags_obj.enable_eager:
      multi_replica = dist_strategy and dist_strategy.num_replicas_in_sync > 1
      if not multi_replica:
        input_layer_batch_size = flags_obj.batch_size

    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES, dtype)
    else:
      model = resnet_model.resnet50(
          num_classes=imagenet_main.NUM_CLASSES,
          dtype=dtype,
          batch_size=input_layer_batch_size)

    if flags_obj.report_accuracy_metrics:
      compile_metrics = ['sparse_categorical_accuracy']
    else:
      compile_metrics = None
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=compile_metrics,
                  cloning=flags_obj.clone_model_in_keras_dist_strat)

  # `learning_rate_schedule` is not defined in this function; it is expected
  # to be available at module level.
  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  steps_per_epoch = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  epochs_to_run = flags_obj.train_epochs
  if flags_obj.train_steps:
    # An explicit step budget caps the epoch length and forces one epoch.
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    epochs_to_run = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None
  else:
    validation_data = eval_dataset

  history = model.fit(train_dataset,
                      epochs=epochs_to_run,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  if flags_obj.skip_eval:
    eval_output = None
  else:
    eval_output = model.evaluate(eval_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)
  return keras_common.build_stats(history, eval_output, callbacks)
def run_customized_training_loop(
        # pylint: disable=invalid-name
        _sentinel=None,
        # pylint: enable=invalid-name
        strategy=None,
        model_fn=None,
        loss_fn=None,
        scale_loss=True,
        model_dir=None,
        train_input_fn=None,
        steps_per_epoch=None,
        steps_per_loop=1,
        epochs=1,
        eval_input_fn=None,
        eval_steps=None,
        metric_fn=None,
        init_checkpoint=None,
        custom_callbacks=None,
        run_eagerly=False,
        sub_model_export_name=None,
        explicit_allreduce=False,
        pre_allreduce_callbacks=None,
        post_allreduce_callbacks=None,
        train_summary_interval=0):
    """Run BERT pretrain model training using low-level API.

    Arguments:
        _sentinel: Used to prevent positional parameters. Internal, do not
          use.
        strategy: Distribution strategy on which to run low level training
          loop.
        model_fn: Function that returns a tuple (model, sub_model). Caller of
          this function should add optimizer to the `model` via calling
          `model.compile()` API or manually setting `model.optimizer`
          attribute. Second element of the returned tuple (sub_model) is an
          optional sub model to be used for initial checkpoint -- if provided.
        loss_fn: Function with signature func(labels, logits) and returns a
          loss tensor.
        scale_loss: Whether to divide the raw loss by number of replicas
          before gradients calculation.
        model_dir: Model directory used during training for restoring/saving
          model weights.
        train_input_fn: Function that returns a tf.data.Dataset used for
          training.
        steps_per_epoch: Number of steps to run per epoch. At the end of each
          epoch, model checkpoint will be saved and evaluation will be
          conducted if evaluation dataset is provided.
        steps_per_loop: Number of steps per graph-mode loop. In order to
          reduce communication in eager context, training logs are printed
          every steps_per_loop. Values larger than `steps_per_epoch` are
          clamped to `steps_per_epoch` (an error is logged).
        epochs: Number of epochs to train.
        eval_input_fn: Function that returns evaluation dataset. If none,
          evaluation is skipped.
        eval_steps: Number of steps to run evaluation. Required if
          `eval_input_fn` is not none.
        metric_fn: A metrics function that returns a Keras Metric object to
          record evaluation result using evaluation dataset or with training
          dataset after every epoch.
        init_checkpoint: Optional checkpoint to load to `sub_model` returned
          by `model_fn`.
        custom_callbacks: A list of Keras Callbacks objects to run during
          training. More specifically, `on_batch_begin()`, `on_batch_end()`,
          `on_epoch_begin()`, `on_epoch_end()` methods are invoked during
          training.  Note that some metrics may be missing from `logs`.
        run_eagerly: Whether to run model training in pure eager execution.
          This should be disabled for TPUStrategy.
        sub_model_export_name: If not None, will export `sub_model` returned
          by `model_fn` into checkpoint files. The name of intermediate
          checkpoint file is {sub_model_export_name}_step_{step}.ckpt and the
          last checkpoint's name is {sub_model_export_name}.ckpt;
          if None, `sub_model` will not be exported as checkpoint.
        explicit_allreduce: Whether to explicitly perform gradient allreduce,
          instead of relying on implicit allreduce in
          optimizer.apply_gradients(). Default is False. For now, if training
          using FP16 mixed precision, explicit allreduce will aggregate
          gradients in FP16 format. For TPU and GPU training using FP32,
          explicit allreduce will aggregate gradients in FP32 format.
        pre_allreduce_callbacks: A list of callback functions that takes
          gradients and model variables pairs as input, manipulate them, and
          returns a new gradients and model variables pairs. The callback
          functions will be invoked in the list order and before gradients
          are allreduced. With mixed precision training, the
          pre_allreduce_callbacks will be applied on scaled_gradients.
          Default is no callbacks. Only used when explicit_allreduce=True.
        post_allreduce_callbacks: A list of callback functions that takes
          gradients and model variables pairs as input, manipulate them, and
          returns a new gradients and model variables pairs. The callback
          functions will be invoked in the list order and right before
          gradients are applied to variables for updates. Default is no
          callbacks. Only used when explicit_allreduce=True.
        train_summary_interval: Step interval for training summaries. If the
          value is a negative number, then training summaries are not
          enabled.

    Returns:
        Trained model.

    Raises:
        ValueError: (1) When model returned by `model_fn` does not have
          optimizer attribute or when required parameters are set to none.
          (2) eval args are not specified correctly. (3) metric_fn must be a
          callable if specified. (4) sub_model_export_name is specified, but
          `sub_model` returned by `model_fn` is None. (5) the function is
          called with a positional argument. (6) `run_eagerly` is requested
          together with a TPUStrategy.
    """

    if _sentinel is not None:
        raise ValueError('only call `run_customized_training_loop()` '
                         'with named arguments.')

    required_arguments = [
        strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
    ]
    if any(arg is None for arg in required_arguments):
        # The message lists exactly the arguments validated above.
        raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
                         '`steps_per_epoch` and `train_input_fn` are required '
                         'parameters.')
    if steps_per_loop > steps_per_epoch:
        logging.error(
            'steps_per_loop: %d is specified to be greater than '
            ' steps_per_epoch: %d, we will use steps_per_epoch as'
            ' steps_per_loop.', steps_per_loop, steps_per_epoch)
        steps_per_loop = steps_per_epoch
    # The loop below reads metric/variable values with `.numpy()`, which
    # requires eager execution at this (outer) level.
    assert tf.executing_eagerly()

    if run_eagerly:
        if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
            raise ValueError(
                'TPUStrategy should not run eagerly as it heavily relies on graph'
                ' optimization for the distributed system.')

    if eval_input_fn and (eval_steps is None or metric_fn is None):
        raise ValueError(
            '`eval_steps` and `metric_fn` are required when `eval_input_fn` '
            'is not none.')
    if metric_fn and not callable(metric_fn):
        raise ValueError(
            'if `metric_fn` is specified, metric_fn must be a callable.')

    callback_list = tf.keras.callbacks.CallbackList(custom_callbacks)

    total_training_steps = steps_per_epoch * epochs
    train_iterator = _get_input_iterator(train_input_fn, strategy)

    with distribution_utils.get_strategy_scope(strategy):
        # To correctly place the model weights on accelerators,
        # model and optimizer should be created in scope.
        model, sub_model = model_fn()
        if not hasattr(model, 'optimizer'):
            raise ValueError('User should set optimizer attribute to model '
                             'inside `model_fn`.')
        if sub_model_export_name and sub_model is None:
            raise ValueError('sub_model_export_name is specified as %s, but '
                             'sub_model is None.' % sub_model_export_name)

        optimizer = model.optimizer

        if init_checkpoint:
            logging.info(
                'Checkpoint file %s found and restoring from '
                'initial checkpoint for core model.', init_checkpoint)
            checkpoint = tf.train.Checkpoint(model=sub_model)
            checkpoint.restore(
                init_checkpoint).assert_existing_objects_matched()
            logging.info('Loading from checkpoint file completed')

        train_loss_metric = tf.keras.metrics.Mean('training_loss',
                                                  dtype=tf.float32)
        eval_metrics = [metric_fn()] if metric_fn else []
        # If evaluation is required, make a copy of metric as it will be used by
        # both train and evaluation.
        train_metrics = [
            metric.__class__.from_config(metric.get_config())
            for metric in eval_metrics
        ]

        # Create summary writers
        if _should_export_summary(strategy):
            summary_dir = os.path.join(model_dir, 'summaries')
        else:
            # In multi worker training we need every worker to write summary, because
            # variables can trigger synchronization on read and synchronization needs
            # all workers to participate.
            summary_dir = tempfile.mkdtemp()
        eval_summary_writer = tf.summary.create_file_writer(
            os.path.join(summary_dir, 'eval'))
        last_summary_step = 0
        if steps_per_loop >= _MIN_SUMMARY_STEPS and train_summary_interval >= 0:
            # Only writes summary when the stats are collected sufficiently over
            # enough steps.
            train_summary_writer = tf.summary.create_file_writer(
                os.path.join(summary_dir, 'train'))
        else:
            train_summary_writer = tf.summary.create_noop_writer()

        # Collects training variables.
        training_vars = model.trainable_variables

        def _replicated_step(inputs):
            """Replicated training step."""

            inputs, labels = inputs
            with tf.GradientTape() as tape:
                model_outputs = model(inputs, training=True)
                loss = loss_fn(labels, model_outputs)
                # Raw loss is used for reporting in metrics/logs.
                raw_loss = loss
                if scale_loss:
                    # Scales down the loss for gradients to be invariant from replicas.
                    loss = loss / strategy.num_replicas_in_sync

            if explicit_allreduce:
                grad_utils.minimize_using_explicit_allreduce(
                    tape, optimizer, loss, training_vars,
                    pre_allreduce_callbacks, post_allreduce_callbacks)
            else:
                if isinstance(
                        optimizer, tf.keras.mixed_precision.experimental.
                        LossScaleOptimizer):
                    # Re-enter the tape so the loss scaling is recorded too.
                    with tape:
                        scaled_loss = optimizer.get_scaled_loss(loss)
                    scaled_grads = tape.gradient(scaled_loss, training_vars)
                    grads = optimizer.get_unscaled_gradients(scaled_grads)
                else:
                    grads = tape.gradient(loss, training_vars)
                optimizer.apply_gradients(zip(grads, training_vars))
            # For reporting, the metric takes the mean of losses.
            train_loss_metric.update_state(raw_loss)
            for metric in train_metrics:
                metric.update_state(labels, model_outputs)

        @tf.function
        def train_steps(iterator, steps):
            """Performs distributed training steps in a loop.

            Args:
                iterator: the distributed iterator of training datasets.
                steps: a tf.int32 integer tensor to specify number of steps
                  to run inside host training loop.

            Raises:
                ValueError: If `steps` is not a `tf.Tensor`.
            """
            if not isinstance(steps, tf.Tensor):
                raise ValueError(
                    'steps should be an Tensor. Python object may cause '
                    'retracing.')

            for _ in tf.range(steps):
                strategy.run(_replicated_step, args=(next(iterator), ))

        def train_single_step(iterator):
            """Performs a single distributed training step.

            Args:
                iterator: the distributed iterator of training datasets.
            """
            strategy.run(_replicated_step, args=(next(iterator), ))

        def test_step(iterator):
            """Calculates evaluation metrics on distributed devices."""
            def _test_step_fn(inputs):
                """Replicated accuracy calculation."""

                inputs, labels = inputs
                model_outputs = model(inputs, training=False)
                for metric in eval_metrics:
                    metric.update_state(labels, model_outputs)

            strategy.run(_test_step_fn, args=(next(iterator), ))

        if not run_eagerly:
            train_single_step = tf.function(train_single_step)
            test_step = tf.function(test_step)

        def _run_evaluation(current_training_step, test_iterator):
            """Runs validation steps and aggregates metrics.

            Args:
                current_training_step: the current training step, used as the
                  summary step and in log messages.
                test_iterator: distributed iterator of test datasets.

            Returns:
                A dict of metric names and values.
            """
            for _ in range(eval_steps):
                test_step(test_iterator)

            logs = {}
            with eval_summary_writer.as_default():
                for metric in eval_metrics + model.metrics:
                    metric_value = _float_metric_value(metric)
                    logs[metric.name] = metric_value
                    logging.info('Step: [%d] Validation %s = %f',
                                 current_training_step, metric.name,
                                 metric_value)
                    tf.summary.scalar(metric.name,
                                      metric_value,
                                      step=current_training_step)
                eval_summary_writer.flush()

            return logs

        # Training loop starts here.
        checkpoint = tf.train.Checkpoint(model=model,
                                         optimizer=optimizer,
                                         global_step=optimizer.iterations)
        sub_model_checkpoint = tf.train.Checkpoint(
            model=sub_model, global_step=optimizer.iterations
        ) if sub_model_export_name else None

        latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
        if latest_checkpoint_file:
            logging.info(
                'Checkpoint file %s found and restoring from '
                'checkpoint', latest_checkpoint_file)
            checkpoint.restore(latest_checkpoint_file)
            logging.info('Loading from checkpoint file completed')

        # Resume from the optimizer's step counter (non-zero after a restore).
        current_step = optimizer.iterations.numpy()
        checkpoint_name = 'ctl_step_{step}.ckpt'

        while current_step < total_training_steps:
            if current_step % steps_per_epoch == 0:
                callback_list.on_epoch_begin(
                    int(current_step / steps_per_epoch) + 1)

            # Training loss/metric are taking average over steps inside micro
            # training loop. We reset their values before each round.
            train_loss_metric.reset_states()
            for metric in train_metrics + model.metrics:
                metric.reset_states()

            callback_list.on_batch_begin(current_step)
            # Runs several steps in the host while loop.
            steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)

            if tf.config.list_physical_devices('GPU'):
                # TODO(zongweiz): merge with train_steps once tf.while_loop
                # GPU performance bugs are fixed.
                for _ in range(steps):
                    train_single_step(train_iterator)
            else:
                # Converts steps to a Tensor to avoid tf.function retracing.
                train_steps(train_iterator,
                            tf.convert_to_tensor(steps, dtype=tf.int32))
            train_loss = _float_metric_value(train_loss_metric)
            current_step += steps
            callback_list.on_batch_end(current_step - 1, {'loss': train_loss})

            # Updates training logging.
            training_status = 'Train Step: %d/%d  / loss = %s' % (
                current_step, total_training_steps, train_loss)

            # Throttle training summaries to one per `train_summary_interval`
            # steps; other rounds write to a no-op writer.
            if current_step >= last_summary_step + train_summary_interval:
                summary_writer = train_summary_writer
                last_summary_step = current_step
            else:
                summary_writer = tf.summary.create_noop_writer()

            with summary_writer.as_default():
                tf.summary.scalar(train_loss_metric.name,
                                  train_loss,
                                  step=current_step)
                for metric in train_metrics + model.metrics:
                    metric_value = _float_metric_value(metric)
                    training_status += '  %s = %f' % (metric.name,
                                                      metric_value)
                    tf.summary.scalar(metric.name,
                                      metric_value,
                                      step=current_step)
                summary_writer.flush()
            logging.info(training_status)

            if current_step % steps_per_epoch == 0:
                # Save a submodel with the step in the file name after each epoch.
                if sub_model_export_name:
                    _save_checkpoint(
                        strategy, sub_model_checkpoint, model_dir,
                        '%s_step_%d.ckpt' %
                        (sub_model_export_name, current_step))

                # Save model checkpoints and run validation steps after each epoch
                # (with the exception of the final epoch which is handled after the
                # training loop).
                if current_step < total_training_steps:
                    _save_checkpoint(strategy, checkpoint, model_dir,
                                     checkpoint_name.format(step=current_step))
                    logs = None
                    if eval_input_fn:
                        logging.info('Running evaluation after step: %s.',
                                     current_step)
                        logs = _run_evaluation(
                            current_step,
                            _get_input_iterator(eval_input_fn, strategy))
                        # Re-initialize evaluation metric.
                        for metric in eval_metrics + model.metrics:
                            metric.reset_states()

                    callback_list.on_epoch_end(
                        int(current_step / steps_per_epoch), logs)

        if sub_model_export_name:
            _save_checkpoint(strategy, sub_model_checkpoint, model_dir,
                             '%s.ckpt' % sub_model_export_name)

        _save_checkpoint(strategy, checkpoint, model_dir,
                         checkpoint_name.format(step=current_step))
        logs = None
        if eval_input_fn:
            logging.info(
                'Running final evaluation after training is complete.')
            logs = _run_evaluation(
                current_step, _get_input_iterator(eval_input_fn, strategy))

        callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)

        training_summary = {
            'total_training_steps': total_training_steps,
            'train_loss': _float_metric_value(train_loss_metric),
        }
        for metric in model.metrics:
            training_summary[metric.name] = _float_metric_value(metric)
        if eval_metrics:
            # TODO(hongkuny): Cleans up summary reporting in text.
            training_summary['last_train_metrics'] = _float_metric_value(
                train_metrics[0])
            training_summary['eval_metrics'] = _float_metric_value(
                eval_metrics[0])

        write_txt_summary(training_summary, summary_dir)

        # The temporary summary directory created for non-exporting workers is
        # removed once training is done.
        if not _should_export_summary(strategy):
            tf.io.gfile.rmtree(summary_dir)

        return model
def run(flags_obj, datasets_override=None, strategy_override=None):
    """Run Malaria model training and eval loop using native Keras APIs.

    Loads fixed slices of the TFDS `malaria` dataset, trains the model
    returned by `build_model()` with `model.fit`, exports it as a SavedModel
    under `flags_obj.model_dir`, and evaluates it on the validation slice.

    Args:
      flags_obj: An object containing parsed flag values.
      datasets_override: A pair of `tf.data.Dataset` objects to train the
        model, representing the train and test sets. NOTE(review): this
        argument is accepted but currently not used by the body -- the
        datasets are always loaded through `tfds.load`.
      strategy_override: A `tf.distribute.Strategy` object to use for model.
        When falsy, a strategy is built from the flags.

    Returns:
      The training and eval stats assembled by `common.build_stats`.
    """
    strategy = strategy_override or distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        tpu_address=flags_obj.tpu)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # Only small slices of the single 'train' split are used: the first 5%
    # for training and the 90%-92% range for validation.
    train_ds, val_ds = tfds.load('malaria',
                                 split=['train[:5%]', 'train[90%:92%]'],
                                 shuffle_files=True,
                                 download=False,
                                 as_supervised=True,
                                 data_dir=flags_obj.data_dir)

    # NOTE(review): batching uses the module-level BATCH_SIZE while the eval
    # step count below uses flags_obj.batch_size -- confirm they agree.
    padded_train_ds = train_ds.cache().map(pad).batch(BATCH_SIZE)
    padded_val_ds = val_ds.cache().map(pad).batch(BATCH_SIZE)

    ckpt_full_path = os.path.join(flags_obj.model_dir,
                                  'model.ckpt-{epoch:04d}')
    with strategy_scope:
        checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                                           save_best_only=True)

        model = build_model()
        model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['accuracy'])

    num_eval_steps = test_length // flags_obj.batch_size

    callbacks = [checkpoint_cb]

    # NOTE(review): the epoch count is hard-coded; flags_obj.train_epochs is
    # not consulted here.
    history = model.fit(padded_train_ds,
                        epochs=3,
                        callbacks=callbacks,
                        validation_data=padded_val_ds)

    export_path = os.path.join(flags_obj.model_dir, 'saved_model')
    model.save(export_path, include_optimizer=False)

    eval_output = model.evaluate(padded_val_ds,
                                 steps=num_eval_steps,
                                 verbose=2)

    return common.build_stats(history, eval_output, callbacks)
Esempio n. 4
0
  def train(self):
    """Trains the model.

    Builds the Transformer model and optimizer under the distribution
    strategy scope, restores the latest checkpoint from
    `flags_obj.model_dir` if one exists, and then trains in rounds of at most
    `flags_obj.steps_between_evals` steps until `flags_obj.train_steps` is
    reached. Each round uses either the custom training loop
    (`params["use_ctl"]`, TPU only) or Keras `model.fit` (non-TPU only), and
    is followed by a BLEU evaluation when both `bleu_source` and `bleu_ref`
    flags are set.

    Returns:
      A dict of training stats. In the custom-training-loop path (or when no
      training round ran) it is `{"loss": <last train loss or None>}`;
      otherwise it is the result of `misc.build_stats`. BLEU scores and
      their per-round history are added when an evaluation produced them.

    Raises:
      NotImplementedError: If the custom training loop is requested without
        TPU, or Keras `model.fit` is requested with TPU.
    """
    params = self.params
    flags_obj = self.flags_obj
    # Sets config options.
    keras_utils.set_session_config(
        enable_xla=flags_obj.enable_xla)

    _ensure_dir(flags_obj.model_dir)
    with distribution_utils.get_strategy_scope(self.distribution_strategy):
      model = transformer.create_model(params, is_train=True)
      opt = self._create_optimizer()

      current_step = 0
      checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
      latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
      if latest_checkpoint:
        checkpoint.restore(latest_checkpoint)
        logging.info("Loaded checkpoint %s", latest_checkpoint)
        # Resume counting from the restored optimizer step.
        current_step = opt.iterations.numpy()

      if params["use_ctl"]:
        train_loss_metric = tf.keras.metrics.Mean(
            "training_loss", dtype=tf.float32)
      else:
        model.compile(opt)

    model.summary()

    if self.use_tpu:
      # Different from experimental_distribute_dataset,
      # experimental_distribute_datasets_from_function requires
      # per-replica/local batch size.
      # NOTE(review): `/=` is true division, so `params["batch_size"]`
      # becomes a float here -- confirm the data pipeline accepts that.
      params["batch_size"] /= self.distribution_strategy.num_replicas_in_sync
      train_ds = (
          self.distribution_strategy
          .experimental_distribute_datasets_from_function(
              lambda ctx: data_pipeline.train_input_fn(params, ctx)))
    else:
      train_ds = data_pipeline.train_input_fn(params)
      map_data_fn = data_pipeline.map_data_for_transformer_fn
      train_ds = train_ds.map(
          map_data_fn, num_parallel_calls=params["num_parallel_calls"])
    if params["use_ctl"]:
      train_ds_iterator = iter(train_ds)

    callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)

    # TODO(b/139418525): Refactor the custom training loop logic.
    @tf.function
    def train_steps(iterator, steps):
      """Training steps function for TPU runs.

      Runs `steps` training steps. The loss is not returned; it is recorded
      in `train_loss_metric`, which is reset before every step and so holds
      only the most recent step's loss afterwards.

      Args:
        iterator: The input iterator of the training dataset.
        steps: An integer, the number of training steps.
      """

      def _step_fn(inputs):
        """Per-replica step function."""
        inputs, targets = inputs
        with tf.GradientTape() as tape:
          logits = model([inputs, targets], training=True)
          loss = metrics.transformer_loss(logits, targets,
                                          params["label_smoothing"],
                                          params["vocab_size"])
          # Scales the loss, which results in using the average loss across all
          # of the replicas for backprop.
          scaled_loss = loss / self.distribution_strategy.num_replicas_in_sync

        # De-dupes variables due to keras tracking issues.
        tvars = list({id(v): v for v in model.trainable_variables}.values())
        grads = tape.gradient(scaled_loss, tvars)
        opt.apply_gradients(zip(grads, tvars))
        # For reporting, the metric takes the mean of losses.
        train_loss_metric.update_state(loss)

      for _ in tf.range(steps):
        train_loss_metric.reset_states()
        self.distribution_strategy.experimental_run_v2(
            _step_fn, args=(next(iterator),))

    cased_score, uncased_score = None, None
    cased_score_history, uncased_score_history = [], []
    # Initialized up front so that building `stats` below does not fail with
    # a NameError when the loop body never runs (e.g. the restored checkpoint
    # is already at or beyond `flags_obj.train_steps`).
    history = None
    train_loss = None
    while current_step < flags_obj.train_steps:
      remaining_steps = flags_obj.train_steps - current_step
      train_steps_per_eval = (
          remaining_steps if remaining_steps < flags_obj.steps_between_evals
          else flags_obj.steps_between_evals)
      current_iteration = current_step // flags_obj.steps_between_evals

      logging.info(
          "Start train iteration at global step:{}".format(current_step))
      history = None
      if params["use_ctl"]:
        if not self.use_tpu:
          raise NotImplementedError(
              "Custom training loop on GPUs is not implemented.")
        # Runs training steps.
        train_steps(train_ds_iterator,
                    tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
        current_step += train_steps_per_eval
        train_loss = train_loss_metric.result().numpy().astype(float)
        logging.info("Train Step: %d/%d / loss = %s",
                     current_step, flags_obj.train_steps, train_loss)

        checkpoint_name = checkpoint.save(
            os.path.join(
                flags_obj.model_dir,
                "ctl_step_{}.ckpt".format(current_step)))
        logging.info("Saved checkpoint to %s", checkpoint_name)
      else:
        if self.use_tpu:
          raise NotImplementedError(
              "Keras model.fit on TPUs is not implemented.")
        # Each round is presented to Keras as a single epoch whose index is
        # the round number.
        history = model.fit(
            train_ds,
            initial_epoch=current_iteration,
            epochs=current_iteration + 1,
            steps_per_epoch=train_steps_per_eval,
            callbacks=callbacks,
            # If TimeHistory is enabled, progress bar would be messy. Increase
            # the verbose level to get rid of it.
            verbose=(2 if flags_obj.enable_time_history else 1))
        current_step += train_steps_per_eval
        logging.info("Train history: {}".format(history.history))

      logging.info("End train iteration at global step:{}".format(current_step))

      if (flags_obj.bleu_source and flags_obj.bleu_ref):
        uncased_score, cased_score = self.eval()
        cased_score_history.append([current_iteration + 1, cased_score])
        uncased_score_history.append([current_iteration + 1, uncased_score])

    stats = ({
        "loss": train_loss
    } if history is None else misc.build_stats(history, callbacks))
    # Compare against None so that a legitimate score of 0.0 is still
    # reported.
    if uncased_score is not None and cased_score is not None:
      stats["bleu_uncased"] = uncased_score
      stats["bleu_cased"] = cased_score
      stats["bleu_uncased_history"] = uncased_score_history
      stats["bleu_cased_history"] = cased_score_history
    return stats
Esempio n. 5
0
def run(flags_obj):
    """Train and evaluate ResNet on ImageNet through the Keras `fit` API.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      The stats produced by `keras_common.build_stats` from the fit history,
      the final evaluation output (None when eval is skipped) and the
      callbacks.
    """
    # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
    # Under tf 2.0 eager execution is the default and must not be toggled.
    if keras_common.is_v2_0():
        keras_common.set_config_v2()
    else:
        session_config = keras_common.get_config_proto_v1()
        if flags_obj.enable_eager:
            tf.compat.v1.enable_eager_execution(config=session_config)
        else:
            tf.keras.backend.set_session(tf.Session(config=session_config))

    # Flag-driven overrides intended to improve model performance.
    if flags_obj.tf_gpu_thread_mode:
        keras_common.set_gpu_thread_mode_and_count(flags_obj)

    dtype = flags_core.get_tf_dtype(flags_obj)
    uses_float16 = dtype == 'float16'
    if uses_float16:
        tf.keras.mixed_precision.experimental.set_policy(
            tf.keras.mixed_precision.experimental.Policy(
                'infer_float32_vars'))

    # Fall back to a build-dependent layout when no data format was requested.
    image_data_format = flags_obj.data_format
    if image_data_format is None:
        if tf.test.is_built_with_cuda():
            image_data_format = 'channels_first'
        else:
            image_data_format = 'channels_last'
    tf.keras.backend.set_image_data_format(image_data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster())
    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if not flags_obj.use_synthetic_data:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_main.input_fn
    else:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=imagenet_main.DEFAULT_IMAGE_SIZE,
            width=imagenet_main.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_main.NUM_CHANNELS,
            num_classes=imagenet_main.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)

    # Keyword arguments common to the training and evaluation pipelines.
    # XLA-GPU doesn't support dynamic shapes, so the remainder of the batches
    # is always dropped when `enable_xla` is True.
    shared_input_kwargs = dict(
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        dtype=dtype,
        drop_remainder=flags_obj.enable_xla)

    train_input_dataset = input_fn(
        is_training=True,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        **shared_input_kwargs)

    # Stays None when evaluation is skipped, which also disables validation
    # inside `fit` below.
    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(is_training=False,
                                      **shared_input_kwargs)

    with strategy_scope:
        optimizer = keras_common.get_optimizer()
        if uses_float16:
            # TODO(reedwm): Remove manually wrapping optimizer once mixed
            # precision can be enabled with a single line of code.
            loss_scale = flags_core.get_loss_scale(flags_obj)
            optimizer = (
                tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                    optimizer, loss_scale=loss_scale))

        # A static batch size on the input layer is only used for XLA in
        # graph mode on a single replica.
        # TODO(b/129861005): Fix OOM issue in eager mode when setting
        # `batch_size` in keras.Input layer.
        # TODO(b/129791381): Specify `input_layer_batch_size` value in
        # DistributionStrategy multi-replica case.
        input_layer_batch_size = None
        if flags_obj.enable_xla and not flags_obj.enable_eager:
            multi_replica = strategy and strategy.num_replicas_in_sync > 1
            if not multi_replica:
                input_layer_batch_size = flags_obj.batch_size

        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES)
        else:
            model = resnet_model.resnet50(
                num_classes=imagenet_main.NUM_CLASSES,
                dtype=dtype,
                batch_size=input_layer_batch_size)

        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=['sparse_categorical_accuracy'])

    callbacks = keras_common.get_callbacks(learning_rate_schedule,
                                           imagenet_main.NUM_IMAGES['train'])

    steps_per_epoch = (imagenet_main.NUM_IMAGES['train'] //
                       flags_obj.batch_size)
    epochs = flags_obj.train_epochs
    if flags_obj.train_steps:
        # An explicit step budget caps the epoch length and forces one epoch.
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        epochs = 1

    num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced
        # by control flow ops in layers that have different implementations
        # for training and inference (e.g., batch norm).
        tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None

    history = model.fit(train_input_dataset,
                        epochs=epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=eval_input_dataset,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)

    if flags_obj.skip_eval:
        eval_output = None
    else:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)
    return keras_common.build_stats(history, eval_output, callbacks)
def run_ncf(_):
    """Run NCF training and eval with Keras.

    Args:
      _: Unused positional argument.

    Returns:
      The stats returned by `build_stats` for the fit history, the evaluation
      results and the timing callback.
    """
    # TODO(seemuch): Support different train and eval batch sizes
    if FLAGS.eval_batch_size != FLAGS.batch_size:
        logging.warning(
            "The Keras implementation of NCF currently does not support batch_size "
            "!= eval_batch_size ({} vs. {}). Overriding eval_batch_size to match "
            "batch_size".format(FLAGS.eval_batch_size, FLAGS.batch_size))
        FLAGS.eval_batch_size = FLAGS.batch_size

    params = ncf_common.parse_flags(FLAGS)

    # ncf_common rounds eval_batch_size (this is needed due to a reshape during
    # eval). This carries over that rounding to batch_size as well. The local
    # `batch_size` must be read *after* the override so that the timing
    # callback below sees the same batch size as the input pipeline.
    params["batch_size"] = params["eval_batch_size"]
    batch_size = params["batch_size"]

    num_users, num_items, num_train_steps, num_eval_steps, producer = (
        ncf_common.get_inputs(params))

    params["num_users"], params["num_items"] = num_users, num_items
    producer.start()
    model_helpers.apply_clean(flags.FLAGS)

    batches_per_step = params["batches_per_step"]
    train_input_dataset, eval_input_dataset = _get_train_and_eval_data(
        producer, params)
    # It is required that for distributed training, the dataset must call
    # batch(). The parameter of batch() here is the number of replicas involved,
    # such that each replica evenly gets a slice of data.
    train_input_dataset = train_input_dataset.batch(batches_per_step)
    eval_input_dataset = eval_input_dataset.batch(batches_per_step)

    strategy = ncf_common.get_distribution_strategy(params)
    with distribution_utils.get_strategy_scope(strategy):
        keras_model = _get_keras_model(params)
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=params["learning_rate"],
            beta_1=params["beta1"],
            beta_2=params["beta2"],
            epsilon=params["epsilon"])
        time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)

        keras_model.compile(loss=_keras_loss,
                            metrics=[_get_metric_fn(params)],
                            optimizer=optimizer,
                            cloning=params["clone_model_in_keras_dist_strat"])

        # No `steps_per_epoch` is passed, so the epoch length is determined
        # by the dataset; `num_train_steps` is intentionally unused here.
        history = keras_model.fit(
            train_input_dataset,
            epochs=FLAGS.train_epochs,
            callbacks=[IncrementEpochCallback(producer), time_callback],
            verbose=2)

        logging.info("Training done. Start evaluating")

        eval_results = keras_model.evaluate(eval_input_dataset,
                                            steps=num_eval_steps,
                                            verbose=2)

    logging.info("Keras evaluation is done.")

    stats = build_stats(history, eval_results, time_callback)
    return stats
Esempio n. 7
0
def run_ncf(_):
  """Run NCF training and eval with Keras.

  Depending on `params["keras_use_ctl"]`, training runs either through a
  custom training loop driven by the distribution strategy, or through the
  Keras `compile`/`fit`/`evaluate` APIs.

  Args:
    _: Unused positional argument.

  Returns:
    The stats returned by `build_stats`, or None when the requested
    configuration is unsupported (an error is logged and the function
    returns early).
  """

  keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

  if FLAGS.seed is not None:
    print("Setting tf seed")
    tf.random.set_seed(FLAGS.seed)

  # TODO(seemuch): Support different train and eval batch sizes
  if FLAGS.eval_batch_size != FLAGS.batch_size:
    logging.warning(
        "The Keras implementation of NCF currently does not support batch_size "
        "!= eval_batch_size ({} vs. {}). Overriding eval_batch_size to match "
        "batch_size".format(FLAGS.eval_batch_size, FLAGS.batch_size)
        )
    FLAGS.eval_batch_size = FLAGS.batch_size

  params = ncf_common.parse_flags(FLAGS)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus)
  params["distribute_strategy"] = strategy

  if not keras_utils.is_v2_0() and strategy is not None:
    logging.error("NCF Keras only works with distribution strategy in TF 2.0")
    return

  if (params["keras_use_ctl"] and (
      not keras_utils.is_v2_0() or strategy is None)):
    logging.error(
        "Custom training loop only works with tensorflow 2.0 and dist strat.")
    return

  # ncf_common rounds eval_batch_size (this is needed due to a reshape during
  # eval). This carries over that rounding to batch_size as well. This is the
  # per device batch size
  params["batch_size"] = params["eval_batch_size"]
  batch_size = params["batch_size"]

  num_users, num_items, num_train_steps, num_eval_steps, producer = (
      ncf_common.get_inputs(params))

  params["num_users"], params["num_items"] = num_users, num_items
  producer.start()
  model_helpers.apply_clean(flags.FLAGS)

  batches_per_step = params["batches_per_step"]
  train_input_dataset, eval_input_dataset = _get_train_and_eval_data(producer,
                                                                     params)
  # It is required that for distributed training, the dataset must call
  # batch(). The parameter of batch() here is the number of replicas involved,
  # such that each replica evenly gets a slice of data.
  # drop_remainder = True, as we would like batch call to return a fixed shape
  # vs None, this prevents an expensive broadcast during weighted_loss
  train_input_dataset = train_input_dataset.batch(batches_per_step,
                                                  drop_remainder=True)
  eval_input_dataset = eval_input_dataset.batch(batches_per_step,
                                                drop_remainder=True)

  time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
  per_epoch_callback = IncrementEpochCallback(producer)
  callbacks = [per_epoch_callback, time_callback]

  if FLAGS.early_stopping:
    early_stopping_callback = CustomEarlyStopping(
        "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
    callbacks.append(early_stopping_callback)

  with distribution_utils.get_strategy_scope(strategy):
    keras_model = _get_keras_model(params)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=params["learning_rate"],
        beta_1=params["beta1"],
        beta_2=params["beta2"],
        epsilon=params["epsilon"])

  # Always bound, so that `build_stats` below never hits an undefined name
  # when no epoch runs or Keras returns an empty history.
  # NOTE(review): assumes `build_stats` tolerates a None loss -- confirm.
  train_loss = None

  if params["keras_use_ctl"]:
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.SUM,
        from_logits=True)
    train_input_iterator = strategy.make_dataset_iterator(train_input_dataset)
    eval_input_iterator = strategy.make_dataset_iterator(eval_input_dataset)

    @tf.function
    def train_step():
      """Called once per step to train the model."""
      def step_fn(features):
        """Computes loss and applied gradient per replica."""
        with tf.GradientTape() as tape:
          softmax_logits = keras_model(features)
          labels = features[rconst.TRAIN_LABEL_KEY]
          loss = loss_object(labels, softmax_logits,
                             sample_weight=features[rconst.VALID_POINT_MASK])
          # The loss object sums over examples; rescale by the global batch
          # size (per-device batch size times number of replicas).
          loss *= (1.0 / (batch_size*strategy.num_replicas_in_sync))

        grads = tape.gradient(loss, keras_model.trainable_variables)
        # Converting gradients to dense form helps in perf on GPU for NCF
        grads = neumf_model.sparse_to_dense_grads(
            list(zip(grads, keras_model.trainable_variables)))
        optimizer.apply_gradients(grads)
        return loss

      per_replica_losses = strategy.experimental_run(step_fn,
                                                     train_input_iterator)
      # Per-replica losses are already scaled by the global batch size, so
      # summing them across replicas yields the mean loss of the step.
      mean_loss = strategy.reduce(
          tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
      return mean_loss

    @tf.function
    def eval_step():
      """Called once per eval step to compute eval metrics."""
      def step_fn(features):
        """Computes eval metrics per replica."""
        softmax_logits = keras_model(features)
        in_top_k, metric_weights = metric_fn(
            softmax_logits, features[rconst.DUPLICATE_MASK], params)
        hr_sum = tf.reduce_sum(in_top_k*metric_weights)
        hr_count = tf.reduce_sum(metric_weights)
        return hr_sum, hr_count

      per_replica_hr_sum, per_replica_hr_count = (
          strategy.experimental_run(step_fn, eval_input_iterator))
      hr_sum = strategy.reduce(
          tf.distribute.ReduceOp.SUM, per_replica_hr_sum, axis=None)
      hr_count = strategy.reduce(
          tf.distribute.ReduceOp.SUM, per_replica_hr_count, axis=None)
      return hr_sum, hr_count

    # Hit rate of the most recently evaluated epoch; stays None if the epoch
    # loop below never executes.
    hit_rate = None

    time_callback.on_train_begin()
    for epoch in range(FLAGS.train_epochs):
      per_epoch_callback.on_epoch_begin(epoch)
      train_input_iterator.initialize()
      train_loss = 0
      for step in range(num_train_steps):
        time_callback.on_batch_begin(step+epoch*num_train_steps)
        train_loss += train_step()
        time_callback.on_batch_end(step+epoch*num_train_steps)
      train_loss /= num_train_steps
      logging.info("Done training epoch %s, epoch loss=%s.",
                   epoch+1, train_loss)
      eval_input_iterator.initialize()
      hr_sum = 0
      hr_count = 0
      for _ in range(num_eval_steps):
        step_hr_sum, step_hr_count = eval_step()
        hr_sum += step_hr_sum
        hr_count += step_hr_count
      hit_rate = hr_sum/hr_count
      logging.info("Done eval epoch %s, hr=%s.", epoch+1, hit_rate)

      if (FLAGS.early_stopping and
          float(hit_rate) > params["hr_threshold"]):
        break

    time_callback.on_train_end()
    # No evaluation loss is computed in the custom loop, hence the None.
    eval_results = [None, hit_rate]

  else:
    with distribution_utils.get_strategy_scope(strategy):

      keras_model.compile(optimizer=optimizer,
                          run_eagerly=FLAGS.run_eagerly)

      history = keras_model.fit(train_input_dataset,
                                epochs=FLAGS.train_epochs,
                                callbacks=callbacks,
                                validation_data=eval_input_dataset,
                                validation_steps=num_eval_steps,
                                verbose=2)

      logging.info("Training done. Start evaluating")

      eval_results = keras_model.evaluate(
          eval_input_dataset,
          steps=num_eval_steps,
          verbose=2)

      logging.info("Keras evaluation is done.")

    # Report the loss of the last recorded epoch, when there is one.
    if history and history.history:
      train_history = history.history
      train_loss = train_history["loss"][-1]

  stats = build_stats(train_loss, eval_results, time_callback)
  return stats
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If `--fp16_implementation=graph_rewrite` is requested
        without `--use_tf_function`.

    Returns:
      Dictionary of training and eval stats.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == tf.bfloat16:
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_bfloat16')
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    # TODO(anj-s): Set data_format without using Keras.
    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    train_ds, test_ds = get_input_dataset(flags_obj, strategy)
    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    # Never run more steps in one loop than an epoch contains.
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
    logging.info(
        "Training %d epochs, each epoch has %d steps, "
        "total steps: %d; Eval %d steps", train_epochs, per_epoch_steps,
        train_epochs * per_epoch_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
                                            flags_obj.log_steps)

    with distribution_utils.get_strategy_scope(strategy):
        model = resnet_model.resnet50(
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            batch_size=flags_obj.batch_size,
            use_l2_regularizer=not flags_obj.single_l2_loss_op)

        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=common.LR_SCHEDULE[0][1],
            boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
            multipliers=list(p[0] for p in common.LR_SCHEDULE),
            compute_lr_on_cpu=True)
        optimizer = common.get_optimizer(lr_schedule)

        if flags_obj.fp16_implementation == 'graph_rewrite':
            if not flags_obj.use_tf_function:
                raise ValueError(
                    '--fp16_implementation=graph_rewrite requires '
                    '--use_tf_function to be true')
            loss_scale = flags_core.get_loss_scale(flags_obj,
                                                   default_for_fp16=128)
            optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                optimizer, loss_scale)

        train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'training_accuracy', dtype=tf.float32)
        test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
        test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)

        trainable_variables = model.trainable_variables

        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            with tf.GradientTape() as tape:
                logits = model(images, training=True)

                prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits)
                loss = tf.reduce_sum(prediction_loss) * (1.0 /
                                                         flags_obj.batch_size)
                num_replicas = tf.distribute.get_strategy(
                ).num_replicas_in_sync

                if flags_obj.single_l2_loss_op:
                    # Compute the weight decay with a single op over all
                    # non-batch-norm variables (selected by name).
                    filtered_variables = [
                        tf.reshape(v, (-1, )) for v in trainable_variables
                        if 'bn' not in v.name
                    ]
                    l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
                        tf.concat(filtered_variables, axis=0))
                    loss += (l2_loss / num_replicas)
                else:
                    loss += (tf.reduce_sum(model.losses) / num_replicas)

                # Scale the loss for the gradient computation only. The
                # unscaled `loss` is kept so that the reported training loss
                # is not multiplied by the loss scale.
                if flags_obj.dtype == "fp16":
                    scaled_loss = optimizer.get_scaled_loss(loss)
                else:
                    scaled_loss = loss

            grads = tape.gradient(scaled_loss, trainable_variables)

            # Unscale the grads
            if flags_obj.dtype == "fp16":
                grads = optimizer.get_unscaled_gradients(grads)

            optimizer.apply_gradients(zip(grads, trainable_variables))
            train_loss.update_state(loss)
            training_accuracy.update_state(labels, logits)

        @tf.function
        def train_steps(iterator, steps):
            """Performs distributed training steps in a loop."""
            for _ in tf.range(steps):
                strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

        def train_single_step(iterator):
            """Performs a single training step, distributed if possible."""
            if strategy:
                strategy.experimental_run_v2(step_fn, args=(next(iterator), ))
            else:
                return step_fn(next(iterator))

        def test_step(iterator):
            """Evaluation StepFn."""
            def step_fn(inputs):
                images, labels = inputs
                logits = model(images, training=False)
                loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits)
                loss = tf.reduce_sum(loss) * (1.0 / flags_obj.batch_size)
                test_loss.update_state(loss)
                test_accuracy.update_state(labels, logits)

            if strategy:
                strategy.experimental_run_v2(step_fn, args=(next(iterator), ))
            else:
                step_fn(next(iterator))

        # `train_steps` is always compiled; the single-step variants are
        # compiled only on request.
        if flags_obj.use_tf_function:
            train_single_step = tf.function(train_single_step)
            test_step = tf.function(test_step)

        train_iter = iter(train_ds)
        time_callback.on_train_begin()
        for epoch in range(train_epochs):
            train_loss.reset_states()
            training_accuracy.reset_states()

            steps_in_current_epoch = 0
            while steps_in_current_epoch < per_epoch_steps:
                time_callback.on_batch_begin(steps_in_current_epoch +
                                             epoch * per_epoch_steps)
                steps = _steps_to_run(steps_in_current_epoch, per_epoch_steps,
                                      steps_per_loop)
                if steps == 1:
                    train_single_step(train_iter)
                else:
                    # Converts steps to a Tensor to avoid tf.function retracing.
                    train_steps(train_iter,
                                tf.convert_to_tensor(steps, dtype=tf.int32))
                time_callback.on_batch_end(steps_in_current_epoch +
                                           epoch * per_epoch_steps)
                steps_in_current_epoch += steps

            logging.info('Training loss: %s, accuracy: %s at epoch %d',
                         train_loss.result().numpy(),
                         training_accuracy.result().numpy(), epoch + 1)

            if (not flags_obj.skip_eval
                    and (epoch + 1) % flags_obj.epochs_between_evals == 0):
                test_loss.reset_states()
                test_accuracy.reset_states()

                test_iter = iter(test_ds)
                for _ in range(eval_steps):
                    test_step(test_iter)

                logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
                             test_loss.result().numpy(),
                             test_accuracy.result().numpy(), epoch + 1)

        time_callback.on_train_end()

        # NOTE(review): training results are only reported together with
        # eval results; with skip_eval both stay None -- confirm this is what
        # `build_stats` expects.
        eval_result = None
        train_result = None
        if not flags_obj.skip_eval:
            eval_result = [
                test_loss.result().numpy(),
                test_accuracy.result().numpy()
            ]
            train_result = [
                train_loss.result().numpy(),
                training_accuracy.result().numpy()
            ]

        stats = build_stats(train_result, eval_result, time_callback)
        return stats
def run(flags_obj):
    """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If the dtype resolved from the flags compares equal to
        'fp16', which is not supported here.

    Returns:
      Dictionary of training and eval stats.
    """
    keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # NOTE(review): if `get_tf_dtype` returns a tf.DType rather than the raw
    # flag string, this comparison may never be true -- confirm.
    if dtype == 'fp16':
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.config.list_physical_devices('GPU') else
                       'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    if flags_obj.use_synthetic_data:
        synthetic_util.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=cifar_preprocessing.HEIGHT,
            width=cifar_preprocessing.WIDTH,
            num_channels=cifar_preprocessing.NUM_CHANNELS,
            num_classes=cifar_preprocessing.NUM_CLASSES,
            dtype=flags_core.get_tf_dtype(flags_obj),
            drop_remainder=True)
    else:
        synthetic_util.undo_set_up_synthetic_data()
        input_fn = cifar_preprocessing.input_fn

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=cifar_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        # Setting drop_remainder to avoid the partial batch logic in normalization
        # layer, which triggers tf.where and leads to extra memory copy of input
        # sizes between host and GPU.
        drop_remainder=(not flags_obj.enable_get_next_as_optional))

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            parse_record_fn=cifar_preprocessing.parse_record)
        # Use the DATA auto-shard policy for the evaluation dataset.
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = (
            tf.data.experimental.AutoShardPolicy.DATA)
        eval_input_dataset = eval_input_dataset.with_options(options)

    steps_per_epoch = (cifar_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)
    # Constant learning rate passed to the optimizer unless a tensor-based
    # schedule is requested below.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        initial_learning_rate = (
            common.BASE_LEARNING_RATE * flags_obj.batch_size / 128)
        lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            boundaries=list(p[1] * steps_per_epoch for p in LR_SCHEDULE),
            values=[initial_learning_rate] + list(p[0] * initial_learning_rate
                                                  for p in LR_SCHEDULE))

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        model = resnet_cifar_model.resnet56(
            classes=cifar_preprocessing.NUM_CLASSES)
        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['sparse_categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      run_eagerly=flags_obj.run_eagerly)

    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks()

    if not flags_obj.use_tensor_lr:
        # Without a tensor schedule, the learning rate is driven per batch
        # by a callback instead.
        lr_callback = LearningRateBatchScheduler(
            schedule=learning_rate_schedule,
            batch_size=flags_obj.batch_size,
            steps_per_epoch=steps_per_epoch)
        callbacks.append(lr_callback)

    # if multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    no_dist_strat_device = None
    if not strategy and flags_obj.explicit_gpu_placement:
        # TODO(b/135607227): Add device scope automatically in Keras training loop
        # when not using distribution strategy.
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    try:
        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=steps_per_epoch,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)
        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)
    finally:
        # Always leave the manually entered device scope, even when training
        # or evaluation raises. The three None arguments follow the context
        # manager protocol for a normal (exception-free) exit.
        if no_dist_strat_device is not None:
            no_dist_strat_device.__exit__(None, None, None)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
Esempio n. 10
0
def run_customized_training_loop(strategy,
                                 model_fn,
                                 loss_fn,
                                 out_dir,
                                 train_input_fn,
                                 eval_input_fn,
                                 test_input_fn,
                                 steps_per_stats,
                                 num_train_steps,
                                 steps_per_eval,
                                 run_eagerly,
                                 explicit_allreduce,
                                 pmetric_name,
                                 all_metric_fn_lst,
                                 keep_checkpoint_max,
                                 feature_type2name,
                                 optimizer_fn,
                                 bert_optimizer_fn,
                                 task_ids,
                                 task_weights,
                                 num_eval_steps,
                                 scale_loss=True,
                                 custom_callbacks=None,
                                 pre_allreduce_callbacks=None,
                                 post_allreduce_callbacks=None):
    """Run DeText model training using the low-level (custom loop) API.

    Reference: official/nlp/bert/model_training_utils.py

    Overall flow, as implemented below:
      1. Build the model, both optimizers, the metrics and the summary writers inside the strategy scope.
      2. Restore the latest checkpoint found in `out_dir`, if any, and resume from `optimizer.iterations`.
      3. Train in chunks whose size is decided by `train_flow_helper.steps_to_run`. Whenever the step counter is a multiple of
         `steps_per_eval`, save a checkpoint (except at the very last step), run validation and call `export_best_model`.
      4. After the loop, save a final checkpoint, run validation once more and call `export_best_model` again.
      5. Load weights from the exported-model variable directory and run evaluation on the test set.

    :param strategy: Distribution strategy on which to run low level training loop.
    :param model_fn: Function called without arguments inside the strategy scope to build the model.
      NOTE(review): earlier documentation described the return value as a tuple (model, sub_model) with the optimizer attached through
      `model.compile()`; this function uses the return value directly as the model and builds the optimizers from `optimizer_fn` and
      `bert_optimizer_fn` -- confirm against callers.
    :param loss_fn: Function with signature func(labels, logits) and returns a loss tensor.
    :param scale_loss: Whether to divide the raw loss by number of replicas before gradients calculation.
    :param keep_checkpoint_max: Maximum checkpoints to keep.
    :param feature_type2name: Mapping from feature types to feature names.
    :param out_dir: Model directory used during training for restoring/saving model weights.
    :param train_input_fn: Function that returns a tf.data.Dataset used for training.
    :param eval_input_fn: Function that returns evaluation dataset
    :param test_input_fn: Function that returns test dataset
    :param num_train_steps: Total training steps.
    :param steps_per_eval: Number of steps to run per eval. At the end of each eval, model checkpoint will be saved and evaluation will be conducted
      if evaluation dataset is provided.
    :param steps_per_stats: Number of steps per graph-mode loop. In order to reduce communication in eager context, training logs are printed every
      steps_per_stats.
    :param task_ids: Task ids related to multi-task training
    :param task_weights: Task weights related to multi-task training
    :param num_eval_steps: Forwarded unchanged to `run_evaluation` for the validation and test runs; presumably the number of
      evaluation steps to execute -- confirm against `run_evaluation`.
    :param optimizer_fn: Function to create optimizer for non-bert parameters
    :param bert_optimizer_fn: Function to create optimizer for bert parameters
    :param pmetric_name: Primary metric name. Model with best pmetric on evaluation dataset will be exported in pb format. Must equal the
      `name` of one of the metrics built from `all_metric_fn_lst`; otherwise the metric lookup raises IndexError.
    :param all_metric_fn_lst: A list of metrics functions that return a Keras Metric object to record evaluation result using evaluation dataset
    :param custom_callbacks: A list of Keras Callbacks objects to run during training. More specifically, `on_batch_begin()`, `on_batch_end()`,
      methods are invoked during training.
    :param run_eagerly: Whether to run model training in pure eager execution. This should be disabled for TPUStrategy.
    :param explicit_allreduce: Whether to explicitly perform gradient allreduce, instead of relying on implicit allreduce in optimizer.apply_gradients().
      For now, if training using FP16 mixed precision, explicit allreduce will aggregate gradients in FP16 format. For TPU and
      GPU training using FP32, explicit allreduce will aggregate gradients in FP32 format.
    :param pre_allreduce_callbacks: A list of callback functions that takes gradients and model variables pairs as input, manipulate them, and returns
      a new gradients and model variables pairs. The callback functions will be invoked in the list order and before gradients are allreduced.
      With mixed precision training, the pre_allreduce_callbacks will be applied on scaled_gradients. Default is no callbacks. Only used when
      explicit_allreduce=True.
    :param post_allreduce_callbacks: A list of callback functions that takes gradients and model variables pairs as input, manipulate them, and
      returns a new gradients and model variables pairs. The callback functions will be invoked in the list order and right before gradients
      are applied to variables for updates. Default is no callbacks. Only used when explicit_allreduce=True.

    :return: Trained model, with the weights loaded from the exported-model variable directory before the test-set evaluation.
    """

    assert tf.executing_eagerly(
    ), 'Outer training flow should be executed in eager mode'

    train_iterator = train_flow_helper.get_input_iterator(
        train_input_fn, strategy)
    with distribution_utils.get_strategy_scope(strategy):
        # To correctly place the model weights on accelerators, model and optimizer should be created in scope.
        model = model_fn()
        optimizer = optimizer_fn()
        bert_optimizer = bert_optimizer_fn()

        # Get metrics
        train_loss_metric = tf.keras.metrics.Mean('training_loss',
                                                  dtype=tf.float32)
        all_metrics = [metric_fn() for metric_fn in all_metric_fn_lst
                       ] if all_metric_fn_lst else []
        # The primary metric is looked up by name; an empty match raises IndexError here.
        pmetric = list(filter(lambda x: x.name == pmetric_name,
                              all_metrics))[0]

        # Create summary writers
        summary_dir = train_flow_helper.create_summary_dir(strategy, out_dir)
        eval_summary_writer = tf.summary.create_file_writer(
            os.path.join(summary_dir, 'eval'))
        train_summary_writer = train_flow_helper.create_train_summary_writer(
            os.path.join(summary_dir, 'train'), steps_per_stats)

        def train_step(iterator):
            """Performs a distributed training step."""
            def _train_step(inputs):
                train_step_fn(inputs, model, loss_fn, scale_loss, strategy,
                              explicit_allreduce, feature_type2name,
                              pre_allreduce_callbacks,
                              post_allreduce_callbacks, optimizer,
                              bert_optimizer, task_ids, task_weights,
                              train_loss_metric)

            strategy.run(_train_step, args=(next(iterator), ))

        def eval_step(iterator):
            """Calculates evaluation metrics on distributed devices."""
            def _eval_step(inputs):
                eval_step_fn(inputs, model, feature_type2name, all_metrics)

            strategy.run(_eval_step, args=(next(iterator), ))

        # Use AutoGraph for efficiency if not in eager mode
        if not run_eagerly:
            train_step = tf.function(train_step)
            eval_step = tf.function(eval_step)

        # Manage checkpoints
        checkpoint = tf.train.Checkpoint(model=model,
                                         optimizer=optimizer,
                                         bert_optimizer=bert_optimizer)
        manager = tf.train.CheckpointManager(checkpoint=checkpoint,
                                             directory=out_dir,
                                             max_to_keep=keep_checkpoint_max)
        if manager.latest_checkpoint:
            logging.info('Checkpoint file found and restoring')
            checkpoint.restore(manager.latest_checkpoint)

        # The step counter is read from the (possibly just restored) optimizer, so a resumed run continues where it stopped.
        current_step = optimizer.iterations.numpy()
        best_pmetric_val = None

        # Training loop starts here
        while current_step < num_train_steps:
            # Training loss/metrics are averaged over the steps of the inner loop, so their values are reset before each round.
            train_loss_metric.reset_states()
            for metric in model.metrics:
                metric.reset_states()

            train_flow_helper.run_callbacks_on_batch_begin(
                current_step, custom_callbacks)
            # Number of steps to run before the next logging/evaluation point, as decided by the helper.
            steps = train_flow_helper.steps_to_run(current_step,
                                                   steps_per_eval,
                                                   steps_per_stats)

            for _ in range(steps):
                train_step(train_iterator)

            train_loss = train_flow_helper.float_metric_value(
                train_loss_metric)
            current_step += steps
            # Callbacks receive the index of the last step that was executed, hence current_step - 1.
            train_flow_helper.run_callbacks_on_batch_end(
                current_step - 1, {'loss': train_loss}, custom_callbacks)
            write_training_summary(train_summary_writer, train_loss,
                                   train_loss_metric, current_step)

            logging.info('Train Step: %d/%d  / loss = %s' %
                         (current_step, num_train_steps, train_loss))

            # Saves model checkpoints and run validation steps
            if current_step % steps_per_eval == 0:
                # To avoid repeated model saving, we do not save after the last step of training.
                if current_step < num_train_steps:
                    train_flow_helper.save_checkpoint(strategy, current_step,
                                                      manager)

                logging.info('Running evaluation after step: %s.',
                             current_step)
                # A fresh iterator is created for every validation run.
                dev_iter = train_flow_helper.get_input_iterator(eval_input_fn,
                                                                strategy,
                                                                is_eval=True)
                run_evaluation(current_step, dev_iter, eval_summary_writer,
                               all_metrics, eval_step, 'Validation',
                               num_eval_steps)

                best_pmetric_val = export_best_model(strategy, pmetric,
                                                     best_pmetric_val, model,
                                                     out_dir)

                # Re-initialize evaluation metric
                for metric in all_metrics:
                    metric.reset_states()

        # Final checkpoint (the in-loop save is skipped on the last step).
        train_flow_helper.save_checkpoint(strategy, current_step, manager)

        logging.info('Running final evaluation after training is complete.')
        dev_iter = train_flow_helper.get_input_iterator(eval_input_fn,
                                                        strategy,
                                                        is_eval=True)
        run_evaluation(current_step, dev_iter, eval_summary_writer,
                       all_metrics, eval_step, 'Validation', num_eval_steps)

        best_pmetric_val = export_best_model(strategy, pmetric,
                                             best_pmetric_val, model, out_dir)

        logging.info(
            f'Running evaluation on test set with best model on dev set ({pmetric_name} = {best_pmetric_val})'
        )
        for metric in all_metrics:
            metric.reset_states()
        # Load weights from the exported-model variable directory (presumably written by export_best_model -- confirm)
        # so that the test evaluation uses them rather than the weights of the last training step.
        model.load_weights(
            train_flow_helper.get_export_model_variable_dir(out_dir))
        test_iter = train_flow_helper.get_input_iterator(test_input_fn,
                                                         strategy,
                                                         is_eval=True)
        run_evaluation(current_step, test_iter, eval_summary_writer,
                       all_metrics, eval_step, 'Test', num_eval_steps)

        # Remove the summary directory when the helper reports that this process should not export summaries.
        if not detext.utils.distributed_utils.should_export_summary(strategy):
            tf.io.gfile.rmtree(summary_dir)

        return model
def run(flags_obj):
    """Train (and optionally evaluate) ResNet on ImageNet with a custom loop.

    Horovod is initialised first and, when GPUs are present, only the GPU
    whose index equals the Horovod local rank is made visible to TensorFlow.
    The distribution strategy is then requested with `num_gpus=1`.

    NOTE: session configuration, the mixed-precision policy and the GPU
    thread-mode tuning are not applied in this variant.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      The result of `build_stats` for the runnable and the timing callback
      (described upstream as a dictionary of training and eval stats).
    """
    hvd.init()

    # TF2-style pinning: enable memory growth on every GPU, then expose only
    # the device that matches this process's local rank.
    physical_gpus = tf.config.experimental.list_physical_devices('GPU')
    for device in physical_gpus:
        tf.config.experimental.set_memory_growth(device, True)
    if physical_gpus:
        local_gpu = physical_gpus[hvd.local_rank()]
        tf.config.experimental.set_visible_devices(local_gpu, 'GPU')

    # TODO(anj-s): Set data_format without using Keras.
    data_format = flags_obj.data_format
    if data_format is None:
        if tf.config.list_physical_devices('GPU'):
            data_format = 'channels_first'
        else:
            data_format = 'channels_last'
    tf.keras.backend.set_image_data_format(data_format)

    # num_gpus is pinned to 1 (instead of flags_obj.num_gpus) to force the
    # non-distributed, single-GPU mode.
    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=1,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
    total_train_steps = train_epochs * per_epoch_steps

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        total_train_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        flags_obj.batch_size,
        flags_obj.log_steps,
        logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
    with distribution_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                                  per_epoch_steps)

    # Evaluation every `epochs_between_evals` epochs; checkpoints and
    # summaries once per epoch, each only when its flag is enabled.
    eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
    if flags_obj.enable_checkpoint_and_export:
        checkpoint_interval = per_epoch_steps
    else:
        checkpoint_interval = None
    if flags_obj.enable_tensorboard:
        summary_interval = per_epoch_steps
    else:
        summary_interval = None

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=flags_obj.model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    evaluate_fn = None if flags_obj.skip_eval else runnable.evaluate
    resnet_controller = controller.Controller(
        strategy,
        runnable.train,
        evaluate_fn,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        train_steps=total_train_steps,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        eval_steps=eval_steps,
        eval_interval=eval_interval)

    time_callback.on_train_begin()
    resnet_controller.train(evaluate=not flags_obj.skip_eval)
    time_callback.on_train_end()

    return build_stats(runnable, time_callback)
Esempio n. 12
0
def run(flags_obj):
    """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.

    Returns:
      Dictionary of training and eval stats.
    """
    # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
    # Eager is default in tf 2.0 and should not be toggled
    if keras_common.is_v2_0():
        keras_common.set_config_v2()
    else:
        config = keras_common.get_config_proto_v1()
        if flags_obj.enable_eager:
            tf.compat.v1.enable_eager_execution(config=config)
        else:
            sess = tf.Session(config=config)
            tf.keras.backend.set_session(sess)

    dtype = flags_core.get_tf_dtype(flags_obj)
    # flags_core.get_tf_dtype presumably returns a tf.DType; a DType never
    # compares equal to the flag spelling 'fp16', so the original guard could
    # not fire. Compare against tf.float16 so the documented ValueError is
    # actually raised; the string comparison is kept for backward
    # compatibility in case a string is returned.
    if dtype == tf.float16 or dtype == 'fp16':
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=cifar_main.HEIGHT,
            width=cifar_main.WIDTH,
            num_channels=cifar_main.NUM_CHANNELS,
            num_classes=cifar_main.NUM_CLASSES,
            dtype=flags_core.get_tf_dtype(flags_obj))
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = cifar_main.input_fn

    train_input_dataset = input_fn(is_training=True,
                                   data_dir=flags_obj.data_dir,
                                   batch_size=flags_obj.batch_size,
                                   num_epochs=flags_obj.train_epochs,
                                   parse_record_fn=parse_record_keras)

    eval_input_dataset = input_fn(is_training=False,
                                  data_dir=flags_obj.data_dir,
                                  batch_size=flags_obj.batch_size,
                                  num_epochs=flags_obj.train_epochs,
                                  parse_record_fn=parse_record_keras)

    # Optimizer and model are created and compiled inside the strategy scope.
    with strategy_scope:
        optimizer = keras_common.get_optimizer()
        model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)

        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=['categorical_accuracy'])

    callbacks = keras_common.get_callbacks(learning_rate_schedule,
                                           cifar_main.NUM_IMAGES['train'])

    train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
    train_epochs = flags_obj.train_epochs

    # An explicit train_steps flag caps the steps and forces a single epoch.
    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # No validation during fit: fix the learning phase to training and
        # drop the validation inputs.
        tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)
    eval_output = None
    if not flags_obj.skip_eval:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)
    stats = keras_common.build_stats(history, eval_output, callbacks)
    return stats
Esempio n. 13
0
def run(flags_obj, datasets_override=None, strategy_override=None):
    """Train, export and evaluate the MNIST model with native Keras APIs.

    Step counts are always derived from the TFDS `mnist` builder metadata,
    even when `datasets_override` supplies the actual data.

    Args:
      flags_obj: An object containing parsed flag values.
      datasets_override: A pair of `tf.data.Dataset` objects (train, test)
        used instead of the TFDS splits when truthy.
      strategy_override: A `tf.distribute.Strategy` object used instead of
        the flag-derived strategy when truthy.

    Returns:
      Dictionary of training and eval stats.
    """
    strategy = strategy_override
    if not strategy:
        strategy = distribution_utils.get_distribution_strategy(
            distribution_strategy=flags_obj.distribution_strategy,
            num_gpus=flags_obj.num_gpus,
            tpu_address=flags_obj.tpu)
    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    mnist = tfds.builder('mnist', data_dir=flags_obj.data_dir)
    if flags_obj.download:
        mnist.download_and_prepare()

    datasets = datasets_override
    if not datasets:
        datasets = mnist.as_dataset(
            split=['train', 'test'],
            decoders={'image': decode_image()},  # pylint: disable=no-value-for-parameter
            as_supervised=True)
    mnist_train, mnist_test = datasets

    batch_size = flags_obj.batch_size
    train_input_dataset = mnist_train.cache().repeat().shuffle(
        buffer_size=50000).batch(batch_size)
    eval_input_dataset = mnist_test.cache().repeat().batch(batch_size)

    # Optimizer and model are built and compiled inside the strategy scope.
    with strategy_scope:
        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            0.05, decay_steps=100000, decay_rate=0.96)
        optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

        model = build_model()
        model.compile(optimizer=optimizer,
                      loss='sparse_categorical_crossentropy',
                      metrics=['sparse_categorical_accuracy'])

    splits = mnist.info.splits
    train_steps = splits['train'].num_examples // flags_obj.batch_size
    train_epochs = flags_obj.train_epochs

    # One weights-only checkpoint per epoch plus TensorBoard logs, both
    # written under model_dir.
    ckpt_full_path = os.path.join(flags_obj.model_dir,
                                  'model.ckpt-{epoch:04d}')
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                                       save_weights_only=True)
    tensorboard_cb = tf.keras.callbacks.TensorBoard(
        log_dir=flags_obj.model_dir)
    callbacks = [checkpoint_cb, tensorboard_cb]

    num_eval_steps = (mnist.info.splits['test'].num_examples //
                      flags_obj.batch_size)

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=eval_input_dataset,
                        validation_freq=flags_obj.epochs_between_evals)

    # Export the trained model (without optimizer state) before the final
    # evaluation pass.
    export_path = os.path.join(flags_obj.model_dir, 'saved_model')
    model.save(export_path, include_optimizer=False)

    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

    return common.build_stats(history, eval_output, callbacks)
Esempio n. 14
0
def run_predict(flags_obj, datasets_override=None, strategy_override=None):
    """Extract 'resnet50_features' outputs for the evaluation split and save them.

    The model is restored from the latest checkpoint found in
    `GB_OPTIONS.pretrained_filepath`, falling back to
    `GB_OPTIONS.checkpoint_folder` when the former is None. Predictions and
    labels are written with `np.save` to 'out_X', 'out_labels' and
    'out_ori_labels' relative to the current working directory.

    Args:
      flags_obj: An object containing parsed flag values.
      datasets_override: Unused; kept for interface compatibility.
      strategy_override: Unused; kept for interface compatibility.

    Raises:
      FileNotFoundError: If no checkpoint can be found in the load path.

    Returns:
      The string 'good' on completion.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    performance.set_mixed_precision_policy(
        flags_core.get_tf_dtype(flags_obj),
        flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                             flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    distribution_utils.undo_set_up_synthetic_data()

    # Only the evaluation split is used for prediction; shuffling is turned
    # off when the datasets are built.
    _, pred_input_dataset, _, pred_dataset = setup_datasets(flags_obj,
                                                            shuffle=False)

    with strategy_scope:
        # The class count is hard-coded to 100 here rather than taken from
        # the training dataset.
        model = build_model(100, mode='resnet50_features')

        load_path = GB_OPTIONS.pretrained_filepath
        if load_path is None:
            load_path = GB_OPTIONS.checkpoint_folder
        latest = tf.train.latest_checkpoint(load_path)
        print(latest)
        # latest_checkpoint returns None when the directory has no
        # checkpoint; fail with a clear message instead of passing None on to
        # load_weights.
        if latest is None:
            raise FileNotFoundError(
                'No checkpoint found in %r' % (load_path, ))
        model.load_weights(latest)

        num_eval_examples = pred_dataset.num_examples_per_epoch()
        num_eval_steps = num_eval_examples // GB_OPTIONS.batch_size
        print(GB_OPTIONS.batch_size)

        pred = model.predict(pred_input_dataset,
                             batch_size=GB_OPTIONS.batch_size,
                             steps=num_eval_steps)

        lab = np.asarray(pred_dataset.data[1])
        # Datasets that carry `ori_labels` get those saved separately;
        # otherwise the regular labels are saved twice.
        if hasattr(pred_dataset, 'ori_labels'):
            ori_lab = pred_dataset.ori_labels
        else:
            ori_lab = lab

        print(pred.shape)

        np.save('out_X', pred)
        np.save('out_labels', lab)
        np.save('out_ori_labels', ori_lab)

        return 'good'
Esempio n. 15
0
def run_train(flags_obj):
    """Train the 'resnet50' model with Keras `fit`, export it and evaluate it.

    Weights are optionally initialised from the latest checkpoint in
    `GB_OPTIONS.pretrained_filepath`. Checkpoints and the exported
    'saved_model' are written under `GB_OPTIONS.checkpoint_folder`.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      The result of `common.build_stats` for the fit history, the evaluation
      output (None when `flags_obj.skip_eval` is set) and the callbacks.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    performance.set_mixed_precision_policy(
        flags_core.get_tf_dtype(flags_obj),
        flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                             flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    distribution_utils.undo_set_up_synthetic_data()

    train_input_dataset, eval_input_dataset, tr_dataset, te_dataset = setup_datasets(
        flags_obj)

    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=GB_OPTIONS.batch_size,
        epoch_size=tr_dataset.num_examples_per_epoch(),
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in common.LR_SCHEDULE),
        compute_lr_on_cpu=True)

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        model = build_model(tr_dataset.num_classes, mode='resnet50')

        if GB_OPTIONS.pretrained_filepath is not None:
            latest = tf.train.latest_checkpoint(GB_OPTIONS.pretrained_filepath)
            print(latest)
            model.load_weights(latest)

        model.compile(
            optimizer=optimizer,
            loss="sparse_categorical_crossentropy",
            metrics=['sparse_categorical_accuracy'])

        num_train_examples = tr_dataset.num_examples_per_epoch()
        steps_per_epoch = num_train_examples // GB_OPTIONS.batch_size
        train_epochs = GB_OPTIONS.num_epochs

        # Poison/cover counts are only used to tag the checkpoint file name;
        # datasets without these attributes are tagged with zeros.
        if not hasattr(tr_dataset, "n_poison"):
            n_poison = 0
            n_cover = 0
        else:
            n_poison = tr_dataset.n_poison
            n_cover = tr_dataset.n_cover

        callbacks = common.get_callbacks(
            steps_per_epoch=steps_per_epoch,
            pruning_method=flags_obj.pruning_method,
            enable_checkpoint_and_export=False,
            model_dir=GB_OPTIONS.checkpoint_folder)
        # '{epoch:04d}' is left unformatted here; it is a placeholder for
        # ModelCheckpoint, while the %d fields are filled in immediately.
        ckpt_full_path = os.path.join(
            GB_OPTIONS.checkpoint_folder,
            'model.ckpt-{epoch:04d}-p%d-c%d' % (n_poison, n_cover))
        callbacks.append(
            tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                               save_weights_only=True,
                                               save_best_only=True))

        num_eval_examples = te_dataset.num_examples_per_epoch()
        num_eval_steps = num_eval_examples // GB_OPTIONS.batch_size

        if flags_obj.skip_eval:
            # Only build the training graph. This reduces memory usage introduced by
            # control flow ops in layers that have different implementations for
            # training and inference (e.g., batch norm).
            if flags_obj.set_learning_phase_to_train:
                # TODO(haoyuzhang): Understand slowdown of setting learning phase when
                # not using distribution strategy.
                tf.keras.backend.set_learning_phase(1)
            num_eval_steps = None
            eval_input_dataset = None

        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=steps_per_epoch,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=eval_input_dataset,
                            validation_freq=flags_obj.epochs_between_evals)

        export_path = os.path.join(GB_OPTIONS.checkpoint_folder, 'saved_model')
        model.save(export_path, include_optimizer=False)

        # With skip_eval the evaluation dataset has been set to None above, so
        # the final evaluation must be skipped too instead of calling
        # evaluate() without data.
        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)

        stats = common.build_stats(history, eval_output, callbacks)

        return stats
Esempio n. 16
0
    def train(self):
        """Trains the model and returns training statistics.

        Restores the latest checkpoint found in ``flags_obj.model_dir`` (when
        one exists), builds the training dataset and then calls ``model.fit``
        in chunks of at most ``flags_obj.steps_between_evals`` steps until
        ``flags_obj.train_steps`` steps have been run.

        Returns:
          ``misc.build_stats(history, callbacks)`` for the last ``model.fit``
          call, or ``{"loss": train_loss}`` when ``model.fit`` was never called
          (e.g. the restored checkpoint is already at ``train_steps``).
          ``train_loss`` is never updated by this method, so it is 0 there.

        Raises:
          NotImplementedError: If ``params["use_ctl"]`` is set (the custom
            training loop is disabled in this version), or if Keras
            ``model.fit`` is requested while running on TPU.
        """
        self.params['train'] = True
        params = self.params
        flags_obj = self.flags_obj
        # Sets config options.
        keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)

        _ensure_dir(flags_obj.model_dir)
        with distribution_utils.get_strategy_scope(self.distribution_strategy):
            model = transformer.create_model(params, is_train=True)
            opt = self._create_optimizer()

            # Resume from the optimizer's iteration counter when a checkpoint
            # is available, otherwise start from step 0.
            current_step = 0
            checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
            latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
            if latest_checkpoint:
                checkpoint.restore(latest_checkpoint)
                logging.info("###Loaded checkpoint %s", latest_checkpoint)
                current_step = opt.iterations.numpy()
                logging.info(current_step)

            model.compile(opt)

        model.summary()

        if self.use_tpu:
            # Different from experimental_distribute_dataset,
            # experimental_distribute_datasets_from_function requires
            # per-replica/local batch size.
            params[
                "batch_size"] /= self.distribution_strategy.num_replicas_in_sync
            train_ds = (
                self.distribution_strategy.
                experimental_distribute_datasets_from_function(
                    lambda ctx: data_pipeline.train_input_fn(params, ctx)))
        else:
            train_ds = data_pipeline.train_input_fn(params)
            map_data_fn = data_pipeline.map_data_for_transformer_fn
            train_ds = train_ds.map(
                map_data_fn, num_parallel_calls=params["num_parallel_calls"])

        callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)

        # TODO(b/139418525): Refactor the custom training loop logic.
        train_loss = 0
        # Bound before the loop so the stats below are well defined even when
        # the loop body never runs (checkpoint already at train_steps).
        history = None

        while current_step < flags_obj.train_steps:
            remaining_steps = flags_obj.train_steps - current_step
            train_steps_per_eval = min(remaining_steps,
                                       flags_obj.steps_between_evals)
            current_iteration = current_step // flags_obj.steps_between_evals

            logging.info("Start train iteration at global step:%s",
                         current_step)
            if params["use_ctl"]:
                if not self.use_tpu:
                    raise NotImplementedError(
                        "Custom training loop on GPUs is not implemented.")
                # The custom-training-loop body is disabled in this version, so
                # nothing here would advance current_step. Fail loudly instead
                # of looping forever.
                raise NotImplementedError(
                    "Custom training loop is not implemented.")

            if self.use_tpu:
                raise NotImplementedError(
                    "Keras model.fit on TPUs is not implemented.")
            # Each chunk is presented to Keras as one "epoch" so that epoch
            # numbering continues across successive fit() calls.
            history = model.fit(
                train_ds,
                initial_epoch=current_iteration,
                epochs=current_iteration + 1,
                steps_per_epoch=train_steps_per_eval,
                callbacks=callbacks,
                # If TimeHistory is enabled, progress bar would be messy. Increase
                # the verbose level to get rid of it.
                verbose=(2 if flags_obj.enable_time_history else 1))
            current_step += train_steps_per_eval
            logging.info("Train history: %s", history.history)

            logging.info("End train iteration at global step:%s",
                         current_step)

        if history is None:
            return {"loss": train_loss}
        return misc.build_stats(history, callbacks)
Esempio n. 17
0
def run_ncf(_):
    """Run NCF training and eval with Keras.

    Args:
      _: Ignored.

    Returns:
      The value of ``build_stats(train_loss, eval_results, time_callback)``,
      or ``None`` when the flag combination is unsupported (an error is logged
      and the function returns early).
    """

    keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

    if FLAGS.seed is not None:
        print("Setting tf seed")
        tf.random.set_seed(FLAGS.seed)

    params = ncf_common.parse_flags(FLAGS)
    model_helpers.apply_clean(flags.FLAGS)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        tpu_address=FLAGS.tpu)
    params["distribute_strategy"] = strategy

    # Reject unsupported flag combinations before any data or model is built.
    if not keras_utils.is_v2_0() and strategy is not None:
        logging.error(
            "NCF Keras only works with distribution strategy in TF 2.0")
        return
    if (params["keras_use_ctl"]
            and (not keras_utils.is_v2_0() or strategy is None)):
        logging.error(
            "Custom training loop only works with tensorflow 2.0 and dist strat."
        )
        return
    if params["use_tpu"] and not params["keras_use_ctl"]:
        logging.error(
            "Custom training loop must be used when using TPUStrategy.")
        return

    batch_size = params["batch_size"]
    time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
    callbacks = [time_callback]

    producer, input_meta_data = None, None
    # Without a pre-generated training dataset, inputs are produced on the fly.
    generate_input_online = params["train_dataset_path"] is None

    if generate_input_online:
        # Start data producing thread.
        num_users, num_items, _, _, producer = ncf_common.get_inputs(params)
        producer.start()
        per_epoch_callback = IncrementEpochCallback(producer)
        callbacks.append(per_epoch_callback)
    else:
        assert params["eval_dataset_path"] and params["input_meta_data_path"]
        with tf.io.gfile.GFile(params["input_meta_data_path"], "rb") as reader:
            input_meta_data = json.loads(reader.read().decode("utf-8"))
            num_users = input_meta_data["num_users"]
            num_items = input_meta_data["num_items"]

    params["num_users"], params["num_items"] = num_users, num_items

    if FLAGS.early_stopping:
        early_stopping_callback = CustomEarlyStopping(
            "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
        callbacks.append(early_stopping_callback)

    with tf.device(tpu_lib.get_primary_cpu_task(params["use_tpu"])):
        (train_input_dataset, eval_input_dataset, num_train_steps,
         num_eval_steps) = ncf_input_pipeline.create_ncf_input_data(
             params, producer, input_meta_data, strategy)
        steps_per_epoch = None if generate_input_online else num_train_steps

        with distribution_utils.get_strategy_scope(strategy):
            keras_model = _get_keras_model(params)
            optimizer = tf.keras.optimizers.Adam(
                learning_rate=params["learning_rate"],
                beta_1=params["beta1"],
                beta_2=params["beta2"],
                epsilon=params["epsilon"])

            if params["keras_use_ctl"]:
                train_loss, eval_results = run_ncf_custom_training(
                    params,
                    strategy,
                    keras_model,
                    optimizer,
                    callbacks,
                    train_input_dataset,
                    eval_input_dataset,
                    num_train_steps,
                    num_eval_steps,
                    generate_input_online=generate_input_online)
            else:
                # Bound up front so build_stats() below never sees an unbound
                # name when fit() records no history (e.g. zero epochs).
                # NOTE(review): build_stats is defined elsewhere; confirm it
                # accepts a None loss.
                train_loss = None

                # TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
                # a valid arg for this model. Also remove as a valid flag.
                if FLAGS.force_v2_in_keras_compile is not None:
                    keras_model.compile(optimizer=optimizer,
                                        run_eagerly=FLAGS.run_eagerly,
                                        experimental_run_tf_function=FLAGS.
                                        force_v2_in_keras_compile)
                else:
                    keras_model.compile(optimizer=optimizer,
                                        run_eagerly=FLAGS.run_eagerly)

                history = keras_model.fit(train_input_dataset,
                                          epochs=FLAGS.train_epochs,
                                          steps_per_epoch=steps_per_epoch,
                                          callbacks=callbacks,
                                          validation_data=eval_input_dataset,
                                          validation_steps=num_eval_steps,
                                          verbose=2)

                logging.info("Training done. Start evaluating")

                eval_results = keras_model.evaluate(eval_input_dataset,
                                                    steps=num_eval_steps,
                                                    verbose=2)

                logging.info("Keras evaluation is done.")

                if history and history.history:
                    # Report the loss of the last completed epoch.
                    train_loss = history.history["loss"][-1]

        stats = build_stats(train_loss, eval_results, time_callback)
        return stats
Esempio n. 18
0
def run_ncf(_):
    """Run NCF training and eval with Keras.

    Depending on ``params["keras_use_ctl"]`` the model is trained either with
    an inline custom training loop driven through the distribution strategy,
    or with Keras ``compile``/``fit``/``evaluate``.

    Args:
      _: Ignored.

    Returns:
      The value of ``build_stats(train_loss, eval_results, time_callback)``,
      or ``None`` when the flag combination is unsupported (an error is logged
      and the function returns early).
    """

    keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

    if FLAGS.seed is not None:
        print("Setting tf seed")
        tf.random.set_seed(FLAGS.seed)

    # TODO(seemuch): Support different train and eval batch sizes
    if FLAGS.eval_batch_size != FLAGS.batch_size:
        logging.warning(
            "The Keras implementation of NCF currently does not support batch_size "
            "!= eval_batch_size ({} vs. {}). Overriding eval_batch_size to match "
            "batch_size".format(FLAGS.eval_batch_size, FLAGS.batch_size))
        FLAGS.eval_batch_size = FLAGS.batch_size

    params = ncf_common.parse_flags(FLAGS)
    model_helpers.apply_clean(flags.FLAGS)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus)
    params["distribute_strategy"] = strategy

    # Reject unsupported flag combinations before any data or model is built.
    if not keras_utils.is_v2_0() and strategy is not None:
        logging.error(
            "NCF Keras only works with distribution strategy in TF 2.0")
        return

    if (params["keras_use_ctl"]
            and (not keras_utils.is_v2_0() or strategy is None)):
        logging.error(
            "Custom training loop only works with tensorflow 2.0 and dist strat."
        )
        return

    # ncf_common rounds eval_batch_size (this is needed due to a reshape during
    # eval). This carries over that rounding to batch_size as well. This is the
    # per device batch size
    params["batch_size"] = params["eval_batch_size"]
    batch_size = params["batch_size"]

    time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
    callbacks = [time_callback]

    producer, input_meta_data = None, None
    # Without a pre-generated training dataset, inputs are produced on the fly.
    generate_input_online = params["train_dataset_path"] is None

    if generate_input_online:
        # Start data producing thread. The step counts returned here are
        # discarded: create_ncf_input_data() below supplies the ones used.
        num_users, num_items, _, _, producer = ncf_common.get_inputs(params)
        producer.start()
        per_epoch_callback = IncrementEpochCallback(producer)
        callbacks.append(per_epoch_callback)
    else:
        assert params["eval_dataset_path"] and params["input_meta_data_path"]
        with tf.io.gfile.GFile(params["input_meta_data_path"], "rb") as reader:
            input_meta_data = json.loads(reader.read().decode("utf-8"))
            num_users = input_meta_data["num_users"]
            num_items = input_meta_data["num_items"]

    params["num_users"], params["num_items"] = num_users, num_items
    (train_input_dataset, eval_input_dataset, num_train_steps,
     num_eval_steps) = ncf_input_pipeline.create_ncf_input_data(
         params, producer, input_meta_data)
    steps_per_epoch = None if generate_input_online else num_train_steps

    if FLAGS.early_stopping:
        early_stopping_callback = CustomEarlyStopping(
            "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
        callbacks.append(early_stopping_callback)
    with distribution_utils.get_strategy_scope(strategy):
        keras_model = _get_keras_model(params)
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=params["learning_rate"],
            beta_1=params["beta1"],
            beta_2=params["beta2"],
            epsilon=params["epsilon"])

    # Bound up front so build_stats() below never sees an unbound name when no
    # epoch runs (FLAGS.train_epochs == 0) or fit() records no history.
    # NOTE(review): build_stats is defined elsewhere; confirm it accepts a
    # None loss / None hit rate.
    train_loss = None

    if params["keras_use_ctl"]:
        loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            reduction="sum", from_logits=True)
        train_input_iterator = strategy.make_dataset_iterator(
            train_input_dataset)
        eval_input_iterator = strategy.make_dataset_iterator(
            eval_input_dataset)

        @tf.function
        def train_step():
            """Called once per step to train the model."""
            def step_fn(features):
                """Computes loss and applied gradient per replica."""
                with tf.GradientTape() as tape:
                    softmax_logits = keras_model(features)
                    labels = features[rconst.TRAIN_LABEL_KEY]
                    loss = loss_object(
                        labels,
                        softmax_logits,
                        sample_weight=features[rconst.VALID_POINT_MASK])
                    # The loss object sums over examples; scale by the global
                    # batch size (per-device batch size times replica count).
                    loss *= (1.0 /
                             (batch_size * strategy.num_replicas_in_sync))

                grads = tape.gradient(loss, keras_model.trainable_variables)
                # Converting gradients to dense form helps in perf on GPU for NCF
                grads = neumf_model.sparse_to_dense_grads(
                    list(zip(grads, keras_model.trainable_variables)))
                optimizer.apply_gradients(grads)
                return loss

            per_replica_losses = strategy.experimental_run(
                step_fn, train_input_iterator)
            mean_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                        per_replica_losses,
                                        axis=None)
            return mean_loss

        @tf.function
        def eval_step():
            """Called once per eval step to compute eval metrics."""
            def step_fn(features):
                """Computes eval metrics per replica."""
                softmax_logits = keras_model(features)
                in_top_k, metric_weights = metric_fn(
                    softmax_logits, features[rconst.DUPLICATE_MASK], params)
                hr_sum = tf.reduce_sum(in_top_k * metric_weights)
                hr_count = tf.reduce_sum(metric_weights)
                return hr_sum, hr_count

            per_replica_hr_sum, per_replica_hr_count = (
                strategy.experimental_run(step_fn, eval_input_iterator))
            hr_sum = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                     per_replica_hr_sum,
                                     axis=None)
            hr_count = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                       per_replica_hr_count,
                                       axis=None)
            return hr_sum, hr_count

        # Hit rate of the most recently evaluated epoch; stays None when no
        # epoch runs.
        hit_rate = None

        time_callback.on_train_begin()
        for epoch in range(FLAGS.train_epochs):
            for cb in callbacks:
                cb.on_epoch_begin(epoch)

            # As NCF dataset is sampled with randomness, not repeating
            # data elements in each epoch has significant impact on
            # convergence. As so, offline-generated TF record files
            # contains all epoch worth of data. Thus we do not need
            # to initialize dataset when reading from tf record files.
            if generate_input_online:
                train_input_iterator.initialize()

            train_loss = 0
            for step in range(num_train_steps):
                time_callback.on_batch_begin(step + epoch * num_train_steps)
                train_loss += train_step()
                time_callback.on_batch_end(step + epoch * num_train_steps)
            train_loss /= num_train_steps
            logging.info("Done training epoch %s, epoch loss=%s.", epoch + 1,
                         train_loss)
            eval_input_iterator.initialize()
            hr_sum = 0
            hr_count = 0
            for _ in range(num_eval_steps):
                step_hr_sum, step_hr_count = eval_step()
                hr_sum += step_hr_sum
                hr_count += step_hr_count
            # Computed once and reused for logging, early stopping and the
            # final result.
            hit_rate = hr_sum / hr_count
            logging.info("Done eval epoch %s, hr=%s.", epoch + 1, hit_rate)

            if (FLAGS.early_stopping
                    and float(hit_rate) > params["hr_threshold"]):
                break

        time_callback.on_train_end()
        eval_results = [None, hit_rate]

    else:
        with distribution_utils.get_strategy_scope(strategy):
            # TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
            # a valid arg for this model. Also remove as a valid flag.
            if FLAGS.force_v2_in_keras_compile is not None:
                keras_model.compile(optimizer=optimizer,
                                    run_eagerly=FLAGS.run_eagerly,
                                    experimental_run_tf_function=FLAGS.
                                    force_v2_in_keras_compile)
            else:
                keras_model.compile(optimizer=optimizer,
                                    run_eagerly=FLAGS.run_eagerly)

            history = keras_model.fit(train_input_dataset,
                                      epochs=FLAGS.train_epochs,
                                      steps_per_epoch=steps_per_epoch,
                                      callbacks=callbacks,
                                      validation_data=eval_input_dataset,
                                      validation_steps=num_eval_steps,
                                      verbose=2)

            logging.info("Training done. Start evaluating")

            eval_results = keras_model.evaluate(eval_input_dataset,
                                                steps=num_eval_steps,
                                                verbose=2)

            logging.info("Keras evaluation is done.")

        if history and history.history:
            # Report the loss of the last completed epoch.
            train_loss = history.history["loss"][-1]

    stats = build_stats(train_loss, eval_results, time_callback)
    return stats
Esempio n. 19
0
def run(flags_obj):
    """Train and evaluate ResNet on ImageNet through a training controller.

    Configures the session, the mixed-precision policy and the Keras image
    data format, builds the distribution strategy, creates a
    ``ResnetRunnable`` inside the strategy scope and hands its ``train`` /
    ``evaluate`` callables to a ``controller.Controller``.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      The value of ``build_stats(runnable, time_callback)``.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)
    tf_dtype = flags_core.get_tf_dtype(flags_obj)
    performance.set_mixed_precision_policy(tf_dtype)

    # This only affects GPU.
    common.set_cudnn_batchnorm_mode()

    # TODO(anj-s): Set data_format without using Keras.
    image_data_format = flags_obj.data_format
    if image_data_format is None:
        # No explicit choice: pick a layout based on whether a GPU is visible.
        if tf.config.list_physical_devices('GPU'):
            image_data_format = 'channels_first'
        else:
            image_data_format = 'channels_last'
    tf.keras.backend.set_image_data_format(image_data_format)

    dist_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    iteration_counts = get_num_train_iterations(flags_obj)
    steps_in_epoch, num_epochs, num_eval_steps = iteration_counts
    # A single loop never spans more than one epoch.
    loop_steps = min(flags_obj.steps_per_loop, steps_in_epoch)
    total_train_steps = num_epochs * steps_in_epoch

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', num_epochs, steps_in_epoch,
        total_train_steps, num_eval_steps)

    # TensorBoard output (log directory and summary cadence) is only
    # configured when the corresponding flag is on.
    if flags_obj.enable_tensorboard:
        tensorboard_dir = flags_obj.model_dir
        summary_every = steps_in_epoch
    else:
        tensorboard_dir = None
        summary_every = None

    timer = keras_utils.TimeHistory(flags_obj.batch_size,
                                    flags_obj.log_steps,
                                    logdir=tensorboard_dir)
    with distribution_utils.get_strategy_scope(dist_strategy):
        runnable = resnet_runnable.ResnetRunnable(flags_obj, timer,
                                                  steps_in_epoch)

    # Checkpoints are written once per epoch, and only when requested.
    if flags_obj.enable_checkpoint_and_export:
        checkpoint_every = steps_in_epoch
    else:
        checkpoint_every = None
    eval_every = flags_obj.epochs_between_evals * steps_in_epoch

    ckpt_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=flags_obj.model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_every)

    trainer = controller.Controller(
        dist_strategy,
        runnable.train,
        runnable.evaluate,
        global_step=runnable.global_step,
        steps_per_loop=loop_steps,
        train_steps=total_train_steps,
        checkpoint_manager=ckpt_manager,
        summary_interval=summary_every,
        eval_steps=num_eval_steps,
        eval_interval=eval_every)

    timer.on_train_begin()
    trainer.train(evaluate=not flags_obj.skip_eval)
    timer.on_train_end()

    return build_stats(runnable, timer)
Esempio n. 20
0
def run(flags_obj):
    """Train and evaluate ResNet on ImageNet with a hand-written loop.

    Builds the model, optimizer and metrics inside the distribution strategy
    scope, then iterates over epochs and steps explicitly, optionally wrapping
    the per-step functions in ``tf.function``.

    Args:
      flags_obj: An object containing parsed flag values.

    Returns:
      The value of ``build_stats(train_result, eval_result, time_callback)``.
      Both result lists are filled in only when ``flags_obj.skip_eval`` is
      false; otherwise ``None`` is passed for each.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # TODO(anj-s): Set data_format without using Keras.
    image_format = flags_obj.data_format
    if image_format is None:
        # No explicit choice: decide from whether TF was built with CUDA.
        if tf.test.is_built_with_cuda():
            image_format = 'channels_first'
        else:
            image_format = 'channels_last'
    tf.keras.backend.set_image_data_format(image_format)

    dist_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    train_dataset, eval_dataset = get_input_dataset(flags_obj, dist_strategy)
    iteration_counts = get_num_train_iterations(flags_obj)
    steps_in_epoch, num_epochs, num_eval_steps = iteration_counts

    timer = keras_utils.TimeHistory(flags_obj.batch_size,
                                    flags_obj.log_steps)

    with distribution_utils.get_strategy_scope(dist_strategy):
        model = resnet_model.resnet50(
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            batch_size=flags_obj.batch_size,
            use_l2_regularizer=not flags_obj.single_l2_loss_op)

        optimizer = tf.keras.optimizers.SGD(
            learning_rate=common.BASE_LEARNING_RATE,
            momentum=0.9,
            nesterov=True)

        train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy(
            'training_accuracy', dtype=tf.float32)
        eval_loss_metric = tf.keras.metrics.Mean('test_loss',
                                                 dtype=tf.float32)
        eval_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)

        weights = model.trainable_variables

        def run_train_step(batch):
            """Runs one training step and returns the summed replica loss."""
            def replica_train_step(inputs):
                """Forward/backward pass executed on a single replica."""
                images, labels = inputs
                with tf.GradientTape() as tape:
                    logits = model(images, training=True)

                    per_example_loss = (
                        tf.keras.losses.sparse_categorical_crossentropy(
                            labels, logits))
                    loss = tf.reduce_sum(per_example_loss) * (
                        1.0 / flags_obj.batch_size)
                    replica_count = (
                        tf.distribute.get_strategy().num_replicas_in_sync)

                    # Weight decay: either one fused L2 op over every
                    # variable whose name does not contain 'bn', or the
                    # regularization losses collected by the model.
                    if flags_obj.single_l2_loss_op:
                        flattened = [
                            tf.reshape(var, (-1, )) for var in weights
                            if 'bn' not in var.name
                        ]
                        l2_loss = (resnet_model.L2_WEIGHT_DECAY * 2 *
                                   tf.nn.l2_loss(tf.concat(flattened, axis=0)))
                        loss += (l2_loss / replica_count)
                    else:
                        loss += (tf.reduce_sum(model.losses) / replica_count)
                gradients = tape.gradient(loss, weights)
                optimizer.apply_gradients(zip(gradients, weights))

                train_acc_metric.update_state(labels, logits)
                return loss

            if not dist_strategy:
                return replica_train_step(batch)
            replica_losses = dist_strategy.experimental_run_v2(
                replica_train_step, args=(batch, ))
            return dist_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                        replica_losses,
                                        axis=None)

        def run_eval_step(batch):
            """Runs one evaluation step, updating the eval metrics."""
            def replica_eval_step(inputs):
                """Forward pass and metric update on a single replica."""
                images, labels = inputs
                logits = model(images, training=False)
                per_example_loss = (
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits))
                batch_loss = tf.reduce_sum(per_example_loss) * (
                    1.0 / flags_obj.batch_size)
                eval_loss_metric.update_state(batch_loss)
                eval_acc_metric.update_state(labels, logits)

            if not dist_strategy:
                replica_eval_step(batch)
            else:
                dist_strategy.experimental_run_v2(replica_eval_step,
                                                  args=(batch, ))

        if flags_obj.use_tf_function:
            run_train_step = tf.function(run_train_step)
            run_eval_step = tf.function(run_eval_step)

        timer.on_train_begin()
        for epoch in range(num_epochs):

            train_batches = iter(train_dataset)
            loss_sum = 0.0
            train_acc_metric.reset_states()

            for step in range(steps_in_epoch):
                # Index of this batch counted from the start of training.
                global_step = step + epoch * steps_in_epoch
                optimizer.lr = resnet_imagenet_main.learning_rate_schedule(
                    epoch, step, steps_in_epoch, flags_obj.batch_size)

                timer.on_batch_begin(global_step)
                loss_sum += run_train_step(next(train_batches))
                timer.on_batch_end(global_step)

            train_loss = loss_sum / steps_in_epoch
            logging.info('Training loss: %s, accuracy: %s%% at epoch: %d',
                         train_loss.numpy(),
                         train_acc_metric.result().numpy(), epoch)

            # Evaluate only every `epochs_between_evals` epochs, and never
            # when evaluation is disabled.
            if (flags_obj.skip_eval
                    or (epoch + 1) % flags_obj.epochs_between_evals != 0):
                continue

            eval_loss_metric.reset_states()
            eval_acc_metric.reset_states()

            eval_batches = iter(eval_dataset)
            for _ in range(num_eval_steps):
                run_eval_step(next(eval_batches))

            logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
                         eval_loss_metric.result().numpy(),
                         eval_acc_metric.result().numpy(), epoch)

        timer.on_train_end()

        train_result = eval_result = None
        if not flags_obj.skip_eval:
            eval_result = [
                eval_loss_metric.result().numpy(),
                eval_acc_metric.result().numpy()
            ]
            train_result = [
                train_loss.numpy(),
                train_acc_metric.result().numpy()
            ]

        return build_stats(train_result, eval_result, timer)
Esempio n. 21
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    Configures the session, mixed-precision policy, image data format and
    distribution strategy from the flags, builds the train/eval input
    datasets, compiles the model inside the strategy scope, then runs
    `model.fit` followed (unless `flags_obj.skip_eval` is set) by
    `model.evaluate`.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If the resolved dtype is float16 while
        `keras_utils.is_v2_0()` reports that TensorFlow 2 is not in use.

    Returns:
      Dictionary of training and eval stats, as produced by
      `common.build_stats`.
    """
    # Local import: only this function needs it, and the file-level import
    # block is not guaranteed to provide it.
    import contextlib

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        common.set_gpu_thread_mode_and_count(flags_obj)
    if flags_obj.data_delay_prefetch:
        common.data_delay_prefetch()
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == 'float16':
        # Reject the unsupported configuration before building or installing
        # the global mixed-precision policy, so the documented ValueError is
        # what the caller actually sees.
        if not keras_utils.is_v2_0():
            raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
        loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_float16', loss_scale=loss_scale)
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                       flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=num_workers,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=imagenet_preprocessing.parse_record,
            dtype=dtype,
            drop_remainder=drop_remainder)

    # A plain float learning rate is used unless a tensor-based schedule is
    # requested via the flag.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=LR_SCHEDULE[0][1],
            boundaries=[p[1] for p in LR_SCHEDULE[1:]],
            multipliers=[p[0] for p in LR_SCHEDULE],
            compute_lr_on_cpu=True)

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        if flags_obj.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
            # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
            # which will ensure tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
            # up.
            optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                optimizer)

        # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(
                imagenet_preprocessing.NUM_CLASSES)
        else:
            model = resnet_model.resnet50(
                num_classes=imagenet_preprocessing.NUM_CLASSES)

        # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
        # a valid arg for this model. Also remove as a valid flag.
        # The extra keyword is only forwarded when the flag was set explicitly.
        compile_kwargs = {}
        if flags_obj.force_v2_in_keras_compile is not None:
            compile_kwargs['experimental_run_tf_function'] = (
                flags_obj.force_v2_in_keras_compile)
        model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer=optimizer,
            metrics=(['sparse_categorical_accuracy']
                     if flags_obj.report_accuracy_metrics else None),
            run_eagerly=flags_obj.run_eagerly,
            **compile_kwargs)

    callbacks = common.get_callbacks(
        learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])

    train_steps = (imagenet_preprocessing.NUM_IMAGES['train'] //
                   flags_obj.batch_size)
    train_epochs = flags_obj.train_epochs

    # An explicit step budget caps the epoch length and forces a single epoch.
    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # ExitStack lets the device scope be entered conditionally while still
    # guaranteeing it is exited (with proper exception info) if training or
    # evaluation raises.
    with contextlib.ExitStack() as device_scope:
        if not strategy and flags_obj.explicit_gpu_placement:
            # TODO(b/135607227): Add device scope automatically in Keras training
            # loop when not using distribution strategy.
            device_scope.enter_context(tf.device('/device:GPU:0'))

        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=train_steps,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)

        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
Esempio n. 22
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

    Configures the session and mixed-precision policy, picks the image data
    format and distribution strategy, builds a `ResnetRunnable` inside the
    strategy scope and drives it with an `orbit.Controller`, bracketing the
    run with the time callback's `on_train_begin` / `on_train_end`.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: Not raised directly by this function.
        NOTE(review): the original docstring listed it for fp16; confirm
        whether one of the helpers called here raises it.

    Returns:
      Dictionary of training and eval stats, as produced by `build_stats`.
    """
    keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)
    performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

    # GPU-specific tuning is only applied when a GPU is actually present.
    if tf.config.list_physical_devices('GPU'):
        if flags_obj.tf_gpu_thread_mode:
            keras_utils.set_gpu_thread_mode_and_count(
                per_gpu_thread_count=flags_obj.per_gpu_thread_count,
                gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
                num_gpus=flags_obj.num_gpus,
                datasets_num_private_threads=(
                    flags_obj.datasets_num_private_threads))
        common.set_cudnn_batchnorm_mode()

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.config.list_physical_devices('GPU') else
                       'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)

    # steps_per_loop defaults to one epoch and is never allowed to exceed it.
    steps_per_loop = flags_obj.steps_per_loop
    if steps_per_loop is None:
        steps_per_loop = per_epoch_steps
    elif steps_per_loop > per_epoch_steps:
        steps_per_loop = per_epoch_steps
        # `logging.warn` is a deprecated alias; use `logging.warning`.
        logging.warning(
            'Setting steps_per_loop to %d to respect epoch boundary.',
            steps_per_loop)

    total_train_steps = train_epochs * per_epoch_steps
    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        total_train_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        flags_obj.batch_size,
        flags_obj.log_steps,
        logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
    with distribution_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                                  per_epoch_steps)

    eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
    # Intervals are passed as None when the corresponding feature is disabled.
    checkpoint_interval = (steps_per_loop * 5
                           if flags_obj.enable_checkpoint_and_export else None)
    summary_interval = steps_per_loop if flags_obj.enable_tensorboard else None

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=flags_obj.model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    # The runnable doubles as the evaluator; it is omitted when eval is skipped.
    resnet_controller = orbit.Controller(
        strategy,
        runnable,
        runnable if not flags_obj.skip_eval else None,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        summary_dir=flags_obj.model_dir,
        eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))

    time_callback.on_train_begin()
    if not flags_obj.skip_eval:
        resnet_controller.train_and_evaluate(train_steps=total_train_steps,
                                             eval_steps=eval_steps,
                                             eval_interval=eval_interval)
    else:
        resnet_controller.train(steps=total_train_steps)
    time_callback.on_train_end()

    stats = build_stats(runnable, time_callback)
    return stats
Esempio n. 23
0
def train_and_eval(
        params: base_configs.ExperimentConfig,
        strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
    """Trains a model with Keras compile/fit and optionally evaluates it.

    Args:
      params: Experiment configuration; its `runtime`, `model`, `train`,
        `evaluation` and `model_dir` fields are read here.
      strategy_override: Distribution strategy to use. When falsy, a strategy
        is created from `params.runtime` instead.

    Returns:
      The stats produced by `common.build_stats` from the fit history, the
      evaluation output (None when evaluation is skipped) and the callbacks.
    """
    logging.info('Running train and eval.')

    # Note: for TPUs, strategy and scope should be created before the dataset
    strategy = strategy_override
    if not strategy:
        runtime = params.runtime
        strategy = distribution_utils.get_distribution_strategy(
            distribution_strategy=runtime.distribution_strategy,
            all_reduce_alg=runtime.all_reduce_alg,
            num_gpus=runtime.num_gpus,
            tpu_address=runtime.tpu)

    scope = distribution_utils.get_strategy_scope(strategy)

    replica_count = strategy.num_replicas_in_sync if strategy else 1
    logging.info('Detected %d devices.', replica_count)

    # One-hot labels are only needed when label smoothing is active.
    label_smoothing = params.model.loss.label_smoothing
    one_hot = label_smoothing and label_smoothing > 0

    builders = _get_dataset_builders(params, strategy, one_hot)
    datasets = []
    for builder in builders:
        datasets.append(builder.build(strategy) if builder else None)

    # Unpack datasets and builders based on train/val/test splits
    train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
    train_dataset, validation_dataset = datasets

    # Explicit step counts from the config win over the builders' defaults.
    train_epochs = params.train.epochs
    train_steps = params.train.steps or train_builder.num_steps
    validation_steps = params.evaluation.steps or validation_builder.num_steps

    initialize(params, train_builder)

    logging.info('Global batch size: %d', train_builder.global_batch_size)

    with scope:
        model_kwargs = params.model.model_params.as_dict()
        model_builder = get_models()[params.model.name]
        model = model_builder(**model_kwargs)

        learning_rate = optimizer_factory.build_learning_rate(
            params=params.model.learning_rate,
            batch_size=train_builder.global_batch_size,
            train_steps=train_steps)
        optimizer = optimizer_factory.build_optimizer(
            optimizer_name=params.model.optimizer.name,
            base_learning_rate=learning_rate,
            params=params.model.optimizer.as_dict())

        available_metrics = _get_metrics(one_hot)
        metrics = []
        for metric_name in params.train.metrics:
            metrics.append(available_metrics[metric_name])

        # Sparse labels use the sparse loss; one-hot labels carry smoothing.
        if not one_hot:
            loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
        else:
            loss_obj = tf.keras.losses.CategoricalCrossentropy(
                label_smoothing=params.model.loss.label_smoothing)

        model.compile(
            optimizer=optimizer,
            loss=loss_obj,
            metrics=metrics,
            experimental_steps_per_execution=params.train.steps_per_loop)

        # Resuming reports the epoch to continue from; otherwise start at 0.
        initial_epoch = 0
        if params.train.resume_checkpoint:
            initial_epoch = resume_from_checkpoint(
                model=model,
                model_dir=params.model_dir,
                train_steps=train_steps)

        callbacks = custom_callbacks.get_callbacks(
            model_checkpoint=(
                params.train.callbacks.enable_checkpoint_and_export),
            include_tensorboard=params.train.callbacks.enable_tensorboard,
            time_history=params.train.callbacks.enable_time_history,
            track_lr=params.train.tensorboard.track_lr,
            write_model_weights=params.train.tensorboard.write_model_weights,
            initial_step=initial_epoch * train_steps,
            batch_size=train_builder.global_batch_size,
            log_steps=params.train.time_history.log_steps,
            model_dir=params.model_dir)

    serialize_config(params=params, model_dir=params.model_dir)

    # Validation arguments are only forwarded to fit() when eval is enabled.
    validation_kwargs = {}
    if not params.evaluation.skip_eval:
        validation_kwargs.update(
            validation_data=validation_dataset,
            validation_steps=validation_steps,
            validation_freq=params.evaluation.epochs_between_evals)

    history = model.fit(train_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        initial_epoch=initial_epoch,
                        callbacks=callbacks,
                        verbose=2,
                        **validation_kwargs)

    if params.evaluation.skip_eval:
        validation_output = None
    else:
        validation_output = model.evaluate(validation_dataset,
                                           steps=validation_steps,
                                           verbose=2)

    # TODO(dankondratyuk): eval and save final test accuracy
    return common.build_stats(history, validation_output, callbacks)
Esempio n. 24
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    Besides the ResNet-50 path this variant selects between the ImageNet and
    CIFAR preprocessing modules, supports MobileNet / Keras application
    models, optional magnitude pruning, SavedModel export, and brackets
    training with the CUDA profiler when `libcudart.so` can be loaded.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If `flags_obj.optimizer` or `flags_obj.model` does not
        select a supported configuration, or if a Keras application model
        name is unknown.
      NotImplementedError: If pruning is requested with a dtype other than
        `tf.float32`, or with a method other than 'polynomial_decay'.
      AssertionError: If synthetic data is requested for a dataset other than
        "imagenet".

    Returns:
      Dictionary of training and eval stats, as produced by
      `common.build_stats`.
    """
    # Local import: only this function needs it, and the file-level import
    # block is not guaranteed to provide it.
    import contextlib

    # The CUDA runtime is optional: it is only used to start/stop the CUDA
    # profiler around training. A missing library surfaces as OSError.
    try:
        _cudart = ctypes.CDLL('libcudart.so')
    except OSError:
        _cudart = None

    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    performance.set_mixed_precision_policy(
        flags_core.get_tf_dtype(flags_obj),
        flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                             flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # Current resnet_model.resnet50 input format is always channel-last.
    # We use keras_application mobilenet model whose input format depends on
    # the keras backend image data format.
    # This use_keras_image_data_format flag indicates whether the image
    # preprocessor output format should be the same as the keras backend image
    # data format or just channel-last format.
    use_keras_image_data_format = flags_obj.keras_application_models

    # Any dataset other than "imagenet" falls back to the CIFAR module.
    if flags_obj.dataset == "imagenet":
        preprocessing = imagenet_preprocessing
    else:
        preprocessing = cifar_preprocessing

    input_shape = (preprocessing.HEIGHT, preprocessing.WIDTH,
                   preprocessing.NUM_CHANNELS)
    if (use_keras_image_data_format
            and tf.keras.backend.image_data_format() == 'channels_first'):
        input_shape = (preprocessing.NUM_CHANNELS, preprocessing.HEIGHT,
                       preprocessing.WIDTH)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        assert flags_obj.dataset == "imagenet", \
           f"Expect to only work with ImageNet, but have {flags_obj.dataset}"
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=preprocessing.DEFAULT_IMAGE_SIZE,
            width=preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=preprocessing.NUM_CHANNELS,
            num_classes=preprocessing.NUM_CLASSES,
            use_keras_image_data_format=use_keras_image_data_format,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
        training_dataset_cache=flags_obj.training_dataset_cache,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            parse_record_fn=preprocessing.get_parse_record_fn(
                use_keras_image_data_format=use_keras_image_data_format),
            dtype=dtype,
            drop_remainder=drop_remainder)

    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=preprocessing.NUM_IMAGES['train'],
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=[p[1] for p in common.LR_SCHEDULE[1:]],
        multipliers=[p[0] for p in common.LR_SCHEDULE],
        compute_lr_on_cpu=True)
    steps_per_epoch = (preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)

    with strategy_scope:
        if flags_obj.optimizer == 'resnet50_default':
            optimizer = common.get_optimizer(lr_schedule)
        elif flags_obj.optimizer == 'mobilenet_default':
            initial_learning_rate = (
                flags_obj.initial_learning_rate_per_sample *
                flags_obj.batch_size)
            optimizer = tf.keras.optimizers.SGD(
                learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
                    initial_learning_rate,
                    decay_steps=steps_per_epoch *
                    flags_obj.num_epochs_per_decay,
                    decay_rate=flags_obj.lr_decay_factor,
                    staircase=True),
                momentum=0.9)
        else:
            # Fail here with a clear message instead of hitting an unbound
            # `optimizer` further down.
            raise ValueError('Unsupported optimizer: %s' %
                             flags_obj.optimizer)
        if flags_obj.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
            # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
            # which will ensure tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
            # up.
            optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                optimizer)

        # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'resnet50_v1.5':
            resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
            model = resnet_model.resnet50(
                num_classes=preprocessing.NUM_CLASSES,
                input_shape=input_shape)
        elif flags_obj.model == 'mobilenet':
            # TODO(kimjaehong): Remove layers attribute when minimum TF version
            # support 2.0 layers by default.
            model = tf.keras.applications.mobilenet.MobileNet(
                weights=None,
                classes=preprocessing.NUM_CLASSES,
                input_shape=input_shape,
                layers=tf.keras.layers)
        elif flags_obj.keras_application_models:
            model_kfn = keras_app_models.get(flags_obj.model, None)
            if model_kfn is None:
                raise ValueError("No keras application model name %s" %
                                 flags_obj.model)
            model = model_kfn(weights=None,
                              input_shape=input_shape,
                              classes=preprocessing.NUM_CLASSES)
        else:
            # Fail here with a clear message instead of hitting an unbound
            # `model` further down.
            raise ValueError('Unsupported model: %s' % flags_obj.model)
        if flags_obj.pretrained_filepath:
            model.load_weights(flags_obj.pretrained_filepath)

        if flags_obj.pruning_method == 'polynomial_decay':
            if dtype != tf.float32:
                raise NotImplementedError(
                    'Pruning is currently only supported on dtype=tf.float32.')
            pruning_params = {
                'pruning_schedule':
                tfmot.sparsity.keras.PolynomialDecay(
                    initial_sparsity=flags_obj.pruning_initial_sparsity,
                    final_sparsity=flags_obj.pruning_final_sparsity,
                    begin_step=flags_obj.pruning_begin_step,
                    end_step=flags_obj.pruning_end_step,
                    frequency=flags_obj.pruning_frequency),
            }
            model = tfmot.sparsity.keras.prune_low_magnitude(
                model, **pruning_params)
        elif flags_obj.pruning_method:
            raise NotImplementedError(
                'Only polynomial_decay is currently supported.')

        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['sparse_categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      run_eagerly=flags_obj.run_eagerly)

    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(
        steps_per_epoch=steps_per_epoch,
        pruning_method=flags_obj.pruning_method,
        enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
        model_dir=flags_obj.model_dir)

    # If multiple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    # ExitStack lets the device scope be entered conditionally while still
    # guaranteeing it is exited (with proper exception info) if anything
    # below raises.
    with contextlib.ExitStack() as device_scope:
        if not strategy and flags_obj.explicit_gpu_placement:
            # TODO(b/135607227): Add device scope automatically in Keras training
            # loop when not using distribution strategy.
            device_scope.enter_context(tf.device('/device:GPU:0'))

        cuda_status = _cudart.cudaProfilerStart() if _cudart else None
        try:
            history = model.fit(train_input_dataset,
                                epochs=train_epochs,
                                steps_per_epoch=steps_per_epoch,
                                callbacks=callbacks,
                                validation_steps=num_eval_steps,
                                validation_data=validation_data,
                                validation_freq=flags_obj.epochs_between_evals,
                                verbose=2)
        finally:
            # Only stop the profiler if starting it reported success (0), and
            # do so even when training fails.
            if cuda_status == 0:
                _cudart.cudaProfilerStop()

        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)

        # Remove the pruning wrappers before the model is exported.
        if flags_obj.pruning_method:
            model = tfmot.sparsity.keras.strip_pruning(model)
        if flags_obj.enable_checkpoint_and_export:
            if dtype == tf.bfloat16:
                logging.warning(
                    'Keras model.save does not support bfloat16 dtype.')
            else:
                # Keras model.save assumes a float32 input signature.
                export_path = os.path.join(flags_obj.model_dir, 'saved_model')
                model.save(export_path, include_optimizer=False)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
Esempio n. 25
0
def run(flags_obj):
    """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

    Args:
        flags_obj: An object containing parsed flag values.

    Raises:
        ValueError: If the dtype resolved from the flags compares equal to
            'fp16'.

    Returns:
        The stats produced by `keras_common.build_stats` from the fit history,
        the final evaluation output (None when eval is skipped) and the
        callbacks.
    """
    keras_utils.set_session_config(
        enable_eager=flags_obj.enable_eager,
        enable_xla=flags_obj.enable_xla,
        enable_grappler_layout_optimizer=(
            flags_obj.enable_grappler_layout_optimizer))

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_common.set_gpu_thread_mode_and_count(flags_obj)
    keras_common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # NOTE(review): if get_tf_dtype returns a tf.DType rather than the raw
    # flag string, this comparison may never be true -- confirm.
    if dtype == 'fp16':
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=cifar_preprocessing.HEIGHT,
            width=cifar_preprocessing.WIDTH,
            num_channels=cifar_preprocessing.NUM_CHANNELS,
            num_classes=cifar_preprocessing.NUM_CLASSES,
            dtype=flags_core.get_tf_dtype(flags_obj),
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = cifar_preprocessing.input_fn

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=cifar_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        # Setting drop_remainder to avoid the partial batch logic in
        # normalization layer, which triggers tf.where and leads to extra
        # memory copy of input sizes between host and GPU.
        drop_remainder=(not flags_obj.enable_get_next_as_optional))

    # The eval dataset is only built when evaluation is requested.
    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            num_epochs=flags_obj.train_epochs,
            parse_record_fn=cifar_preprocessing.parse_record)

    with strategy_scope:
        optimizer = keras_common.get_optimizer()
        model = resnet_cifar_model.resnet56(
            classes=cifar_preprocessing.NUM_CLASSES)

        model.compile(
            loss='categorical_crossentropy',
            optimizer=optimizer,
            metrics=(['categorical_accuracy']
                     if flags_obj.report_accuracy_metrics else None),
            run_eagerly=flags_obj.run_eagerly,
            experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)

    callbacks = keras_common.get_callbacks(
        learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train'])

    train_steps = (cifar_preprocessing.NUM_IMAGES['train'] //
                   flags_obj.batch_size)
    train_epochs = flags_obj.train_epochs

    # An explicit step budget caps the run to a single (possibly shortened)
    # epoch.
    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    no_dist_strat_device = None
    if not strategy and flags_obj.explicit_gpu_placement:
        # TODO(b/135607227): Add device scope automatically in Keras training
        # loop when not using distribution strategy.
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    try:
        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=train_steps,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)
        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)
    finally:
        # Always leave the manually entered device scope, even when training
        # or evaluation raises. __exit__ is given the three exception
        # arguments of the context-manager protocol so it works for managers
        # that require them.
        if no_dist_strat_device is not None:
            no_dist_strat_device.__exit__(None, None, None)

    stats = keras_common.build_stats(history, eval_output, callbacks)
    return stats
Esempio n. 26
0
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If the dtype resolved from the flags compares equal to
      'fp16'.

  Returns:
    The stats produced by `keras_common.build_stats` from the fit history,
    the final evaluation output (None when eval is skipped) and the
    callbacks.
  """
  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                 enable_xla=flags_obj.enable_xla)

  dtype = flags_core.get_tf_dtype(flags_obj)
  # NOTE(review): if get_tf_dtype returns a tf.DType rather than the raw flag
  # string, this comparison may never be true -- confirm.
  if dtype == 'fp16':
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus)

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=cifar_main.HEIGHT,
        width=cifar_main.WIDTH,
        num_channels=cifar_main.NUM_CHANNELS,
        num_classes=cifar_main.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj))
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = cifar_main.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  # Unlike the training set, the eval dataset is built unconditionally here;
  # it is simply not used when skip_eval is set.
  eval_input_dataset = input_fn(
      is_training=False,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)

    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  run_eagerly=flags_obj.run_eagerly,
                  metrics=['categorical_accuracy'])

  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, cifar_main.NUM_IMAGES['train'])

  train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # An explicit step budget caps the run to a single (possibly shortened)
  # epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  no_dist_strat_device = None
  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  try:
    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)
    eval_output = None
    if not flags_obj.skip_eval:
      eval_output = model.evaluate(eval_input_dataset,
                                   steps=num_eval_steps,
                                   verbose=2)
  finally:
    # Always leave the manually entered device scope, even when training or
    # evaluation raises. __exit__ is given the three exception arguments of
    # the context-manager protocol so it works for managers that require them.
    if no_dist_strat_device is not None:
      no_dist_strat_device.__exit__(None, None, None)

  stats = keras_common.build_stats(history, eval_output, callbacks)
  return stats
Esempio n. 27
0
def run_ncf(_):
    """Run NCF training and eval with Keras.

    All configuration is read from the module-level `FLAGS` object.

    Args:
        _: Unused.

    Returns:
        The result of `build_stats(train_loss, eval_results, time_callback)`,
        or None when the flag combination is rejected (an error is logged and
        the function returns early).
    """

    keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

    if FLAGS.seed is not None:
        print("Setting tf seed")
        tf.random.set_seed(FLAGS.seed)

    model_helpers.apply_clean(FLAGS)

    if FLAGS.dtype == "fp16" and FLAGS.fp16_implementation == "keras":
        policy = tf.keras.mixed_precision.experimental.Policy(
            "mixed_float16",
            loss_scale=flags_core.get_loss_scale(FLAGS,
                                                 default_for_fp16="dynamic"))
        tf.keras.mixed_precision.experimental.set_policy(policy)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        tpu_address=FLAGS.tpu)

    params = ncf_common.parse_flags(FLAGS)
    params["distribute_strategy"] = strategy

    # Reject unsupported combinations of TF version, strategy and training
    # loop before doing any expensive work.
    if not keras_utils.is_v2_0() and strategy is not None:
        logging.error(
            "NCF Keras only works with distribution strategy in TF 2.0")
        return
    if (params["keras_use_ctl"]
            and (not keras_utils.is_v2_0() or strategy is None)):
        logging.error(
            "Custom training loop only works with tensorflow 2.0 and dist strat."
        )
        return
    if params["use_tpu"] and not params["keras_use_ctl"]:
        logging.error(
            "Custom training loop must be used when using TPUStrategy.")
        return

    batch_size = params["batch_size"]
    time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
    callbacks = [time_callback]

    producer, input_meta_data = None, None
    # Without a pre-generated training dataset path, input is produced online
    # by a background producer.
    generate_input_online = params["train_dataset_path"] is None

    if generate_input_online:
        # Start data producing thread.
        num_users, num_items, _, _, producer = ncf_common.get_inputs(params)
        producer.start()
        per_epoch_callback = IncrementEpochCallback(producer)
        callbacks.append(per_epoch_callback)
    else:
        assert params["eval_dataset_path"] and params["input_meta_data_path"]
        with tf.io.gfile.GFile(params["input_meta_data_path"], "rb") as reader:
            input_meta_data = json.loads(reader.read().decode("utf-8"))
            num_users = input_meta_data["num_users"]
            num_items = input_meta_data["num_items"]

    params["num_users"], params["num_items"] = num_users, num_items

    if FLAGS.early_stopping:
        early_stopping_callback = CustomEarlyStopping(
            "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
        callbacks.append(early_stopping_callback)

    (train_input_dataset, eval_input_dataset, num_train_steps,
     num_eval_steps) = ncf_input_pipeline.create_ncf_input_data(
         params, producer, input_meta_data, strategy)
    steps_per_epoch = None if generate_input_online else num_train_steps

    # Default so build_stats() below always receives a bound value, even when
    # Model.fit() yields no usable history.
    train_loss = None

    with distribution_utils.get_strategy_scope(strategy):
        keras_model = _get_keras_model(params)
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=params["learning_rate"],
            beta_1=params["beta1"],
            beta_2=params["beta2"],
            epsilon=params["epsilon"])
        if FLAGS.fp16_implementation == "graph_rewrite":
            optimizer = (
                tf.compat.v1.train.experimental
                .enable_mixed_precision_graph_rewrite(
                    optimizer,
                    loss_scale=flags_core.get_loss_scale(
                        FLAGS, default_for_fp16="dynamic")))
        elif FLAGS.dtype == "fp16" and params["keras_use_ctl"]:
            # When keras_use_ctl is False, instead Model.fit() automatically
            # applies loss scaling so we don't need to create a
            # LossScaleOptimizer.
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                optimizer,
                tf.keras.mixed_precision.experimental.global_policy().
                loss_scale)

        if params["keras_use_ctl"]:
            train_loss, eval_results = run_ncf_custom_training(
                params,
                strategy,
                keras_model,
                optimizer,
                callbacks,
                train_input_dataset,
                eval_input_dataset,
                num_train_steps,
                num_eval_steps,
                generate_input_online=generate_input_online)
        else:
            # TODO(b/138957587): Remove when force_v2_in_keras_compile is no
            # longer a valid arg for this model. Also remove as a valid flag.
            if FLAGS.force_v2_in_keras_compile is not None:
                keras_model.compile(optimizer=optimizer,
                                    run_eagerly=FLAGS.run_eagerly,
                                    experimental_run_tf_function=FLAGS.
                                    force_v2_in_keras_compile)
            else:
                keras_model.compile(optimizer=optimizer,
                                    run_eagerly=FLAGS.run_eagerly)

            if not FLAGS.ml_perf:
                # Create Tensorboard summary and checkpoint callbacks.
                summary_dir = os.path.join(FLAGS.model_dir, "summaries")
                summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
                checkpoint_path = os.path.join(FLAGS.model_dir, "checkpoint")
                checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
                    checkpoint_path, save_weights_only=True)

                callbacks += [summary_callback, checkpoint_callback]

            history = keras_model.fit(train_input_dataset,
                                      epochs=FLAGS.train_epochs,
                                      steps_per_epoch=steps_per_epoch,
                                      callbacks=callbacks,
                                      validation_data=eval_input_dataset,
                                      validation_steps=num_eval_steps,
                                      verbose=2)

            logging.info("Training done. Start evaluating")

            eval_loss_and_metrics = keras_model.evaluate(eval_input_dataset,
                                                         steps=num_eval_steps,
                                                         verbose=2)

            logging.info("Keras evaluation is done.")

            # Keras evaluate() API returns scalar loss and metric values from
            # evaluation as a list. Here, the returned list would contain
            # [evaluation loss, hr sum, hr count].
            eval_hit_rate = eval_loss_and_metrics[1] / eval_loss_and_metrics[2]

            # Format evaluation result into [eval loss, eval hit accuracy].
            eval_results = [eval_loss_and_metrics[0], eval_hit_rate]

            # Report the loss of the last recorded epoch, when there is one.
            if history and history.history:
                train_history = history.history
                train_loss = train_history["loss"][-1]

    stats = build_stats(train_loss, eval_results, time_callback)
    return stats
Esempio n. 28
0
def run(flags_obj):
    """Train and evaluate `Conv4_model` on CIFAR-10 with native Keras APIs.

    The data is loaded in memory via `tf.keras.datasets.cifar10.load_data()`
    and the model is optimized with an SGD optimizer wrapped in
    `SynchronousSGDOptimizer`. The number of steps per epoch is divided by
    `current_cluster_size()`.

    Args:
        flags_obj: An object containing parsed flag values.

    Raises:
        ValueError: If the dtype resolved from the flags compares equal to
            'fp16'.

    Returns:
        The stats produced by `common.build_stats` from the fit history, the
        final evaluation output (None when eval is skipped) and the callbacks.
    """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    # NOTE(review): if get_tf_dtype returns a tf.DType rather than the raw
    # flag string, this comparison may never be true -- confirm.
    if dtype == 'fp16':
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    # NOTE(review): `strategy_scope` and `input_fn` are not used by the
    # training code below (the model is built outside any strategy scope and
    # fed in-memory arrays). The calls are kept in case callers rely on their
    # side effects -- confirm before removing.
    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=cifar_preprocessing.HEIGHT,
            width=cifar_preprocessing.WIDTH,
            num_channels=cifar_preprocessing.NUM_CHANNELS,
            num_classes=cifar_preprocessing.NUM_CLASSES,
            dtype=flags_core.get_tf_dtype(flags_obj),
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = cifar_preprocessing.input_fn

    # Load CIFAR-10 as in-memory arrays, scale pixels to [0, 1] and one-hot
    # encode the labels.
    # NOTE(review): `num_classes` is not defined in this function; it is
    # presumably a module-level name -- confirm.
    (x_train, y_train), (x_test,
                         y_test) = tf.keras.datasets.cifar10.load_data()
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes)

    opt = tf.keras.optimizers.SGD(learning_rate=0.1)

    logging.info(opt.__dict__)
    optimizer = SynchronousSGDOptimizer(opt, use_locking=True)
    # Share the wrapped optimizer's hyperparameter table with the wrapper.
    optimizer._hyper = opt._hyper

    logging.info(optimizer.__dict__)

    model = Conv4_model(x_train, num_classes)

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
    # a valid arg for this model. Also remove as a valid flag.
    if flags_obj.force_v2_in_keras_compile is not None:
        model.compile(
            loss='categorical_crossentropy',
            optimizer=optimizer,
            metrics=(['accuracy']),
            run_eagerly=flags_obj.run_eagerly,
            experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
    else:
        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['accuracy']),
                      run_eagerly=flags_obj.run_eagerly)

    # Split the per-epoch step count across the cluster.
    cluster_size = current_cluster_size()
    steps_per_epoch = (cifar_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)
    steps_per_epoch = steps_per_epoch // cluster_size
    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(steps_per_epoch, current_rank(),
                                     cluster_size, learning_rate_schedule)
    callbacks.append(BroadcastGlobalVariablesCallback())

    if flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)

    num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = (x_test, y_test)
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        # Skip per-epoch validation as well as the final evaluation.
        num_eval_steps = None
        validation_data = None

    tf.compat.v1.logging.info(x_train.shape)
    history = model.fit(x_train,
                        y_train,
                        batch_size=flags_obj.batch_size,
                        epochs=train_epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)
    eval_output = None
    if not flags_obj.skip_eval:
        # Inputs and targets are passed separately; evaluate() does not treat
        # a single (x, y) tuple of arrays as inputs plus targets.
        eval_output = model.evaluate(x_test,
                                     y_test,
                                     steps=num_eval_steps,
                                     verbose=2)
    stats = common.build_stats(history, eval_output, callbacks)
    return stats
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

    When the resolved dtype compares equal to 'float16', the Keras
    'infer_float32_vars' mixed-precision policy is installed. The optimizer
    is always wrapped in a `LossScaleOptimizer` with a fixed loss scale of
    256, and a TensorBoard callback logging under `log/<timestamp>` is added.

    Args:
        flags_obj: An object containing parsed flag values.

    Returns:
        The stats produced by `keras_common.build_stats` from the fit history,
        the final evaluation output (None when eval is skipped) and the
        callbacks.
    """
    # Function-scope imports, kept local as in the original code.
    from time import time
    from tensorflow.keras.mixed_precision.experimental import LossScaleOptimizer
    from tensorflow.python.keras.callbacks import TensorBoard

    keras_utils.set_session_config(
        enable_eager=flags_obj.enable_eager,
        enable_xla=flags_obj.enable_xla,
        enable_grappler_layout_optimizer=(
            flags_obj.enable_grappler_layout_optimizer))

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_common.set_gpu_thread_mode_and_count(flags_obj)
    if flags_obj.data_delay_prefetch:
        keras_common.data_delay_prefetch()
    keras_common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == 'float16':
        policy = tf.keras.mixed_precision.experimental.Policy(
            'infer_float32_vars')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        num_workers=distribution_utils.configure_cluster(),
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = keras_common.get_synth_input_fn(
            height=imagenet_main.DEFAULT_IMAGE_SIZE,
            width=imagenet_main.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_main.NUM_CHANNELS,
            num_classes=imagenet_main.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_main.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
    )

    # The eval dataset is only built when evaluation is requested.
    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(is_training=False,
                                      data_dir=flags_obj.data_dir,
                                      batch_size=flags_obj.batch_size,
                                      num_epochs=flags_obj.train_epochs,
                                      parse_record_fn=parse_record_keras,
                                      dtype=dtype,
                                      drop_remainder=drop_remainder)

    # Constant learning rate by default; replaced by a piecewise-constant
    # schedule with warmup when use_tensor_lr is set.
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_main.NUM_IMAGES['train'],
            warmup_epochs=LR_SCHEDULE[0][1],
            boundaries=list(p[1] for p in LR_SCHEDULE[1:]),
            multipliers=list(p[0] for p in LR_SCHEDULE),
            compute_lr_on_cpu=True)

    with strategy_scope:
        optimizer = keras_common.get_optimizer(lr_schedule)
        # NOTE: the optimizer is wrapped unconditionally with a fixed loss
        # scale, regardless of dtype.
        # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
        # can be enabled with a single line of code.
        optimizer = LossScaleOptimizer(optimizer, loss_scale=256)

        if flags_obj.use_trivial_model:
            model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES,
                                                dtype)
        else:
            model = resnet_model.resnet18(
                num_classes=imagenet_main.NUM_CLASSES, dtype=dtype)

        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['sparse_categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      run_eagerly=flags_obj.run_eagerly,
                      run_distributed=flags_obj.force_v2_in_keras_compile)

    callbacks = keras_common.get_callbacks(learning_rate_schedule,
                                           imagenet_main.NUM_IMAGES['train'])
    # Each run logs to its own timestamped TensorBoard directory.
    tensorboard = TensorBoard(log_dir="log/{}".format(time()))
    callbacks += [tensorboard]

    train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
    train_epochs = flags_obj.train_epochs

    # An explicit step budget caps the run to a single (possibly shortened)
    # epoch.
    if flags_obj.train_steps:
        train_steps = min(flags_obj.train_steps, train_steps)
        train_epochs = 1

    num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase
            # when not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    no_dist_strat_device = None
    if not strategy and flags_obj.explicit_gpu_placement:
        # TODO(b/135607227): Add device scope automatically in Keras training
        # loop when not using distribution strategy.
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    try:
        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=train_steps,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=validation_data,
                            validation_freq=flags_obj.epochs_between_evals,
                            verbose=2)

        eval_output = None
        if not flags_obj.skip_eval:
            eval_output = model.evaluate(eval_input_dataset,
                                         steps=num_eval_steps,
                                         verbose=2)
    finally:
        # Always leave the manually entered device scope, even when training
        # or evaluation raises. __exit__ is given the three exception
        # arguments of the context-manager protocol so it works for managers
        # that require them.
        if no_dist_strat_device is not None:
            no_dist_strat_device.__exit__(None, None, None)

    stats = keras_common.build_stats(history, eval_output, callbacks)
    return stats
def run_customized_training_loop(
        # pylint: disable=invalid-name
        _sentinel=None,
        # pylint: enable=invalid-name
        strategy=None,
        model_fn=None,
        loss_fn=None,
        model_dir=None,
        train_input_fn=None,
        steps_per_epoch=None,
        steps_per_loop=1,
        epochs=1,
        eval_input_fn=None,
        eval_steps=None,
        metric_fn=None,
        init_checkpoint=None,
        use_remote_tpu=False,
        custom_callbacks=None,
        run_eagerly=False):
    """Run BERT pretrain model training using low-level API.

    Arguments:
        _sentinel: Used to prevent positional parameters. Internal, do not
            use.
        strategy: Distribution strategy on which to run low level training
            loop.
        model_fn: Function that returns a tuple (model, sub_model). Caller of
            this function should add optimizer to the `model` via calling
            `model.compile()` API or manually setting `model.optimizer`
            attribute. Second element of the returned tuple (sub_model) is an
            optional sub model to be used for initial checkpoint -- if
            provided.
        loss_fn: Function with signature func(labels, logits) and returns a
            loss tensor.
        model_dir: Model directory used during training for restoring/saving
            model weights. Summaries are written below
            `<model_dir>/summaries`.
        train_input_fn: Function that returns a tf.data.Dataset used for
            training.
        steps_per_epoch: Number of steps to run per epoch. At the end of each
            epoch, model checkpoint will be saved and evaluation will be
            conducted if evaluation dataset is provided.
        steps_per_loop: Number of steps per graph-mode loop. In order to
            reduce communication in eager context, training logs are printed
            every steps_per_loop. Clamped to `steps_per_epoch` when larger.
        epochs: Number of epochs to train.
        eval_input_fn: Function that returns evaluation dataset. If none,
            evaluation is skipped.
        eval_steps: Number of steps to run evaluation. Required if
            `eval_input_fn` is not none.
        metric_fn: A metrics function that returns a Keras Metric object to
            record evaluation result using evaluation dataset or with
            training dataset after every epoch.
        init_checkpoint: Optional checkpoint to load to `sub_model` returned
            by `model_fn`.
        use_remote_tpu: Ignored, will be removed in the future.
        custom_callbacks: A list of Keras Callbacks objects to run during
            training. More specifically, `on_batch_begin()`, `on_batch_end()`,
            methods are invoked during training.
        run_eagerly: Whether to run model training in pure eager execution.
            This should be disabled for TPUStrategy.

    Returns:
        Trained model.

    Raises:
        ValueError: (1) When model returned by `model_fn` does not have
            optimizer attribute or when required parameters are set to none.
            (2) eval args are not specified correctly. (3) metric_fn must be
            a callable if specified. (4) `run_eagerly` is combined with
            `steps_per_loop > 1` or with a TPUStrategy.
    """
    # TODO(bfontain): Remove use_remote_tpu once there are no models using it.
    del use_remote_tpu

    if _sentinel is not None:
        raise ValueError('only call `run_customized_training_loop()` '
                         'with named arguments.')

    required_arguments = [
        strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
    ]
    if any(arg is None for arg in required_arguments):
        # The message lists exactly the arguments checked above.
        raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
                         '`train_input_fn` and `steps_per_epoch` are required '
                         'parameters.')
    if steps_per_loop > steps_per_epoch:
        logging.error(
            'steps_per_loop: %d is specified to be greater than '
            ' steps_per_epoch: %d, we will use steps_per_epoch as'
            ' steps_per_loop.', steps_per_loop, steps_per_epoch)
        steps_per_loop = steps_per_epoch
    # The host loop below reads values with `.numpy()`, which needs eager mode.
    assert tf.executing_eagerly()

    if run_eagerly:
        if steps_per_loop > 1:
            raise ValueError(
                'steps_per_loop is used for performance optimization. When you want '
                'to run eagerly, you cannot leverage graph mode loop.')
        if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
            raise ValueError(
                'TPUStrategy should not run eagerly as it heavily relies on graph'
                ' optimization for the distributed system.')

    if eval_input_fn and (eval_steps is None or metric_fn is None):
        raise ValueError(
            '`eval_steps` and `metric_fn` are required when `eval_input_fn` '
            'is not none.')
    if metric_fn and not callable(metric_fn):
        raise ValueError(
            'if `metric_fn` is specified, metric_fn must be a callable.')

    total_training_steps = steps_per_epoch * epochs

    # To reduce unnecessary send/receive input pipeline operation, we place input
    # pipeline ops in worker task.
    train_iterator = _get_input_iterator(train_input_fn, strategy)

    with distribution_utils.get_strategy_scope(strategy):
        # To correctly place the model weights on accelerators,
        # model and optimizer should be created in scope.
        model, sub_model = model_fn()
        if not hasattr(model, 'optimizer'):
            raise ValueError('User should set optimizer attribute to model '
                             'inside `model_fn`.')
        optimizer = model.optimizer
        # A LossScaleOptimizer signals that loss scaling must be applied
        # around the gradient computation.
        use_float16 = isinstance(
            optimizer,
            tf.keras.mixed_precision.experimental.LossScaleOptimizer)

        if init_checkpoint:
            logging.info(
                'Checkpoint file %s found and restoring from '
                'initial checkpoint for core model.', init_checkpoint)
            checkpoint = tf.train.Checkpoint(model=sub_model)
            checkpoint.restore(
                init_checkpoint).assert_existing_objects_matched()
            logging.info('Loading from checkpoint file completed')

        train_loss_metric = tf.keras.metrics.Mean('training_loss',
                                                  dtype=tf.float32)
        eval_metrics = [metric_fn()] if metric_fn else []
        # If evaluation is required, make a copy of metric as it will be used by
        # both train and evaluation.
        train_metrics = [
            metric.__class__.from_config(metric.get_config())
            for metric in eval_metrics
        ]

        # Create summary writers
        summary_dir = os.path.join(model_dir, 'summaries')
        eval_summary_writer = tf.summary.create_file_writer(
            os.path.join(summary_dir, 'eval'))
        if steps_per_loop >= _MIN_SUMMARY_STEPS:
            # Only writes summary when the stats are collected sufficiently over
            # enough steps.
            train_summary_writer = tf.summary.create_file_writer(
                os.path.join(summary_dir, 'train'))
        else:
            train_summary_writer = None

        # Collects training variables.
        training_vars = model.trainable_variables

        def _replicated_step(inputs):
            """Replicated training step."""

            inputs, labels = inputs
            with tf.GradientTape() as tape:
                model_outputs = model(inputs, training=True)
                loss = loss_fn(labels, model_outputs)
                if use_float16:
                    scaled_loss = optimizer.get_scaled_loss(loss)

            if use_float16:
                # Differentiate the scaled loss, then undo the scaling on the
                # gradients before they are applied.
                scaled_grads = tape.gradient(scaled_loss, training_vars)
                grads = optimizer.get_unscaled_gradients(scaled_grads)
            else:
                grads = tape.gradient(loss, training_vars)
            optimizer.apply_gradients(zip(grads, training_vars))
            # For reporting, the metric takes the mean of losses.
            train_loss_metric.update_state(loss)
            for metric in train_metrics:
                metric.update_state(labels, model_outputs)

        @tf.function
        def train_steps(iterator, steps):
            """Performs distributed training steps in a loop.

            Args:
                iterator: the distributed iterator of training datasets.
                steps: an tf.int32 integer tensor to specify number of steps
                    to run inside host training loop.

            Raises:
                ValueError: Any of the arguments or tensor shapes are invalid.
            """
            if not isinstance(steps, tf.Tensor):
                raise ValueError(
                    'steps should be an Tensor. Python object may cause '
                    'retracing.')

            for _ in tf.range(steps):
                strategy.experimental_run_v2(_replicated_step,
                                             args=(next(iterator), ))

        def train_single_step(iterator):
            """Performs a distributed training step.

            Args:
                iterator: the distributed iterator of training datasets.

            Raises:
                ValueError: Any of the arguments or tensor shapes are invalid.
            """
            strategy.experimental_run_v2(_replicated_step,
                                         args=(next(iterator), ))

        def test_step(iterator):
            """Calculates evaluation metrics on distributed devices."""
            def _test_step_fn(inputs):
                """Replicated accuracy calculation."""

                inputs, labels = inputs
                model_outputs = model(inputs, training=False)
                for metric in eval_metrics:
                    metric.update_state(labels, model_outputs)

            strategy.experimental_run_v2(_test_step_fn,
                                         args=(next(iterator), ))

        if not run_eagerly:
            train_single_step = tf.function(train_single_step)
            test_step = tf.function(test_step)

        def _run_evaluation(current_training_step, test_iterator):
            """Runs validation steps and aggregate metrics."""
            for _ in range(eval_steps):
                test_step(test_iterator)

            with eval_summary_writer.as_default():
                for metric in eval_metrics + model.metrics:
                    metric_value = _float_metric_value(metric)
                    logging.info('Step: [%d] Validation %s = %f',
                                 current_training_step, metric.name,
                                 metric_value)
                    tf.summary.scalar(metric.name,
                                      metric_value,
                                      step=current_training_step)
                eval_summary_writer.flush()

        def _run_callbacks_on_batch_begin(batch):
            """Runs custom callbacks at the start of every step."""
            if not custom_callbacks:
                return
            for callback in custom_callbacks:
                callback.on_batch_begin(batch)

        def _run_callbacks_on_batch_end(batch):
            """Runs custom callbacks at the end of every step."""
            if not custom_callbacks:
                return
            for callback in custom_callbacks:
                callback.on_batch_end(batch)

        # Training loop starts here.
        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
        if latest_checkpoint_file:
            logging.info(
                'Checkpoint file %s found and restoring from '
                'checkpoint', latest_checkpoint_file)
            checkpoint.restore(latest_checkpoint_file)
            logging.info('Loading from checkpoint file completed')

        # The optimizer's iteration counter is the starting step, so a
        # restored checkpoint resumes where it left off.
        current_step = optimizer.iterations.numpy()
        checkpoint_name = 'ctl_step_{step}.ckpt'

        while current_step < total_training_steps:
            # Training loss/metric are taking average over steps inside micro
            # training loop. We reset the their values before each round.
            train_loss_metric.reset_states()
            for metric in train_metrics + model.metrics:
                metric.reset_states()

            _run_callbacks_on_batch_begin(current_step)
            # Runs several steps in the host while loop.
            steps = _steps_to_run(current_step, steps_per_epoch,
                                  steps_per_loop)

            if steps == 1:
                # TODO(zongweiz): merge with train_steps once tf.while_loop
                # GPU performance bugs are fixed.
                train_single_step(train_iterator)
            else:
                # Converts steps to a Tensor to avoid tf.function retracing.
                train_steps(train_iterator,
                            tf.convert_to_tensor(steps, dtype=tf.int32))
            _run_callbacks_on_batch_end(current_step)
            current_step += steps

            train_loss = _float_metric_value(train_loss_metric)
            # Updates training logging.
            training_status = 'Train Step: %d/%d  / loss = %s' % (
                current_step, total_training_steps, train_loss)

            if train_summary_writer:
                with train_summary_writer.as_default():
                    tf.summary.scalar(train_loss_metric.name,
                                      train_loss,
                                      step=current_step)
                    for metric in train_metrics + model.metrics:
                        metric_value = _float_metric_value(metric)
                        training_status += '  %s = %f' % (metric.name,
                                                          metric_value)
                        tf.summary.scalar(metric.name,
                                          metric_value,
                                          step=current_step)
                    train_summary_writer.flush()
            logging.info(training_status)

            # Saves model checkpoints and run validation steps at every epoch end.
            if current_step % steps_per_epoch == 0:
                # To avoid repeated model saving, we do not save after the last
                # step of training.
                if current_step < total_training_steps:
                    _save_checkpoint(checkpoint, model_dir,
                                     checkpoint_name.format(step=current_step))

                if eval_input_fn:
                    logging.info('Running evaluation after step: %s.',
                                 current_step)
                    _run_evaluation(
                        current_step,
                        _get_input_iterator(eval_input_fn, strategy))
                    # Re-initialize evaluation metric.
                    for metric in eval_metrics + model.metrics:
                        metric.reset_states()

        _save_checkpoint(checkpoint, model_dir,
                         checkpoint_name.format(step=current_step))

        if eval_input_fn:
            logging.info(
                'Running final evaluation after training is complete.')
            _run_evaluation(current_step,
                            _get_input_iterator(eval_input_fn, strategy))

        training_summary = {
            'total_training_steps': total_training_steps,
            'train_loss': _float_metric_value(train_loss_metric),
        }
        if eval_metrics:
            # TODO(hongkuny): Cleans up summary reporting in text.
            training_summary['last_train_metrics'] = _float_metric_value(
                train_metrics[0])
            training_summary['eval_metrics'] = _float_metric_value(
                eval_metrics[0])

        write_txt_summary(training_summary, summary_dir)

        return model
Esempio n. 31
0
def run_ncf(_):
  """Train the NCF model with Keras `fit()` and score it with `evaluate()`.

  Args:
    _: Unused positional argument.

  Returns:
    Whatever `build_stats` produces from the fit history, the evaluation
    results and the timing callback.
  """
  # TODO(seemuch): Support different train and eval batch sizes
  if FLAGS.eval_batch_size != FLAGS.batch_size:
    override_notice = (
        "The Keras implementation of NCF currently does not support batch_size "
        "!= eval_batch_size ({} vs. {}). Overriding eval_batch_size to match "
        "batch_size".format(FLAGS.eval_batch_size, FLAGS.batch_size))
    logging.warning(override_notice)
    FLAGS.eval_batch_size = FLAGS.batch_size

  params = ncf_common.parse_flags(FLAGS)
  # Read before the override below; this is the value the timing callback gets.
  timing_batch_size = params["batch_size"]

  # ncf_common rounds eval_batch_size (this is needed due to a reshape during
  # eval). This carries over that rounding to batch_size as well.
  params['batch_size'] = params['eval_batch_size']

  ncf_inputs = ncf_common.get_inputs(params)
  num_users, num_items, _num_train_steps, num_eval_steps, producer = ncf_inputs

  params["num_users"] = num_users
  params["num_items"] = num_items
  producer.start()
  model_helpers.apply_clean(flags.FLAGS)

  replica_batches = params["batches_per_step"]
  raw_train_data, raw_eval_data = _get_train_and_eval_data(producer, params)
  # It is required that for distributed training, the dataset must call
  # batch(). The parameter of batch() here is the number of replicas involved,
  # such that each replica evenly gets a slice of data.
  train_batches = raw_train_data.batch(replica_batches)
  eval_batches = raw_eval_data.batch(replica_batches)

  strategy = ncf_common.get_distribution_strategy(params)
  with distribution_utils.get_strategy_scope(strategy):
    keras_model = _get_keras_model(params)
    adam = tf.keras.optimizers.Adam(
        learning_rate=params["learning_rate"],
        beta_1=params["beta1"],
        beta_2=params["beta2"],
        epsilon=params["epsilon"])
    time_callback = keras_utils.TimeHistory(timing_batch_size, FLAGS.log_steps)

    keras_model.compile(
        loss=_keras_loss,
        metrics=[_get_metric_fn(params)],
        optimizer=adam)

    fit_callbacks = [IncrementEpochCallback(producer), time_callback]
    history = keras_model.fit(
        train_batches,
        epochs=FLAGS.train_epochs,
        callbacks=fit_callbacks,
        verbose=2)

    logging.info("Training done. Start evaluating")

    eval_results = keras_model.evaluate(
        eval_batches, steps=num_eval_steps, verbose=2)

  logging.info("Keras evaluation is done.")

  return build_stats(history, eval_results, time_callback)
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using custom training loops.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If `--fp16_implementation=graph_rewrite` is requested while
      `--use_tf_function` is false.
    NotImplementedError: If `learning_rate_decay_type` is neither 'piecewise'
      nor 'cosine'.

  Returns:
    The value returned by `build_stats` for the final train/eval results and
    the timing callback. Train and eval results are both `None` when
    `skip_eval` is set.
  """
  logging.info('enable_eager = %s', flags_obj.enable_eager)
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.float16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_float16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
  elif dtype == tf.bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  # This only affects GPU.
  common.set_cudnn_batchnorm_mode()

  # TODO(anj-s): Set data_format without using Keras.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  train_ds, test_ds = get_input_dataset(flags_obj, strategy)
  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
      flags_obj)
  steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
  logging.info("Training %d epochs, each epoch has %d steps, "
               "total steps: %d; Eval %d steps",
               train_epochs, per_epoch_steps, train_epochs * per_epoch_steps,
               eval_steps)

  time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
                                          flags_obj.log_steps)

  with distribution_utils.get_strategy_scope(strategy):
    resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
    # When a single fused L2 op is requested, the per-layer regularizers are
    # turned off and the penalty is added explicitly inside `step_fn`.
    use_l2_regularizer = not flags_obj.single_l2_loss_op

    if flags_obj.use_resnet_d:
      resnetd = network_tweaks.ResnetD(
          image_data_format=tf.keras.backend.image_data_format(),
          use_l2_regularizer=use_l2_regularizer)
    else:
      resnetd = None

    model = resnet_model.resnet50(
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        batch_size=flags_obj.batch_size,
        zero_gamma=flags_obj.zero_gamma,
        last_pool_channel_type=flags_obj.last_pool_channel_type,
        use_l2_regularizer=use_l2_regularizer,
        resnetd=resnetd)

    if flags_obj.learning_rate_decay_type == 'piecewise':
      lr_schedule = common.PiecewiseConstantDecayWithWarmup(
          batch_size=flags_obj.batch_size,
          epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
          warmup_epochs=common.LR_SCHEDULE[0][1],
          boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
          multipliers=list(p[0] for p in common.LR_SCHEDULE),
          compute_lr_on_cpu=True)
    elif flags_obj.learning_rate_decay_type == 'cosine':
      lr_schedule = common.CosineDecayWithWarmup(
          base_lr=flags_obj.base_learning_rate,
          batch_size=flags_obj.batch_size,
          epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
          warmup_epochs=common.LR_SCHEDULE[0][1],
          train_epochs=flags_obj.train_epochs,
          compute_lr_on_cpu=True)
    else:
      raise NotImplementedError(
          'Unsupported learning_rate_decay_type: %s' %
          flags_obj.learning_rate_decay_type)

    optimizer = common.get_optimizer(lr_schedule)

    if dtype == tf.float16:
      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale)
    elif flags_obj.fp16_implementation == 'graph_rewrite':
      # `dtype` is still float32 in this case. We built the graph in float32 and
      # let the graph rewrite change parts of it float16.
      if not flags_obj.use_tf_function:
        raise ValueError('--fp16_implementation=graph_rewrite requires '
                         '--use_tf_function to be true')
      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer, loss_scale)

    # Resume from the newest checkpoint in model_dir, if there is one.
    current_step = 0
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
    if latest_checkpoint:
      checkpoint.restore(latest_checkpoint)
      logging.info("Load checkpoint %s", latest_checkpoint)
      current_step = optimizer.iterations.numpy()

    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)

    categorical_cross_entopy_and_acc = losses.CategoricalCrossEntropyAndAcc(
        batch_size=flags_obj.batch_size,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        label_smoothing=flags_obj.label_smoothing)
    trainable_variables = model.trainable_variables

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        loss = categorical_cross_entopy_and_acc.loss_and_update_acc(
            labels, logits, training=True)
        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync

        if flags_obj.single_l2_loss_op:
          l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([
              tf.nn.l2_loss(v)
              for v in trainable_variables
              if 'bn' not in v.name
          ])

          loss += (l2_loss / num_replicas)
        else:
          loss += (tf.reduce_sum(model.losses) / num_replicas)

        # Scale the loss for differentiation only. `loss` itself stays
        # unscaled so the reported training loss is not multiplied by the
        # loss scale.
        if flags_obj.dtype == "fp16":
          scaled_loss = optimizer.get_scaled_loss(loss)
        else:
          scaled_loss = loss

      grads = tape.gradient(scaled_loss, trainable_variables)

      # Unscale the grads
      if flags_obj.dtype == "fp16":
        grads = optimizer.get_unscaled_gradients(grads)

      optimizer.apply_gradients(zip(grads, trainable_variables))
      train_loss.update_state(loss)

    @tf.function
    def train_steps(iterator, steps):
      """Performs distributed training steps in a loop."""
      for _ in tf.range(steps):
        strategy.experimental_run_v2(step_fn, args=(next(iterator),))

    def train_single_step(iterator):
      """Runs one training step, through the strategy when one is set."""
      if strategy:
        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
      else:
        return step_fn(next(iterator))

    def test_step(iterator):
      """Evaluation StepFn."""
      def step_fn(inputs):
        images, labels = inputs
        logits = model(images, training=False)
        loss = categorical_cross_entopy_and_acc.loss_and_update_acc(
            labels, logits, training=False)
        test_loss.update_state(loss)

      if strategy:
        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
      else:
        step_fn(next(iterator))

    if flags_obj.use_tf_function:
      train_single_step = tf.function(train_single_step)
      test_step = tf.function(test_step)

    if flags_obj.enable_tensorboard:
      summary_writer = tf.summary.create_file_writer(flags_obj.model_dir)
    else:
      summary_writer = None

    train_iter = iter(train_ds)
    time_callback.on_train_begin()
    # Start from the epoch implied by the restored step counter.
    for epoch in range(current_step // per_epoch_steps, train_epochs):
      train_loss.reset_states()
      categorical_cross_entopy_and_acc.training_accuracy.reset_states()

      steps_in_current_epoch = 0
      while steps_in_current_epoch < per_epoch_steps:
        time_callback.on_batch_begin(
            steps_in_current_epoch + epoch * per_epoch_steps)
        steps = _steps_to_run(steps_in_current_epoch, per_epoch_steps,
                              steps_per_loop)
        if steps == 1:
          train_single_step(train_iter)
        else:
          # Converts steps to a Tensor to avoid tf.function retracing.
          train_steps(train_iter, tf.convert_to_tensor(steps, dtype=tf.int32))
        time_callback.on_batch_end(
            steps_in_current_epoch + epoch * per_epoch_steps)
        steps_in_current_epoch += steps

      # The cross_entropy field is logged as a constant 0. placeholder.
      logging.info(
          'Training loss: %s, accuracy: %s, cross_entropy: %s at epoch %d',
          train_loss.result().numpy(),
          categorical_cross_entopy_and_acc.training_accuracy.result().numpy(),
          0.,
          epoch + 1)

      if (not flags_obj.skip_eval and
          (epoch + 1) % flags_obj.epochs_between_evals == 0):
        test_loss.reset_states()
        categorical_cross_entopy_and_acc.test_accuracy.reset_states()

        test_iter = iter(test_ds)
        for _ in range(eval_steps):
          test_step(test_iter)

        logging.info(
            'Test loss: %s, accuracy: %s%% at epoch: %d',
            test_loss.result().numpy(),
            categorical_cross_entopy_and_acc.test_accuracy.result().numpy(),
            epoch + 1)

      if flags_obj.enable_checkpoint_and_export:
        checkpoint_name = checkpoint.save(
            os.path.join(flags_obj.model_dir,
                         'model.ckpt-{}'.format(epoch + 1)))
        logging.info('Saved checkpoint to %s', checkpoint_name)

      if summary_writer:
        current_steps = steps_in_current_epoch + (epoch * per_epoch_steps)
        with summary_writer.as_default():
          tf.summary.scalar('train_loss', train_loss.result(), current_steps)
          tf.summary.scalar(
              'train_accuracy',
              categorical_cross_entopy_and_acc.training_accuracy.result(),
              current_steps)
          lr_for_monitor = lr_schedule(current_steps)
          if callable(lr_for_monitor):
            lr_for_monitor = lr_for_monitor()
          tf.summary.scalar('learning_rate', lr_for_monitor, current_steps)
          # NOTE(review): the eval scalars are written every epoch, including
          # epochs where no evaluation ran; they then repeat the previous
          # metric state.
          tf.summary.scalar('eval_loss', test_loss.result(), current_steps)
          tf.summary.scalar(
              'eval_accuracy',
              categorical_cross_entopy_and_acc.test_accuracy.result(),
              current_steps)

    time_callback.on_train_end()
    if summary_writer:
      summary_writer.close()

    eval_result = None
    train_result = None
    if not flags_obj.skip_eval:
      eval_result = [
          test_loss.result().numpy(),
          categorical_cross_entopy_and_acc.test_accuracy.result().numpy()]
      train_result = [
          train_loss.result().numpy(),
          categorical_cross_entopy_and_acc.training_accuracy.result().numpy()]

    stats = build_stats(train_result, eval_result, time_callback)
    return stats