Ejemplo n.º 1
0
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
    NotImplementedError: If some features are not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    performance.set_mixed_precision_policy(
        flags_core.get_tf_dtype(flags_obj),
        flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                             flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    # pylint: disable=protected-access
    if flags_obj.use_synthetic_data:
        distribution_utils.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
            num_channels=imagenet_preprocessing.NUM_CHANNELS,
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            dtype=dtype,
            drop_remainder=True)
    else:
        distribution_utils.undo_set_up_synthetic_data()
        input_fn = imagenet_preprocessing.input_fn

    # When `enable_xla` is True, we always drop the remainder of the batches
    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
    drop_remainder = flags_obj.enable_xla

    # Current resnet_model.resnet50 input format is always channel-last.
    # We use keras_application mobilenet model which input format is depends on
    # the keras beckend image data format.
    # This use_keras_image_data_format flags indicates whether image preprocessor
    # output format should be same as the keras backend image data format or just
    # channel-last format.
    use_keras_image_data_format = (flags_obj.model == 'mobilenet')
    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        drop_remainder=drop_remainder,
        tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
        training_dataset_cache=flags_obj.training_dataset_cache,
    )

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
                use_keras_image_data_format=use_keras_image_data_format),
            dtype=dtype,
            drop_remainder=drop_remainder)

    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in common.LR_SCHEDULE),
        compute_lr_on_cpu=True)
    steps_per_epoch = (imagenet_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)

    with strategy_scope:
        if flags_obj.optimizer == 'resnet50_default':
            optimizer = common.get_optimizer(lr_schedule)
        elif flags_obj.optimizer == 'mobilenet_default':
            initial_learning_rate = \
                flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size
            optimizer = tf.keras.optimizers.SGD(
                learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
                    initial_learning_rate,
                    decay_steps=steps_per_epoch *
                    flags_obj.num_epochs_per_decay,
                    decay_rate=flags_obj.lr_decay_factor,
                    staircase=True),
                momentum=0.9)
        if flags_obj.fp16_implementation == 'graph_rewrite':
            # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
            # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
            # which will ensure tf.compat.v2.keras.mixed_precision and
            # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
            # up.
            optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                optimizer)

        # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
        if flags_obj.use_trivial_model:
            model = test_utils.trivial_model(
                imagenet_preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'resnet50_v1.5':
            model = resnet_model.resnet50(
                num_classes=imagenet_preprocessing.NUM_CLASSES)
        elif flags_obj.model == 'mobilenet':
            # TODO(kimjaehong): Remove layers attribute when minimum TF version
            # support 2.0 layers by default.
            model = tf.keras.applications.mobilenet.MobileNet(
                weights=None,
                classes=imagenet_preprocessing.NUM_CLASSES,
                layers=tf.keras.layers)
        if flags_obj.pretrained_filepath:
            model.load_weights(flags_obj.pretrained_filepath)

        if flags_obj.pruning_method == 'polynomial_decay':
            if dtype != tf.float32:
                raise NotImplementedError(
                    'Pruning is currently only supported on dtype=tf.float32.')
            pruning_params = {
                'pruning_schedule':
                tfmot.sparsity.keras.PolynomialDecay(
                    initial_sparsity=flags_obj.pruning_initial_sparsity,
                    final_sparsity=flags_obj.pruning_final_sparsity,
                    begin_step=flags_obj.pruning_begin_step,
                    end_step=flags_obj.pruning_end_step,
                    frequency=flags_obj.pruning_frequency),
            }
            model = tfmot.sparsity.keras.prune_low_magnitude(
                model, **pruning_params)
        elif flags_obj.pruning_method:
            raise NotImplementedError(
                'Only polynomial_decay is currently supported.')

        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['sparse_categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      run_eagerly=flags_obj.run_eagerly)


# Harry up you can taken
    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(
        steps_per_epoch=steps_per_epoch,
        pruning_method=flags_obj.pruning_method,
        enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
        model_dir=flags_obj.model_dir)

    # if mutliple epochs, ignore the train_steps flag.
    # and the num of epochs is define train_steps
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        # Only build the training graph. This reduces memory usage introduced by
        # control flow ops in layers that have different implementations for
        # training and inference (e.g., batch norm).
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    if not strategy and flags_obj.explicit_gpu_placement:
        # TODO(b/135607227): Add device scope automatically in Keras training loop
        # when not using distribition strategy.
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)

    eval_output = None
    if not flags_obj.skip_eval:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)

    if flags_obj.pruning_method:
        model = tfmot.sparsity.keras.strip_pruning(model)
    if flags_obj.enable_checkpoint_and_export:
        if dtype == tf.bfloat16:
            logging.warning(
                'Keras model.save does not support bfloat16 dtype.')
        else:
            # Keras model.save assumes a float32 input designature.
            #
            export_path = os.path.join(flags_obj.model_dir, 'saved_model')
            model.save(export_path, include_optimizer=False)

    if not strategy and flags_obj.explicit_gpu_placement:
        no_dist_strat_device.__exit__()

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
Ejemplo n.º 2
0
    def __init__(self, flags_obj, time_callback, epoch_steps):
        standard_runnable.StandardTrainable.__init__(
            self, flags_obj.use_tf_while_loop, flags_obj.use_tf_function)
        standard_runnable.StandardEvaluable.__init__(self,
                                                     flags_obj.use_tf_function)

        self.strategy = tf.distribute.get_strategy()
        self.flags_obj = flags_obj
        self.dtype = flags_core.get_tf_dtype(flags_obj)
        self.time_callback = time_callback

        # Input pipeline related
        batch_size = flags_obj.batch_size
        if batch_size % self.strategy.num_replicas_in_sync != 0:
            raise ValueError(
                'Batch size must be divisible by number of replicas : {}'.
                format(self.strategy.num_replicas_in_sync))

        # As auto rebatching is not supported in
        # `experimental_distribute_datasets_from_function()` API, which is
        # required when cloning dataset to multiple workers in eager mode,
        # we use per-replica batch size.
        self.batch_size = int(batch_size / self.strategy.num_replicas_in_sync)

        if self.flags_obj.use_synthetic_data:
            self.input_fn = common.get_synth_input_fn(
                height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
                width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
                num_channels=imagenet_preprocessing.NUM_CHANNELS,
                num_classes=imagenet_preprocessing.NUM_CLASSES,
                dtype=self.dtype,
                drop_remainder=True)
        else:
            self.input_fn = imagenet_preprocessing.input_fn

        self.model = resnet_model.resnet50(
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            batch_size=flags_obj.batch_size,
            use_l2_regularizer=not flags_obj.single_l2_loss_op)

        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=common.LR_SCHEDULE[0][1],
            boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
            multipliers=list(p[0] for p in common.LR_SCHEDULE),
            compute_lr_on_cpu=True)
        self.optimizer = common.get_optimizer(lr_schedule)
        # Make sure iterations variable is created inside scope.
        self.global_step = self.optimizer.iterations

        use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
        if use_graph_rewrite and not flags_obj.use_tf_function:
            raise ValueError('--fp16_implementation=graph_rewrite requires '
                             '--use_tf_function to be true')
        self.optimizer = performance.configure_optimizer(
            self.optimizer,
            use_float16=self.dtype == tf.float16,
            use_graph_rewrite=use_graph_rewrite,
            loss_scale=flags_core.get_loss_scale(flags_obj,
                                                 default_for_fp16=128))

        self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'train_accuracy', dtype=tf.float32)
        self.test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)

        self.checkpoint = tf.train.Checkpoint(model=self.model,
                                              optimizer=self.optimizer)

        # Handling epochs.
        self.epoch_steps = epoch_steps
        self.epoch_helper = utils.EpochHelper(epoch_steps, self.global_step)
