Example #1
    def test_initializers(self, init):
        """Test that each initializer runs, and the output is a valid pytree."""

        rng = jax.random.PRNGKey(0)
        flax_module, params, input_shape, model_hps = _load_model(
            'fully_connected')
        _, init_rng = jax.random.split(rng)
        initializer = initializers.get_initializer(init)
        init_hps = initializers.get_initializer_hparams(init)
        init_hps.update(model_hps)
        loss_name = 'cross_entropy'
        loss_fn = losses.get_loss_fn(loss_name)
        new_params = initializer(loss_fn=loss_fn,
                                 flax_module=flax_module,
                                 params=params,
                                 hps=init_hps,
                                 input_shape=input_shape[1:],
                                 output_shape=OUTPUT_SHAPE,
                                 rng_key=init_rng)

        # Check new params are still valid params
        outputs = flax_module.apply({'params': new_params},
                                    jnp.ones(input_shape),
                                    train=True)
        utils.log_pytree_shape_and_statistics(new_params)
        self.assertEqual(outputs.shape, (input_shape[0], OUTPUT_SHAPE[-1]))
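The initializer returned by initializers.get_initializer is invoked with keyword arguments only, exactly as in the test above. A minimal sketch of a function matching that call signature (the name _identity_initializer is hypothetical; the authoritative contract lives in initializers.py):

def _identity_initializer(loss_fn, flax_module, params, hps, input_shape,
                          output_shape, rng_key):
  # A no-op initializer: ignore every input and hand back the params pytree
  # unchanged, which is the minimal behavior the test above can accept.
  del loss_fn, flax_module, hps, input_shape, output_shape, rng_key
  return params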
Example #2
  def test_initialize_rescale(self):
    """Test rescaling a single layer of a model."""
    input_shape = (28, 28, 1)
    output_shape = (10,)
    model_str = 'fully_connected'
    model_cls = models.get_model(model_str)
    model_hps = models.get_model_hparams(model_str)
    loss_name = 'cross_entropy'
    metrics_name = 'classification_metrics'
    hps = copy.copy(model_hps)
    hps.update({'output_shape': output_shape})
    rng = jax.random.PRNGKey(0)
    model = model_cls(hps, {}, loss_name, metrics_name)
    initializer = initializers.get_initializer('noop')

    rng, init_rng = jax.random.split(rng)

    # First initialize with no rescale.
    flax_module, _ = trainer.initialize(
        model.flax_module_def,
        initializer,
        model.loss_fn,
        input_shape,
        output_shape,
        hps,
        init_rng,
        metrics_logger=None)

    utils.log_pytree_shape_and_statistics(flax_module.params)
    # Now rescale a layer by 100.
    rescale_factor = 100
    hps.layer_rescale_factors = {
        '/Dense_1/kernel': rescale_factor,
    }

    rescaled_module, _ = trainer.initialize(
        model.flax_module_def,
        initializer,
        model.loss_fn,
        input_shape,
        output_shape,
        hps,
        init_rng,
        metrics_logger=None)

    # Check the right variable is rescaled
    v1 = flax_module.params['Dense_1']['kernel']
    v2 = rescaled_module.params['Dense_1']['kernel']
    diff = np.linalg.norm(v1.reshape(-1) * rescale_factor - v2.reshape(-1))
    self.assertAlmostEqual(diff, 0.0)

    # Check that other variables are the same
    v1 = flax_module.params['Dense_2']['kernel']
    v2 = rescaled_module.params['Dense_2']['kernel']
    diff = np.linalg.norm(v1.reshape(-1) - v2.reshape(-1))
    self.assertAlmostEqual(diff, 0.0)
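The rescaling verified above multiplies a single parameter leaf by a constant keyed by its '/'-joined path. A minimal sketch of that operation on a Flax params pytree, for illustration only (this shows the checked effect, not necessarily how trainer.initialize applies hps.layer_rescale_factors internally):

from flax import traverse_util
from flax.core import freeze, unfreeze


def rescale_params(params, layer_rescale_factors):
  # Flatten the pytree to {('Dense_1', 'kernel'): array, ...}, scale any leaf
  # whose '/'-joined path appears in layer_rescale_factors, and rebuild.
  flat_params = traverse_util.flatten_dict(unfreeze(params))
  for path, leaf in flat_params.items():
    key = '/' + '/'.join(path)
    if key in layer_rescale_factors:
      flat_params[path] = leaf * layer_rescale_factors[key]
  return freeze(traverse_util.unflatten_dict(flat_params))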
Example #3
def _load_model(model_name):
    """Load a test model."""
    rng = jax.random.PRNGKey(0)
    model_cls = models.get_model(model_name)
    loss_name = 'cross_entropy'
    metrics_name = 'classification_metrics'
    model_hps = models.get_model_hparams(model_name)

    hps = copy.copy(model_hps)
    hps.update({'output_shape': OUTPUT_SHAPE})
    model = model_cls(hps, {}, loss_name, metrics_name)

    input_shape = (BATCH_SIZE, ) + MODEL_TO_INPUT_SHAPE[model_name]
    _, flax_module = model.flax_module_def.create_by_shape(rng, [input_shape],
                                                           train=True)
    utils.log_pytree_shape_and_statistics(flax_module.params)
    return flax_module, input_shape
Example #4
def _load_model(model_name):
    """Load a test model."""
    rng = jax.random.PRNGKey(0)
    model_cls = models.get_model(model_name)
    loss_name = 'cross_entropy'
    metrics_name = 'classification_metrics'
    model_hps = models.get_model_hparams(model_name)

    hps = copy.copy(model_hps)
    hps.update({'output_shape': OUTPUT_SHAPE})
    model = model_cls(hps, {}, loss_name, metrics_name)

    input_shape = (BATCH_SIZE, ) + MODEL_TO_INPUT_SHAPE[model_name]
    model_init_fn = jax.jit(
        functools.partial(model.flax_module.init, train=True))
    init_dict = model_init_fn({'params': rng}, jnp.zeros(input_shape))
    # Trainable model parameters.
    params = init_dict['params']
    utils.log_pytree_shape_and_statistics(params)
    return model.flax_module, params, input_shape, hps
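A short usage sketch of this helper; it mirrors the forward pass performed in test_initializers above and assumes the same BATCH_SIZE and OUTPUT_SHAPE module constants:

flax_module, params, input_shape, hps = _load_model('fully_connected')
outputs = flax_module.apply({'params': params}, jnp.ones(input_shape), train=True)
assert outputs.shape == (input_shape[0], OUTPUT_SHAPE[-1])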
Example #5
def train(train_dir,
          model,
          dataset_builder,
          initializer,
          num_train_steps,
          hps,
          rng,
          eval_batch_size,
          eval_num_batches,
          eval_train_num_batches,
          eval_frequency,
          checkpoint_steps,
          early_stopping_target_name=None,
          early_stopping_target_value=None,
          early_stopping_mode=None,
          eval_steps=None,
          metrics_logger=None,
          init_logger=None,
          training_metrics_config=None,
          callback_configs=None,
          external_checkpoint_path=None):
    """Main training loop.

  Trains the given network on the specified dataset for the given number of
  training steps. Saves the training curve in train_dir/r=3/results.tsv.

  Args:
    train_dir: (str) Path of the training directory.
    model: (BaseModel) Model object to be trained.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    initializer: Must have API as defined in initializers.py
    num_train_steps: (int) Number of steps to train on.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_batch_size: the evaluation batch size. If None, use hps.batch_size.
    eval_num_batches: (int) The number of batches used for evaluating on the
      validation and test sets. Set to None to evaluate on the full validation
      and test sets.
    eval_train_num_batches: (int) The number of batches for evaluating on train.
      Set to None to evaluate on the whole training set.
    eval_frequency: (int) Evaluate every k steps.
    checkpoint_steps: List of integers indicating special steps to save
      checkpoints at. These checkpoints do not get used for preemption recovery.
    early_stopping_target_name: A string naming the metric to use to perform
      early stopping. If this metric reaches the value
      `early_stopping_target_value`, training will stop. Must include the
      dataset split (ex: validation/error_rate).
    early_stopping_target_value: A float indicating the value at which to stop
      training.
    early_stopping_mode: One of "above" or "below", indicates if we should stop
      when the metric is above or below the threshold value. Example: if
      "above", then training will stop when
      `report[early_stopping_target_name] >= early_stopping_target_value`.
    eval_steps: List of integers indicating which steps to perform evals. If
      provided, eval_frequency will be ignored. Performing an eval implies
      saving a checkpoint that will be used to resume training in the case of
      preemption.
    metrics_logger: Used to log all eval metrics during training. See
      utils.MetricLogger for API definition.
    init_logger: Used for black box initializers that have learning curves.
    training_metrics_config: Dict specifying the configuration of the
      training_metrics_grabber. Set to None to skip logging of advanced training
      metrics.
    callback_configs: List of configs specifying general callbacks to run
      during the eval phase. Empty list means no callbacks are run. See
      callbacks.py for details on what is expected in a config.
    external_checkpoint_path: (str) If this argument is set, we will load the
      optimizer_state, params, batch_stats, and training_metrics from the
      checkpoint at this location.

  Yields:
    metrics: A dictionary of all eval metrics from the given epoch.
  """
    # NOTE: the initialization RNG should *not* be per-host, as this will create
    # different sets of weights per host. However, all other RNGs should be
    # per-host.
    # TODO(znado,gilmer,gdahl): implement replicating the same initialization
    # across hosts.
    rng, init_rng = jax.random.split(rng)
    rng = jax.random.fold_in(rng, jax.process_index())
    rng, data_rng = jax.random.split(rng)

    # Only used if checkpoint_steps is non-empty.
    checkpoint_dir = os.path.join(train_dir, 'checkpoints')

    # For logging / processing off the main thread
    pool = multiprocessing.pool.ThreadPool()

    if jax.process_index() == 0:
        logging.info('Let the training begin!')
        logging.info('Dataset input shape: %r', hps.input_shape)
        logging.info('Hyperparameters: %s', hps)

    if eval_batch_size is None:
        eval_batch_size = hps.batch_size
    if callback_configs is None:
        callback_configs = []

    # Maybe run the initializer.
    unreplicated_params, unreplicated_batch_stats = init_utils.initialize(
        model.flax_module, initializer, model.loss_fn,
        hps.input_shape, hps.output_shape, hps, init_rng, init_logger,
        model.get_fake_batch(hps))

    if jax.process_index() == 0:
        utils.log_pytree_shape_and_statistics(unreplicated_params)
        logging.info('train_size: %d,', hps.train_size)

    # Note that global_step refers to the number of gradient calculations, not
    # the number of model updates. This means when using gradient accumulation,
    # one must supply configs where the number of steps is in units of gradient
    # calculations, not model updates, and in post-processing one must divide
    # global_step by grad_accum_step_multiplier to get the number of updates.
    #
    # If using gradient accumulation, stretch the learning rate schedule by the
    # number of gradient calculations per weight update.
    stretch_factor = 1
    if hps.get('total_accumulated_batch_size') is not None:
        stretch_factor = hps.total_accumulated_batch_size // hps.batch_size
    lr_fn = schedules.get_schedule_fn(hps.lr_hparams,
                                      num_train_steps,
                                      stretch_factor=stretch_factor)

    optimizer_init_fn, optimizer_update_fn = optimizers.get_optimizer(
        hps, model)
    unreplicated_optimizer_state = optimizer_init_fn(unreplicated_params)

    (unreplicated_metrics_state, metrics_update_fn,
     metrics_summary_fn) = None, None, None
    if training_metrics_config is not None:
        (metrics_init_fn, metrics_update_fn,
         metrics_summary_fn) = make_training_metrics(num_train_steps,
                                                     **training_metrics_config)
        unreplicated_metrics_state = metrics_init_fn(unreplicated_params)

    (optimizer_state, params, batch_stats, metrics_state, global_step,
     sum_train_cost, preemption_count,
     is_restored) = checkpoint.replicate_and_maybe_restore_checkpoint(
         unreplicated_optimizer_state,
         unreplicated_params,
         unreplicated_batch_stats,
         unreplicated_metrics_state,
         train_dir=train_dir,
         external_checkpoint_path=external_checkpoint_path)

    if is_restored:
        preemption_count += 1
        # Fold the restored step into the dataset RNG so that we get a different
        # shuffle each time we restore and do not repeat a previous dataset
        # ordering. This is not the only difference in shuffling after each
        # preemption, because we oftentimes reshuffle the input files in a
        # non-deterministic manner.
        #
        # Note that if we are pre-empted more than once per epoch then we will
        # retrain more on the beginning of the training split, because each time we
        # restore we refill the shuffle buffer with the first `shuffle_buffer_size`
        # elements from the training split to continue training.
        #
        # Also note that for evaluating on the training split, because we are
        # reshuffling each time, we will get a new eval_train split each time we are
        # pre-empted.
        data_rng = jax.random.fold_in(data_rng, global_step)

    assert hps.batch_size % (jax.device_count()) == 0
    assert eval_batch_size % (jax.device_count()) == 0
    dataset = dataset_builder(
        data_rng,
        hps.batch_size,
        eval_batch_size=eval_batch_size,
        hps=hps,
    )

    update_fn = functools.partial(update,
                                  training_cost=model.training_cost,
                                  grad_clip=hps.get('grad_clip'),
                                  optimizer_update_fn=optimizer_update_fn,
                                  metrics_update_fn=metrics_update_fn)
    # in_axes = (
    #     optimizer_state = 0,
    #     params = 0,
    #     batch_stats = 0,
    #     metrics_state = 0,
    #     batch = 0,
    #     step = None,
    #     lr = None,
    #     rng = None,
    #     local_device_index = 0,
    #     running_train_cost = 0,
    #     training_cost,
    #     grad_clip,
    #     optimizer_update_fn,
    #     metrics_update_fn)
    # Also, we can donate buffers for 'optimizer', 'batch_stats',
    # 'batch' and 'training_metrics_state' for update's pmapped computation.
    update_pmapped = jax.pmap(update_fn,
                              axis_name='batch',
                              in_axes=(0, 0, 0, 0, 0, None, None, None, 0, 0),
                              donate_argnums=(0, 1, 2, 8))
    # During eval, we can donate the 'batch' buffer. We don't donate the
    # 'params' and 'batch_stats' buffers as we don't re-assign those values in
    # eval, we do that only in train.
    evaluate_batch_pmapped = jax.pmap(model.evaluate_batch,
                                      axis_name='batch',
                                      donate_argnums=(2, ))
    start_time = time.time()
    start_step = global_step
    prev_eval_step = start_step

    def get_step_frequency(cur_step):
        return float(cur_step - start_step) / (time.time() - start_time)

    if jax.process_index() == 0:
        trainer_utils.log_message('Starting training!', pool, xm_work_unit)

    # Numpy array of range(0, local_device_count) to send to each device to be
    # folded into the RNG inside each train step to get a unique per-device RNG.
    local_device_indices = np.arange(jax.local_device_count())

    # Start at the resumed step and continue until we have finished the number of
    # training steps. If building a dataset iterator using a tf.data.Dataset and
    # the batch size does not evenly divide the training dataset size, then with
    # `ds.batch(..., drop_remainder=False)` on the training dataset the final
    # batch of an epoch in this iterator will be a partial batch. However, with
    # `drop_remainder=True`, this iterator will always return batches of the same
    # size, and the final batch will have elements from the start of the
    # (num_epochs + 1)-th epoch.
    train_iter = itertools.islice(dataset.train_iterator_fn(), global_step,
                                  num_train_steps)

    eval_callbacks = []
    rng, callback_rng = jax.random.split(rng)
    callback_rngs = jax.random.split(callback_rng, len(callback_configs))
    for callback_rng, config in zip(callback_rngs, callback_configs):
        eval_callback = callbacks.get_callback(config['callback_name'])(
            model, params, batch_stats, optimizer_state, dataset, hps, config,
            train_dir, callback_rng)
        eval_callbacks.append(eval_callback)

    train_iter = trainer_utils.prefetch_input_pipeline(
        train_iter, hps.num_device_prefetches)

    eval_start_time = start_time
    eval_start_step = start_step
    for _ in range(start_step, num_train_steps):
        with jax.profiler.StepTraceAnnotation('train', step_num=global_step):
            # NOTE(dsuo): to properly profile each step, we must include batch
            # creation in the StepTraceAnnotation context (as opposed to putting
            # `train_iter` directly in the top-level for loop).
            batch = next(train_iter)

            if global_step in checkpoint_steps and jax.process_index() == 0:
                checkpoint.save_unreplicated_checkpoint_background(
                    checkpoint_dir,
                    optimizer_state,
                    params,
                    batch_stats,
                    metrics_state,
                    global_step,
                    preemption_count,
                    sum_train_cost,
                    max_to_keep=None)
            lr = lr_fn(global_step)
            optimizer_state, params, batch_stats, sum_train_cost, metrics_state, grad_norm = update_pmapped(
                optimizer_state, params, batch_stats, metrics_state, batch,
                global_step, lr, rng, local_device_indices, sum_train_cost)
            global_step += 1
            # TODO(gdahl, gilmer): consider moving this test up.
            # NB: Since this test is after we increment global_step, having 0 in
            # eval_steps does nothing.
            if trainer_utils.should_eval(global_step, eval_frequency,
                                         eval_steps):
                train_steps_per_sec = (global_step - eval_start_step) / (
                    time.time() - eval_start_time)
                eval_start_step = global_step
                eval_start_time = time.time()
                batch_stats = trainer_utils.maybe_sync_batchnorm_stats(
                    batch_stats)
                report, eval_time = eval_metrics(params, batch_stats, dataset,
                                                 eval_num_batches,
                                                 eval_train_num_batches,
                                                 evaluate_batch_pmapped)
                mean_train_cost = sum_train_cost.mean().item() / max(
                    1, global_step - prev_eval_step)
                report.update(
                    learning_rate=float(lr),
                    global_step=global_step,
                    epoch=global_step * hps.batch_size // hps.train_size,
                    train_steps_per_sec=train_steps_per_sec,
                    overall_steps_per_sec=get_step_frequency(global_step),
                    eval_time=eval_time,
                    grad_norm=np.mean(grad_norm),
                    preemption_count=preemption_count,
                    train_cost=mean_train_cost)

                for eval_callback in eval_callbacks:
                    callback_metrics = eval_callback.run_eval(
                        params, batch_stats, optimizer_state, global_step)
                    if set(callback_metrics.keys()).intersection(
                            set(report.keys())):
                        raise ValueError(
                            'There was a collision between the callback '
                            'metrics and the standard eval metrics keys.')
                    report.update(callback_metrics)
                yield report
                if jax.process_index() == 0:
                    trainer_utils.log_eta(pool, xm_work_unit, global_step,
                                          train_steps_per_sec, num_train_steps,
                                          start_time, eval_frequency,
                                          eval_steps, eval_time)
                    trainer_utils.log_epoch_report(report, metrics_logger)
                    trainer_utils.maybe_log_training_metrics(
                        metrics_state, metrics_summary_fn, metrics_logger)
                    checkpoint.save_unreplicated_checkpoint_background(
                        train_dir, optimizer_state, params, batch_stats,
                        metrics_state, global_step, preemption_count,
                        sum_train_cost)
                sum_train_cost = jnp.zeros(jax.local_device_count())
                prev_eval_step = global_step

                early_stopping_condition = trainer_utils.check_for_early_stopping(
                    early_stopping_target_name, early_stopping_target_value,
                    early_stopping_mode, report)
                if early_stopping_condition:
                    comparison_string = '>=' if early_stopping_mode == 'above' else '<='
                    logging.info(
                        'Early stopping because metric %s=%f reached the target value '
                        'of %s %f.', early_stopping_target_name,
                        report[early_stopping_target_name], comparison_string,
                        early_stopping_target_value)
                    return

    # Always log and checkpoint on host 0 at the end of training.
    # If evals happened at a different point in the loop body, this check would
    # not be needed.
    if prev_eval_step != num_train_steps:
        train_steps_per_sec = (global_step - eval_start_step) / (
            time.time() - eval_start_time)
        batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
        report, eval_time = eval_metrics(params, batch_stats, dataset,
                                         eval_num_batches,
                                         eval_train_num_batches,
                                         evaluate_batch_pmapped)
        lr = lr_fn(global_step)
        # Correct the average for the final partial epoch.
        mean_train_cost = sum_train_cost.mean().item() / max(
            1, global_step - prev_eval_step)
        report.update(learning_rate=float(lr),
                      global_step=global_step,
                      epoch=global_step * hps.batch_size // hps.train_size,
                      train_steps_per_sec=train_steps_per_sec,
                      overall_steps_per_sec=get_step_frequency(global_step),
                      eval_time=eval_time,
                      grad_norm=np.mean(grad_norm),
                      preemption_count=preemption_count,
                      train_cost=mean_train_cost)
        yield report
        if jax.process_index() == 0:
            trainer_utils.log_eta(pool, xm_work_unit, global_step,
                                  train_steps_per_sec, num_train_steps,
                                  start_time, eval_frequency, eval_steps,
                                  eval_time)
            trainer_utils.log_epoch_report(report, metrics_logger)
            trainer_utils.maybe_log_training_metrics(metrics_state,
                                                     metrics_summary_fn,
                                                     metrics_logger)
            checkpoint.save_unreplicated_checkpoint_background(
                train_dir, optimizer_state, params, batch_stats, metrics_state,
                global_step, preemption_count, sum_train_cost)
    # To make sure the last checkpoint was correctly saved.
    checkpoint.wait_for_checkpoint_save()
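Because train is a generator that yields one metrics dictionary per eval, a caller drives training by iterating it. A hedged usage sketch follows; the train_dir value is a hypothetical path, the argument values are illustrative only, and model, dataset_builder, initializer, and hps are assumed to have been built with the library helpers shown in the earlier examples:

rng = jax.random.PRNGKey(0)
for report in train(
    train_dir='/tmp/fully_connected_run',  # hypothetical output directory
    model=model,
    dataset_builder=dataset_builder,
    initializer=initializer,
    num_train_steps=1000,
    hps=hps,
    rng=rng,
    eval_batch_size=None,         # falls back to hps.batch_size
    eval_num_batches=None,        # evaluate on the full validation/test sets
    eval_train_num_batches=None,  # evaluate on the full training split
    eval_frequency=100,           # eval (and checkpoint) every 100 steps
    checkpoint_steps=[]):
  logging.info('step %d eval metrics: %r', report['global_step'], report)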
Example #6
def train(train_dir,
          model,
          dataset_builder,
          initializer,
          num_train_steps,
          hps,
          rng,
          eval_batch_size,
          eval_num_batches,
          eval_train_num_batches,
          eval_frequency,
          checkpoint_steps,
          eval_steps=None,
          metrics_logger=None,
          init_logger=None,
          training_metrics_config=None,
          use_deprecated_checkpointing=True):
    """Main training loop.

  Trains the given network on the specified dataset for the given number of
  training steps. Saves the training curve in train_dir/r=3/results.tsv.

  Args:
    train_dir: (str) Path of the training directory.
    model: (BaseModel) Model object to be trained.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    initializer: Must have API as defined in initializers.py
    num_train_steps: (int) Number of steps to train on.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_batch_size: the evaluation batch size. If None, use hps.batch_size.
    eval_num_batches: (int) The number of batches used for evaluating on the
      validation and test sets. Set to None to evaluate on the full validation
      and test sets.
    eval_train_num_batches: (int) The number of batches for evaluating on train.
      Set to None to evaluate on the whole training set.
    eval_frequency: (int) Evaluate every k steps.
    checkpoint_steps: List of integers indicating special steps to save
      checkpoints at. These checkpoints do not get used for preemption recovery.
    eval_steps: List of integers indicating which steps to perform evals. If
      provided, eval_frequency will be ignored. Performing an eval implies
      saving a checkpoint that will be used to resume training in the case of
      preemption.
    metrics_logger: Used to log all eval metrics during training. See
      utils.MetricLogger for API definition.
    init_logger: Used for black box initializers that have learning curves.
    training_metrics_config: Dict specifying the configuration of the
      training_metrics_grabber. Set to None to skip logging of advanced training
      metrics.
    use_deprecated_checkpointing: Whether to use deprecated checkpointing.

  Yields:
    metrics: A dictionary of all eval metrics from the given epoch.
  """
    # NOTE: the initialization RNG should *not* be per-host, as this will create
    # different sets of weights per host. However, all other RNGs should be
    # per-host.
    # TODO(znado,gilmer,gdahl): implement replicating the same initialization
    # across hosts.
    rng, init_rng = jax.random.split(rng)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, data_rng = jax.random.split(rng)

    # Only used if checkpoint_steps is non-empty.
    checkpoint_dir = os.path.join(train_dir, 'checkpoints')

    if jax.host_id() == 0:
        logging.info('Let the training begin!')
        logging.info('Dataset input shape: %r', hps.input_shape)
        logging.info('Hyperparameters: %s', hps)

    if eval_batch_size is None:
        eval_batch_size = hps.batch_size

    # Maybe run the initializer.
    flax_module, batch_stats = initialize(model.flax_module_def, initializer,
                                          model.loss_fn, hps.input_shape,
                                          hps.output_shape, hps, init_rng,
                                          init_logger)

    if jax.host_id() == 0:
        utils.log_pytree_shape_and_statistics(flax_module.params)
        logging.info('train_size: %d,', hps.train_size)

    lr_fn = schedules.get_schedule_fn(hps.lr_hparams, num_train_steps)

    optimizer = get_optimizer(hps).create(flax_module)

    training_metrics_grabber = None
    if training_metrics_config:
        training_metrics_grabber = utils.TrainingMetricsGrabber.create(
            optimizer.target.params, training_metrics_config)

    (optimizer, batch_stats, training_metrics_grabber, global_step,
     sum_train_cost, preemption_count,
     is_restored) = _maybe_restore_latest_checkpoint(
         unreplicated_optimizer=optimizer,
         unreplicated_batch_stats=batch_stats,
         unreplicated_training_metrics_grabber=training_metrics_grabber,
         train_dir=train_dir,
         use_deprecated_checkpointing=use_deprecated_checkpointing)

    if is_restored:
        preemption_count += 1
        # Fold the restored step into the dataset RNG so that we get a different
        # shuffle each time we restore and do not repeat a previous dataset
        # ordering. This is not the only difference in shuffling after each
        # preemption, because we oftentimes reshuffle the input files in a
        # non-deterministic manner.
        #
        # Note that if we are pre-empted more than once per epoch then we will
        # retrain more on the beginning of the training split, because each time we
        # restore we refill the shuffle buffer with the first `shuffle_buffer_size`
        # elements from the training split to continue training.
        #
        # Also note that for evaluating on the training split, because we are
        # reshuffling each time, we will get a new eval_train split each time we are
        # pre-empted.
        data_rng = jax.random.fold_in(data_rng, global_step)

    assert hps.batch_size % (jax.device_count()) == 0
    assert eval_batch_size % (jax.device_count()) == 0
    dataset = dataset_builder(
        data_rng,
        hps.batch_size,
        eval_batch_size=eval_batch_size,
        hps=hps,
    )

    # pmap functions for the training loop.
    # in_axes = (optimizer = 0, batch_stats = 0, batch = 0, step = None,
    # lr = None, rng = None, local_device_index = 0,
    # training_metrics_grabber = 0; training_cost is bound via partial).
    update_pmapped = jax.pmap(functools.partial(
        update, training_cost=model.training_cost),
                              axis_name='batch',
                              in_axes=(0, 0, 0, None, None, None, 0, 0))
    evaluate_batch_pmapped = jax.pmap(model.evaluate_batch, axis_name='batch')
    start_time = time.time()
    start_step = global_step
    prev_eval_step = start_step

    def get_step_frequency(cur_step):
        return float(cur_step - start_step) / (time.time() - start_time)

    if jax.host_id() == 0:
        logging.info('Starting training!')

    # Numpy array of range(0, local_device_count) to send to each device to be
    # folded into the RNG inside each train step to get a unique per-device RNG.
    local_device_indices = np.arange(jax.local_device_count())

    # Start at the resumed step and continue until we have finished the number of
    # training steps. If building a dataset iterator using a tf.data.Dataset and
    # the batch size does not evenly divide the training dataset size, then with
    # `ds.batch(..., drop_remainder=False)` on the training dataset the final
    # batch of an epoch in this iterator will be a partial batch. However, with
    # `drop_remainder=True`, this iterator will always return batches of the same
    # size, and the final batch will have elements from the start of the
    # (num_epochs + 1)-th epoch.
    train_iter = itertools.islice(dataset.train_iterator_fn(), global_step,
                                  num_train_steps)
    for batch in train_iter:
        if global_step in checkpoint_steps and jax.host_id() == 0:
            save_checkpoint(
                checkpoint_dir, {
                    'optimizer': optimizer,
                    'batch_stats': batch_stats,
                    'training_metrics_grabber': training_metrics_grabber
                },
                global_step,
                preemption_count,
                sum_train_cost,
                max_to_keep=None,
                use_deprecated_checkpointing=use_deprecated_checkpointing)
        batch = data_utils.shard(batch)
        lr = lr_fn(global_step)
        optimizer, batch_stats, cost_val, training_metrics_grabber = update_pmapped(
            optimizer, batch_stats, batch, global_step, lr, rng,
            local_device_indices, training_metrics_grabber)
        # Calling float is needed since cost_val is a shape (1,) DeviceArray.
        sum_train_cost += float(jnp.mean(cost_val))
        global_step += 1
        # TODO(gdahl, gilmer): consider moving this test up.
        # NB: Since this test is after we increment global_step, having 0 in
        # eval_steps does nothing.
        if should_eval(global_step, eval_frequency, eval_steps):
            batch_stats = _maybe_sync_batchnorm_stats(batch_stats)
            report, eval_time = eval_metrics(optimizer.target, batch_stats,
                                             dataset, eval_num_batches,
                                             eval_train_num_batches,
                                             evaluate_batch_pmapped)
            mean_train_cost = sum_train_cost / max(
                1, global_step - prev_eval_step)
            report.update(learning_rate=float(lr),
                          global_step=global_step,
                          epoch=global_step * hps.batch_size // hps.train_size,
                          steps_per_sec=get_step_frequency(global_step),
                          eval_time=eval_time,
                          preemption_count=preemption_count,
                          train_cost=mean_train_cost)
            yield report
            if jax.host_id() == 0:
                _log_epoch_report(report, metrics_logger)
                _maybe_log_training_metrics(training_metrics_grabber,
                                            metrics_logger)
                save_checkpoint(
                    train_dir, {
                        'optimizer': optimizer,
                        'batch_stats': batch_stats,
                        'training_metrics_grabber': training_metrics_grabber
                    },
                    global_step,
                    preemption_count,
                    sum_train_cost,
                    use_deprecated_checkpointing=use_deprecated_checkpointing)
            sum_train_cost = 0.0
            prev_eval_step = global_step

    # Always log and checkpoint on host 0 at the end of training.
    # If evals happened at a different point in the loop body, this check would
    # not be needed.
    if prev_eval_step != num_train_steps:
        batch_stats = _maybe_sync_batchnorm_stats(batch_stats)
        report, eval_time = eval_metrics(optimizer.target, batch_stats,
                                         dataset, eval_num_batches,
                                         eval_train_num_batches,
                                         evaluate_batch_pmapped)
        lr = lr_fn(global_step)
        # Correct the average for the final partial epoch.
        mean_train_cost = sum_train_cost / max(1, global_step - prev_eval_step)
        report.update(learning_rate=float(lr),
                      global_step=global_step,
                      epoch=global_step * hps.batch_size // hps.train_size,
                      steps_per_sec=get_step_frequency(global_step),
                      eval_time=eval_time,
                      preemption_count=preemption_count,
                      train_cost=mean_train_cost)
        yield report
        if jax.host_id() == 0:
            _log_epoch_report(report, metrics_logger)
            _maybe_log_training_metrics(training_metrics_grabber,
                                        metrics_logger)
            save_checkpoint(
                train_dir, {
                    'optimizer': optimizer,
                    'batch_stats': batch_stats,
                    'training_metrics_grabber': training_metrics_grabber
                },
                global_step,
                preemption_count,
                sum_train_cost,
                use_deprecated_checkpointing=use_deprecated_checkpointing)
    # To make sure the last checkpoint was correctly saved.
    checkpoint.wait_for_checkpoint_save()
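The eval trigger described in the docstring (an explicit eval_steps list, when provided, overrides eval_frequency) can be summarized with a small sketch; this illustrates the documented behavior rather than reproducing the library's actual should_eval helper:

def should_eval(global_step, eval_frequency, eval_steps):
  # An explicit list of eval steps takes precedence over the fixed frequency.
  if eval_steps:
    return global_step in eval_steps
  return global_step % eval_frequency == 0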