Example #1
 def load_model(self):
     if self.network:
         logger.info(
             'Attempting to reload model when model is loaded. Returning')
         return
     with self.lock:
         logger.info('Loading Model')
         start = timer()
         logger.info(f"JAX Devices: {jax.device_count()}")
         logger.info(f"JAX Runtime Initialized in {timer(start):.06} secs")
         mesh_shape = (jax.device_count() //
                       self.params['cores_per_replica'],
                       self.params['cores_per_replica'])
         self.devices = np.array(jax.devices()).reshape(mesh_shape)
         self.total_batch = (self.params['per_replica_batch'] *
                             jax.device_count() //
                             self.params['cores_per_replica'] * 8)
         maps.thread_resources.env = maps.ResourceEnv(
             maps.Mesh(self.devices, ('dp', 'mp')))
         network = CausalTransformer(self.params)
         logger.info(f'Loading Checkpoint')
         network.state = read_ckpt(network.state, "/app/model/",
                                   self.devices.shape[1])
         logger.info(
             f"GPTJ Network loaded in {timer(start):.06} secs. Total Batch Size: {self.total_batch}"
         )
         del network.state["opt_state"]
         network.state = network.move_xmap(
             network.state, np.zeros(self.params['cores_per_replica']))
         self.network = network
Example #2
    def background(self):
        logger.info(f'Init Background')
        maps.thread_resources.env = maps.ResourceEnv(
            maps.Mesh(self.devices, ('dp', 'mp')))
        while True:
            batch, qids = [], []
            # Collect at most total_batch items before running inference.
            while len(batch) < self.total_batch:
                try:
                    req = self.queue.get(block=False)
                    logger.info(f'Got Queue Item: {req}')
                    batch.append(req['item'])
                    qids.append(req['qidx'])

                except Empty:
                    if len(batch):
                        break
                    else:
                        time.sleep(0.01)
            batch_size = len(batch)
            logger.info(f'Working on Batch: {batch_size} - {qids}')
            while len(batch) < self.total_batch:
                batch.append(self.placeholder_item)
            start = timer()
            results = self.infer_batch(batch)
            for res, qid in zip(results, qids):
                self.queue_ids[qid].put(res)
            logger.info(
                f'Completed Current Batch of {batch_size} Items in {timer(start):.2f} secs'
            )
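
The producer side of this queue is not part of the excerpt. A rough sketch of what it might look like (a hypothetical submit method; it assumes from queue import Queue, Empty at module level and the self.queue / self.queue_ids attributes used above):

    def submit(self, item, qidx):
        # Hypothetical counterpart to background(): register a per-request
        # result queue, enqueue the work item, and block until the batch
        # worker posts the result back.
        self.queue_ids[qidx] = Queue()
        self.queue.put({'item': item, 'qidx': qidx})
        return self.queue_ids[qidx].get()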
Example #3
  def _pjit(inp):
    if isinstance(inp, GlobalDeviceArray):
      if inp.is_fully_replicated:
        return inp.local_data(0).to_py()
      global_mesh = inp.mesh
      in_axis_resources = FROM_GDA
    else:
      # DA/SDA/np.array will be sharded based on global_mesh.local_mesh.
      # Shape of local_mesh will always be (1, local_device_count())
      devices = np.array(jax.devices()).reshape(jax.process_count(),
                                                jax.local_device_count())
      global_mesh = maps.Mesh(devices, ('processes', 'local_devices'))
      in_axis_resources = P('processes')
      if inp.ndim == 0 or not tiled:
        inp = np.expand_dims(inp, axis=0)

    with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
      out = pjit(lambda x: x, in_axis_resources=in_axis_resources,
                 out_axis_resources=None)(inp)
    return out.local_data(0).to_py()
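
This closure matches the internals of a process-allgather style helper in JAX's multihost utilities. A rough usage sketch, assuming jax.experimental.multihost_utils.process_allgather is available in the JAX version at hand:

import numpy as np
import jax
from jax.experimental import multihost_utils

# Each process contributes its host-local values; every process receives the
# gathered result back as a plain numpy array.
local_values = np.arange(jax.local_device_count())
gathered = multihost_utils.process_allgather(local_values)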
Example #4
 def test_pjit_inherits_effects(self):
   if jax.default_backend() not in {'gpu', 'tpu'}:
     raise unittest.SkipTest("pjit only supports GPU and TPU backends")
   def f(x):
     effect_p.bind(effect='foo')
     effect_p.bind(effect='bar')
     return x
   f = pjit.pjit(f, in_axis_resources=pjit.PartitionSpec('x'),
       out_axis_resources=pjit.PartitionSpec('x'))
   with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
     with maps.Mesh(np.array(jax.devices()), ['x']):
       jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
Example #5
  def test_unordered_print_with_xmap(self):

    def f(x):
      debug_print("{}", x, ordered=False)
    f = maps.xmap(f, in_axes=['a'], out_axes=None, backend='cpu',
                  axis_resources={'a': 'dev'})
    with maps.Mesh(np.array(jax.devices(backend='cpu')), ['dev']):
      with capture_stdout() as output:
        f(jnp.arange(40))
        jax.effects_barrier()
      lines = [f"{i}\n" for i in range(40)]
      self._assertLinesEqual(output(), "".join(lines))
Example #6
    def test_pjit_inherits_effects(self):
        def f(x):
            effect_p.bind(effect='foo')
            effect_p.bind(effect='bar')
            return x

        f = pjit.pjit(f,
                      in_axis_resources=pjit.PartitionSpec('x'),
                      out_axis_resources=pjit.PartitionSpec('x'))
        with self.assertRaisesRegex(NotImplementedError,
                                    'Effects not supported'):
            with maps.Mesh(np.array(jax.devices()), ['x']):
                jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
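
Examples #4 and #6 are variants of the same test. A minimal standalone sketch of the pattern they exercise (the double function is just a stand-in; the sketch assumes an older JAX release where jax.experimental.maps, jax.experimental.pjit and the in_axis_resources/out_axis_resources arguments are still available):

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import maps, pjit

def double(x):
    return x * 2

# Shard the input and output along a single mesh axis named 'x'.
sharded_double = pjit.pjit(double,
                           in_axis_resources=pjit.PartitionSpec('x'),
                           out_axis_resources=pjit.PartitionSpec('x'))

# Put every local device on the 'x' axis and run the function under the mesh.
with maps.Mesh(np.array(jax.devices()), ('x',)):
    out = sharded_double(jnp.arange(jax.local_device_count()))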
Example #7
    "seq": 2048,
    "cores_per_replica": 1,  # only running on one GPU
    "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]

params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

devices = np.array([jax.devices()[0]]).reshape((1, 1))
maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

network = CausalTransformer(params)

start = time.time()

# here we load a checkpoint which was written with 8 shards into 1 shard
network.state = read_ckpt(network.state,
                          "step_383500/",
                          8,
                          shards_out=cores_per_replica)


def infer(context, top_p=0.9, temp=1.0, gen_len=512):
Example #8
def decode_once_spmd_model(
    task_p: InstantiableParams,
    input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    checkpoint_type: CheckpointType,
    restore_checkpoint_dir: str,
    restore_checkpoint_step: Optional[int],
) -> None:
  """Runs the decoding once on the entire decoder datasets for SPMD model.

  Args:
    task_p: Params for the task that encapsulates an SPMD model.
    input_p: List of input params to be decoded.
    job_log_dir: Directory for the job logs.
    checkpoint_type: Type of model checkpointing method to use.
    restore_checkpoint_dir: The directory from which to restore checkpoint.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
  """
  # TODO(bf-jax): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  if restore_checkpoint_dir:
    restore_checkpoint_parent_dir = restore_checkpoint_dir
    if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
      # TODO(zhouwk): add sanity check on number of subdirs and number of
      # processes and fail early if unequal.
      restore_checkpoint_dir = os.path.join(restore_checkpoint_dir,
                                            f'{jax.process_index():03d}')

  multi_host_checkpointing = bool(checkpoint_type in {
      CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
  })

  sample_inputs = input_p[0].Instantiate().get_next()
  inputs_shape = tf.nest.map_structure(py_utils.get_global_input_shape_dtype,
                                       sample_inputs)

  model_p = task_p.model
  # TODO(b/198356509): This is a hack for now as we need to change some
  # annotations for mode='decode'. A future cl will move this logic
  # to a more generic model_p.update_sharding_params_v1(mode='decode').
  model_p.lm = model_p.lm.cls.set_sharding_params_v1(
      model_p.lm,
      replica_axis=model_p.lm.mesh_axis_names[0],
      data_axis=model_p.lm.mesh_axis_names[1],
      mdl_axis=model_p.lm.mesh_axis_names[2],
      device_ids_mesh=model_p.lm.device_mesh,
      mesh_axis_names=model_p.lm.mesh_axis_names,
      mode='decode')

  mesh_shape = model_p.device_mesh.shape
  device_mesh = mesh_utils.create_device_mesh(mesh_shape)
  logging.info('device_mesh: %s', device_mesh)
  jax_task = task_p.Instantiate()
  global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
  with global_mesh:
    if restore_checkpoint_dir:
      model = jax_task.model
      model.instantiate_variable_configs()
      # Get the metadata from variables instead of actually instantiating them.
      partitioned_specs = jax_task.create_train_state_partition_specs(
          model.vars, is_eval=True)
      # Instantiate the TrainState directly from the checkpoint.
      partitioned_train_state = checkpoints.restore_checkpoint(
          None,
          restore_checkpoint_dir,
          global_mesh=global_mesh,
          checkpoint_type=checkpoint_type,
          state_specs=partitioned_specs,
          step=restore_checkpoint_step)
      if multi_host_checkpointing:
        py_utils.sync_global_devices(
            f'checkpointer:restored:{restore_checkpoint_parent_dir}')
      decode_step_fn, inputs_partition_spec = (
          trainer_lib.get_partitioned_spmd_model_decode_fn(
              jax_task, init_key, partitioned_train_state, partitioned_specs,
              inputs_shape))
    else:
      # When restore is not specified, randomly initialize the train_state.
      (partitioned_train_state, inputs_partition_spec, partitioned_specs,
       decode_step_fn) = trainer_lib.partition_spmd_model_decode(
           task_p, init_key, inputs_shape)
    logging.info('partitioned_train_state: %s',
                 jax.tree_map(lambda x: x.shape, partitioned_train_state))
    # We do not fold in jax.process_index in contrast to the pmap version and
    # use a single global key instead to rely on pjit to split for different
    # replicas.
    logging.info('root prng_key: %s', prng_key)
    prng_key, decode_key = jax.random.split(prng_key)
    logging.info('eval prng_key: %s', decode_key)
    spmd_decode_step_fn = functools.partial(decode_step_fn,
                                            partitioned_train_state.mdl_vars,
                                            decode_key,
                                            partitioned_train_state.step)

    num_steps = [
        -1 if p.reset_for_eval else p.eval_loop_num_batches for p in input_p
    ]
    inputs = [p.Instantiate() for p in input_p]
    decodes = [list() for _ in input_p]
    process_id = jax.process_index()

    for split, num_split_steps in enumerate(num_steps):
      logging.info('Start decoding on input %s', input_p[split].name)
      step_num = 0
      while num_split_steps < 0 or step_num < num_split_steps:
        step_num += 1
        try:
          batch = inputs[split].get_next()
        except (tf.errors.OutOfRangeError, StopIteration):
          break
        if jax.config.jax_parallel_functions_output_gda:
          batch = py_utils.create_gda(batch, inputs_shape, global_mesh,
                                      inputs_partition_spec)
        _, out = spmd_decode_step_fn(batch)
        # Output is fully replicated now, so it's ok to unreplicate it by
        # retrieving from device 0 only.
        out = py_utils.maybe_unreplicate_gda(out)
        global_batch_size = next(iter(out.values())).shape[0]
        logging.info('Finished decoding input batch %d with %d examples',
                     step_num, global_batch_size)
        # Manually shard the output per each jax process.
        # We require that all fields in the output are batch major.
        if global_batch_size % jax.process_count() != 0:
          raise ValueError(f'Global batch size {global_batch_size} must be '
                           f'divisible by jax process count '
                           f'{jax.process_count()}')
        for k, v in out.items():
          if v.shape[0] != global_batch_size:
            raise ValueError('We require that all fields in the decode output '
                             'have batch size as the first dim, got shape='
                             f'{v.shape} with key={k}, expect batch size = '
                             f'{global_batch_size}')
        per_process_batch_size = global_batch_size // jax.process_count()

        def shard(x, per_process_batch_size=per_process_batch_size):
          return x[(process_id *
                    per_process_batch_size):((process_id + 1) *
                                             per_process_batch_size)]

        out = jax.tree_map(shard, out)
        _, processed = jax_task.model.process_decode_out(inputs[split], out)
        decodes[split].extend(processed)
        logging.info('Finished processing decoded input batch %d', step_num)

  basedir = os.path.join(job_log_dir, 'decoder_out')
  dirnames = _get_dir_names(input_p)
  filename = _get_filename(
      py_utils.maybe_unreplicate_gda(partitioned_train_state.step))
  for s in dirnames:
    dir_path = os.path.join(basedir, s)
    if not tf.io.gfile.exists(dir_path):
      tf.io.gfile.makedirs(dir_path)
  filenames = [os.path.join(basedir, s, filename) for s in dirnames]
  for split, output_file in enumerate(filenames):
    logging.info('Writing decoder output to %s with %d entries', output_file,
                 len(decodes[split]))
    io_utils.WriteKeyValuePairs(output_file, decodes[split])
Example #9
def evaluate_spmd_model(
    task_p: InstantiableParams,
    eval_input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    checkpoint_type: CheckpointType,
) -> None:
  """Runs the evaluation loop on the entire test dataset for SPMD model.

  Args:
    task_p: Params of the task encapsulating an SPMD model.
    eval_input_p: List of Params for the eval data pipelines.
    job_log_dir: Directory for the job logs.
    checkpoint_type: Type of model checkpointing method to use.
  """
  logging.info('Using SPMD sharding for model parallelism.')
  eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]
  # TODO(bf-jax): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  checkpoint_dir = os.path.join(job_log_dir, 'checkpoints')
  # Note that GDA checkpoint requires all processes to participate in
  # checkpointing but it does not require a separate checkpoint_dir per process.
  if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
    checkpoint_task_dir = os.path.join(checkpoint_dir,
                                       f'{jax.process_index():03d}')
  else:
    checkpoint_task_dir = checkpoint_dir

  multi_host_checkpointing = bool(checkpoint_type in {
      CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
  })

  def get_shape_dtype(x):
    y = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return y

  # Do not use eval_input_pipelines[0] directly.
  sample_model_inputs = eval_input_p[0].Instantiate().get_next()
  inputs_shape = tf.nest.map_structure(get_shape_dtype, sample_model_inputs)

  model_p = task_p.model
  mesh_shape = model_p.device_mesh.shape
  device_mesh = mesh_utils.create_device_mesh(mesh_shape)
  logging.info('device_mesh: %s', device_mesh)
  global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
  with global_mesh:
    partitioned_train_state, partitioned_specs, eval_inputs_partition_specs, _, eval_step, _ = (
        trainer_lib.partition_spmd_model(task_p, init_key, inputs_shape))
    partitioned_train_state = checkpoints.restore_checkpoint(
        partitioned_train_state,
        checkpoint_task_dir,
        global_mesh=global_mesh,
        checkpoint_type=checkpoint_type,
        state_specs=partitioned_specs)
    logging.info('partitioned_train_state: %s',
                 jax.tree_map(lambda x: x.shape, partitioned_train_state))
    if multi_host_checkpointing:
      py_utils.sync_global_devices(f'checkpointer:restored:{checkpoint_dir}')

    # We do not fold in jax.process_index in contrast to the pmap version and
    # use a single global key instead to rely on pjit to split for different
    # replicas.
    logging.info('root prng_key: %s', prng_key)
    prng_key, eval_key = jax.random.split(prng_key)
    logging.info('eval prng_key: %s', eval_key)

    logging.info('Evaluation loop starting...')
    summary_base_dir = os.path.join(job_log_dir, 'summaries')
    summary_eval_dirs = [
        os.path.join(summary_base_dir, f'eval_{split}')
        for split, _ in enumerate(eval_input_p)
    ]

    num_steps = [-1 if p.reset_for_eval else 1 for p in eval_input_p]
    last_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
    with contextlib.ExitStack() as exit_stack:
      eval_summary_writers = [
          exit_stack.enter_context(summary_utils.get_summary_writer(d))
          for d in summary_eval_dirs
      ]
      while True:
        step_i = int(jax.device_get(partitioned_train_state.step))
        eval_step_fn = functools.partial(eval_step,
                                         partitioned_train_state.mdl_vars,
                                         eval_key, partitioned_train_state.step)
        # Run the eval loop.
        model_utils.run_eval_loop_over_test_splits(
            num_steps,
            eval_step_fn,
            eval_summary_writers,
            step_i,
            eval_input_pipelines,
            eval_inputs_partition_specs,
            inputs_shape,
            global_mesh,
            reshard_inputs=False)
        # If the last checkpoint evaluated matches max train steps, exit.
        if last_checkpoint is not None:
          last_ckpt_step = checkpoints.get_step_from_checkpoint_asset(
              last_checkpoint)
          exceeded_ckpt = last_ckpt_step + task_p.train.save_interval_steps
          if exceeded_ckpt >= task_p.train.num_train_steps:
            break
        new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
        while new_checkpoint == last_checkpoint:
          # Sleep for a minute.
          time.sleep(60)
          new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
        # There must be a new checkpoint here.
        logging.info('Found new checkpoint: %s', new_checkpoint)
        partitioned_train_state = checkpoints.restore_checkpoint(
            partitioned_train_state,
            checkpoint_task_dir,
            global_mesh=global_mesh,
            checkpoint_type=checkpoint_type,
            state_specs=partitioned_specs)
        if multi_host_checkpointing:
          py_utils.sync_global_devices(
              f'checkpointer:restored:{checkpoint_dir}')
        last_checkpoint = new_checkpoint
Example #10
File: train.py Project: tensorflow/lingvo
def train_and_evaluate_spmd_model(
        task_p: InstantiableParams, train_input_p: InstantiableParams,
        job_log_dir: Optional[str],
        checkpoint_manager: checkpoint_managers.CheckpointManager,
        checkpoint_type: CheckpointType, restore_checkpoint_dir: Optional[str],
        restore_checkpoint_step: Optional[int],
        eval_input_p: Optional[Sequence[InstantiableParams]]) -> None:
    """Runs the training and evaluation loop.

    Args:
      task_p: Params for task encapsulating the SPMD model.
      train_input_p: Params for the train data pipeline.
      job_log_dir: Directory for the job logs.
      checkpoint_manager: A checkpoint manager controlling how often to save and
        delete checkpoints.
      checkpoint_type: The type of checkpoint to use.
      restore_checkpoint_dir: If set, the directory from which to restore
        checkpoint. If unset, use job_log_dir's `checkpoints` subdirectory
        instead.
      restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
        try to restore from the latest checkpoint if any.
      eval_input_p: Optional list of params for the eval input pipelines.
    """
    logging.info('Using SPMD sharding for model parallelism.')
    model_p = task_p.model

    if eval_input_p is not None:
        eval_input_pipelines = [
            input_p.Instantiate() for input_p in eval_input_p
        ]
        # Do not mutate eval_input_pipelines itself. Instantiate a new one
        # to get sample input.
        sample_eval_model_inputs = eval_input_p[0].Instantiate().get_next()
        eval_test_inputs_shape = tf.nest.map_structure(
            py_utils.get_global_input_shape_dtype, sample_eval_model_inputs)
        eval_test_inputs_pspecs = trainer_lib.get_input_partition_specs(
            model_p.mesh_axis_names, eval_test_inputs_shape)

    # TODO(bf-jax): Retrieve the seeds from the model definition instead.
    prng_key = jax.random.PRNGKey(1234)
    prng_key, init_key = jax.random.split(prng_key)

    checkpoint_dir = _checkpoint_dir(job_log_dir)
    restore_checkpoint_dir = restore_checkpoint_dir or checkpoint_dir
    # Note that GDA checkpoint requires all processes to participate in
    # checkpointing but it does not require a separate checkpoint_dir per process.
    if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
        checkpoint_task_dir = os.path.join(checkpoint_dir,
                                           f'{jax.process_index():03d}')
        restore_checkpoint_task_dir = os.path.join(
            restore_checkpoint_dir, f'{jax.process_index():03d}')
    else:
        checkpoint_task_dir = checkpoint_dir
        restore_checkpoint_task_dir = restore_checkpoint_dir

    multi_host_checkpointing = bool(checkpoint_type in {
        CheckpointType.CHECKPOINT_MULTI_HOST_FLAX,
        CheckpointType.CHECKPOINT_GDA
    })

    if jax.process_index() == 0:
        tf.io.gfile.makedirs(checkpoint_dir)
    if multi_host_checkpointing:
        # Block all hosts until directory is ready.
        py_utils.sync_global_devices(f'checkpointer:makedirs:{checkpoint_dir}')

    logging.info('Retrieving model inputs for shape info.')
    train_input_for_shape = train_input_p.Instantiate()
    model_inputs_for_shape = train_input_for_shape.get_next()
    inputs_shape = tf.nest.map_structure(py_utils.get_global_input_shape_dtype,
                                         model_inputs_for_shape)

    mesh_shape = model_p.device_mesh.shape
    device_mesh = mesh_utils.create_device_mesh(mesh_shape)
    logging.info('device_mesh: %s', device_mesh)

    global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
    with global_mesh:
        (partitioned_train_state, train_state_pspecs, inputs_pspecs,
         train_step, eval_step,
         total_num_params) = trainer_lib.partition_spmd_model(
             task_p, init_key, inputs_shape)

        partitioned_train_state = checkpoints.restore_checkpoint(
            partitioned_train_state,
            restore_checkpoint_task_dir,
            global_mesh=global_mesh,
            checkpoint_type=checkpoint_type,
            state_specs=train_state_pspecs,
            step=restore_checkpoint_step)
        logging.info(
            'partitioned_train_state shapes '
            '(global shape for GDA, host-local shape for non-GDA): %s',
            jax.tree_map(lambda x: x.shape, partitioned_train_state))
        if multi_host_checkpointing:
            py_utils.sync_global_devices(
                f'checkpointer:restored:{checkpoint_dir}')

        train_p = task_p.train
        initial_global_step = int(
            jax.device_get(
                py_utils.maybe_unreplicate_gda(partitioned_train_state.step)))
        logging.info('Model initial global_step=%d', initial_global_step)
        _update_latest_model_step(train_input_p, initial_global_step,
                                  train_p.eval_interval_steps)
        train_input_pipeline = train_input_p.Instantiate()

        # We do not fold in jax.process_index in contrast to the pmap version and
        # use a single global key instead to rely on pjit to split for different
        # replicas.
        logging.info('root prng_key: %s', prng_key)
        prng_key, train_key, eval_key = jax.random.split(prng_key, 3)
        logging.info('train prng_key: %s', train_key)
        logging.info('eval prng_key: %s', eval_key)

        logging.info('Training loop starting...')
        summary_base_dir = os.path.join(job_log_dir, 'summaries')
        summary_train_dir = os.path.join(summary_base_dir, 'train')
        summary_eval_dir = os.path.join(summary_base_dir, 'eval_train')
        summary_writer = summary_utils.get_summary_writer
        if eval_input_p is not None:
            summary_eval_test_dirs = [
                os.path.join(summary_base_dir, f'eval_test_{split}')
                for split, _ in enumerate(eval_input_p)
            ]
            # We either run p.eval_loop_num_batches steps or one epoch (when supported
            # by a resettable input) per eval loop during training. When
            # p.reset_for_eval is set to True, we run the eval loop until
            # tf.errors.OutOfRangeError (or StopIteration) is raised, which can be
            # triggered either because the input pipeline has reached the end of
            # the input sequence, or because a pre-determined num_batches has been
            # reached.
            eval_num_steps = [
                -1 if p.reset_for_eval else p.eval_loop_num_batches
                for p in eval_input_p
            ]
        else:
            summary_eval_test_dirs = []

        with contextlib.ExitStack() as exit_stack:
            train_summary_writer = exit_stack.enter_context(
                summary_writer(summary_train_dir))
            eval_summary_writer = exit_stack.enter_context(
                summary_writer(summary_eval_dir))
            eval_test_summary_writers = [
                exit_stack.enter_context(summary_writer(d))
                for d in summary_eval_test_dirs
            ]

            # This only prints the view from the first host machine.
            summary_utils.write_model_structure(train_summary_writer,
                                                partitioned_train_state,
                                                is_vars_replicated=False)
            summary_utils.write_total_num_params(train_summary_writer,
                                                 total_num_params)

            summary_last_time = time.time()
            summary_last_step = None

            step_i = int(
                jax.device_get(
                    py_utils.maybe_unreplicate_gda(
                        partitioned_train_state.step)))

            # Start the train loop. Make sure all hosts are at the same step.
            py_utils.sync_global_devices(
                f'Start training loop from step: {step_i}')
            while True:
                logging.debug('step=`%d`: Beginning', step_i)
                if step_i >= train_p.num_train_steps:
                    logging.info(
                        'Training loop completed (step (`%d`) greater than '
                        'num_train_steps (`%d`)).', step_i,
                        train_p.num_train_steps)
                    break

                if summary_last_step is None:
                    summary_last_step = step_i - 1

                if checkpoint_manager.should_save(step_i):
                    logging.info('Saving a ckpt at step: %d', step_i)
                    if multi_host_checkpointing:
                        py_utils.sync_global_devices(
                            f'checkpointer:saving:{checkpoint_dir}:step-{step_i}'
                        )
                    if multi_host_checkpointing or jax.process_index() == 0:
                        checkpoints.save_checkpoint(
                            partitioned_train_state,
                            checkpoint_task_dir,
                            checkpoint_type=checkpoint_type,
                            state_specs=train_state_pspecs,
                            unreplicate=False)
                    checkpoint_manager.save_metadata(global_step_id=step_i)
                    if multi_host_checkpointing:
                        py_utils.sync_global_devices(
                            f'checkpointer:saved:{checkpoint_dir}:step-{step_i}'
                        )

                # Get new model inputs
                if step_i <= _N_STEPS_WARMUP_LOGGING:
                    logging.info('step=`%d`: Retrieving model inputs.', step_i)
                logging.debug('  Retrieving inputs.')
                model_inputs = train_input_pipeline.get_next()

                if jax.config.jax_parallel_functions_output_gda:
                    if step_i <= _N_STEPS_WARMUP_LOGGING:
                        start = time.time()
                    py_utils.assert_same_shape_and_dtype(
                        inputs_shape,
                        tf.nest.map_structure(
                            py_utils.get_global_input_shape_dtype,
                            model_inputs))
                    model_inputs = py_utils.create_gda(model_inputs,
                                                       inputs_shape,
                                                       global_mesh,
                                                       inputs_pspecs)
                    if step_i <= _N_STEPS_WARMUP_LOGGING:
                        logging.info('GDA train batch input creation time %s',
                                     time.time() - start)

                logging.debug('  Retrieved inputs.')

                logging.debug('  Performing train_step().')
                with jax.profiler.StepTraceAnnotation('train',
                                                      step_num=step_i):
                    (partitioned_train_state, loss, metrics, per_example_out,
                     summary_tensors) = train_step(partitioned_train_state,
                                                   train_key, model_inputs)
                logging.debug('  Completed train_step().')

                logging.debug('  Writing summaries (attempt).')
                if summary_utils.write_summary_every_n_steps(
                        partitioned_train_state,
                        train_summary_writer,
                        step_i,
                        train_p.summary_interval_steps,
                        loss,
                        metrics,
                        per_example_out,
                        summary_tensors,
                        train_p.norm_summary_interval_steps,
                        summary_last_time,
                        summary_last_step,
                        unreplicate_mdl_vars=False,
                        unreplicate_metrics=False):
                    summary_last_time = time.time()
                    summary_last_step = step_i
                    step_i = int(
                        py_utils.maybe_unreplicate_gda(
                            partitioned_train_state.step))
                else:
                    # Increment train step locally to avoid an explicit device sync.
                    step_i += 1
                logging.debug('  Wrote summaries (attempted).')

                # Run eval at regular step interval.
                if step_i % train_p.eval_interval_steps == 0:
                    logging.debug('  Starting eval_step().')
                    logging.debug('  Retrieving eval model_inputs.')
                    eval_inputs = train_input_pipeline.get_next()

                    if jax.config.jax_parallel_functions_output_gda:
                        eval_inputs = py_utils.create_gda(
                            eval_inputs, inputs_shape, global_mesh,
                            inputs_pspecs)

                    logging.debug('  Retrieved eval model_inputs.')
                    logging.debug(
                        '  Performing eval_step() runs on training split.')

                    eval_step_fn = functools.partial(
                        eval_step, partitioned_train_state.mdl_vars, eval_key,
                        partitioned_train_state.step)
                    loss, mean_metrics, summary_tensors = model_utils.run_eval_one_step(
                        eval_inputs, eval_step_fn, reshard_inputs=False)
                    logging.debug(
                        '  Completed eval_step() runs on training split.')

                    logging.info('step=`%d`', step_i)
                    logging.info('  eval loss: %s', loss)
                    logging.info('  mean_metrics: %s', mean_metrics)
                    logging.info('  summary_tensors: %s', summary_tensors)
                    if step_i % train_p.summary_interval_steps == 0:
                        logging.debug('  Writing eval summaries.')
                        summary_utils.write_summary_entry(
                            eval_summary_writer,
                            step_i,
                            loss,
                            mean_metrics,
                            summary_tensors,
                            unreplicate_metrics=False)
                        logging.debug('  Wrote eval summaries.')
                    # If we have eval test then also evaluate on test.
                    if eval_input_p is not None:
                        logging.debug(
                            '  Performing eval_step() runs on test splits.')
                        model_utils.run_eval_loop_over_test_splits(
                            eval_num_steps,
                            eval_step_fn,
                            eval_test_summary_writers,
                            step_i,
                            eval_input_pipelines,
                            eval_test_inputs_pspecs,
                            eval_test_inputs_shape,
                            global_mesh,
                            reshard_inputs=False)
                        logging.debug(
                            '  Completed eval_step() runs on test splits.')

                logging.debug('step=`%d`: End', step_i - 1)
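
Examples #8 through #10 all build their global mesh the same way before partitioning the step functions with pjit. A minimal sketch of just that setup (the mesh shape and axis names below are illustrative, not taken from the examples; it assumes an older JAX with jax.experimental.maps and jax.experimental.mesh_utils):

import jax
from jax.experimental import maps, mesh_utils

# Illustrative only: one replica axis, all devices on the model axis.
mesh_shape = (1, jax.device_count())
device_mesh = mesh_utils.create_device_mesh(mesh_shape)
global_mesh = maps.Mesh(device_mesh, ('replica', 'mdl'))

with global_mesh:
  # pjit-partitioned train/eval/decode step functions would be built and
  # invoked here, as trainer_lib.partition_spmd_model(...) does above.
  pass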