def write_to_store(obj, path):
  """Writes an arbitrary picklable object to a file in the store.

  To be robust to interruptions, the object is first written to a temporary
  file which is then renamed over `path`.

  Args:
    obj: arbitrary object to be saved.
    path: a path to save the object.
  """

  def serialize(value, pickle=builtin_pickle):
    """Returns the pickled bytes representation of `value`.

    Args:
      value: Object to pickle.
      pickle: pickle library to use.
    """
    return pickle.dumps(value)

  def write_pickled(value, file, pickle=builtin_pickle):
    """Pickles `value` and writes it to `file`.

    Args:
      value: Object to pickle.
      file: Name of, or open file-handle to file to write pickled object to.
      pickle: pickle library to use.
    """
    # Because of latency involved in CNS sequential reads, it is way faster
    # to serialize in memory and write the bytes in one call than to stream
    # via pickle.dump directly into the file object.
    payload = serialize(value, pickle=pickle)
    with _maybe_open(file, mode="wb") as fh:
      fh.write(payload)

  # In order to be robust to interruptions we first save the checkpoint to a
  # temporary file and then move it to the actual path name.
  tmp_path = path + "-TEMPORARY"
  write_pickled(obj, tmp_path)
  gfile.rename(tmp_path, path, overwrite=True)
def save_checkpoint(tree: Params, path: str, step_for_copy: Optional[int] = None) -> None:
  """Saves the values of JAX pytrees to disk in a NumPy `.npz` file.

  Args:
    tree: A JAX pytree to be saved.
    path: A path to save the checkpoint.
    step_for_copy: Optional integer that, when not None, will be used to save
      a copy of the checkpoint with the name `path-{step_for_copy}`.
  """
  # NOTE: In general, this could be greatly simplified as follows. However, we
  # currently need to store the leaf names as well in order to be able to load
  # and reconstruct the tree directly from the checkpoint when initialized a
  # subset of a model from a pretrained model for fine tuning.
  # ```
  # values, _ = jax.tree_util.tree_flatten(tree)
  # io_buffer = io.BytesIO()
  # np.savez(io_buffer, *values)
  # ```
  flat_names_and_values, _ = _tree_flatten_with_names(tree)
  buffer = io.BytesIO()
  np.savez(buffer, **dict(flat_names_and_values))

  # In order to be robust to interruptions during saving, we first save the
  # checkpoint to a temporary file, and then rename it to the actual path name.
  tmp_path = path + "-TEMPORARY"
  with gfile.GFile(tmp_path, "wb") as f:
    f.write(buffer.getvalue())
  gfile.rename(tmp_path, path, overwrite=True)

  if step_for_copy is not None:
    gfile.copy(path, f"{path}-{step_for_copy:09d}", overwrite=True)
def save_checkpoint(ckpt_dir, target, step, prefix='checkpoint_', keep=1):
  """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to temporary before
  a final rename and cleanup of past files.

  Args:
    ckpt_dir: str: path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep.

  Returns:
    Filename of saved checkpoint.
  """
  # Write temporary checkpoint file.
  logging.info('Saving checkpoint at step: %s', step)
  ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
  ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
  gfile.makedirs(os.path.dirname(ckpt_path))
  with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
    fp.write(serialization.to_bytes(target))

  # Rename once serialization and writing finished. Pass overwrite=True so
  # that re-saving the same step after a pre-emption/restart does not fail;
  # this also matches the other checkpoint writers in this file.
  gfile.rename(ckpt_tmp_path, ckpt_path, overwrite=True)
  logging.info('Saved checkpoint at %s', ckpt_path)

  # Remove old checkpoint files, keeping the `keep` most recent ones.
  # (Was `os.path.join(ckpt_dir, f'{prefix}')`; the f-string was redundant.)
  base_path = os.path.join(ckpt_dir, prefix)
  checkpoint_files = natural_sort(gfile.glob(base_path + '*'))
  if len(checkpoint_files) > keep:
    for path in checkpoint_files[:-keep]:
      logging.info('Removing checkpoint at %s', path)
      gfile.remove(path)

  return ckpt_path
