Example #1
def write_to_store(obj, path):
    """Util for writing arbitrary objects to store.

    Args:
      obj: arbitrary object to be saved.
      path: a path to save the object.
    """
    def dumps(obj, pickle=builtin_pickle):
        """Returns the bytes representation of a pickled object.

        Args:
          obj: Object to pickle.
          pickle: pickle library to use.
        """
        return pickle.dumps(obj)

    def dump(obj, file, pickle=builtin_pickle):
        """Pickles an object to a file.

        Args:
          obj: Object to pickle.
          file: Name of, or open file-handle to, the file to write the pickled
            object to.
          pickle: pickle library to use.
        """
        # Because of the latency of sequential writes to CNS, it is much faster
        # to do f.write(dill.dumps(obj)) than dill.dump(obj, f).
        data = dumps(obj, pickle=pickle)
        with _maybe_open(file, mode="wb") as f:
            f.write(data)

    # In order to be robust to interruptions, we first save the checkpoint to a
    # temporary file and then rename it to the actual path name.
    path_tmp = path + "-TEMPORARY"
    dump(obj, path_tmp)
    gfile.rename(path_tmp, path, overwrite=True)
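
# The `dump` helper above relies on a `_maybe_open` context manager that is not
# shown in this example. The sketch below is an assumption about its behavior:
# it accepts either a path string or an already-open file handle and yields
# something with a `write()` method (using gfile so remote paths work too).
import contextlib

@contextlib.contextmanager
def _maybe_open(file, mode="wb"):
    if isinstance(file, str):
        with gfile.GFile(file, mode) as f:
            yield f
    else:
        yield file
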
def save_checkpoint(tree: Params,
                    path: str,
                    step_for_copy: Optional[int] = None) -> None:
    """Saves the values of JAX pytrees to disk in a NumPy `.npz` file.

    Args:
      tree: A JAX pytree to be saved.
      path: A path to save the checkpoint.
      step_for_copy: Optional integer that, when not None, will be used to save a
        copy of the checkpoint with the name `path-{step_for_copy}`.
    """
    # NOTE: In general, this could be greatly simplified as follows. However, we
    # currently need to store the leaf names as well in order to be able to load
    # and reconstruct the tree directly from the checkpoint when initializing a
    # subset of a model from a pretrained model for fine-tuning.
    # ```
    # values, _ = jax.tree_util.tree_flatten(tree)
    # io_buffer = io.BytesIO()
    # np.savez(io_buffer, *values)
    # ```
    names_and_vals, _ = _tree_flatten_with_names(tree)
    io_buffer = io.BytesIO()
    np.savez(io_buffer, **{k: v for k, v in names_and_vals})

    # In order to be robust to interruptions during saving, we first save the
    # checkpoint to a temporary file, and then rename it to the actual path name.
    path_tmp = path + "-TEMPORARY"
    with gfile.GFile(path_tmp, "wb") as f:
        f.write(io_buffer.getvalue())
    gfile.rename(path_tmp, path, overwrite=True)

    if step_for_copy is not None:
        gfile.copy(path, f"{path}-{step_for_copy:09d}", overwrite=True)
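
# `_tree_flatten_with_names` is referenced above but not shown. Below is a
# minimal sketch of what such a helper could look like, built on
# `jax.tree_util.tree_flatten_with_path`; the real helper may format leaf
# names differently (e.g. "params/dense/kernel" instead of keystr output).
import jax

def _tree_flatten_with_names(tree):
    # Pair every leaf with a string derived from its path in the pytree, so the
    # leaves can be written to the `.npz` file by name and matched up on load.
    leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(tree)
    names_and_vals = [(jax.tree_util.keystr(path), leaf)
                      for path, leaf in leaves_with_paths]
    return names_and_vals, treedef
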
Example #3
def save_checkpoint(ckpt_dir,
                    target,
                    step,
                    prefix='checkpoint_',
                    keep=1):
  """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to a temporary file before
  a final rename and cleanup of past files.

  Args:
    ckpt_dir: str: path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep.

  Returns:
    Filename of saved checkpoint.
  """
  # Write temporary checkpoint file.
  logging.info('Saving checkpoint at step: %s', step)
  ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
  ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
  gfile.makedirs(os.path.dirname(ckpt_path))
  with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
    fp.write(serialization.to_bytes(target))

  # Rename once serialization and writing finished.
  gfile.rename(ckpt_tmp_path, ckpt_path)
  logging.info('Saved checkpoint at %s', ckpt_path)

  # Remove old checkpoint files.
  base_path = os.path.join(ckpt_dir, prefix)
  checkpoint_files = natural_sort(gfile.glob(base_path + '*'))
  if len(checkpoint_files) > keep:
    old_ckpts = checkpoint_files[:-keep]
    for path in old_ckpts:
      logging.info('Removing checkpoint at %s', path)
      gfile.remove(path)

  return ckpt_path
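
# `_checkpoint_path` and `natural_sort` are helpers assumed by the function
# above; they are not shown in this example. The sketches below illustrate the
# behavior the code relies on: the step interpolated into the filename, and a
# sort order in which `checkpoint_9` precedes `checkpoint_10`.
import re

def _checkpoint_path(ckpt_dir, step, prefix='checkpoint_'):
  return os.path.join(ckpt_dir, f'{prefix}{step}')

def natural_sort(file_list):
  # Split digit runs out of each name and compare them numerically.
  def key(name):
    return [int(part) if part.isdigit() else part
            for part in re.split(r'(\d+)', name)]
  return sorted(file_list, key=key)
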
def save(data, path):
    """Util for checkpointing: saves jax pytree objects to the disk.

    These checkpoints can later be recovered with `load()`.

    Args:
      data: arbitrary jax pytree to be saved.
      path: a path to save the data.
    """
    names_and_vals, _ = tree_flatten_with_names(data)
    io_buffer = io.BytesIO()

    # `np.savez` uses the `seek()` API call, which is not supported by CNS. Thus,
    # we first write the checkpoint to a temporary in-memory buffer and then
    # write it to disk.
    np.savez(io_buffer, **{k: v for k, v in names_and_vals})

    # In order to be robust to interruptions, we first save the checkpoint to a
    # temporary file and then rename it to the actual path name.
    path_tmp = path + '-TEMPORARY'
    gfile.makedirs(os.path.dirname(path))
    with gfile.GFile(path_tmp, 'wb') as f:
        f.write(io_buffer.getvalue())
    gfile.rename(path_tmp, path, overwrite=True)
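
# The docstring above notes that checkpoints written by `save()` can be read
# back with `load()`, which is not shown. A minimal sketch of such a
# counterpart, returning a flat {name: array} dict; the project's real loader
# may instead rebuild the full pytree from the stored names.
def load(path):
    with gfile.GFile(path, 'rb') as f:
        npz = np.load(io.BytesIO(f.read()))
        return {k: npz[k] for k in npz.files}
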
Example #5
def save_checkpoint(ckpt_dir: Union[str, os.PathLike],
                    target,
                    step,
                    prefix='checkpoint_',
                    keep=1,
                    overwrite=False):
    """Save a checkpoint of the model.

    Attempts to be pre-emption safe by writing to a temporary file before
    a final rename and cleanup of past files.

    Args:
      ckpt_dir: str or pathlib-like path to store checkpoint files in.
      target: serializable flax object, usually a flax optimizer.
      step: int or float: training step number or other metric number.
      prefix: str: checkpoint file name prefix.
      keep: number of past checkpoint files to keep.
      overwrite: overwrite existing checkpoint files if a checkpoint
        at the current or a later step already exists (default: False).

    Returns:
      Filename of saved checkpoint.
    """
    ckpt_dir = os.fspath(ckpt_dir)  # Pathlib -> str
    # Write temporary checkpoint file.
    logging.info('Saving checkpoint at step: %s', step)
    if ckpt_dir.startswith('./'):
        ckpt_dir = ckpt_dir[2:]  # gfile.glob() can remove leading './'
    ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
    ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
    gfile.makedirs(os.path.dirname(ckpt_path))
    base_path = os.path.join(ckpt_dir, prefix)
    checkpoint_files = gfile.glob(base_path + '*')

    if ckpt_path in checkpoint_files:
        if not overwrite:
            raise errors.InvalidCheckpointError(ckpt_path, step)
    else:
        checkpoint_files.append(ckpt_path)

    checkpoint_files = natural_sort(checkpoint_files)
    if checkpoint_files[-1] == ckpt_tmp_path:
        checkpoint_files.pop(-1)
    if ckpt_path != checkpoint_files[-1]:
        if not overwrite:
            raise errors.InvalidCheckpointError(ckpt_path, step)

    with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
        fp.write(serialization.to_bytes(target))

    # Rename once serialization and writing finished.
    gfile.rename(ckpt_tmp_path, ckpt_path, overwrite=overwrite)
    logging.info('Saved checkpoint at %s', ckpt_path)

    # Remove newer checkpoints
    if overwrite:
        ind = checkpoint_files.index(ckpt_path) + 1
        newer_ckpts = checkpoint_files[ind:]
        checkpoint_files = checkpoint_files[:ind]
        for path in newer_ckpts:
            logging.info('Removing checkpoint at %s', path)
            gfile.remove(path)

    # Remove old checkpoint files.
    if len(checkpoint_files) > keep:
        old_ckpts = checkpoint_files[:-keep]
        for path in old_ckpts:
            logging.info('Removing checkpoint at %s', path)
            gfile.remove(path)

    return ckpt_path
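
# Hypothetical usage illustrating the overwrite semantics above: saving at a
# step earlier than an existing checkpoint raises InvalidCheckpointError unless
# overwrite=True, in which case the newer checkpoints are removed. `state` here
# stands in for any serializable flax object.
save_checkpoint('/tmp/run', target=state, step=100)
try:
    save_checkpoint('/tmp/run', target=state, step=50)
except errors.InvalidCheckpointError:
    pass  # Expected: step 50 is older than the existing checkpoint_100.
save_checkpoint('/tmp/run', target=state, step=50, overwrite=True)  # Removes checkpoint_100.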