    def test_optstate_sumsq(self):
        """Test that optstate sumsq and sum are computed correctly."""
        init_fn, update_fn, _ = make_training_metrics(
            self.num_train_steps,
            optstate_sumsq_fields=['nu'],
            optstate_sum_fields=['nu'])
        initial_metrics_state = init_fn(self.mock_params0)
        self.assertTrue(
            pytree_equal(initial_metrics_state['optstate_sumsq'],
                         {'nu': jnp.zeros(self.num_train_steps)}))
        self.assertTrue(
            pytree_equal(initial_metrics_state['optstate_sum'],
                         {'nu': jnp.zeros(self.num_train_steps)}))
        updated_metrics_state = update_fn(initial_metrics_state, 0,
                                          self.mock_cost0, self.mock_grad1,
                                          self.mock_params0, self.mock_params1,
                                          self.mock_optimizer_state0)
        updated_metrics_state = update_fn(updated_metrics_state, 1,
                                          self.mock_cost1, self.mock_grad2,
                                          self.mock_params1, self.mock_params2,
                                          self.mock_optimizer_state1)

        self.assertEqual(updated_metrics_state['optstate_sumsq']['nu'][0],
                         total_tree_norm_sql2(self.mock_nu0))
        self.assertEqual(updated_metrics_state['optstate_sumsq']['nu'][1],
                         total_tree_norm_sql2(self.mock_nu1))

        self.assertEqual(updated_metrics_state['optstate_sum']['nu'][0],
                         total_tree_sum(self.mock_nu0))
        self.assertEqual(updated_metrics_state['optstate_sum']['nu'][1],
                         total_tree_sum(self.mock_nu1))
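
        # For reference, minimal sketches of the helpers these assertions rely
        # on (hypothetical reimplementations, not necessarily the library's):
        #
        #   def total_tree_norm_sql2(tree):
        #       # Sum of squared entries across all leaves of the pytree.
        #       return sum(jnp.sum(x**2) for x in jax.tree_util.tree_leaves(tree))
        #
        #   def total_tree_sum(tree):
        #       # Sum of all entries across all leaves of the pytree.
        #       return sum(jnp.sum(x) for x in jax.tree_util.tree_leaves(tree))
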
    def test_update_param_norms(self):
        """Ensure that we update param norms correctly."""

        init_fn, update_fn, _ = make_training_metrics(self.num_train_steps,
                                                      enable_param_norms=True)
        initial_metrics_state = init_fn(self.mock_params0)
        updated_metrics_state = update_fn(initial_metrics_state, 0,
                                          self.mock_cost0, self.mock_grad1,
                                          self.mock_params0, self.mock_params1,
                                          self.mock_optimizer_state0)
        updated_metrics_state = update_fn(updated_metrics_state, 1,
                                          self.mock_cost1, self.mock_grad2,
                                          self.mock_params1, self.mock_params2,
                                          self.mock_optimizer_state1)
        self.assertTrue(
            pytree_equal(
                updated_metrics_state['param_norms'], {
                    'foo':
                    jnp.array([
                        jnp.linalg.norm(self.mock_params0['foo']),
                        jnp.linalg.norm(self.mock_params1['foo']), 0.0, 0.0,
                        0.0
                    ]),
                    'bar': {
                        'baz':
                        jnp.array([
                            jnp.linalg.norm(self.mock_params0['bar']['baz']),
                            jnp.linalg.norm(self.mock_params1['bar']['baz']),
                            0.0, 0.0, 0.0
                        ])
                    }
                }))
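
        # Only the first two entries of each timeseries are populated because
        # update_fn ran for steps 0 and 1; the rest keep their zero
        # initialization (num_train_steps is 5 in these fixtures).
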
    def test_summarize(self):
        """Test the training metrics summarizer."""
        _, _, summarize_fn = make_training_metrics(
            self.num_train_steps, enable_train_cost=True, enable_ema=True)
        metrics_state = {
            'train_cost': jnp.array([1.0, 0.5, 0.25, 0.0, 0.0]),
            'param_norm': {
                'foo': 7.0,
                'bar': {
                    'baz': 2.0
                }
            },
            'grad_ema': {
                'foo': 1 * jnp.ones(5),
                'bar': {
                    'baz': 2 * jnp.ones(10)
                }
            },
            'grad_sq_ema': {
                'foo': 2 * jnp.ones(5),
                'bar': {
                    'baz': 6 * jnp.ones(10)
                }
            },
            'update_ema': {
                'foo': 2 * jnp.ones(5),
                'bar': {
                    'baz': 1 * jnp.ones(10)
                }
            },
            'update_sq_ema': {
                'foo': 6 * jnp.ones(5),
                'bar': {
                    'baz': 2 * jnp.ones(10)
                }
            },
        }
        tree_summary = summarize_fn(metrics_state)
        self.assertTrue(
            pytree_equal(
                tree_summary, {
                    'param_norm': {
                        '/foo': 7.0,
                        '/bar/baz': 2.0
                    },
                    'grad_var': {
                        '/foo': 5 * (2 - 1**2),
                        '/bar/baz': 10 * (6 - 2**2)
                    },
                    'update_var': {
                        '/foo': 5 * (6 - 2**2),
                        '/bar/baz': 10 * (2 - 1**2)
                    },
                    'update_ratio': {
                        '/foo': 5 * (6 - 2**2) / 7.0,
                        '/bar/baz': 10 * (2 - 1**2) / 2.0
                    }
                }))
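
        # The expected variances assume the summarizer computes, per leaf,
        # sum(sq_ema - ema**2) over all entries: a leaf of n identical entries
        # yields n * (sq_ema - ema**2), e.g. grad_var for 'foo' is
        # 5 * (2 - 1**2) = 5. update_ratio then divides update_var by the
        # corresponding param_norm entry.
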
    def test_init(self):
        """Test the training metrics initializer."""

        zeros_like_params = jax.tree_map(jnp.zeros_like, self.mock_params0)
        zeros_scalar_like_params = jax.tree_map(lambda x: 0.0,
                                                self.mock_params0)
        zeros_timeseries = jnp.zeros(self.num_train_steps)
        zeros_timeseries_like_params = jax.tree_map(
            lambda x: jnp.zeros(self.num_train_steps), self.mock_params0)

        # Test init with everything disabled.
        init_fn, _, _ = make_training_metrics(self.num_train_steps)
        initial_metrics_state = init_fn(self.mock_params0)
        self.assertTrue(
            pytree_equal({'param_norm': zeros_scalar_like_params},
                         initial_metrics_state))

        # Test init with all of the training metrics enabled.
        init_fn, _, _ = make_training_metrics(self.num_train_steps,
                                              enable_ema=True,
                                              enable_train_cost=True,
                                              enable_param_norms=True,
                                              enable_gradient_norm=True,
                                              enable_update_norm=True,
                                              enable_update_norms=True)
        initial_metrics_state = init_fn(self.mock_params0)
        self.assertTrue(
            pytree_equal(
                initial_metrics_state, {
                    'train_cost': zeros_timeseries,
                    'param_norm': zeros_scalar_like_params,
                    'grad_ema': zeros_like_params,
                    'grad_sq_ema': zeros_like_params,
                    'update_ema': zeros_like_params,
                    'update_sq_ema': zeros_like_params,
                    'param_norms': zeros_timeseries_like_params,
                    'gradient_norm': zeros_timeseries,
                    'update_norm': zeros_timeseries,
                    'update_norms': zeros_timeseries_like_params
                }))
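
        # Each enable_* flag is expected to add exactly one state key above:
        # scalar timeseries ('train_cost', 'gradient_norm', 'update_norm')
        # initialize to jnp.zeros(num_train_steps), while per-parameter
        # metrics ('param_norms', 'update_norms') initialize to one such
        # timeseries per parameter leaf.
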
    def test_update_update_norms(self):
        """Ensure that we update gradient and update norms correctly."""
        init_fn, update_fn, _ = make_training_metrics(
            self.num_train_steps,
            enable_gradient_norm=True,
            enable_update_norm=True,
            enable_update_norms=True)
        initial_metrics_state = init_fn(self.mock_params0)
        updated_metrics_state = update_fn(initial_metrics_state, 0,
                                          self.mock_cost0, self.mock_grad1,
                                          self.mock_params0, self.mock_params1,
                                          self.mock_optimizer_state0)
        updated_metrics_state = update_fn(updated_metrics_state, 1,
                                          self.mock_cost1, self.mock_grad2,
                                          self.mock_params1, self.mock_params2,
                                          self.mock_optimizer_state1)
        self.assertTrue(
            pytree_equal(
                updated_metrics_state['update_norms'], {
                    'foo':
                    jnp.array([
                        self.step_size *
                        jnp.linalg.norm(self.mock_grad1['foo']),
                        self.step_size *
                        jnp.linalg.norm(self.mock_grad2['foo']), 0.0, 0.0, 0.0
                    ]),
                    'bar': {
                        'baz':
                        jnp.array([
                            self.step_size *
                            jnp.linalg.norm(self.mock_grad1['bar']['baz']),
                            self.step_size *
                            jnp.linalg.norm(self.mock_grad2['bar']['baz']),
                            0.0, 0.0, 0.0
                        ])
                    }
                }))

        self.assertEqual(updated_metrics_state['gradient_norm'][0],
                         total_tree_norm_l2(self.mock_grad1))
        self.assertEqual(updated_metrics_state['gradient_norm'][1],
                         total_tree_norm_l2(self.mock_grad2))

        self.assertEqual(updated_metrics_state['update_norm'][0],
                         self.step_size * total_tree_norm_l2(self.mock_grad1))
        self.assertEqual(updated_metrics_state['update_norm'][1],
                         self.step_size * total_tree_norm_l2(self.mock_grad2))
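
        # These expected values rely on the mock optimizer state corresponding
        # to plain SGD, so each update is step_size * gradient and therefore
        # update_norm = step_size * gradient_norm at every step (an assumption
        # about this test's fixtures).
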
    def test_train_cost(self):
        """Ensure that the train cost is logged correctly."""
        init_fn, update_fn, _ = make_training_metrics(self.num_train_steps,
                                                      enable_train_cost=True)
        initial_metrics_state = init_fn(self.mock_params0)
        updated_metrics_state = update_fn(initial_metrics_state, 0,
                                          self.mock_cost0, self.mock_grad1,
                                          self.mock_params0, self.mock_params1,
                                          self.mock_optimizer_state0)
        updated_metrics_state = update_fn(updated_metrics_state, 1,
                                          self.mock_cost1, self.mock_grad2,
                                          self.mock_params1, self.mock_params2,
                                          self.mock_optimizer_state1)

        self.assertTrue(
            pytree_equal(
                updated_metrics_state['train_cost'],
                jnp.array([self.mock_cost0, self.mock_cost1, 0.0, 0.0, 0.0])))
    def test_update_grad_ema(self):
        """Ensure that the training metrics updater updates grad ema correctly."""

        init_fn, update_fn, _ = make_training_metrics(self.num_train_steps,
                                                      enable_ema=True,
                                                      ema_beta=0.5)
        initial_metrics_state = init_fn(self.mock_params0)
        updated_metrics_state = update_fn(initial_metrics_state, 0,
                                          self.mock_cost0, self.mock_grad1,
                                          self.mock_params0, self.mock_params1,
                                          self.mock_optimizer_state0)
        updated_metrics_state = update_fn(updated_metrics_state, 1,
                                          self.mock_cost1, self.mock_grad2,
                                          self.mock_params1, self.mock_params2,
                                          self.mock_optimizer_state1)

        self.assertTrue(
            pytree_equal(
                updated_metrics_state['grad_ema'],
                jax.tree_map(lambda x, y, z: 0.25 * x + 0.25 * y + 0.5 * z,
                             self.mock_zeros, self.mock_grad1,
                             self.mock_grad2)))
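
        # The 0.25 / 0.25 / 0.5 coefficients come from unrolling two EMA steps
        # with beta = 0.5, assuming the rule ema <- beta * ema + (1 - beta) * x
        # and a zeros initialization:
        #   step 0: ema = 0.5 * 0   + 0.5 * grad1
        #   step 1: ema = 0.5 * ema + 0.5 * grad2
        #               = 0.25 * zeros + 0.25 * grad1 + 0.5 * grad2
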
def train(train_dir,
          model,
          dataset_builder,
          initializer,
          num_train_steps,
          hps,
          rng,
          eval_batch_size,
          eval_num_batches,
          eval_train_num_batches,
          eval_frequency,
          checkpoint_steps,
          early_stopping_target_name=None,
          early_stopping_target_value=None,
          early_stopping_mode=None,
          eval_steps=None,
          metrics_logger=None,
          init_logger=None,
          training_metrics_config=None,
          callback_configs=None,
          external_checkpoint_path=None):
    """Main training loop.

  Trains the given network on the specified dataset for the given number of
  epochs. Saves the training curve in train_dir/r=3/results.tsv.

  Args:
    train_dir: (str) Path of the training directory.
    model: (BaseModel) Model object to be trained.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    initializer: Must have API as defined in initializers.py
    num_train_steps: (int) Number of steps to train on.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_batch_size: the evaluation batch size. If None, use hps.batch_size.
    eval_num_batches: (int) The number of batches used for evaluating on
      validation and test sets. Set to None to evaluate on the whole train set.
    eval_train_num_batches: (int) The number of batches for evaluating on train.
      Set to None to evaluate on the whole training set.
    eval_frequency: (int) Evaluate every k steps.
    checkpoint_steps: List of integers indicating special steps to save
      checkpoints at. These checkpoints do not get used for preemption recovery.
    early_stopping_target_name: A string naming the metric to use to perform
       early stopping. If this metric reaches the value
      `early_stopping_target_value`, training will stop. Must include the
      dataset split (ex: validation/error_rate).
    early_stopping_target_value: A float indicating the value at which to stop
      training.
    early_stopping_mode: One of "above" or "below", indicates if we should stop
      when the metric is above or below the threshold value. Example: if
      "above", then training will stop when
      `report[early_stopping_target_name] >= early_stopping_target_value`.
    eval_steps: List of integers indicating which steps to perform evals. If
      provided, eval_frequency will be ignored. Performing an eval implies
      saving a checkpoint that will be used to resume training in the case of
      preemption.
    metrics_logger: Used to log all eval metrics during training. See
      utils.MetricLogger for API definition.
    init_logger: Used for black box initializers that have learning curves.
    training_metrics_config: Dict specifying the configuration of the
      training_metrics_grabber. Set to None to skip logging of advanced training
      metrics.
    callback_configs: List of configs specifying general callbacks to run
      during the eval phase. Empty list means no callbacks are run. See
      callbacks.py for details on what is expected in a config.
    external_checkpoint_path: (str) If this argument is set, we will load the
      optimizer_state, params, batch_stats, and training_metrics from the
      checkpoint at this location.

  Yields:
    metrics: A dictionary of all eval metrics from the given epoch.
  """
    # NOTE: the initialization RNG should *not* be per-host, as this will create
    # different sets of weights per host. However, all other RNGs should be
    # per-host.
    # TODO(znado,gilmer,gdahl): implement replicating the same initialization
    # across hosts.
    rng, init_rng = jax.random.split(rng)
    rng = jax.random.fold_in(rng, jax.process_index())
    rng, data_rng = jax.random.split(rng)

    # Only used if checkpoint_steps is non-empty.
    checkpoint_dir = os.path.join(train_dir, 'checkpoints')

    # For logging / processing off the main thread
    pool = multiprocessing.pool.ThreadPool()

    if jax.process_index() == 0:
        logging.info('Let the training begin!')
        logging.info('Dataset input shape: %r', hps.input_shape)
        logging.info('Hyperparameters: %s', hps)

    if eval_batch_size is None:
        eval_batch_size = hps.batch_size
    if callback_configs is None:
        callback_configs = []

    # Maybe run the initializer.
    unreplicated_params, unreplicated_batch_stats = init_utils.initialize(
        model.flax_module, initializer, model.loss_fn,
        hps.input_shape, hps.output_shape, hps, init_rng, init_logger,
        model.get_fake_batch(hps))

    if jax.process_index() == 0:
        utils.log_pytree_shape_and_statistics(unreplicated_params)
        logging.info('train_size: %d', hps.train_size)

    # Note that global_step refers to the number of gradient calculations, not
    # the number of model updates. This means when using gradient accumulation,
    # one must supply configs where the number of steps are in units of gradient
    # calculations, not model updates, and in post processing one must divide
    # global_step by grad_accum_step_multiplier to get the number of updates.
    #
    # If using gradient accumulation, stretch the learning rate schedule by the
    # number of gradient calculations per weight update.
    stretch_factor = 1
    if hps.get('total_accumulated_batch_size') is not None:
        stretch_factor = hps.total_accumulated_batch_size // hps.batch_size
    lr_fn = schedules.get_schedule_fn(hps.lr_hparams,
                                      num_train_steps,
                                      stretch_factor=stretch_factor)
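    # For example (illustrative numbers, not from any config): with
    # hps.total_accumulated_batch_size = 4096 and hps.batch_size = 512,
    # stretch_factor = 8, so the schedule is stretched to span 8 gradient
    # calculations per model update.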

    optimizer_init_fn, optimizer_update_fn = optimizers.get_optimizer(
        hps, model)
    unreplicated_optimizer_state = optimizer_init_fn(unreplicated_params)

    (unreplicated_metrics_state, metrics_update_fn,
     metrics_summary_fn) = None, None, None
    if training_metrics_config is not None:
        (metrics_init_fn, metrics_update_fn,
         metrics_summary_fn) = make_training_metrics(num_train_steps,
                                                     **training_metrics_config)
        unreplicated_metrics_state = metrics_init_fn(unreplicated_params)

    (optimizer_state, params, batch_stats, metrics_state, global_step,
     sum_train_cost, preemption_count,
     is_restored) = checkpoint.replicate_and_maybe_restore_checkpoint(
         unreplicated_optimizer_state,
         unreplicated_params,
         unreplicated_batch_stats,
         unreplicated_metrics_state,
         train_dir=train_dir,
         external_checkpoint_path=external_checkpoint_path)

    if is_restored:
        preemption_count += 1
        # Fold the restored step into the dataset RNG so that we will get a
        # different shuffle each time we restore, so that we do not repeat a
        # previous dataset ordering again after restoring. This is not the only
        # source of shuffling differences across preemptions, because we often
        # reshuffle the input files in a non-deterministic manner.
        #
        # Note that if we are pre-empted more than once per epoch then we will
        # retrain more on the beginning of the training split, because each time we
        # restore we refill the shuffle buffer with the first `shuffle_buffer_size`
        # elements from the training split to continue training.
        #
        # Also note that for evaluating on the training split, because we are
        # reshuffling each time, we will get a new eval_train split each time we are
        # pre-empted.
        data_rng = jax.random.fold_in(data_rng, global_step)

    assert hps.batch_size % (jax.device_count()) == 0
    assert eval_batch_size % (jax.device_count()) == 0
    dataset = dataset_builder(
        data_rng,
        hps.batch_size,
        eval_batch_size=eval_batch_size,
        hps=hps,
    )

    update_fn = functools.partial(update,
                                  training_cost=model.training_cost,
                                  grad_clip=hps.get('grad_clip'),
                                  optimizer_update_fn=optimizer_update_fn,
                                  metrics_update_fn=metrics_update_fn)
    # in_axes = (
    #     optimizer_state = 0,
    #     params = 0,
    #     batch_stats = 0,
    #     metrics_state = 0,
    #     batch = 0,
    #     step = None,
    #     lr = None,
    #     rng = None,
    #     local_device_index = 0,
    #     running_train_cost = 0,
    #     training_cost,
    #     grad_clip,
    #     optimizer_update_fn,
    #     metrics_state_update_fn)
    # Also, we can donate buffers for 'optimizer', 'batch_stats',
    # 'batch' and 'training_metrics_state' for update's pmapped computation.
    update_pmapped = jax.pmap(update_fn,
                              axis_name='batch',
                              in_axes=(0, 0, 0, 0, 0, None, None, None, 0, 0),
                              donate_argnums=(0, 1, 2, 8))
    # During eval, we can donate the 'batch' buffer. We don't donate the
    # 'params' and 'batch_stats' buffers as we don't re-assign those values in
    # eval, we do that only in train.
    evaluate_batch_pmapped = jax.pmap(model.evaluate_batch,
                                      axis_name='batch',
                                      donate_argnums=(2, ))
    start_time = time.time()
    start_step = global_step
    prev_eval_step = start_step

    def get_step_frequency(cur_step):
        return float(cur_step - start_step) / (time.time() - start_time)

    if jax.process_index() == 0:
        trainer_utils.log_message('Starting training!', pool, xm_work_unit)

    # Numpy array of range(0, local_device_count) to send to each device to be
    # folded into the RNG inside each train step to get a unique per-device RNG.
    local_device_indices = np.arange(jax.local_device_count())
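    # A sketch of the assumed per-device folding inside the pmapped update
    # (not the actual update body):
    #   per_device_rng = jax.random.fold_in(rng, local_device_index)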

    # Start at the resumed step and continue until we have finished the number of
    # training steps. If building a dataset iterator using a tf.data.Dataset, in
    # the case of a batch size that does not evenly divide the training dataset
    # size, if using `ds.batch(..., drop_remainder=False)` on the training
    # dataset then the final batch in this iterator will be a partial batch.
    # However, if `drop_remainder=True`, then this iterator will always return
    # batches of the same size, and the final batch will have elements from the
    # start of the
    # same size, and the final batch will have elements from the start of the
    # (num_epochs + 1)-th epoch.
    train_iter = itertools.islice(dataset.train_iterator_fn(), global_step,
                                  num_train_steps)

    eval_callbacks = []
    rng, callback_rng = jax.random.split(rng)
    callback_rngs = jax.random.split(callback_rng, len(callback_configs))
    for callback_rng, config in zip(callback_rngs, callback_configs):
        eval_callback = callbacks.get_callback(config['callback_name'])(
            model, params, batch_stats, optimizer_state, dataset, hps, config,
            train_dir, callback_rng)
        eval_callbacks.append(eval_callback)

    train_iter = trainer_utils.prefetch_input_pipeline(
        train_iter, hps.num_device_prefetches)

    eval_start_time = start_time
    eval_start_step = start_step
    for _ in range(start_step, num_train_steps):
        with jax.profiler.StepTraceAnnotation('train', step_num=global_step):
            # NOTE(dsuo): to properly profile each step, we must include batch
            # creation in the StepTraceAnnotation (as opposed to putting
            # `train_iter` directly in the top-level for loop).
            batch = next(train_iter)

            if global_step in checkpoint_steps and jax.process_index() == 0:
                checkpoint.save_unreplicated_checkpoint_background(
                    checkpoint_dir,
                    optimizer_state,
                    params,
                    batch_stats,
                    metrics_state,
                    global_step,
                    preemption_count,
                    sum_train_cost,
                    max_to_keep=None)
            lr = lr_fn(global_step)
            (optimizer_state, params, batch_stats, sum_train_cost,
             metrics_state, grad_norm) = update_pmapped(
                 optimizer_state, params, batch_stats, metrics_state, batch,
                 global_step, lr, rng, local_device_indices, sum_train_cost)
            global_step += 1
            # TODO(gdahl, gilmer): consider moving this test up.
            # NB: Since this test is after we increment global_step, having 0 in
            # eval_steps does nothing.
            if trainer_utils.should_eval(global_step, eval_frequency,
                                         eval_steps):
                train_steps_per_sec = (global_step - eval_start_step) / (
                    time.time() - eval_start_time)
                eval_start_step = global_step
                eval_start_time = time.time()
                batch_stats = trainer_utils.maybe_sync_batchnorm_stats(
                    batch_stats)
                report, eval_time = eval_metrics(params, batch_stats, dataset,
                                                 eval_num_batches,
                                                 eval_train_num_batches,
                                                 evaluate_batch_pmapped)
                mean_train_cost = sum_train_cost.mean().item() / max(
                    1, global_step - prev_eval_step)
                report.update(
                    learning_rate=float(lr),
                    global_step=global_step,
                    epoch=global_step * hps.batch_size // hps.train_size,
                    train_steps_per_sec=train_steps_per_sec,
                    overall_steps_per_sec=get_step_frequency(global_step),
                    eval_time=eval_time,
                    grad_norm=np.mean(grad_norm),
                    preemption_count=preemption_count,
                    train_cost=mean_train_cost)

                for eval_callback in eval_callbacks:
                    callback_metrics = eval_callback.run_eval(
                        params, batch_stats, optimizer_state, global_step)
                    if set(callback_metrics.keys()).intersection(
                            set(report.keys())):
                        raise ValueError(
                            'There was a collision between the callback '
                            'metrics and the standard eval metrics keys.')
                    report.update(callback_metrics)
                yield report
                if jax.process_index() == 0:
                    trainer_utils.log_eta(pool, xm_work_unit, global_step,
                                          train_steps_per_sec, num_train_steps,
                                          start_time, eval_frequency,
                                          eval_steps, eval_time)
                    trainer_utils.log_epoch_report(report, metrics_logger)
                    trainer_utils.maybe_log_training_metrics(
                        metrics_state, metrics_summary_fn, metrics_logger)
                    checkpoint.save_unreplicated_checkpoint_background(
                        train_dir, optimizer_state, params, batch_stats,
                        metrics_state, global_step, preemption_count,
                        sum_train_cost)
                sum_train_cost = jnp.zeros(jax.local_device_count())
                prev_eval_step = global_step

                early_stopping_condition = trainer_utils.check_for_early_stopping(
                    early_stopping_target_name, early_stopping_target_value,
                    early_stopping_mode, report)
                if early_stopping_condition:
                    comparison_string = '>=' if early_stopping_mode == 'above' else '<='
                    logging.info(
                        'Early stopping because metric %s=%f, reached the target value '
                        'of %s %f.', early_stopping_target_name,
                        report[early_stopping_target_name], comparison_string,
                        early_stopping_target_value)
                    return

    # Always log and checkpoint on host 0 at the end of training.
    # If evals happened at a different point in the loop body, this final check
    # would not be needed.
    if prev_eval_step != num_train_steps:
        train_steps_per_sec = (global_step - eval_start_step) / (
            time.time() - eval_start_time)
        batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
        report, eval_time = eval_metrics(params, batch_stats, dataset,
                                         eval_num_batches,
                                         eval_train_num_batches,
                                         evaluate_batch_pmapped)
        lr = lr_fn(global_step)
        # Correct the average for the final partial epoch.
        mean_train_cost = sum_train_cost.mean().item() / max(
            1, global_step - prev_eval_step)
        report.update(learning_rate=float(lr),
                      global_step=global_step,
                      epoch=global_step * hps.batch_size // hps.train_size,
                      train_steps_per_sec=train_steps_per_sec,
                      overall_steps_per_sec=get_step_frequency(global_step),
                      eval_time=eval_time,
                      grad_norm=np.mean(grad_norm),
                      preemption_count=preemption_count,
                      train_cost=mean_train_cost)
        yield report
        if jax.process_index() == 0:
            trainer_utils.log_eta(pool, xm_work_unit, global_step,
                                  train_steps_per_sec, num_train_steps,
                                  start_time, eval_frequency, eval_steps,
                                  eval_time)
            trainer_utils.log_epoch_report(report, metrics_logger)
            trainer_utils.maybe_log_training_metrics(metrics_state,
                                                     metrics_summary_fn,
                                                     metrics_logger)
            checkpoint.save_unreplicated_checkpoint_background(
                train_dir, optimizer_state, params, batch_stats, metrics_state,
                global_step, preemption_count, sum_train_cost)
    # To make sure the last checkpoint was correctly saved.
    checkpoint.wait_for_checkpoint_save()