Example #1
    def _sweep(self, global_step_id: int) -> None:
        """Deletes or preserves managed checkpoints."""
        if not self._max_to_keep:
            return

        kept_checkpoints = []
        maybe_delete_checkpoints = []
        for checkpoint in self._checkpoint_history.checkpoints:
            if (self._last_kept_checkpoint_datetime is not None
                    and (from_timestamp(checkpoint.timestamp_sec) <=
                         self._last_kept_checkpoint_datetime)):
                kept_checkpoints.append(checkpoint)
            else:
                maybe_delete_checkpoints.append(checkpoint)

        if not self.use_multi_host_flax:
            py_utils.sync_global_devices(
                'checkpoint_manager:begin_delete_checkpoints:'
                f'step_{global_step_id}')
        while len(maybe_delete_checkpoints) > self._max_to_keep:
            checkpoint = maybe_delete_checkpoints.pop(0)
            if (self._keep_interval_timedelta
                    and (not kept_checkpoints
                         or self._last_kept_checkpoint_datetime is None or
                         ((from_timestamp(checkpoint.timestamp_sec) -
                           self._last_kept_checkpoint_datetime) >=
                          self._keep_interval_timedelta))):
                kept_checkpoints.append(checkpoint)
                self._last_kept_checkpoint_datetime = from_timestamp(
                    checkpoint.timestamp_sec)
                continue
            self._delete_checkpoint(checkpoint)
        if not self.use_multi_host_flax:
            py_utils.sync_global_devices(
                'checkpoint_manager:end_delete_checkpoints:'
                f'step_{global_step_id}')

        self._checkpoint_history = self._create_checkpoint_history()
        for c in itertools.chain(kept_checkpoints, maybe_delete_checkpoints):
            logging.info('Keeping checkpoint: (step: %d, timestamp: %s)',
                         c.global_step_id, c.timestamp_sec)
            self._checkpoint_history.checkpoints.add().CopyFrom(c)
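
A minimal, self-contained sketch of the same retention policy, using (step, datetime) tuples in place of the CheckpointHistory proto entries; `sweep`, `max_to_keep`, and `keep_interval` below are illustrative stand-ins, not the manager's API.

import datetime

def sweep(checkpoints, max_to_keep, keep_interval, last_kept=None):
  """Returns the (step, timestamp) pairs to keep; mirrors _sweep() above."""
  kept, candidates = [], []
  for step, ts in checkpoints:
    if last_kept is not None and ts <= last_kept:
      kept.append((step, ts))
    else:
      candidates.append((step, ts))
  while len(candidates) > max_to_keep:
    step, ts = candidates.pop(0)
    if keep_interval and (not kept or last_kept is None or
                          ts - last_kept >= keep_interval):
      kept.append((step, ts))
      last_kept = ts
      continue
    print(f'Deleting checkpoint at step {step}')  # _delete_checkpoint() above
  return sorted(kept + candidates, key=lambda c: c[1])

now = datetime.datetime(2022, 1, 1)
history = [(i * 100, now + datetime.timedelta(minutes=10 * i)) for i in range(8)]
print(sweep(history, max_to_keep=3, keep_interval=datetime.timedelta(minutes=30)))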
Example #2
    def _save_checkpoint_file(self,
                              global_step_id: int,
                              key: str = '') -> None:
        """Saves the checkpoint file with latest checkpoint metadata.

    Note: This method overrides previous version of the checkpoint file.

    Args:
      global_step_id: The current global step.
      key: A key to add to the synchronization string.
    """
        py_utils.sync_global_devices(
            f'checkpoint_manager:begin_save_checkpoint_file:{key}_{global_step_id}'
        )
        if self.use_multi_host_flax or jax.process_index() == 0:
            with tf.io.gfile.GFile(self.checkpoint_filename, 'wb') as writer:
                writer.write(self._checkpoint_history.SerializeToString())
        py_utils.sync_global_devices(
            f'checkpoint_manager:end_save_checkpoint_file:{key}_{global_step_id}'
        )
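
The pattern above is a lead-process-only write bracketed by global barriers. A simplified single-host sketch, with JSON standing in for the serialized CheckpointHistory proto and the py_utils barriers omitted:

import json
import jax
import tensorflow as tf

def save_metadata_file(path, history):
  # In the multi-host setting, all processes barrier before and after this
  # write; only the lead process actually touches the file.
  if jax.process_index() == 0:
    with tf.io.gfile.GFile(path, 'w') as writer:
      writer.write(json.dumps(history))

save_metadata_file('/tmp/checkpoints.json', {'checkpoints': [{'step': 100}]})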
Example #3
def _restore_checkpoint_gda(
    train_state: Optional[train_states.TrainState],
    checkpoint_dir: str,
    global_mesh: Optional[maps.Mesh],
    state_specs: Optional[train_states.TrainState],
    step: Optional[int] = None) -> train_states.TrainState:
  """Restores a checkpoint using JAX GDA deserialization mechanism."""
  if not tf.io.gfile.exists(checkpoint_dir) or not tf.io.gfile.listdir(
      checkpoint_dir):
    if train_state is not None and step is None:
      logging.info(
          'GDA checkpoint restore did not find checkpoint_dir %s; '
          'returning the train_state passed in', checkpoint_dir)
      return train_state
    raise FileNotFoundError(
        f'No checkpoint found for restore in {checkpoint_dir}')

  if step is None:
    checkpoint_dirnames = tf.io.gfile.listdir(checkpoint_dir)
    tmp_checkpoint_dirnames = [
        x for x in checkpoint_dirnames if _is_tmp_checkpoint_asset(x)
    ]
    if tmp_checkpoint_dirnames:
      logging.warning('Found incompletely saved checkpoints %s; skipping them',
                      tmp_checkpoint_dirnames)
    sorted_dirnames = sorted(
        [x for x in checkpoint_dirnames if _is_checkpoint_asset(x)])
    if not sorted_dirnames:
      raise FileNotFoundError(
          f'No checkpoint found for restore in {checkpoint_dir}')
    latest_checkpoint_dirname = sorted_dirnames[-1]
    step = get_step_from_checkpoint_asset(latest_checkpoint_dirname)
    checkpoint_step_dir = _make_checkpoint_step_dir(checkpoint_dir, step)
    logging.info('Found latest checkpoint: %s', checkpoint_step_dir)
  else:
    checkpoint_step_dir = _make_checkpoint_step_dir(checkpoint_dir, step)
    if not tf.io.gfile.exists(checkpoint_step_dir) or not tf.io.gfile.listdir(
        checkpoint_step_dir):
      raise FileNotFoundError(
          f'No checkpoint found for restore in {checkpoint_step_dir}')

  logging.info('GDA checkpoint restore started...')
  if train_state is not None:
    leaves, treedef = jax.tree_util.tree_flatten(train_state)
    partition_spec_leaves, _ = jax.tree_util.tree_flatten(state_specs)
    nested_names = _extract_nested_prefix_names(train_state)
    global_shapes = jax.tree_map(lambda x: x.shape, leaves)
  else:
    partition_spec_leaves, treedef = jax.tree_util.tree_flatten(state_specs)
    nested_names = _extract_nested_prefix_names(state_specs)
    global_shapes = None

  flattened_nested_names, _ = jax.tree_util.tree_flatten(nested_names)

  ckpt_paths = [
      os.path.join(checkpoint_step_dir, x).rstrip('/')
      for x in flattened_nested_names
  ]
  tspecs = jax.tree_map(gda_serialization.get_tensorstore_spec, ckpt_paths)

  train_state_gda = gda_serialization.run_deserialization(
      [global_mesh] * len(tspecs),
      partition_spec_leaves,
      tspecs,
      global_shapes=global_shapes)

  restored_train_state = jax.tree_util.tree_unflatten(treedef, train_state_gda)
  # Barrier across all processes to ensure all restore finish.
  py_utils.sync_global_devices('Wait for checkpoint restore from '
                               f'{checkpoint_step_dir} to finish.')
  logging.info('Successfully restored GDA checkpoint at %s!',
               checkpoint_step_dir)
  return restored_train_state
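
The latest-step discovery above boils down to: list the checkpoint directory, drop temporary assets, and take the newest remaining asset's step. A standalone sketch of that selection, assuming a hypothetical `checkpoint_<step>` naming scheme rather than the library's actual `_is_checkpoint_asset` / `get_step_from_checkpoint_asset` helpers:

def latest_step(dirnames, prefix='checkpoint_', tmp_prefix='tmp'):
  """Returns the newest step among completed checkpoint assets."""
  steps = []
  for name in dirnames:
    name = name.rstrip('/')
    if name.startswith(tmp_prefix):  # skip incompletely saved checkpoints
      continue
    if name.startswith(prefix):
      steps.append(int(name[len(prefix):]))
  if not steps:
    raise FileNotFoundError('No checkpoint found for restore')
  return max(steps)

print(latest_step(['checkpoint_00000100/', 'checkpoint_00000200/',
                   'tmp.checkpoint_00000300/']))  # -> 200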
Example #4
def _save_checkpoint_gda(train_state: train_states.TrainState,
                         checkpoint_dir: str, overwrite: bool,
                         step: int) -> None:
  """Saves a checkpoint using JAX GDA serialization mechanism.

  Note that all JAX processes must call _save_checkpoint_gda in sync because
  each process may only have a slice of the global data.

  Args:
    train_state: A partitioned train_state that is a Pytree of
      GlobalDeviceArray.
    checkpoint_dir: Full path to parent checkpoint_dir.
    overwrite: Whether to allow overwriting an existing target directory.
    step: Step to save checkpoint for.
  """
  if not overwrite:
    # listdir returns only base dirnames, not full directory paths.
    checkpoint_dirnames = tf.io.gfile.listdir(checkpoint_dir)
    # Delete tmp directories if any.
    if jax.process_index() == 0:
      tmp_checkpoint_dirnames = [
          x for x in checkpoint_dirnames if _is_tmp_checkpoint_asset(x)
      ]
      if tmp_checkpoint_dirnames:
        logging.warning('Found incompletely saved checkpoints %s; deleting them',
                        tmp_checkpoint_dirnames)
        for x in tmp_checkpoint_dirnames:
          tf.io.gfile.rmtree(os.path.join(checkpoint_dir, x))
    # Note we must barrier across all processes after the tmp directory delete.
    py_utils.sync_global_devices('Wait for checkpoint tmp dir deletions to '
                                 'finish.')

    sorted_dirnames = sorted(
        [x for x in checkpoint_dirnames if _is_checkpoint_asset(x)])
    if sorted_dirnames:
      latest_checkpoint_dirname = sorted_dirnames[-1]
      previous_step = get_step_from_checkpoint_asset(latest_checkpoint_dirname)
      if previous_step >= step:
        logging.warning(
            'A more recent checkpoint `%d` has already been saved compared '
            'to the current step `%d`. Skipping saving a checkpoint.',
            previous_step, step)
        return

  checkpoint_step_dir = _make_checkpoint_step_dir(checkpoint_dir, step)
  checkpoint_step_tmp_dir = _make_tmp_checkpoint_dir(
      checkpoint_dir, step, sync_timestamp=True)
  logging.info('Saving to a tmp checkpoint dir %s', checkpoint_step_tmp_dir)

  nested_names = _extract_nested_prefix_names(train_state)
  flattened_nested_names, _ = jax.tree_util.tree_flatten(nested_names)

  if jax.process_index() == 0:
    # Create the tmp parent dir.
    tf.io.gfile.makedirs(checkpoint_step_tmp_dir)

  with futures.ThreadPoolExecutor() as executor:
    ckpt_paths = list(
        executor.map(_mkdir_path, flattened_nested_names,
                     [checkpoint_step_tmp_dir] * len(flattened_nested_names)))
  py_utils.sync_global_devices('Wait for checkpoint tmp dir and subdirs '
                               f'creation {checkpoint_step_tmp_dir} to finish.')

  tspecs = jax.tree_map(gda_serialization.get_tensorstore_spec, ckpt_paths)
  leaves, _ = jax.tree_util.tree_flatten(train_state)

  gda_serialization.run_serialization(leaves, tspecs)

  # Note we must barrier across all processes before the directory rename.
  py_utils.sync_global_devices('Wait for checkpoint chunk writes to '
                               f'{checkpoint_step_tmp_dir} to finish.')

  if jax.process_index() == 0:
    # Rename temporary checkpoint directory to its final location.
    logging.info('Renaming %s to %s', checkpoint_step_tmp_dir,
                 checkpoint_step_dir)
    tf.io.gfile.rename(checkpoint_step_tmp_dir, checkpoint_step_dir)

  logging.info('Finished saving GDA checkpoint for step `%s` to `%s`.', step,
               checkpoint_step_dir)
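
Structurally, the save is a write-to-temporary-directory-then-rename commit, so readers never observe a partially written checkpoint. A single-process sketch of that commit step with plain os calls (the directory names and the `write_fn` hook are illustrative only):

import os
import tempfile

def commit_checkpoint(checkpoint_dir, step, write_fn):
  final_dir = os.path.join(checkpoint_dir, f'checkpoint_{step:08d}')
  tmp_dir = os.path.join(checkpoint_dir, f'tmp.checkpoint_{step:08d}')
  os.makedirs(tmp_dir)
  write_fn(tmp_dir)  # write all checkpoint chunks into the temporary directory
  # In the multi-host version a barrier precedes this rename and only the lead
  # process performs it; the rename publishes the checkpoint atomically.
  os.rename(tmp_dir, final_dir)
  return final_dir

root = tempfile.mkdtemp()
print(commit_checkpoint(root, 100,
                        lambda d: open(os.path.join(d, 'data'), 'w').close()))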
Example #5
def _save_checkpoint_flax(train_state: train_states.TrainState,
                          checkpoint_dir: str, overwrite: bool,
                          unreplicate: bool, step: int,
                          use_multi_host: bool) -> None:
  """Saves a checkpoint using Flax serialization mechanism."""
  if not overwrite:
    previous_filename = latest_checkpoint(checkpoint_dir)
    if previous_filename:
      previous_step = int(previous_filename.rsplit('_', 1)[-1])
      if previous_step >= step:
        logging.warning(
            'A more recent checkpoint `%d` has already been saved compared '
            'to the current step `%d`. Skipping saving a checkpoint.',
            previous_step, step)
        return

  # Assume data parallel-only model for now and retrieve train states
  # from the first replica only.
  def maybe_unreplicate(data):
    if unreplicate:
      return jax.device_get(jax_utils.unreplicate(data))
    else:
      return jax.device_get(data)

  # Extract/flatten data structure to store to disk. Flax requires a flattened
  # data structure to be passed to the checkpointer.
  flattened_state, pytree_state = jax.tree_flatten(
      maybe_unreplicate(train_state))
  checkpoint_target = {
      'flattened_state': flattened_state,
      # Saves a serialized version of the pytree structure to detect potential
      # mismatch caused by different versions of saver/restorer.
      'str_pytree_state': str(pytree_state),
  }

  prefix = CHECKPOINT_PREFIX
  if use_multi_host:
    # Notes:
    # 1. We currently don't broadcast / synchronize the timestamp across
    #    all the JAX processes.
    # 2. Flax checkpointing already saves the checkpoint file into a temporary
    #    file that is ultimately moved. Hence we don't need to add a second
    #    layer of temporary files for single-host checkpointing.
    timestamp = _to_timestamp(datetime.datetime.utcnow())
    prefix = f'{TMP_PREFIX}{timestamp}.{prefix}'

  checkpoints.save_checkpoint(
      checkpoint_dir,
      checkpoint_target,
      step,
      prefix=prefix,
      keep=_MAX_CHECKPOINT_FLAX,
      overwrite=overwrite)

  if use_multi_host:
    py_utils.sync_global_devices(
        f'Renaming temporary checkpoint files at step {step} into their final '
        'destination.')
    tmp_filename = os.path.join(checkpoint_dir, f'{prefix}{step}')
    new_filename = os.path.join(checkpoint_dir, f'{CHECKPOINT_PREFIX}{step}')
    logging.debug('Renaming %s to %s', tmp_filename, new_filename)
    tf.io.gfile.rename(tmp_filename, new_filename)
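
The flatten-plus-structure-string trick above can be exercised in isolation: the string form of the treedef is saved alongside the leaves so a restorer built from different code can detect a pytree mismatch before unflattening. A toy sketch with a nested dict standing in for the TrainState:

import jax

train_state = {'step': 0, 'mdl_vars': {'w': [1.0, 2.0], 'b': 0.5}}
flattened_state, pytree_state = jax.tree_util.tree_flatten(train_state)
checkpoint_target = {
    'flattened_state': flattened_state,
    'str_pytree_state': str(pytree_state),
}

# On restore: rebuild the treedef from the current code and compare structures.
_, restorer_treedef = jax.tree_util.tree_flatten(train_state)
assert checkpoint_target['str_pytree_state'] == str(restorer_treedef)
restored = jax.tree_util.tree_unflatten(restorer_treedef,
                                        checkpoint_target['flattened_state'])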
Example #6
def decode_once_spmd_model(
    task_p: InstantiableParams,
    input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    checkpoint_type: CheckpointType,
    restore_checkpoint_dir: str,
    restore_checkpoint_step: Optional[int],
) -> None:
  """Runs the decoding once on the entire decoder datasets for SPMD model.

  Args:
    task_p: Params for the task that encapsulates an SPMD model.
    input_p: List of input params to be decoded.
    job_log_dir: Directory for the job logs.
    checkpoint_type: Type of model checkpointing method to use.
    restore_checkpoint_dir: The directory from which to restore the checkpoint.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
  """
  # TODO(bf-jax): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  if restore_checkpoint_dir:
    restore_checkpoint_parent_dir = restore_checkpoint_dir
    if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
      # TODO(zhouwk): add sanity check on number of subdirs and number of
      # processes and fail early if unequal.
      restore_checkpoint_dir = os.path.join(restore_checkpoint_dir,
                                            f'{jax.process_index():03d}')

  multi_host_checkpointing = bool(checkpoint_type in {
      CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
  })

  sample_inputs = input_p[0].Instantiate().get_next()
  inputs_shape = tf.nest.map_structure(py_utils.get_global_input_shape_dtype,
                                       sample_inputs)

  model_p = task_p.model
  # TODO(b/198356509): This is a hack for now as we need to change some
  # annotations for mode='decode'. A future cl will move this logic
  # to a more generic model_p.update_sharding_params_v1(mode='decode').
  model_p.lm = model_p.lm.cls.set_sharding_params_v1(
      model_p.lm,
      replica_axis=model_p.lm.mesh_axis_names[0],
      data_axis=model_p.lm.mesh_axis_names[1],
      mdl_axis=model_p.lm.mesh_axis_names[2],
      device_ids_mesh=model_p.lm.device_mesh,
      mesh_axis_names=model_p.lm.mesh_axis_names,
      mode='decode')

  mesh_shape = model_p.device_mesh.shape
  device_mesh = mesh_utils.create_device_mesh(mesh_shape)
  logging.info('device_mesh: %s', device_mesh)
  jax_task = task_p.Instantiate()
  global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
  with global_mesh:
    if restore_checkpoint_dir:
      model = jax_task.model
      model.instantiate_variable_configs()
      # Get the metadata from variables instead of actually instantiating them.
      partitioned_specs = jax_task.create_train_state_partition_specs(
          model.vars, is_eval=True)
      # Instantiate the TrainState directly from the checkpoint.
      partitioned_train_state = checkpoints.restore_checkpoint(
          None,
          restore_checkpoint_dir,
          global_mesh=global_mesh,
          checkpoint_type=checkpoint_type,
          state_specs=partitioned_specs,
          step=restore_checkpoint_step)
      if multi_host_checkpointing:
        py_utils.sync_global_devices(
            f'checkpointer:restored:{restore_checkpoint_parent_dir}')
      decode_step_fn, inputs_partition_spec = (
          trainer_lib.get_partitioned_spmd_model_decode_fn(
              jax_task, init_key, partitioned_train_state, partitioned_specs,
              inputs_shape))
    else:
      # When restore is not specified, randomly initialize the train_state.
      (partitioned_train_state, inputs_partition_spec, partitioned_specs,
       decode_step_fn) = trainer_lib.partition_spmd_model_decode(
           task_p, init_key, inputs_shape)
    logging.info('partitioned_train_state: %s',
                 jax.tree_map(lambda x: x.shape, partitioned_train_state))
    # We do not fold in jax.process_index in contrast to the pmap version and
    # use a single global key instead to rely on pjit to split for different
    # replicas.
    logging.info('root prng_key: %s', prng_key)
    prng_key, decode_key = jax.random.split(prng_key)
    logging.info('eval prng_key: %s', decode_key)
    spmd_decode_step_fn = functools.partial(decode_step_fn,
                                            partitioned_train_state.mdl_vars,
                                            decode_key,
                                            partitioned_train_state.step)

    num_steps = [
        -1 if p.reset_for_eval else p.eval_loop_num_batches for p in input_p
    ]
    inputs = [p.Instantiate() for p in input_p]
    decodes = [list() for _ in input_p]
    process_id = jax.process_index()

    for split, num_split_steps in enumerate(num_steps):
      logging.info('Start decoding on input %s', input_p[split].name)
      step_num = 0
      while num_split_steps < 0 or step_num < num_split_steps:
        step_num += 1
        try:
          batch = inputs[split].get_next()
        except (tf.errors.OutOfRangeError, StopIteration):
          break
        if jax.config.jax_parallel_functions_output_gda:
          batch = py_utils.create_gda(batch, inputs_shape, global_mesh,
                                      inputs_partition_spec)
        _, out = spmd_decode_step_fn(batch)
        # Output is fully replicated now, so it's ok to unreplicate it by
        # retrieving from device 0 only.
        out = py_utils.maybe_unreplicate_gda(out)
        global_batch_size = next(iter(out.values())).shape[0]
        logging.info('Finished decoding input batch %d with %d examples',
                     step_num, global_batch_size)
        # Manually shard the output for each jax process.
        # We require that all fields in the output are batch-major.
        if global_batch_size % jax.process_count() != 0:
          raise ValueError(f'Global batch size {global_batch_size} must be '
                           f'divisible by the number of JAX processes '
                           f'{jax.process_count()}')
        for k, v in out.items():
          if v.shape[0] != global_batch_size:
            raise ValueError('We require all fields in the decode output to '
                             'have the batch size as the first dim; got shape='
                             f'{v.shape} with key={k}, expected batch size = '
                             f'{global_batch_size}')
        per_process_batch_size = global_batch_size // jax.process_count()

        def shard(x, per_process_batch_size=per_process_batch_size):
          return x[(process_id *
                    per_process_batch_size):((process_id + 1) *
                                             per_process_batch_size)]

        out = jax.tree_map(shard, out)
        _, processed = jax_task.model.process_decode_out(inputs[split], out)
        decodes[split].extend(processed)
        logging.info('Finished processing decoded input batch %d', step_num)

  basedir = os.path.join(job_log_dir, 'decoder_out')
  dirnames = _get_dir_names(input_p)
  filename = _get_filename(
      py_utils.maybe_unreplicate_gda(partitioned_train_state.step))
  for s in dirnames:
    dir_path = os.path.join(basedir, s)
    if not tf.io.gfile.exists(dir_path):
      tf.io.gfile.makedirs(dir_path)
  filenames = [os.path.join(basedir, s, filename) for s in dirnames]
  for split, output_file in enumerate(filenames):
    logging.info('Writing decoder output to %s with %d entries', output_file,
                 len(decodes[split]))
    io_utils.WriteKeyValuePairs(output_file, decodes[split])
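
The manual per-process sharding of the fully replicated decode output is just a contiguous slice along the batch dimension. A standalone numpy sketch, with hypothetical process_id / process_count values in place of jax.process_index() / jax.process_count():

import numpy as np

out = {'ids': np.arange(16).reshape(8, 2), 'scores': np.arange(8.0)}
process_count, process_id = 4, 1  # hypothetical values for illustration
global_batch_size = next(iter(out.values())).shape[0]
assert global_batch_size % process_count == 0
per_process_batch_size = global_batch_size // process_count

def shard(x):
  start = process_id * per_process_batch_size
  return x[start:start + per_process_batch_size]

local_out = {k: shard(v) for k, v in out.items()}  # jax.tree_map(shard, out) above
print(local_out['ids'])  # rows 2 and 3 of the replicated output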
Example #7
def evaluate_spmd_model(
    task_p: InstantiableParams,
    eval_input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    checkpoint_type: CheckpointType,
) -> None:
  """Runs the evaluation loop on the entire test dataset for SPMD model.

  Args:
    task_p: Params of the task encapsulating an SPMD model.
    eval_input_p: List of Params for the eval data pipelines.
    job_log_dir: Directory for the job logs.
    checkpoint_type: Type of model checkpointing method to use.
  """
  logging.info('Using SPMD sharding for model parallelism.')
  eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]
  # TODO(bf-jax): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  checkpoint_dir = os.path.join(job_log_dir, 'checkpoints')
  # Note that GDA checkpointing requires all processes to participate, but it
  # does not require a separate checkpoint_dir per process.
  if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
    checkpoint_task_dir = os.path.join(checkpoint_dir,
                                       f'{jax.process_index():03d}')
  else:
    checkpoint_task_dir = checkpoint_dir

  multi_host_checkpointing = bool(checkpoint_type in {
      CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
  })

  def get_shape_dtype(x):
    y = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return y

  # Do not use eval_input_pipelines[0] directly.
  sample_model_inputs = eval_input_p[0].Instantiate().get_next()
  inputs_shape = tf.nest.map_structure(get_shape_dtype, sample_model_inputs)

  model_p = task_p.model
  mesh_shape = model_p.device_mesh.shape
  device_mesh = mesh_utils.create_device_mesh(mesh_shape)
  logging.info('device_mesh: %s', device_mesh)
  global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
  with global_mesh:
    partitioned_train_state, partitioned_specs, eval_inputs_partition_specs, _, eval_step, _ = (
        trainer_lib.partition_spmd_model(task_p, init_key, inputs_shape))
    partitioned_train_state = checkpoints.restore_checkpoint(
        partitioned_train_state,
        checkpoint_task_dir,
        global_mesh=global_mesh,
        checkpoint_type=checkpoint_type,
        state_specs=partitioned_specs)
    logging.info('partitioned_train_state: %s',
                 jax.tree_map(lambda x: x.shape, partitioned_train_state))
    if multi_host_checkpointing:
      py_utils.sync_global_devices(f'checkpointer:restored:{checkpoint_dir}')

    # We do not fold in jax.process_index in contrast to the pmap version and
    # use a single global key instead to rely on pjit to split for different
    # replicas.
    logging.info('root prng_key: %s', prng_key)
    prng_key, eval_key = jax.random.split(prng_key)
    logging.info('eval prng_key: %s', eval_key)

    logging.info('Evaluation loop starting...')
    summary_base_dir = os.path.join(job_log_dir, 'summaries')
    summary_eval_dirs = [
        os.path.join(summary_base_dir, f'eval_{split}')
        for split, _ in enumerate(eval_input_p)
    ]

    num_steps = [-1 if p.reset_for_eval else 1 for p in eval_input_p]
    last_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
    with contextlib.ExitStack() as exit_stack:
      eval_summary_writers = [
          exit_stack.enter_context(summary_utils.get_summary_writer(d))
          for d in summary_eval_dirs
      ]
      while True:
        step_i = int(jax.device_get(partitioned_train_state.step))
        eval_step_fn = functools.partial(eval_step,
                                         partitioned_train_state.mdl_vars,
                                         eval_key, partitioned_train_state.step)
        # Run the eval loop.
        model_utils.run_eval_loop_over_test_splits(
            num_steps,
            eval_step_fn,
            eval_summary_writers,
            step_i,
            eval_input_pipelines,
            eval_inputs_partition_specs,
            inputs_shape,
            global_mesh,
            reshard_inputs=False)
        # If the last checkpoint evaluated matches max train steps, exit.
        if last_checkpoint is not None:
          last_ckpt_step = checkpoints.get_step_from_checkpoint_asset(
              last_checkpoint)
          exceeded_ckpt = last_ckpt_step + task_p.train.save_interval_steps
          if exceeded_ckpt >= task_p.train.num_train_steps:
            break
        new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
        while new_checkpoint == last_checkpoint:
          # Sleep for a minute.
          time.sleep(60)
          new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
        # There must be a new checkpoint here.
        logging.info('Found new checkpoint: %s', new_checkpoint)
        partitioned_train_state = checkpoints.restore_checkpoint(
            partitioned_train_state,
            checkpoint_task_dir,
            global_mesh=global_mesh,
            checkpoint_type=checkpoint_type,
            state_specs=partitioned_specs)
        if multi_host_checkpointing:
          py_utils.sync_global_devices(
              f'checkpointer:restored:{checkpoint_dir}')
        last_checkpoint = new_checkpoint
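
The eval job advances by polling for a newer checkpoint and sleeping in between. A minimal sketch of that polling step, taking a `latest_checkpoint_fn` callable so it stays independent of the checkpoints module:

import time

def wait_for_new_checkpoint(latest_checkpoint_fn, last_checkpoint,
                            poll_seconds=60):
  """Blocks until latest_checkpoint_fn() returns something new, then returns it."""
  new_checkpoint = latest_checkpoint_fn()
  while new_checkpoint == last_checkpoint:
    time.sleep(poll_seconds)
    new_checkpoint = latest_checkpoint_fn()
  return new_checkpoint

# e.g. wait_for_new_checkpoint(
#     lambda: checkpoints.latest_checkpoint(checkpoint_dir), last_checkpoint)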
Example #8
  def test_sync_global_devices(self):
    py_utils.sync_global_devices('sync')
Example #9
    def _init_checkpoint_history(self) -> None:
        """Initializes the checkpoint history and sets related class attributes."""
        self._last_saved_checkpoint_step: Optional[int] = None
        self._last_kept_checkpoint_datetime: Optional[datetime.datetime] = None

        if not tf.io.gfile.exists(self.checkpoint_filename):
            self._checkpoint_history = self._create_checkpoint_history()
            return

        # Read the previous checkpoint file and perform a sanity check.
        self._checkpoint_history = self._read_checkpoint_file()

        last_saved_timestamp = (
            self._checkpoint_history.checkpoints[-1].timestamp_sec)
        current_datetime = datetime.datetime.utcnow()
        if current_datetime < from_timestamp(last_saved_timestamp):
            # Time seems to have reversed itself.
            logging.warning(
                'datetime.datetime.utcnow() returned a value `%s` behind the last '
                'saved checkpoint timestamp.',
                from_timestamp(last_saved_timestamp) - current_datetime)

        # Add a few of the checkpoints to the `kept` list.
        kept_checkpoints = []
        if self._keep_interval_timedelta is None:
            maybe_delete_checkpoints = list(
                self._checkpoint_history.checkpoints)
        else:
            maybe_delete_checkpoints = []
            oldest_kept_timestamp = None
            for checkpoint in self._checkpoint_history.checkpoints:
                if (oldest_kept_timestamp is None
                        or ((from_timestamp(oldest_kept_timestamp) -
                             from_timestamp(checkpoint.timestamp_sec)) >=
                            self._keep_interval_timedelta)):
                    oldest_kept_timestamp = checkpoint.timestamp_sec
                    kept_checkpoints.append(checkpoint)
                    if self._last_kept_checkpoint_datetime is None:
                        self._last_kept_checkpoint_datetime = (from_timestamp(
                            checkpoint.timestamp_sec))
                else:
                    maybe_delete_checkpoints.append(checkpoint)

        # Keep at most `max_to_keep` of the remaining checkpoints and delete
        # the old ones.
        if not self.use_multi_host_flax:
            py_utils.sync_global_devices(
                'checkpoint_manager:begin_delete_checkpoints:'
                f'init_{self._checkpoint_history.checkpoints[-1].global_step_id}'
            )
        for i, checkpoint in enumerate(reversed(maybe_delete_checkpoints)):
            if self._max_to_keep is None or i < self._max_to_keep:
                kept_checkpoints.append(checkpoint)
            else:
                self._delete_checkpoint(checkpoint)
        if not self.use_multi_host_flax:
            py_utils.sync_global_devices(
                'checkpoint_manager:end_delete_checkpoints:'
                f'init_{self._checkpoint_history.checkpoints[-1].global_step_id}'
            )

        # Finally create a new CheckpointHistory and save a new checkpoint file.
        kept_checkpoints = sorted(
            kept_checkpoints, key=lambda c: from_timestamp(c.timestamp_sec))
        latest_global_step = kept_checkpoints[-1].global_step_id
        self._last_saved_checkpoint_step = latest_global_step
        self._checkpoint_history = self._create_checkpoint_history()
        for c in kept_checkpoints:
            self._checkpoint_history.checkpoints.add().CopyFrom(c)

        self._save_checkpoint_file(latest_global_step)
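
The trailing deletion loop keeps at most `max_to_keep` of the remaining candidates, newest first, and deletes the rest. A toy illustration with integers standing in for the checkpoint protos:

max_to_keep = 3
maybe_delete_checkpoints = list(range(10))  # stand-in checkpoints, oldest to newest
kept, deleted = [], []
for i, ckpt in enumerate(reversed(maybe_delete_checkpoints)):  # newest first
  if max_to_keep is None or i < max_to_keep:
    kept.append(ckpt)
  else:
    deleted.append(ckpt)
print(kept)     # [9, 8, 7]
print(deleted)  # [6, 5, 4, 3, 2, 1, 0]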
Example #10
def train_and_evaluate_spmd_model(
        task_p: InstantiableParams, train_input_p: InstantiableParams,
        job_log_dir: Optional[str],
        checkpoint_manager: checkpoint_managers.CheckpointManager,
        checkpoint_type: CheckpointType, restore_checkpoint_dir: Optional[str],
        restore_checkpoint_step: Optional[int],
        eval_input_p: Optional[Sequence[InstantiableParams]]) -> None:
    """Runs the training and evaluation loop.

  Args:
    task_p: Params for task encapsulating the SPMD model.
    train_input_p: Params for the train data pipeline.
    job_log_dir: Directory for the job logs.
    checkpoint_manager: A checkpoint manager controlling how often to save and
      delete checkpoints.
    checkpoint_type: The type of checkpoint to use.
    restore_checkpoint_dir: If set, the directory from which to restore
      checkpoint. If unset, use job_log_dir's `checkpoints` subdirectory
      instead.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
    eval_input_p: Optional list of params for the eval input pipelines.
  """
    logging.info('Using SPMD sharding for model parallelism.')
    model_p = task_p.model

    if eval_input_p is not None:
        eval_input_pipelines = [
            input_p.Instantiate() for input_p in eval_input_p
        ]
        # Do not mutate eval_input_pipelines itself. Instantiate a new one
        # to get sample input.
        sample_eval_model_inputs = eval_input_p[0].Instantiate().get_next()
        eval_test_inputs_shape = tf.nest.map_structure(
            py_utils.get_global_input_shape_dtype, sample_eval_model_inputs)
        eval_test_inputs_pspecs = trainer_lib.get_input_partition_specs(
            model_p.mesh_axis_names, eval_test_inputs_shape)

    # TODO(bf-jax): Retrieve the seeds from the model definition instead.
    prng_key = jax.random.PRNGKey(1234)
    prng_key, init_key = jax.random.split(prng_key)

    checkpoint_dir = _checkpoint_dir(job_log_dir)
    restore_checkpoint_dir = restore_checkpoint_dir or checkpoint_dir
    # Note that GDA checkpointing requires all processes to participate, but
    # it does not require a separate checkpoint_dir per process.
    if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
        checkpoint_task_dir = os.path.join(checkpoint_dir,
                                           f'{jax.process_index():03d}')
        restore_checkpoint_task_dir = os.path.join(
            restore_checkpoint_dir, f'{jax.process_index():03d}')
    else:
        checkpoint_task_dir = checkpoint_dir
        restore_checkpoint_task_dir = restore_checkpoint_dir

    multi_host_checkpointing = bool(checkpoint_type in {
        CheckpointType.CHECKPOINT_MULTI_HOST_FLAX,
        CheckpointType.CHECKPOINT_GDA
    })

    if jax.process_index() == 0:
        tf.io.gfile.makedirs(checkpoint_dir)
    if multi_host_checkpointing:
        # Block all hosts until directory is ready.
        py_utils.sync_global_devices(f'checkpointer:makedirs:{checkpoint_dir}')

    logging.info('Retrieving model inputs for shape info.')
    train_input_for_shape = train_input_p.Instantiate()
    model_inputs_for_shape = train_input_for_shape.get_next()
    inputs_shape = tf.nest.map_structure(py_utils.get_global_input_shape_dtype,
                                         model_inputs_for_shape)

    mesh_shape = model_p.device_mesh.shape
    device_mesh = mesh_utils.create_device_mesh(mesh_shape)
    logging.info('device_mesh: %s', device_mesh)

    global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
    with global_mesh:
        (partitioned_train_state, train_state_pspecs, inputs_pspecs,
         train_step, eval_step,
         total_num_params) = trainer_lib.partition_spmd_model(
             task_p, init_key, inputs_shape)

        partitioned_train_state = checkpoints.restore_checkpoint(
            partitioned_train_state,
            restore_checkpoint_task_dir,
            global_mesh=global_mesh,
            checkpoint_type=checkpoint_type,
            state_specs=train_state_pspecs,
            step=restore_checkpoint_step)
        logging.info(
            'partitioned_train_state shapes '
            '(global shape for GDA, host-local shape for non-GDA): %s',
            jax.tree_map(lambda x: x.shape, partitioned_train_state))
        if multi_host_checkpointing:
            py_utils.sync_global_devices(
                f'checkpointer:restored:{checkpoint_dir}')

        train_p = task_p.train
        initial_global_step = int(
            jax.device_get(
                py_utils.maybe_unreplicate_gda(partitioned_train_state.step)))
        logging.info('Model initial global_step=%d', initial_global_step)
        _update_latest_model_step(train_input_p, initial_global_step,
                                  train_p.eval_interval_steps)
        train_input_pipeline = train_input_p.Instantiate()

        # We do not fold in jax.process_index in contrast to the pmap version and
        # use a single global key instead to rely on pjit to split for different
        # replicas.
        logging.info('root prng_key: %s', prng_key)
        prng_key, train_key, eval_key = jax.random.split(prng_key, 3)
        logging.info('train prng_key: %s', train_key)
        logging.info('eval prng_key: %s', eval_key)

        logging.info('Training loop starting...')
        summary_base_dir = os.path.join(job_log_dir, 'summaries')
        summary_train_dir = os.path.join(summary_base_dir, 'train')
        summary_eval_dir = os.path.join(summary_base_dir, 'eval_train')
        summary_writer = summary_utils.get_summary_writer
        if eval_input_p is not None:
            summary_eval_test_dirs = [
                os.path.join(summary_base_dir, f'eval_test_{split}')
                for split, _ in enumerate(eval_input_p)
            ]
            # We either run p.eval_loop_num_batches steps or one epoch (when
            # supported by a resettable input) per eval loop during training.
            # When p.reset_for_eval is set to True, we run the eval loop until
            # tf.errors.OutOfRangeError (or StopIteration) is raised, which is
            # triggered either when the input pipeline reaches the end of the
            # input sequence or when a pre-determined number of batches has
            # been reached.
            eval_num_steps = [
                -1 if p.reset_for_eval else p.eval_loop_num_batches
                for p in eval_input_p
            ]
        else:
            summary_eval_test_dirs = []

        with contextlib.ExitStack() as exit_stack:
            train_summary_writer = exit_stack.enter_context(
                summary_writer(summary_train_dir))
            eval_summary_writer = exit_stack.enter_context(
                summary_writer(summary_eval_dir))
            eval_test_summary_writers = [
                exit_stack.enter_context(summary_writer(d))
                for d in summary_eval_test_dirs
            ]

            # This only prints the view from the first host machine.
            summary_utils.write_model_structure(train_summary_writer,
                                                partitioned_train_state,
                                                is_vars_replicated=False)
            summary_utils.write_total_num_params(train_summary_writer,
                                                 total_num_params)

            summary_last_time = time.time()
            summary_last_step = None

            step_i = int(
                jax.device_get(
                    py_utils.maybe_unreplicate_gda(
                        partitioned_train_state.step)))

            # Start the train loop. Make sure all hosts start at the same step.
            py_utils.sync_global_devices(
                f'Start training loop from step: {step_i}')
            while True:
                logging.debug('step=`%d`: Beginning', step_i)
                if step_i >= train_p.num_train_steps:
                    logging.info(
                        'Training loop completed (step `%d` reached '
                        'num_train_steps `%d`).', step_i,
                        train_p.num_train_steps)
                    break

                if summary_last_step is None:
                    summary_last_step = step_i - 1

                if checkpoint_manager.should_save(step_i):
                    logging.info('Saving a ckpt at step: %d', step_i)
                    if multi_host_checkpointing:
                        py_utils.sync_global_devices(
                            f'checkpointer:saving:{checkpoint_dir}:step-{step_i}'
                        )
                    if multi_host_checkpointing or jax.process_index() == 0:
                        checkpoints.save_checkpoint(
                            partitioned_train_state,
                            checkpoint_task_dir,
                            checkpoint_type=checkpoint_type,
                            state_specs=train_state_pspecs,
                            unreplicate=False)
                    checkpoint_manager.save_metadata(global_step_id=step_i)
                    if multi_host_checkpointing:
                        py_utils.sync_global_devices(
                            f'checkpointer:saved:{checkpoint_dir}:step-{step_i}'
                        )

                # Get new model inputs
                if step_i <= _N_STEPS_WARMUP_LOGGING:
                    logging.info('step=`%d`: Retrieving model inputs.', step_i)
                logging.debug('  Retrieving inputs.')
                model_inputs = train_input_pipeline.get_next()

                if jax.config.jax_parallel_functions_output_gda:
                    if step_i <= _N_STEPS_WARMUP_LOGGING:
                        start = time.time()
                    py_utils.assert_same_shape_and_dtype(
                        inputs_shape,
                        tf.nest.map_structure(
                            py_utils.get_global_input_shape_dtype,
                            model_inputs))
                    model_inputs = py_utils.create_gda(model_inputs,
                                                       inputs_shape,
                                                       global_mesh,
                                                       inputs_pspecs)
                    if step_i <= _N_STEPS_WARMUP_LOGGING:
                        logging.info('GDA train batch input creation time %s',
                                     time.time() - start)

                logging.debug('  Retrieved inputs.')

                logging.debug('  Performing train_step().')
                with jax.profiler.StepTraceAnnotation('train',
                                                      step_num=step_i):
                    (partitioned_train_state, loss, metrics, per_example_out,
                     summary_tensors) = train_step(partitioned_train_state,
                                                   train_key, model_inputs)
                logging.debug('  Completed train_step().')

                logging.debug('  Writing summaries (attempt).')
                if summary_utils.write_summary_every_n_steps(
                        partitioned_train_state,
                        train_summary_writer,
                        step_i,
                        train_p.summary_interval_steps,
                        loss,
                        metrics,
                        per_example_out,
                        summary_tensors,
                        train_p.norm_summary_interval_steps,
                        summary_last_time,
                        summary_last_step,
                        unreplicate_mdl_vars=False,
                        unreplicate_metrics=False):
                    summary_last_time = time.time()
                    summary_last_step = step_i
                    step_i = int(
                        py_utils.maybe_unreplicate_gda(
                            partitioned_train_state.step))
                else:
                    # Increment train step locally to avoid an explicit device sync.
                    step_i += 1
                logging.debug('  Wrote summaries (attempted).')

                # Run eval at regular step interval.
                if step_i % train_p.eval_interval_steps == 0:
                    logging.debug('  Starting eval_step().')
                    logging.debug('  Retrieving eval model_inputs.')
                    eval_inputs = train_input_pipeline.get_next()

                    if jax.config.jax_parallel_functions_output_gda:
                        eval_inputs = py_utils.create_gda(
                            eval_inputs, inputs_shape, global_mesh,
                            inputs_pspecs)

                    logging.debug('  Retrieved eval model_inputs.')
                    logging.debug(
                        '  Performing eval_step() runs on training split.')

                    eval_step_fn = functools.partial(
                        eval_step, partitioned_train_state.mdl_vars, eval_key,
                        partitioned_train_state.step)
                    loss, mean_metrics, summary_tensors = model_utils.run_eval_one_step(
                        eval_inputs, eval_step_fn, reshard_inputs=False)
                    logging.debug(
                        '  Completed eval_step() runs on training split.')

                    logging.info('step=`%d`', step_i)
                    logging.info('  eval loss: %s', loss)
                    logging.info('  mean_metrics: %s', mean_metrics)
                    logging.info('  summary_tensors: %s', summary_tensors)
                    if step_i % train_p.summary_interval_steps == 0:
                        logging.debug('  Writing eval summaries.')
                        summary_utils.write_summary_entry(
                            eval_summary_writer,
                            step_i,
                            loss,
                            mean_metrics,
                            summary_tensors,
                            unreplicate_metrics=False)
                        logging.debug('  Wrote eval summaries.')
                    # If we have eval test then also evaluate on test.
                    if eval_input_p is not None:
                        logging.debug(
                            '  Performing eval_step() runs on test splits.')
                        model_utils.run_eval_loop_over_test_splits(
                            eval_num_steps,
                            eval_step_fn,
                            eval_test_summary_writers,
                            step_i,
                            eval_input_pipelines,
                            eval_test_inputs_pspecs,
                            eval_test_inputs_shape,
                            global_mesh,
                            reshard_inputs=False)
                        logging.debug(
                            '  Completed eval_step() runs on test splits.')

                logging.debug('step=`%d`: End', step_i - 1)
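
The loop saves whenever checkpoint_manager.should_save(step_i) returns True and brackets the save with global barriers in the multi-host case. A minimal, purely illustrative interval-based predicate in the spirit of that check (a stand-in, not the CheckpointManager implementation):

def should_save(step, save_interval_steps, last_saved_step=None):
  """Saves on every multiple of save_interval_steps not already saved."""
  if last_saved_step is not None and step <= last_saved_step:
    return False
  return step % save_interval_steps == 0

assert should_save(1000, save_interval_steps=100)
assert not should_save(1050, save_interval_steps=100)
assert not should_save(1000, save_interval_steps=100, last_saved_step=1000)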