Ejemplo n.º 3
0
def run(flags_obj):
    """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    dtype = flags_core.get_tf_dtype(flags_obj)
    if dtype == 'fp16':
        raise ValueError(
            'dtype fp16 is not supported in Keras. Use the default '
            'value(fp32).')

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.config.list_physical_devices('GPU') else
                       'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    if flags_obj.use_synthetic_data:
        synthetic_util.set_up_synthetic_data()
        input_fn = common.get_synth_input_fn(
            height=cifar_preprocessing.HEIGHT,
            width=cifar_preprocessing.WIDTH,
            num_channels=cifar_preprocessing.NUM_CHANNELS,
            num_classes=cifar_preprocessing.NUM_CLASSES,
            dtype=flags_core.get_tf_dtype(flags_obj),
            drop_remainder=True)
    else:
        synthetic_util.undo_set_up_synthetic_data()
        input_fn = cifar_preprocessing.input_fn

    train_input_dataset = input_fn(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=cifar_preprocessing.parse_record,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        dtype=dtype,
        # Setting drop_remainder to avoid the partial batch logic in normalization
        # layer, which triggers tf.where and leads to extra memory copy of input
        # sizes between host and GPU.
        drop_remainder=(not flags_obj.enable_get_next_as_optional))

    eval_input_dataset = None
    if not flags_obj.skip_eval:
        eval_input_dataset = input_fn(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=flags_obj.batch_size,
            parse_record_fn=cifar_preprocessing.parse_record)

    steps_per_epoch = (cifar_preprocessing.NUM_IMAGES['train'] //
                       flags_obj.batch_size)
    lr_schedule = 0.1
    if flags_obj.use_tensor_lr:
        initial_learning_rate = common.BASE_LEARNING_RATE * flags_obj.batch_size / 128
        lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            boundaries=list(p[1] * steps_per_epoch for p in LR_SCHEDULE),
            values=[initial_learning_rate] + list(p[0] * initial_learning_rate
                                                  for p in LR_SCHEDULE))

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        model = resnet_cifar_model.resnet56(
            classes=cifar_preprocessing.NUM_CLASSES)
        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=(['sparse_categorical_accuracy']
                               if flags_obj.report_accuracy_metrics else None),
                      run_eagerly=flags_obj.run_eagerly)

    train_epochs = flags_obj.train_epochs

    callbacks = common.get_callbacks(steps_per_epoch)

    if not flags_obj.use_tensor_lr:
        lr_callback = LearningRateBatchScheduler(
            schedule=learning_rate_schedule,
            batch_size=flags_obj.batch_size,
            steps_per_epoch=steps_per_epoch)
        callbacks.append(lr_callback)

    # if mutliple epochs, ignore the train_steps flag.
    if train_epochs <= 1 and flags_obj.train_steps:
        steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
        train_epochs = 1

    num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                      flags_obj.batch_size)

    validation_data = eval_input_dataset
    if flags_obj.skip_eval:
        if flags_obj.set_learning_phase_to_train:
            # TODO(haoyuzhang): Understand slowdown of setting learning phase when
            # not using distribution strategy.
            tf.keras.backend.set_learning_phase(1)
        num_eval_steps = None
        validation_data = None

    if not strategy and flags_obj.explicit_gpu_placement:
        # TODO(b/135607227): Add device scope automatically in Keras training loop
        # when not using distribition strategy.
        no_dist_strat_device = tf.device('/device:GPU:0')
        no_dist_strat_device.__enter__()

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=validation_data,
                        validation_freq=flags_obj.epochs_between_evals,
                        verbose=2)
    eval_output = None
    if not flags_obj.skip_eval:
        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)

    if not strategy and flags_obj.explicit_gpu_placement:
        no_dist_strat_device.__exit__()

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
Ejemplo n.º 4
0
def run_train(flags_obj):
    keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
                                   enable_xla=flags_obj.enable_xla)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads
        )
    common.set_cudnn_batchnorm_mode()

    performance.set_mixed_precision_policy(
        flags_core.get_tf_dtype(flags_obj),
        flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    # Configures cluster spec for distribution strategy.
    _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                             flags_obj.task_index)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    if strategy:
        # flags_obj.enable_get_next_as_optional controls whether enabling
        # get_next_as_optional behavior in DistributedIterator. If true, last
        # partial batch can be supported.
        strategy.extended.experimental_enable_get_next_as_optional = (
            flags_obj.enable_get_next_as_optional)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    distribution_utils.undo_set_up_synthetic_data()

    train_input_dataset, eval_input_dataset, tr_dataset, te_dataset = setup_datasets(
        flags_obj)

    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=GB_OPTIONS.batch_size,
        epoch_size=tr_dataset.num_examples_per_epoch(),
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in common.LR_SCHEDULE),
        compute_lr_on_cpu=True)

    with strategy_scope:
        optimizer = common.get_optimizer(lr_schedule)
        model = build_model(tr_dataset.num_classes, mode='resnet50')

        if GB_OPTIONS.pretrained_filepath is not None:
            latest = tf.train.latest_checkpoint(GB_OPTIONS.pretrained_filepath)
            print(latest)
            model.load_weights(latest)

        #losses = ["sparse_categorical_crossentropy"]
        #lossWeights = [1.0]
        model.compile(
            optimizer=optimizer,
            loss="sparse_categorical_crossentropy",
            #loss_weights=lossWeights,
            metrics=['sparse_categorical_accuracy'])

        num_train_examples = tr_dataset.num_examples_per_epoch()
        steps_per_epoch = num_train_examples // GB_OPTIONS.batch_size
        train_epochs = GB_OPTIONS.num_epochs

        if not hasattr(tr_dataset, "n_poison"):
            n_poison = 0
            n_cover = 0
        else:
            n_poison = tr_dataset.n_poison
            n_cover = tr_dataset.n_cover

        callbacks = common.get_callbacks(
            steps_per_epoch=steps_per_epoch,
            pruning_method=flags_obj.pruning_method,
            enable_checkpoint_and_export=False,
            model_dir=GB_OPTIONS.checkpoint_folder)
        ckpt_full_path = os.path.join(
            GB_OPTIONS.checkpoint_folder,
            'model.ckpt-{epoch:04d}-p%d-c%d' % (n_poison, n_cover))
        callbacks.append(
            tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                               save_weights_only=True,
                                               save_best_only=True))

        num_eval_examples = te_dataset.num_examples_per_epoch()
        num_eval_steps = num_eval_examples // GB_OPTIONS.batch_size

        if flags_obj.skip_eval:
            # Only build the training graph. This reduces memory usage introduced by
            # control flow ops in layers that have different implementations for
            # training and inference (e.g., batch norm).
            if flags_obj.set_learning_phase_to_train:
                # TODO(haoyuzhang): Understand slowdown of setting learning phase when
                # not using distribution strategy.
                tf.keras.backend.set_learning_phase(1)
            num_eval_steps = None
            eval_input_dataset = None

        history = model.fit(train_input_dataset,
                            epochs=train_epochs,
                            steps_per_epoch=steps_per_epoch,
                            callbacks=callbacks,
                            validation_steps=num_eval_steps,
                            validation_data=eval_input_dataset,
                            validation_freq=flags_obj.epochs_between_evals)

        export_path = os.path.join(GB_OPTIONS.checkpoint_folder, 'saved_model')
        model.save(export_path, include_optimizer=False)

        eval_output = model.evaluate(eval_input_dataset,
                                     steps=num_eval_steps,
                                     verbose=2)

        stats = common.build_stats(history, eval_output, callbacks)

        return stats