def save(data, path):
  """Util for checkpointing: saves jax pytree objects to the disk.

  These checkpoints can later be recovered with `load()`.

  Args:
    data: arbitrary jax pytree to be saved.
    path: a path to save the data.
  """
  names_and_vals, _ = tree_flatten_with_names(data)

  # savez uses the `seek()` API call, which is not supported by cns. Thus, we
  # first write the checkpoint into an in-memory buffer, then flush the bytes
  # to disk in a single write.
  buf = io.BytesIO()
  np.savez(buf, **dict(names_and_vals))

  # In order to be robust to interruptions we first save checkpoint to the
  # temporal file and then move to actual path name.
  tmp_path = path + '-TEMPORARY'
  gfile.makedirs(os.path.dirname(path))
  with gfile.GFile(tmp_path, 'wb') as f:
    f.write(buf.getvalue())
  gfile.rename(tmp_path, path, overwrite=True)
def save_checkpoint(ckpt_dir: Union[str, os.PathLike],
                    target,
                    step,
                    prefix='checkpoint_',
                    keep=1,
                    overwrite=False):
  """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to temporary before
  a final rename and cleanup of past files.

  Args:
    ckpt_dir: str or pathlib-like path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep.
    overwrite: overwrite existing checkpoint files if a checkpoint at the
      current or a later step already exits (default: False).

  Returns:
    Filename of saved checkpoint.

  Raises:
    errors.InvalidCheckpointError: if a checkpoint at the current or a later
      step already exists and `overwrite` is False.
  """
  ckpt_dir = os.fspath(ckpt_dir)  # Pathlib -> str
  # Write temporary checkpoint file.
  logging.info('Saving checkpoint at step: %s', step)
  if ckpt_dir.startswith('./'):
    ckpt_dir = ckpt_dir[2:]  # gfile.glob() can remove leading './'
  ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
  ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
  gfile.makedirs(os.path.dirname(ckpt_path))
  base_path = os.path.join(ckpt_dir, prefix)
  checkpoint_files = gfile.glob(base_path + '*')

  # Guard against saving at a step that already exists, or behind the latest
  # existing checkpoint, unless the caller explicitly allows overwriting.
  if ckpt_path in checkpoint_files:
    if not overwrite:
      raise errors.InvalidCheckpointError(ckpt_path, step)
  else:
    checkpoint_files.append(ckpt_path)
  checkpoint_files = natural_sort(checkpoint_files)
  if checkpoint_files[-1] == ckpt_tmp_path:
    checkpoint_files.pop(-1)
  if ckpt_path != checkpoint_files[-1]:
    if not overwrite:
      raise errors.InvalidCheckpointError(ckpt_path, step)

  with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
    fp.write(serialization.to_bytes(target))

  # Rename once serialization and writing finished.
  gfile.rename(ckpt_tmp_path, ckpt_path, overwrite=overwrite)
  logging.info('Saved checkpoint at %s', ckpt_path)
  # Fix: removed stray debug `print(ckpt_path)`; the logging call above
  # already records the saved path.

  # Remove newer checkpoints that the overwrite invalidated.
  if overwrite:
    ind = checkpoint_files.index(ckpt_path) + 1
    newer_ckpts = checkpoint_files[ind:]
    checkpoint_files = checkpoint_files[:ind]
    for path in newer_ckpts:
      logging.info('Removing checkpoint at %s', path)
      gfile.remove(path)

  # Remove old checkpoint files, keeping the `keep` most recent ones.
  if len(checkpoint_files) > keep:
    for path in checkpoint_files[:-keep]:
      logging.info('Removing checkpoint at %s', path)
      gfile.remove(path)

  return ckpt_path