Example #1
 def test_schedule_stretching(self):
     """Test that schedules can be properly stretched."""
     max_training_steps = 100
     lr_hparams = config_dict.ConfigDict({
         'schedule': 'mlperf_polynomial',
         'base_lr': 10.0,
         'warmup_steps': 10,
         'decay_end': -1,
         'end_lr': 1e-4,
         'power': 2.0,
         'start_lr': 0.0,
         'warmup_power': 1.0,
     })
     lr_fn = schedules.get_schedule_fn(lr_hparams, max_training_steps)
     stretch_factor = 3
     stretched_lr_fn = schedules.get_schedule_fn(
         lr_hparams, max_training_steps, stretch_factor=stretch_factor)
     lrs = [lr_fn(t) for t in range(max_training_steps)]
     stretched_lrs = [
         stretched_lr_fn(t)
         for t in range(stretch_factor * max_training_steps)
     ]
     self.assertEqual(lrs, stretched_lrs[::stretch_factor])
     self.assertEqual(lrs, stretched_lrs[1::stretch_factor])
     self.assertEqual(lrs, stretched_lrs[2::stretch_factor])
     # Assert that the stretched schedule has proper staircase behavior.
     for update_step in range(max_training_steps):
         start = update_step * stretch_factor
         end = (update_step + 1) * stretch_factor
         expected = [lrs[update_step]] * stretch_factor
         self.assertEqual(stretched_lrs[start:end], expected)
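
The staircase behaviour asserted above is equivalent to evaluating the base schedule at `step // stretch_factor`. A minimal sketch of that wrapping, using illustrative names rather than the library's internals:

def stretch_schedule(base_schedule, stretch_factor):
    """Holds every value of base_schedule for stretch_factor consecutive steps."""
    def stretched(step):
        # Integer division produces the staircase the test asserts: steps
        # k * stretch_factor .. (k + 1) * stretch_factor - 1 all map to base step k.
        return base_schedule(step // stretch_factor)
    return stretched
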
Example #2
    def test_raises(self):
        """Test that an exception is raised with extra hparams."""
        good_hps = config_dict.ConfigDict(
            dict(
                lr_hparams={
                    'schedule': 'mlperf_polynomial',
                    'base_lr': .1,
                    'warmup_steps': 200,
                    'decay_end': -1,
                    'end_lr': 1e-4,
                    'power': 2.0,
                    'start_lr': 0.0,
                    'warmup_power': 1.0,
                }))
        bad_hps = config_dict.ConfigDict(
            dict(
                lr_hparams={
                    'schedule': 'mlperf_polynomial',
                    'base_lr': .1,
                    'warmup_steps': 200,
                    'initial_value': .1,
                    'decay_end': -1,
                    'end_lr': 1e-4,
                    'power': 2.0,
                    'start_lr': 0.0,
                }))
        bad_hps2 = config_dict.ConfigDict(
            dict(
                lr_hparams={
                    'schedule': 'polynomial',
                    'power': 2.0,
                    'initial_value': .1,
                    'end_factor': .01,
                    'decay_steps': 200,
                    'decay_steps_factor': 0.5
                }))
        # This should pass.
        schedules.get_schedule_fn(good_hps.lr_hparams, 1)

        # This should raise an exception due to the extra hparam.
        with self.assertRaises(ValueError):
            schedules.get_schedule_fn(bad_hps.lr_hparams, 1)

        # This should raise an exception due to the mutually exclusive hparams.
        with self.assertRaises(ValueError):
            schedules.get_schedule_fn(bad_hps2.lr_hparams, 1)
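
The two failure modes exercised above (an unrecognized key, and mutually exclusive keys) can be caught with plain set operations. A hedged sketch of such validation; ALLOWED_KEYS, EXCLUSIVE_PAIRS and validate_lr_hparams are illustrative names, not the library's actual tables or API:

ALLOWED_KEYS = {
    'mlperf_polynomial': {
        'schedule', 'base_lr', 'warmup_steps', 'decay_end', 'end_lr',
        'power', 'start_lr', 'warmup_power',
    },
}
EXCLUSIVE_PAIRS = [('decay_steps', 'decay_steps_factor')]

def validate_lr_hparams(lr_hparams):
    """Raises ValueError for unknown or mutually exclusive hparams (sketch only)."""
    keys = set(lr_hparams.keys())
    allowed = ALLOWED_KEYS.get(lr_hparams['schedule'])
    if allowed is not None and not keys <= allowed:
        raise ValueError('Unrecognized lr hparams: %s' % (keys - allowed))
    for a, b in EXCLUSIVE_PAIRS:
        if a in keys and b in keys:
            raise ValueError('%s and %s are mutually exclusive.' % (a, b))
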
Example #3
 def test_polynomial_decay_decay_steps(self):
     """Test polynomial schedule works correctly with decay_steps."""
     hps = config_dict.ConfigDict(
         dict(
             lr_hparams={
                 'schedule': 'polynomial',
                 'power': 2.0,
                 'initial_value': .1,
                 'end_factor': .01,
                 'decay_steps': 200,
             }))
     max_training_steps = 400
     lr_fn = schedules.get_schedule_fn(hps.lr_hparams, max_training_steps)
     hps = hps.lr_hparams
     decay_steps = hps['decay_steps']
     for step in range(max_training_steps):
         expected_learning_rate = tf.train.polynomial_decay(
             hps['initial_value'],
             step,
             decay_steps,
             hps['end_factor'] * hps['initial_value'],
             power=hps['power'])().numpy()
         self.assertAlmostEqual(lr_fn(step), expected_learning_rate)
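
The reference values in this test come from TensorFlow's non-cycling polynomial decay, which reduces to a short closed form. A standalone sketch of that formula (no TensorFlow required), using the hparam names from the config above:

def polynomial_decay(initial_value, step, decay_steps, end_value, power=1.0):
    """Closed form of non-cycling polynomial decay from initial_value to end_value."""
    step = min(step, decay_steps)  # Hold end_value once decay_steps is reached.
    fraction = 1.0 - step / decay_steps
    return (initial_value - end_value) * fraction ** power + end_value

# With the config above: initial_value=0.1, decay_steps=200,
# end_value=end_factor * initial_value=0.001, power=2.0.
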
Example #4
def train(train_dir,
          model,
          dataset_builder,
          initializer,
          num_train_steps,
          hps,
          rng,
          eval_batch_size,
          eval_num_batches,
          eval_train_num_batches,
          eval_frequency,
          checkpoint_steps,
          early_stopping_target_name=None,
          early_stopping_target_value=None,
          early_stopping_mode=None,
          eval_steps=None,
          metrics_logger=None,
          init_logger=None,
          training_metrics_config=None,
          callback_configs=None,
          external_checkpoint_path=None):
    """Main training loop.

  Trains the given network on the specified dataset for the given number of
  epochs. Saves the training curve in train_dir/r=3/results.tsv.

  Args:
    train_dir: (str) Path of the training directory.
    model: (BaseModel) Model object to be trained.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    initializer: Must have API as defined in initializers.py
    num_train_steps: (int) Number of steps to train on.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_batch_size: the evaluation batch size. If None, use hps.batch_size.
    eval_num_batches: (int) The number of batches used for evaluating on
      validation and test sets. Set to None to evaluate on the whole validation
      and test sets.
    eval_train_num_batches: (int) The number of batches for evaluating on train.
      Set to None to evaluate on the whole training set.
    eval_frequency: (int) Evaluate every k steps.
    checkpoint_steps: List of integers indicating special steps to save
      checkpoints at. These checkpoints do not get used for preemption recovery.
    early_stopping_target_name: A string naming the metric to use to perform
      early stopping. If this metric reaches the value
      `early_stopping_target_value`, training will stop. Must include the
      dataset split (ex: validation/error_rate).
    early_stopping_target_value: A float indicating the value at which to stop
      training.
    early_stopping_mode: One of "above" or "below", indicates if we should stop
      when the metric is above or below the threshold value. Example: if
      "above", then training will stop when
      `report[early_stopping_target_name] >= early_stopping_target_value`.
    eval_steps: List of integers indicating which steps to perform evals. If
      provided, eval_frequency will be ignored. Performing an eval implies
      saving a checkpoint that will be used to resume training in the case of
      preemption.
    metrics_logger: Used to log all eval metrics during training. See
      utils.MetricLogger for API definition.
    init_logger: Used for black box initializers that have learning curves.
    training_metrics_config: Dict specifying the configuration of the
      training_metrics_grabber. Set to None to skip logging of advanced training
      metrics.
    callback_configs: List of configs specifying general callbacks to run
      during the eval phase. Empty list means no callbacks are run. See
      callbacks.py for details on what is expected in a config.
    external_checkpoint_path: (str) If this argument is set, we will load the
      optimizer_state, params, batch_stats, and training_metrics from the
      checkpoint at this location.

  Yields:
    metrics: A dictionary of all eval metrics from the given epoch.
  """
    # NOTE: the initialization RNG should *not* be per-host, as this will create
    # different sets of weights per host. However, all other RNGs should be
    # per-host.
    # TODO(znado,gilmer,gdahl): implement replicating the same initialization
    # across hosts.
    rng, init_rng = jax.random.split(rng)
    rng = jax.random.fold_in(rng, jax.process_index())
    rng, data_rng = jax.random.split(rng)

    # Only used if checkpoint_steps is non-empty.
    checkpoint_dir = os.path.join(train_dir, 'checkpoints')

    # For logging / processing off the main thread
    pool = multiprocessing.pool.ThreadPool()

    if jax.process_index() == 0:
        logging.info('Let the training begin!')
        logging.info('Dataset input shape: %r', hps.input_shape)
        logging.info('Hyperparameters: %s', hps)

    if eval_batch_size is None:
        eval_batch_size = hps.batch_size
    if callback_configs is None:
        callback_configs = []

    # Maybe run the initializer.
    unreplicated_params, unreplicated_batch_stats = init_utils.initialize(
        model.flax_module, initializer, model.loss_fn,
        hps.input_shape, hps.output_shape, hps, init_rng, init_logger,
        model.get_fake_batch(hps))

    if jax.process_index() == 0:
        utils.log_pytree_shape_and_statistics(unreplicated_params)
        logging.info('train_size: %d,', hps.train_size)

    # Note that global_step refers to the number of gradients calculations, not
    # the number of model updates. This means when using gradient accumulation,
    # one must supply configs where the number of steps are in units of gradient
    # calculations, not model updates, and in post processing one must divide
    # global_step by grad_accum_step_multiplier to get the number of updates.
    #
    # If using gradient accumulation, stretch the learning rate schedule by the
    # number of gradient calculations per weight update.
    stretch_factor = 1
    if hps.get('total_accumulated_batch_size') is not None:
        stretch_factor = hps.total_accumulated_batch_size // hps.batch_size
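    # For example (illustrative numbers, not library defaults): with
    # hps.batch_size == 256 and hps.total_accumulated_batch_size == 1024, each
    # weight update consumes 4 gradient computations, so the stretched schedule
    # holds every base value for 4 consecutive values of global_step.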
    lr_fn = schedules.get_schedule_fn(hps.lr_hparams,
                                      num_train_steps,
                                      stretch_factor=stretch_factor)

    optimizer_init_fn, optimizer_update_fn = optimizers.get_optimizer(
        hps, model)
    unreplicated_optimizer_state = optimizer_init_fn(unreplicated_params)

    (unreplicated_metrics_state, metrics_update_fn,
     metrics_summary_fn) = None, None, None
    if training_metrics_config is not None:
        (metrics_init_fn, metrics_update_fn,
         metrics_summary_fn) = make_training_metrics(num_train_steps,
                                                     **training_metrics_config)
        unreplicated_metrics_state = metrics_init_fn(unreplicated_params)

    (optimizer_state, params, batch_stats, metrics_state, global_step,
     sum_train_cost, preemption_count,
     is_restored) = checkpoint.replicate_and_maybe_restore_checkpoint(
         unreplicated_optimizer_state,
         unreplicated_params,
         unreplicated_batch_stats,
         unreplicated_metrics_state,
         train_dir=train_dir,
         external_checkpoint_path=external_checkpoint_path)

    if is_restored:
        preemption_count += 1
        # Fold the restored step into the dataset RNG so that we will get a
        # different shuffle each time we restore, so that we do not repeat a
        # previous dataset ordering again after restoring. This is not the only
        # difference in shuffling each pre-emption, because we often times reshuffle
        # the input files each time in a non-deterministic manner.
        #
        # Note that if we are pre-empted more than once per epoch then we will
        # retrain more on the beginning of the training split, because each time we
        # restore we refill the shuffle buffer with the first `shuffle_buffer_size`
        # elements from the training split to continue training.
        #
        # Also note that for evaluating on the training split, because we are
        # reshuffling each time, we will get a new eval_train split each time we are
        # pre-empted.
        data_rng = jax.random.fold_in(data_rng, global_step)

    assert hps.batch_size % (jax.device_count()) == 0
    assert eval_batch_size % (jax.device_count()) == 0
    dataset = dataset_builder(
        data_rng,
        hps.batch_size,
        eval_batch_size=eval_batch_size,
        hps=hps,
    )

    update_fn = functools.partial(update,
                                  training_cost=model.training_cost,
                                  grad_clip=hps.get('grad_clip'),
                                  optimizer_update_fn=optimizer_update_fn,
                                  metrics_update_fn=metrics_update_fn)
    # in_axes = (
    #     optimizer_state = 0,
    #     params = 0,
    #     batch_stats = 0,
    #     metrics_state = 0,
    #     batch = 0,
    #     step = None,
    #     lr = None,
    #     rng = None,
    #     local_device_index = 0,
    #     running_train_cost = 0,
    #     training_cost,
    #     grad_clip,
    #     optimizer_update_fn,
    #     metrics_state_update_fn)
    # Also, we can donate buffers for 'optimizer', 'batch_stats',
    # 'batch' and 'training_metrics_state' for update's pmapped computation.
    update_pmapped = jax.pmap(update_fn,
                              axis_name='batch',
                              in_axes=(0, 0, 0, 0, 0, None, None, None, 0, 0),
                              donate_argnums=(0, 1, 2, 8))
    # During eval, we can donate the 'batch' buffer. We don't donate the
    # 'params' and 'batch_stats' buffers as we don't re-assign those values in
    # eval, we do that only in train.
    evaluate_batch_pmapped = jax.pmap(model.evaluate_batch,
                                      axis_name='batch',
                                      donate_argnums=(2, ))
    start_time = time.time()
    start_step = global_step
    prev_eval_step = start_step

    def get_step_frequency(cur_step):
        return float(cur_step - start_step) / (time.time() - start_time)

    if jax.process_index() == 0:
        trainer_utils.log_message('Starting training!', pool, xm_work_unit)

    # Numpy array of range(0, local_device_count) to send to each device to be
    # folded into the RNG inside each train step to get a unique per-device RNG.
    local_device_indices = np.arange(jax.local_device_count())

    # Start at the resumed step and continue until we have finished the number of
    # training steps. If building a dataset iterator using a tf.data.Dataset, in
    # the case of a batch size that does not evenly divide the training dataset
    # size, if using `ds.batch(..., drop_remainder=True)` on the training dataset
    # then the final batch in this iterator will be a partial batch. However, if
    # `drop_remainder=False`, then this iterator will always return batches of the
    # same size, and the final batch will have elements from the start of the
    # (num_epochs + 1)-th epoch.
    train_iter = itertools.islice(dataset.train_iterator_fn(), global_step,
                                  num_train_steps)

    eval_callbacks = []
    rng, callback_rng = jax.random.split(rng)
    callback_rngs = jax.random.split(callback_rng, len(callback_configs))
    for callback_rng, config in zip(callback_rngs, callback_configs):
        eval_callback = callbacks.get_callback(config['callback_name'])(
            model, params, batch_stats, optimizer_state, dataset, hps, config,
            train_dir, callback_rng)
        eval_callbacks.append(eval_callback)

    train_iter = trainer_utils.prefetch_input_pipeline(
        train_iter, hps.num_device_prefetches)

    eval_start_time = start_time
    eval_start_step = start_step
    for _ in range(start_step, num_train_steps):
        with jax.profiler.StepTraceAnnotation('train', step_num=global_step):
            # NOTE(dsuo): to properly profile each step, we must include batch
            # creation in the StepTraceContext (as opposed to putting `train_iter`
            # directly in the top-level for loop).
            batch = next(train_iter)

            if global_step in checkpoint_steps and jax.process_index() == 0:
                checkpoint.save_unreplicated_checkpoint_background(
                    checkpoint_dir,
                    optimizer_state,
                    params,
                    batch_stats,
                    metrics_state,
                    global_step,
                    preemption_count,
                    sum_train_cost,
                    max_to_keep=None)
            lr = lr_fn(global_step)
            optimizer_state, params, batch_stats, sum_train_cost, metrics_state, grad_norm = update_pmapped(
                optimizer_state, params, batch_stats, metrics_state, batch,
                global_step, lr, rng, local_device_indices, sum_train_cost)
            global_step += 1
            # TODO(gdahl, gilmer): consider moving this test up.
            # NB: Since this test is after we increment global_step, having 0 in
            # eval_steps does nothing.
            if trainer_utils.should_eval(global_step, eval_frequency,
                                         eval_steps):
                train_steps_per_sec = (global_step - eval_start_step) / (
                    time.time() - eval_start_time)
                eval_start_step = global_step
                eval_start_time = time.time()
                batch_stats = trainer_utils.maybe_sync_batchnorm_stats(
                    batch_stats)
                report, eval_time = eval_metrics(params, batch_stats, dataset,
                                                 eval_num_batches,
                                                 eval_train_num_batches,
                                                 evaluate_batch_pmapped)
                mean_train_cost = sum_train_cost.mean().item() / max(
                    1, global_step - prev_eval_step)
                report.update(
                    learning_rate=float(lr),
                    global_step=global_step,
                    epoch=global_step * hps.batch_size // hps.train_size,
                    train_steps_per_sec=train_steps_per_sec,
                    overall_steps_per_sec=get_step_frequency(global_step),
                    eval_time=eval_time,
                    grad_norm=np.mean(grad_norm),
                    preemption_count=preemption_count,
                    train_cost=mean_train_cost)

                for eval_callback in eval_callbacks:
                    callback_metrics = eval_callback.run_eval(
                        params, batch_stats, optimizer_state, global_step)
                    if set(callback_metrics.keys()).intersection(
                            set(report.keys())):
                        raise ValueError(
                            'There was a collision between the callback '
                            'metrics and the standard eval metrics keys.')
                    report.update(callback_metrics)
                yield report
                if jax.process_index() == 0:
                    trainer_utils.log_eta(pool, xm_work_unit, global_step,
                                          train_steps_per_sec, num_train_steps,
                                          start_time, eval_frequency,
                                          eval_steps, eval_time)
                    trainer_utils.log_epoch_report(report, metrics_logger)
                    trainer_utils.maybe_log_training_metrics(
                        metrics_state, metrics_summary_fn, metrics_logger)
                    checkpoint.save_unreplicated_checkpoint_background(
                        train_dir, optimizer_state, params, batch_stats,
                        metrics_state, global_step, preemption_count,
                        sum_train_cost)
                sum_train_cost = jnp.zeros(jax.local_device_count())
                prev_eval_step = global_step

                early_stopping_condition = trainer_utils.check_for_early_stopping(
                    early_stopping_target_name, early_stopping_target_value,
                    early_stopping_mode, report)
                if early_stopping_condition:
                    comparison_string = '>=' if early_stopping_mode == 'above' else '<='
                    logging.info(
                        'Early stopping because metric %s=%f reached the target '
                        'value of %s %f.', early_stopping_target_name,
                        report[early_stopping_target_name], comparison_string,
                        early_stopping_target_value)
                    return

    # Always log and checkpoint on host 0 at the end of training.
    # If we moved where in the loop body evals happen then we would not need this
    # test.
    if prev_eval_step != num_train_steps:
        train_steps_per_sec = (global_step - eval_start_step) / (
            time.time() - eval_start_time)
        batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
        report, eval_time = eval_metrics(params, batch_stats, dataset,
                                         eval_num_batches,
                                         eval_train_num_batches,
                                         evaluate_batch_pmapped)
        lr = lr_fn(global_step)
        # Correct the average for the final partial epoch.
        mean_train_cost = sum_train_cost.mean().item() / max(
            1, global_step - prev_eval_step)
        report.update(learning_rate=float(lr),
                      global_step=global_step,
                      epoch=global_step * hps.batch_size // hps.train_size,
                      train_steps_per_sec=train_steps_per_sec,
                      overall_steps_per_sec=get_step_frequency(global_step),
                      eval_time=eval_time,
                      grad_norm=np.mean(grad_norm),
                      preemption_count=preemption_count,
                      train_cost=mean_train_cost)
        yield report
        if jax.process_index() == 0:
            trainer_utils.log_eta(pool, xm_work_unit, global_step,
                                  train_steps_per_sec, num_train_steps,
                                  start_time, eval_frequency, eval_steps,
                                  eval_time)
            trainer_utils.log_epoch_report(report, metrics_logger)
            trainer_utils.maybe_log_training_metrics(metrics_state,
                                                     metrics_summary_fn,
                                                     metrics_logger)
            checkpoint.save_unreplicated_checkpoint_background(
                train_dir, optimizer_state, params, batch_stats, metrics_state,
                global_step, preemption_count, sum_train_cost)
    # To make sure the last checkpoint was correctly saved.
    checkpoint.wait_for_checkpoint_save()
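
Because train is a generator that yields one metrics dict per evaluation, a caller drives it with an ordinary for loop. A minimal consumption sketch; the path, step counts, and the pre-built model/dataset_builder/initializer/hps/metrics_logger objects are placeholders, not values taken from the library:

import jax

def run_experiment(model, dataset_builder, initializer, hps, metrics_logger):
    """Drives the train() generator above; all arguments are assumed to be pre-built."""
    for metrics in train(
            train_dir='/tmp/my_experiment',   # placeholder path
            model=model,
            dataset_builder=dataset_builder,
            initializer=initializer,
            num_train_steps=10000,
            hps=hps,
            rng=jax.random.PRNGKey(0),
            eval_batch_size=None,             # fall back to hps.batch_size
            eval_num_batches=None,            # evaluate on the full splits
            eval_train_num_batches=None,
            eval_frequency=1000,
            checkpoint_steps=[],
            metrics_logger=metrics_logger):
        print(metrics['global_step'], metrics.get('train_cost'))
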
Example #5
 def test_mlperf_schedule(self):
     """Test there are no changes to the MLPerf polynomial decay schedule."""
     expected_lrs = [
         0.0,
         0.2,
         0.4,
         0.6,
         0.8,
         1.0,
         1.2,
         1.4,
         1.6,
         1.8,
         2.0,
         2.2,
         2.4,
         2.6,
         2.8,
         3.0,
         3.2,
         3.4,
         3.6,
         3.8,
         4.0,
         4.2,
         4.4,
         4.6,
         4.8,
         5.0,
         5.2,
         5.4,
         5.6,
         5.8,
         6.0,
         6.2,
         6.4,
         6.6,
         6.8,
         7.0,
         7.2,
         7.4,
         7.6,
         7.8,
         8.0,
         8.2,
         8.4,
         8.6,
         8.8,
         9.0,
         9.2,
         9.4,
         9.6,
         9.8,
         10.0,
         9.802962,
         9.607885,
         9.414769,
         9.223614,
         9.034419,
         8.847184,
         8.661909,
         8.478596,
         8.297242,
         8.117851,
         7.940418,
         7.764947,
         7.5914364,
         7.419886,
         7.2502966,
         7.082668,
         6.917,
         6.7532916,
         6.591545,
         6.4317584,
         6.273932,
         6.1180663,
         5.964162,
         5.812217,
         5.662234,
         5.5142093,
         5.368148,
         5.2240453,
         5.0819044,
         4.941723,
         4.803503,
         4.6672425,
         4.532944,
         4.4006047,
         4.2702274,
         4.1418095,
         4.0153522,
         3.8908558,
         3.7683203,
         3.647745,
         3.5291305,
         3.4124763,
         3.297783,
         3.18505,
         3.0742776,
         2.965466,
         2.858614,
         2.753724,
         2.6507936,
         2.5498245,
         2.4508152,
         2.353767,
         2.2586792,
         2.1655521,
         2.0743854,
         1.9851794,
         1.8979341,
         1.8126491,
         1.7293249,
         1.6479613,
         1.568558,
         1.4911155,
         1.4156334,
         1.342112,
         1.2705511,
         1.2009507,
         1.133311,
         1.0676318,
         1.0039133,
         0.94215524,
         0.8823574,
         0.8245205,
         0.7686443,
         0.71472853,
         0.66277343,
         0.61277884,
         0.56474483,
         0.5186714,
         0.4745585,
         0.43240622,
         0.39221448,
         0.35398334,
         0.31771275,
         0.28340274,
         0.2510533,
         0.22066444,
         0.19223614,
         0.16576843,
         0.14126128,
         0.11871469,
         0.098128565,
         0.079503134,
         0.06283828,
         0.048134,
         0.035390284,
         0.02460714,
         0.01578457,
         0.00892257,
         0.004021142,
     ]
     hps = config_dict.ConfigDict(
         dict(
             lr_hparams={
                 'schedule': 'mlperf_polynomial',
                 'base_lr': 10.0,
                 'warmup_steps': 50,
                 'decay_end': -1,
                 'end_lr': 1e-4,
                 'power': 2.0,
                 'start_lr': 0.0,
                 'warmup_power': 1.0,
             }))
     max_training_steps = 150
     lr_fn = schedules.get_schedule_fn(hps.lr_hparams, max_training_steps)
     for step in range(max_training_steps):
         self.assertAlmostEqual(lr_fn(step), expected_lrs[step])
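
The first 51 expected values are just the linear warmup implied by warmup_power=1.0: the rate climbs from start_lr to base_lr over warmup_steps, i.e. 10.0 * step / 50, or 0.2 per step. A small arithmetic sketch of that portion (not the library's implementation):

def warmup_lr(step, base_lr=10.0, start_lr=0.0, warmup_steps=50):
    """Linear warmup (warmup_power = 1.0) matching the first warmup values above."""
    frac = min(step, warmup_steps) / warmup_steps
    return start_lr + (base_lr - start_lr) * frac
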
Example #6
def train(train_dir,
          model,
          dataset_builder,
          initializer,
          num_train_steps,
          hps,
          rng,
          eval_batch_size,
          eval_num_batches,
          eval_train_num_batches,
          eval_frequency,
          checkpoint_steps,
          eval_steps=None,
          metrics_logger=None,
          init_logger=None,
          training_metrics_config=None,
          use_deprecated_checkpointing=True):
    """Main training loop.

  Trains the given network on the specified dataset for the given number of
  epochs. Saves the training curve in train_dir/r=3/results.tsv.

  Args:
    train_dir: (str) Path of the training directory.
    model: (BaseModel) Model object to be trained.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    initializer: Must have API as defined in initializers.py
    num_train_steps: (int) Number of steps to train on.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_batch_size: the evaluation batch size. If None, use hps.batch_size.
    eval_num_batches: (int) The number of batches used for evaluating on
      validation and test sets. Set to None to evaluate on the whole validation
      and test sets.
    eval_train_num_batches: (int) The number of batches for evaluating on train.
      Set to None to evaluate on the whole training set.
    eval_frequency: (int) Evaluate every k steps.
    checkpoint_steps: List of integers indicating special steps to save
      checkpoints at. These checkpoints do not get used for preemption recovery.
    eval_steps: List of integers indicating which steps to perform evals. If
      provided, eval_frequency will be ignored. Performing an eval implies
      saving a checkpoint that will be used to resume training in the case of
      preemption.
    metrics_logger: Used to log all eval metrics during training. See
      utils.MetricLogger for API definition.
    init_logger: Used for black box initializers that have learning curves.
    training_metrics_config: Dict specifying the configuration of the
      training_metrics_grabber. Set to None to skip logging of advanced training
      metrics.
    use_deprecated_checkpointing: Whether to use deprecated checkpointing.

  Yields:
    metrics: A dictionary of all eval metrics from the given epoch.
  """
    # NOTE: the initialization RNG should *not* be per-host, as this will create
    # different sets of weights per host. However, all other RNGs should be
    # per-host.
    # TODO(znado,gilmer,gdahl): implement replicating the same initialization
    # across hosts.
    rng, init_rng = jax.random.split(rng)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, data_rng = jax.random.split(rng)

    # Only used if checkpoint_steps is non-empty.
    checkpoint_dir = os.path.join(train_dir, 'checkpoints')

    if jax.host_id() == 0:
        logging.info('Let the training begin!')
        logging.info('Dataset input shape: %r', hps.input_shape)
        logging.info('Hyperparameters: %s', hps)

    if eval_batch_size is None:
        eval_batch_size = hps.batch_size

    # Maybe run the initializer.
    flax_module, batch_stats = initialize(model.flax_module_def, initializer,
                                          model.loss_fn, hps.input_shape,
                                          hps.output_shape, hps, init_rng,
                                          init_logger)

    if jax.host_id() == 0:
        utils.log_pytree_shape_and_statistics(flax_module.params)
        logging.info('train_size: %d,', hps.train_size)

    lr_fn = schedules.get_schedule_fn(hps.lr_hparams, num_train_steps)

    optimizer = get_optimizer(hps).create(flax_module)

    training_metrics_grabber = None
    if training_metrics_config:
        training_metrics_grabber = utils.TrainingMetricsGrabber.create(
            optimizer.target.params, training_metrics_config)

    (optimizer, batch_stats, training_metrics_grabber, global_step,
     sum_train_cost, preemption_count,
     is_restored) = _maybe_restore_latest_checkpoint(
         unreplicated_optimizer=optimizer,
         unreplicated_batch_stats=batch_stats,
         unreplicated_training_metrics_grabber=training_metrics_grabber,
         train_dir=train_dir,
         use_deprecated_checkpointing=use_deprecated_checkpointing)

    if is_restored:
        preemption_count += 1
        # Fold the restored step into the dataset RNG so that we will get a
        # different shuffle each time we restore, so that we do not repeat a
        # previous dataset ordering again after restoring. This is not the only
        # difference in shuffling each pre-emption, because we often times reshuffle
        # the input files each time in a non-deterministic manner.
        #
        # Note that if we are pre-empted more than once per epoch then we will
        # retrain more on the beginning of the training split, because each time we
        # restore we refill the shuffle buffer with the first `shuffle_buffer_size`
        # elements from the training split to continue training.
        #
        # Also note that for evaluating on the training split, because we are
        # reshuffling each time, we will get a new eval_train split each time we are
        # pre-empted.
        data_rng = jax.random.fold_in(data_rng, global_step)

    assert hps.batch_size % (jax.device_count()) == 0
    assert eval_batch_size % (jax.device_count()) == 0
    dataset = dataset_builder(
        data_rng,
        hps.batch_size,
        eval_batch_size=eval_batch_size,
        hps=hps,
    )

    # pmap functions for the training loop.
    # in_axes = (optimizer = 0, batch_stats = 0, batch = 0, step = None,
    #            lr = None, rng = None, local_device_index = 0,
    #            training_metrics_grabber = 0)
    # training_cost is bound via functools.partial and is not mapped.
    update_pmapped = jax.pmap(functools.partial(
        update, training_cost=model.training_cost),
                              axis_name='batch',
                              in_axes=(0, 0, 0, None, None, None, 0, 0))
    evaluate_batch_pmapped = jax.pmap(model.evaluate_batch, axis_name='batch')
    start_time = time.time()
    start_step = global_step
    prev_eval_step = start_step

    def get_step_frequency(cur_step):
        return float(cur_step - start_step) / (time.time() - start_time)

    if jax.host_id() == 0:
        logging.info('Starting training!')

    # Numpy array of range(0, local_device_count) to send to each device to be
    # folded into the RNG inside each train step to get a unique per-device RNG.
    local_device_indices = np.arange(jax.local_device_count())

    # Start at the resumed step and continue until we have finished the number of
    # training steps. If building a dataset iterator using a tf.data.Dataset, in
    # the case of a batch size that does not evenly divide the training dataset
    # size, if using `ds.batch(..., drop_remainder=True)` on the training dataset
    # then the final batch in this iterator will be a partial batch. However, if
    # `drop_remainder=False`, then this iterator will always return batches of the
    # same size, and the final batch will have elements from the start of the
    # (num_epochs + 1)-th epoch.
    train_iter = itertools.islice(dataset.train_iterator_fn(), global_step,
                                  num_train_steps)
    for batch in train_iter:
        if global_step in checkpoint_steps and jax.host_id() == 0:
            save_checkpoint(
                checkpoint_dir, {
                    'optimizer': optimizer,
                    'batch_stats': batch_stats,
                    'training_metrics_grabber': training_metrics_grabber
                },
                global_step,
                preemption_count,
                sum_train_cost,
                max_to_keep=None,
                use_deprecated_checkpointing=use_deprecated_checkpointing)
        batch = data_utils.shard(batch)
        lr = lr_fn(global_step)
        optimizer, batch_stats, cost_val, training_metrics_grabber = update_pmapped(
            optimizer, batch_stats, batch, global_step, lr, rng,
            local_device_indices, training_metrics_grabber)
        # Calling float is needed since cost_val is a shape (1,) DeviceArray.
        sum_train_cost += float(jnp.mean(cost_val))
        global_step += 1
        # TODO(gdahl, gilmer): consider moving this test up.
        # NB: Since this test is after we increment global_step, having 0 in
        # eval_steps does nothing.
        if should_eval(global_step, eval_frequency, eval_steps):
            batch_stats = _maybe_sync_batchnorm_stats(batch_stats)
            report, eval_time = eval_metrics(optimizer.target, batch_stats,
                                             dataset, eval_num_batches,
                                             eval_train_num_batches,
                                             evaluate_batch_pmapped)
            mean_train_cost = sum_train_cost / max(
                1, global_step - prev_eval_step)
            report.update(learning_rate=float(lr),
                          global_step=global_step,
                          epoch=global_step * hps.batch_size // hps.train_size,
                          steps_per_sec=get_step_frequency(global_step),
                          eval_time=eval_time,
                          preemption_count=preemption_count,
                          train_cost=mean_train_cost)
            yield report
            if jax.host_id() == 0:
                _log_epoch_report(report, metrics_logger)
                _maybe_log_training_metrics(training_metrics_grabber,
                                            metrics_logger)
                save_checkpoint(
                    train_dir, {
                        'optimizer': optimizer,
                        'batch_stats': batch_stats,
                        'training_metrics_grabber': training_metrics_grabber
                    },
                    global_step,
                    preemption_count,
                    sum_train_cost,
                    use_deprecated_checkpointing=use_deprecated_checkpointing)
            sum_train_cost = 0.0
            prev_eval_step = global_step

    # Always log and checkpoint on host 0 at the end of training.
    # If we moved where in the loop body evals happen then we would not need this
    # test.
    if prev_eval_step != num_train_steps:
        batch_stats = _maybe_sync_batchnorm_stats(batch_stats)
        report, eval_time = eval_metrics(optimizer.target, batch_stats,
                                         dataset, eval_num_batches,
                                         eval_train_num_batches,
                                         evaluate_batch_pmapped)
        lr = lr_fn(global_step)
        # Correct the average for the final partial epoch.
        mean_train_cost = sum_train_cost / max(1, global_step - prev_eval_step)
        report.update(learning_rate=float(lr),
                      global_step=global_step,
                      epoch=global_step * hps.batch_size // hps.train_size,
                      steps_per_sec=get_step_frequency(global_step),
                      eval_time=eval_time,
                      preemption_count=preemption_count,
                      train_cost=mean_train_cost)
        yield report
        if jax.host_id() == 0:
            _log_epoch_report(report, metrics_logger)
            _maybe_log_training_metrics(training_metrics_grabber,
                                        metrics_logger)
            save_checkpoint(
                train_dir, {
                    'optimizer': optimizer,
                    'batch_stats': batch_stats,
                    'training_metrics_grabber': training_metrics_grabber
                },
                global_step,
                preemption_count,
                sum_train_cost,
                use_deprecated_checkpointing=use_deprecated_checkpointing)
    # To make sure the last checkpoint was correctly saved.
    checkpoint.wait_for_checkpoint_save()
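
Both training loops gate evaluation on should_eval(global_step, eval_frequency, eval_steps). The docstrings state that eval_steps, when provided, overrides eval_frequency; a plausible sketch of such a predicate follows (the actual helper may differ):

def should_eval(global_step, eval_frequency, eval_steps):
    """Sketch only: explicit eval_steps take precedence over eval_frequency."""
    if eval_steps:
        return global_step in eval_steps
    return global_step % eval_frequency == 0
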