Code Example #1
File: mesh_utils_test.py  Project: frederikwilde/jax
    def test_create_contiguous_submeshes_errors(self):
        v4 = mesh_utils._TPU_V4

        topology = (4, 4, 8)
        mesh_shape = (1, 16, 8)
        devices = mock_devices(topology[0],
                               topology[1],
                               topology[2],
                               v4,
                               one_device_per_chip=True)
        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "create_device_mesh cannot create contiguous submeshes for "
                "physical mesh topology (4, 4, 8)"):
            mesh_utils.create_device_mesh(mesh_shape,
                                          devices=devices,
                                          contiguous_submeshes=True)

        topology = (4, 8, 8)
        mesh_shape = (1, 128, 2)
        devices = mock_devices(topology[0],
                               topology[1],
                               topology[2],
                               v4,
                               one_device_per_chip=True)
        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "create_device_mesh cannot create contiguous submeshes for mesh_shape "
                "(1, 128, 2) and physical mesh topology (4, 8, 8). "
                "Available mesh_shapes: [(1, 64, 4), (1, 4, 64), (64, 4), (4, 64)]"
        ):
            mesh_utils.create_device_mesh(mesh_shape,
                                          devices=devices,
                                          contiguous_submeshes=True)
Code Example #2
File: mesh_utils_test.py  Project: cloudhan/jax
 def test_v3_create_device_mesh(self, devices, mesh_shape,
                                expected_device_id_mesh):
   global_devices = devices()
   mesh = mesh_utils.create_device_mesh(
       mesh_shape, devices=global_devices, contiguous_submeshes=False)
   device_id_mesh = np.vectorize(lambda d: d.id)(mesh)
   self.assertAllClose(np.array(expected_device_id_mesh), device_id_mesh)
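
The test above checks that create_device_mesh arranges device ids into the requested logical mesh shape. For reference, a minimal sketch of the same call outside the test harness, run over whatever devices are locally available instead of mocked TPU devices (purely illustrative):

import numpy as np
import jax
from jax.experimental import mesh_utils

# Build a 1-D logical mesh over all available devices and read back the
# device ids, mirroring the assertion in the test above.
devices = jax.devices()
mesh = mesh_utils.create_device_mesh((len(devices),), devices=devices)
device_id_mesh = np.vectorize(lambda d: d.id)(mesh)
print(device_id_mesh)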
Code Example #3
File: mesh_utils_test.py  Project: cloudhan/jax
 def test_create_contiguous_submeshes_for_tpu_v4(self):
   v4 = mesh_utils._TPU_V4
   for topology, mesh_shapes in mesh_utils._TRANSPOSE_TRICKS.items():
     logging.vlog(1, "topology: %s", topology)
     devices = mock_devices(topology[0], topology[1], topology[2], v4,
                            one_device_per_chip=True)
     for mesh_shape in mesh_shapes:
       logging.vlog(1, "  mesh_shape: %s", mesh_shape)
       mesh = mesh_utils.create_device_mesh(
           mesh_shape, devices=devices, contiguous_submeshes=True)
       self._assert_contiguous_submeshes(mesh)
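
The array returned by create_device_mesh is a plain numpy array of device objects; the training and decoding functions below (Code Examples #4 to #6) wrap it in maps.Mesh so that pjit-partitioned computations can refer to named mesh axes. A minimal sketch of that wrapping, with illustrative axis names and a trivial (1, N) mesh shape:

import jax
from jax.experimental import maps, mesh_utils

# Two logical axes: a size-1 replica axis and a data axis over all devices.
devices = jax.devices()
device_mesh = mesh_utils.create_device_mesh((1, len(devices)), devices=devices)
global_mesh = maps.Mesh(device_mesh, ('replica', 'data'))
with global_mesh:
  # pjit-partitioned computations would run here, as in the examples below.
  pass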
Code Example #4
def decode_once_spmd_model(
    task_p: InstantiableParams,
    input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    checkpoint_type: CheckpointType,
    restore_checkpoint_dir: str,
    restore_checkpoint_step: Optional[int],
) -> None:
  """Runs the decoding once on the entire decoder datasets for SPMD model.

  Args:
    task_p: Params for the task that encapsulates an SPMD model.
    input_p: List of input params to be decoded.
    job_log_dir: Directory for the job logs.
    checkpoint_type: Type of model checkpointing method to use.
    restore_checkpoint_dir: The directory from which to restore the checkpoint.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
  """
  # TODO(bf-jax): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  if restore_checkpoint_dir:
    restore_checkpoint_parent_dir = restore_checkpoint_dir
    if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
      # TODO(zhouwk): add sanity check on number of subdirs and number of
      # processes and fail early if unequal.
      restore_checkpoint_dir = os.path.join(restore_checkpoint_dir,
                                            f'{jax.process_index():03d}')

  multi_host_checkpointing = bool(checkpoint_type in {
      CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
  })

  sample_inputs = input_p[0].Instantiate().get_next()
  inputs_shape = tf.nest.map_structure(py_utils.get_global_input_shape_dtype,
                                       sample_inputs)

  model_p = task_p.model
  # TODO(b/198356509): This is a hack for now as we need to change some
  # annotations for mode='decode'. A future cl will move this logic
  # to a more generic model_p.update_sharding_params_v1(mode='decode').
  model_p.lm = model_p.lm.cls.set_sharding_params_v1(
      model_p.lm,
      replica_axis=model_p.lm.mesh_axis_names[0],
      data_axis=model_p.lm.mesh_axis_names[1],
      mdl_axis=model_p.lm.mesh_axis_names[2],
      device_ids_mesh=model_p.lm.device_mesh,
      mesh_axis_names=model_p.lm.mesh_axis_names,
      mode='decode')

  mesh_shape = model_p.device_mesh.shape
  device_mesh = mesh_utils.create_device_mesh(mesh_shape)
  logging.info('device_mesh: %s', device_mesh)
  jax_task = task_p.Instantiate()
  global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
  with global_mesh:
    if restore_checkpoint_dir:
      model = jax_task.model
      model.instantiate_variable_configs()
      # Get the metadata from variables instead of actually instantiating them.
      partitioned_specs = jax_task.create_train_state_partition_specs(
          model.vars, is_eval=True)
      # Instantiate the TrainState directly from the checkpoint.
      partitioned_train_state = checkpoints.restore_checkpoint(
          None,
          restore_checkpoint_dir,
          global_mesh=global_mesh,
          checkpoint_type=checkpoint_type,
          state_specs=partitioned_specs,
          step=restore_checkpoint_step)
      if multi_host_checkpointing:
        py_utils.sync_global_devices(
            f'checkpointer:restored:{restore_checkpoint_parent_dir}')
      decode_step_fn, inputs_partition_spec = (
          trainer_lib.get_partitioned_spmd_model_decode_fn(
              jax_task, init_key, partitioned_train_state, partitioned_specs,
              inputs_shape))
    else:
      # When restore is not specified, randomly initialize the train_state.
      (partitioned_train_state, inputs_partition_spec, partitioned_specs,
       decode_step_fn) = trainer_lib.partition_spmd_model_decode(
           task_p, init_key, inputs_shape)
    logging.info('partitioned_train_state: %s',
                 jax.tree_map(lambda x: x.shape, partitioned_train_state))
    # Unlike the pmap version, we do not fold in jax.process_index; we use a
    # single global key and rely on pjit to split it across replicas.
    logging.info('root prng_key: %s', prng_key)
    prng_key, decode_key = jax.random.split(prng_key)
    logging.info('eval prng_key: %s', decode_key)
    spmd_decode_step_fn = functools.partial(decode_step_fn,
                                            partitioned_train_state.mdl_vars,
                                            decode_key,
                                            partitioned_train_state.step)

    num_steps = [
        -1 if p.reset_for_eval else p.eval_loop_num_batches for p in input_p
    ]
    inputs = [p.Instantiate() for p in input_p]
    decodes = [list() for _ in input_p]
    process_id = jax.process_index()

    for split, num_split_steps in enumerate(num_steps):
      logging.info('Start decoding on input %s', input_p[split].name)
      step_num = 0
      while num_split_steps < 0 or step_num < num_split_steps:
        step_num += 1
        try:
          batch = inputs[split].get_next()
        except (tf.errors.OutOfRangeError, StopIteration):
          break
        if jax.config.jax_parallel_functions_output_gda:
          batch = py_utils.create_gda(batch, inputs_shape, global_mesh,
                                      inputs_partition_spec)
        _, out = spmd_decode_step_fn(batch)
        # Output is fully replicated now, so it's ok to unreplicate it by
        # retrieving from device 0 only.
        out = py_utils.maybe_unreplicate_gda(out)
        global_batch_size = next(iter(out.values())).shape[0]
        logging.info('Finished decoding input batch %d with %d examples',
                     step_num, global_batch_size)
        # Manually shard the output for each jax process.
        # We require that all fields in the output are batch major.
        if global_batch_size % jax.process_count() != 0:
          raise ValueError(
              f'Global batch size {global_batch_size} must be divisible by '
              f'the jax process count {jax.process_count()}')
        for k, v in out.items():
          if v.shape[0] != global_batch_size:
            raise ValueError('We require all fields in the decode output to '
                             'have batch size as the first dim; got shape='
                             f'{v.shape} for key={k}, expected batch size = '
                             f'{global_batch_size}')
        per_process_batch_size = global_batch_size // jax.process_count()

        def shard(x, per_process_batch_size=per_process_batch_size):
          return x[(process_id *
                    per_process_batch_size):((process_id + 1) *
                                             per_process_batch_size)]

        out = jax.tree_map(shard, out)
        _, processed = jax_task.model.process_decode_out(inputs[split], out)
        decodes[split].extend(processed)
        logging.info('Finished processing decoded input batch %d', step_num)

  basedir = os.path.join(job_log_dir, 'decoder_out')
  dirnames = _get_dir_names(input_p)
  filename = _get_filename(
      py_utils.maybe_unreplicate_gda(partitioned_train_state.step))
  for s in dirnames:
    dir_path = os.path.join(basedir, s)
    if not tf.io.gfile.exists(dir_path):
      tf.io.gfile.makedirs(dir_path)
  filenames = [os.path.join(basedir, s, filename) for s in dirnames]
  for split, output_file in enumerate(filenames):
    logging.info('Writing decoder output to %s with %d entries', output_file,
                 len(decodes[split]))
    io_utils.WriteKeyValuePairs(output_file, decodes[split])
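
The shard helper in the example above slices the fully replicated decode output so that each JAX process post-processes only its own rows. A standalone sketch of that slicing logic over a made-up output dict (the field name is hypothetical):

import numpy as np
import jax

# Hypothetical fully replicated decode output with batch-major fields.
out = {'logprobs': np.arange(16.0).reshape(16, 1)}

global_batch_size = next(iter(out.values())).shape[0]
assert global_batch_size % jax.process_count() == 0
per_process_batch_size = global_batch_size // jax.process_count()
process_id = jax.process_index()

def shard(x):
  start = process_id * per_process_batch_size
  return x[start:start + per_process_batch_size]

# Each process keeps only its contiguous slice of the global batch.
local_out = jax.tree_map(shard, out)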
Code Example #5
def evaluate_spmd_model(
    task_p: InstantiableParams,
    eval_input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    checkpoint_type: CheckpointType,
) -> None:
  """Runs the evaluation loop on the entire test dataset for SPMD model.

  Args:
    task_p: Params of the task encapsulating an SPMD model.
    eval_input_p: List of Params for the eval data pipelines.
    job_log_dir: Directory for the job logs.
    checkpoint_type: Type of model checkpointing method to use.
  """
  logging.info('Using SPMD sharding for model parallelism.')
  eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]
  # TODO(bf-jax): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  checkpoint_dir = os.path.join(job_log_dir, 'checkpoints')
  # Note that GDA checkpointing requires all processes to participate, but it
  # does not require a separate checkpoint_dir per process.
  if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
    checkpoint_task_dir = os.path.join(checkpoint_dir,
                                       f'{jax.process_index():03d}')
  else:
    checkpoint_task_dir = checkpoint_dir

  multi_host_checkpointing = bool(checkpoint_type in {
      CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
  })

  def get_shape_dtype(x):
    y = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return y

  # Do not use eval_input_pipelines[0] directly.
  sample_model_inputs = eval_input_p[0].Instantiate().get_next()
  inputs_shape = tf.nest.map_structure(get_shape_dtype, sample_model_inputs)

  model_p = task_p.model
  mesh_shape = model_p.device_mesh.shape
  device_mesh = mesh_utils.create_device_mesh(mesh_shape)
  logging.info('device_mesh: %s', device_mesh)
  global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
  with global_mesh:
    partitioned_train_state, partitioned_specs, eval_inputs_partition_specs, _, eval_step, _ = (
        trainer_lib.partition_spmd_model(task_p, init_key, inputs_shape))
    partitioned_train_state = checkpoints.restore_checkpoint(
        partitioned_train_state,
        checkpoint_task_dir,
        global_mesh=global_mesh,
        checkpoint_type=checkpoint_type,
        state_specs=partitioned_specs)
    logging.info('partitioned_train_state: %s',
                 jax.tree_map(lambda x: x.shape, partitioned_train_state))
    if multi_host_checkpointing:
      py_utils.sync_global_devices(f'checkpointer:restored:{checkpoint_dir}')

    # Unlike the pmap version, we do not fold in jax.process_index; we use a
    # single global key and rely on pjit to split it across replicas.
    logging.info('root prng_key: %s', prng_key)
    prng_key, eval_key = jax.random.split(prng_key)
    logging.info('eval prng_key: %s', eval_key)

    logging.info('Evaluation loop starting...')
    summary_base_dir = os.path.join(job_log_dir, 'summaries')
    summary_eval_dirs = [
        os.path.join(summary_base_dir, f'eval_{split}')
        for split, _ in enumerate(eval_input_p)
    ]

    num_steps = [-1 if p.reset_for_eval else 1 for p in eval_input_p]
    last_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
    with contextlib.ExitStack() as exit_stack:
      eval_summary_writers = [
          exit_stack.enter_context(summary_utils.get_summary_writer(d))
          for d in summary_eval_dirs
      ]
      while True:
        step_i = int(jax.device_get(partitioned_train_state.step))
        eval_step_fn = functools.partial(eval_step,
                                         partitioned_train_state.mdl_vars,
                                         eval_key, partitioned_train_state.step)
        # Run the eval loop.
        model_utils.run_eval_loop_over_test_splits(
            num_steps,
            eval_step_fn,
            eval_summary_writers,
            step_i,
            eval_input_pipelines,
            eval_inputs_partition_specs,
            inputs_shape,
            global_mesh,
            reshard_inputs=False)
        # If the last checkpoint evaluated matches max train steps, exit.
        if last_checkpoint is not None:
          last_ckpt_step = checkpoints.get_step_from_checkpoint_asset(
              last_checkpoint)
          exceeded_ckpt = last_ckpt_step + task_p.train.save_interval_steps
          if exceeded_ckpt >= task_p.train.num_train_steps:
            break
        new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
        while new_checkpoint == last_checkpoint:
          # Sleep for a minute.
          time.sleep(60)
          new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
        # There must be a new checkpoint here.
        logging.info('Found new checkpoint: %s', new_checkpoint)
        partitioned_train_state = checkpoints.restore_checkpoint(
            partitioned_train_state,
            checkpoint_task_dir,
            global_mesh=global_mesh,
            checkpoint_type=checkpoint_type,
            state_specs=partitioned_specs)
        if multi_host_checkpointing:
          py_utils.sync_global_devices(
              f'checkpointer:restored:{checkpoint_dir}')
        last_checkpoint = new_checkpoint
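
The get_shape_dtype helper in the example above turns a concrete sample batch into abstract jax.ShapeDtypeStruct specs, so the model can be partitioned without holding real input data. A minimal self-contained sketch of the same mapping over a made-up nested batch (field names and shapes are illustrative):

import numpy as np
import tensorflow as tf
import jax

# Hypothetical nested batch standing in for the input pipeline's get_next().
sample_batch = {
    'ids': np.zeros((8, 128), np.int32),
    'paddings': np.zeros((8, 128), np.float32),
}

# Map every leaf to a ShapeDtypeStruct, as evaluate_spmd_model does above.
inputs_shape = tf.nest.map_structure(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), sample_batch)
print(inputs_shape)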
Code Example #6
File: train.py  Project: tensorflow/lingvo
def train_and_evaluate_spmd_model(
        task_p: InstantiableParams, train_input_p: InstantiableParams,
        job_log_dir: Optional[str],
        checkpoint_manager: checkpoint_managers.CheckpointManager,
        checkpoint_type: CheckpointType, restore_checkpoint_dir: Optional[str],
        restore_checkpoint_step: Optional[int],
        eval_input_p: Optional[Sequence[InstantiableParams]]) -> None:
    """Runs the training and evaluation loop.

  Args:
    task_p: Params for task encapsulating the SPMD model.
    train_input_p: Params for the train data pipeline.
    job_log_dir: Directory for the job logs.
    checkpoint_manager: A checkpoint manager controlling how often to save and
      delete checkpoints.
    checkpoint_type: The type of checkpoint to use.
    restore_checkpoint_dir: If set, the directory from which to restore
      checkpoint. If unset, use job_log_dir's `checkpoints` subdirectory
      instead.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
    eval_input_p: Optional list of params for the eval input pipelines.
  """
    logging.info('Using SPMD sharding for model parallelism.')
    model_p = task_p.model

    if eval_input_p is not None:
        eval_input_pipelines = [
            input_p.Instantiate() for input_p in eval_input_p
        ]
        # Do not mutate eval_input_pipelines itself. Instantiate a new one
        # to get sample input.
        sample_eval_model_inputs = eval_input_p[0].Instantiate().get_next()
        eval_test_inputs_shape = tf.nest.map_structure(
            py_utils.get_global_input_shape_dtype, sample_eval_model_inputs)
        eval_test_inputs_pspecs = trainer_lib.get_input_partition_specs(
            model_p.mesh_axis_names, eval_test_inputs_shape)

    # TODO(bf-jax): Retrieve the seeds from the model definition instead.
    prng_key = jax.random.PRNGKey(1234)
    prng_key, init_key = jax.random.split(prng_key)

    checkpoint_dir = _checkpoint_dir(job_log_dir)
    restore_checkpoint_dir = restore_checkpoint_dir or checkpoint_dir
    # Note that GDA checkpointing requires all processes to participate, but it
    # does not require a separate checkpoint_dir per process.
    if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
        checkpoint_task_dir = os.path.join(checkpoint_dir,
                                           f'{jax.process_index():03d}')
        restore_checkpoint_task_dir = os.path.join(
            restore_checkpoint_dir, f'{jax.process_index():03d}')
    else:
        checkpoint_task_dir = checkpoint_dir
        restore_checkpoint_task_dir = restore_checkpoint_dir

    multi_host_checkpointing = bool(checkpoint_type in {
        CheckpointType.CHECKPOINT_MULTI_HOST_FLAX,
        CheckpointType.CHECKPOINT_GDA
    })

    if jax.process_index() == 0:
        tf.io.gfile.makedirs(checkpoint_dir)
    if multi_host_checkpointing:
        # Block all hosts until directory is ready.
        py_utils.sync_global_devices(f'checkpointer:makedirs:{checkpoint_dir}')

    logging.info('Retrieving model inputs for shape info.')
    train_input_for_shape = train_input_p.Instantiate()
    model_inputs_for_shape = train_input_for_shape.get_next()
    inputs_shape = tf.nest.map_structure(py_utils.get_global_input_shape_dtype,
                                         model_inputs_for_shape)

    mesh_shape = model_p.device_mesh.shape
    device_mesh = mesh_utils.create_device_mesh(mesh_shape)
    logging.info('device_mesh: %s', device_mesh)

    global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
    with global_mesh:
        (partitioned_train_state, train_state_pspecs, inputs_pspecs,
         train_step, eval_step,
         total_num_params) = trainer_lib.partition_spmd_model(
             task_p, init_key, inputs_shape)

        partitioned_train_state = checkpoints.restore_checkpoint(
            partitioned_train_state,
            restore_checkpoint_task_dir,
            global_mesh=global_mesh,
            checkpoint_type=checkpoint_type,
            state_specs=train_state_pspecs,
            step=restore_checkpoint_step)
        logging.info(
            'partitioned_train_state shapes '
            '(global shape for GDA, host-local shape for non-GDA): %s',
            jax.tree_map(lambda x: x.shape, partitioned_train_state))
        if multi_host_checkpointing:
            py_utils.sync_global_devices(
                f'checkpointer:restored:{checkpoint_dir}')

        train_p = task_p.train
        initial_global_step = int(
            jax.device_get(
                py_utils.maybe_unreplicate_gda(partitioned_train_state.step)))
        logging.info('Model initial global_step=%d', initial_global_step)
        _update_latest_model_step(train_input_p, initial_global_step,
                                  train_p.eval_interval_steps)
        train_input_pipeline = train_input_p.Instantiate()

        # Unlike the pmap version, we do not fold in jax.process_index; we use
        # a single global key and rely on pjit to split it across replicas.
        logging.info('root prng_key: %s', prng_key)
        prng_key, train_key, eval_key = jax.random.split(prng_key, 3)
        logging.info('train prng_key: %s', train_key)
        logging.info('eval prng_key: %s', eval_key)

        logging.info('Training loop starting...')
        summary_base_dir = os.path.join(job_log_dir, 'summaries')
        summary_train_dir = os.path.join(summary_base_dir, 'train')
        summary_eval_dir = os.path.join(summary_base_dir, 'eval_train')
        summary_writer = summary_utils.get_summary_writer
        if eval_input_p is not None:
            summary_eval_test_dirs = [
                os.path.join(summary_base_dir, f'eval_test_{split}')
                for split, _ in enumerate(eval_input_p)
            ]
            # Per eval loop during training we run either p.eval_loop_num_batches
            # steps or one epoch (when supported by a resettable input). When
            # p.reset_for_eval is set to True, we run the eval loop until
            # tf.errors.OutOfRangeError (or StopIteration) is raised, which is
            # triggered either because the input pipeline has reached the end of
            # the input sequence or because a pre-determined num_batches has been
            # reached.
            eval_num_steps = [
                -1 if p.reset_for_eval else p.eval_loop_num_batches
                for p in eval_input_p
            ]
        else:
            summary_eval_test_dirs = []

        with contextlib.ExitStack() as exit_stack:
            train_summary_writer = exit_stack.enter_context(
                summary_writer(summary_train_dir))
            eval_summary_writer = exit_stack.enter_context(
                summary_writer(summary_eval_dir))
            eval_test_summary_writers = [
                exit_stack.enter_context(summary_writer(d))
                for d in summary_eval_test_dirs
            ]

            # This only prints the view from the first host machine.
            summary_utils.write_model_structure(train_summary_writer,
                                                partitioned_train_state,
                                                is_vars_replicated=False)
            summary_utils.write_total_num_params(train_summary_writer,
                                                 total_num_params)

            summary_last_time = time.time()
            summary_last_step = None

            step_i = int(
                jax.device_get(
                    py_utils.maybe_unreplicate_gda(
                        partitioned_train_state.step)))

            # Start the train loop. Make sure all hosts are at the same step.
            py_utils.sync_global_devices(
                f'Start training loop from step: {step_i}')
            while True:
                logging.debug('step=`%d`: Beginning', step_i)
                if step_i >= train_p.num_train_steps:
                    logging.info(
                        'Training loop completed (step (`%d`) greater than or '
                        'equal to num_train_steps (`%d`)).', step_i,
                        train_p.num_train_steps)
                    break

                if summary_last_step is None:
                    summary_last_step = step_i - 1

                if checkpoint_manager.should_save(step_i):
                    logging.info('Saving a ckpt at step: %d', step_i)
                    if multi_host_checkpointing:
                        py_utils.sync_global_devices(
                            f'checkpointer:saving:{checkpoint_dir}:step-{step_i}'
                        )
                    if multi_host_checkpointing or jax.process_index() == 0:
                        checkpoints.save_checkpoint(
                            partitioned_train_state,
                            checkpoint_task_dir,
                            checkpoint_type=checkpoint_type,
                            state_specs=train_state_pspecs,
                            unreplicate=False)
                    checkpoint_manager.save_metadata(global_step_id=step_i)
                    if multi_host_checkpointing:
                        py_utils.sync_global_devices(
                            f'checkpointer:saved:{checkpoint_dir}:step-{step_i}'
                        )

                # Get new model inputs
                if step_i <= _N_STEPS_WARMUP_LOGGING:
                    logging.info('step=`%d`: Retrieving model inputs.', step_i)
                logging.debug('  Retrieving inputs.')
                model_inputs = train_input_pipeline.get_next()

                if jax.config.jax_parallel_functions_output_gda:
                    if step_i <= _N_STEPS_WARMUP_LOGGING:
                        start = time.time()
                    py_utils.assert_same_shape_and_dtype(
                        inputs_shape,
                        tf.nest.map_structure(
                            py_utils.get_global_input_shape_dtype,
                            model_inputs))
                    model_inputs = py_utils.create_gda(model_inputs,
                                                       inputs_shape,
                                                       global_mesh,
                                                       inputs_pspecs)
                    if step_i <= _N_STEPS_WARMUP_LOGGING:
                        logging.info('GDA train batch input creation time %s',
                                     time.time() - start)

                logging.debug('  Retrieved inputs.')

                logging.debug('  Performing train_step().')
                with jax.profiler.StepTraceAnnotation('train',
                                                      step_num=step_i):
                    (partitioned_train_state, loss, metrics, per_example_out,
                     summary_tensors) = train_step(partitioned_train_state,
                                                   train_key, model_inputs)
                logging.debug('  Completed train_step().')

                logging.debug('  Writing summaries (attempt).')
                if summary_utils.write_summary_every_n_steps(
                        partitioned_train_state,
                        train_summary_writer,
                        step_i,
                        train_p.summary_interval_steps,
                        loss,
                        metrics,
                        per_example_out,
                        summary_tensors,
                        train_p.norm_summary_interval_steps,
                        summary_last_time,
                        summary_last_step,
                        unreplicate_mdl_vars=False,
                        unreplicate_metrics=False):
                    summary_last_time = time.time()
                    summary_last_step = step_i
                    step_i = int(
                        py_utils.maybe_unreplicate_gda(
                            partitioned_train_state.step))
                else:
                    # Increment train step locally to avoid an explicit device sync.
                    step_i += 1
                logging.debug('  Wrote summaries (attempted).')

                # Run eval at regular step interval.
                if step_i % train_p.eval_interval_steps == 0:
                    logging.debug('  Starting eval_step().')
                    logging.debug('  Retrieving eval model_inputs.')
                    eval_inputs = train_input_pipeline.get_next()

                    if jax.config.jax_parallel_functions_output_gda:
                        eval_inputs = py_utils.create_gda(
                            eval_inputs, inputs_shape, global_mesh,
                            inputs_pspecs)

                    logging.debug('  Retrieved eval model_inputs.')
                    logging.debug(
                        '  Performing eval_step() runs on training split.')

                    eval_step_fn = functools.partial(
                        eval_step, partitioned_train_state.mdl_vars, eval_key,
                        partitioned_train_state.step)
                    loss, mean_metrics, summary_tensors = model_utils.run_eval_one_step(
                        eval_inputs, eval_step_fn, reshard_inputs=False)
                    logging.debug(
                        '  Completed eval_step() runs on training split.')

                    logging.info('step=`%d`', step_i)
                    logging.info('  eval loss: %s', loss)
                    logging.info('  mean_metrics: %s', mean_metrics)
                    logging.info('  summary_tensors: %s', summary_tensors)
                    if step_i % train_p.summary_interval_steps == 0:
                        logging.debug('  Writing eval summaries.')
                        summary_utils.write_summary_entry(
                            eval_summary_writer,
                            step_i,
                            loss,
                            mean_metrics,
                            summary_tensors,
                            unreplicate_metrics=False)
                        logging.debug('  Wrote eval summaries.')
                    # If we have eval test splits, also evaluate on them.
                    if eval_input_p is not None:
                        logging.debug(
                            '  Performing eval_step() runs on test splits.')
                        model_utils.run_eval_loop_over_test_splits(
                            eval_num_steps,
                            eval_step_fn,
                            eval_test_summary_writers,
                            step_i,
                            eval_input_pipelines,
                            eval_test_inputs_pspecs,
                            eval_test_inputs_shape,
                            global_mesh,
                            reshard_inputs=False)
                        logging.debug(
                            '  Completed eval_step() runs on test splits.')

                logging.debug('step=`%d`: End', step_i - 1)