def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
  """Restore last/best checkpoint from checkpoints in path.

  Sorts the checkpoint files naturally, returning the highest-valued
  file, e.g.:
    ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
    ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
    ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

  Args:
    ckpt_dir: str: directory of checkpoints to restore from.
    target: matching object to rebuild via deserialized state-dict.
    step: int: step number to load or None to load latest. A step of 0 is
      an explicit step and loads checkpoint 0, not the latest one.
    prefix: str: name prefix of checkpoint files.

  Returns:
    Restored `target` updated from checkpoint file, or if no step specified
    and no checkpoint files present, returns the passed-in `target` unchanged.

  Raises:
    ValueError: if `step` is given but no matching checkpoint file exists.
  """
  # Compare against None explicitly: a plain truthiness test would treat
  # step=0 as "no step given" and silently load the latest checkpoint.
  if step is not None:
    ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
    if not gfile.exists(ckpt_path):
      raise ValueError(f'Matching checkpoint not found: {ckpt_path}')
  else:
    # Highest naturally-sorted checkpoint, excluding the tmp staging file.
    ckpt_path = latest_checkpoint_path(ckpt_dir, prefix)
    if ckpt_path is None:
      return target
  logging.info('Restoring checkpoint from %s', ckpt_path)
  with gfile.GFile(ckpt_path, 'rb') as fp:
    return serialization.from_bytes(target, fp.read())
def save_checkpoint(ckpt_dir, target, step, prefix='checkpoint_', keep=1,
                    overwrite=False):
  """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to temporary before
  a final rename and cleanup of past files.

  Args:
    ckpt_dir: str: path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep. Retention keeps the
      last `keep` files in natural-sort order (the highest-valued ones).
    overwrite: bool: allow overwriting when writing a checkpoint.

  Returns:
    Filename of saved checkpoint.
  """
  logging.info('Saving checkpoint at step: %s', step)
  ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
  ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
  gfile.makedirs(os.path.dirname(ckpt_path))

  # Write temporary checkpoint file first, so a partially written file is
  # never left under a final checkpoint name.
  logging.info('Writing to temporary checkpoint location: %s', ckpt_tmp_path)
  with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
    fp.write(serialization.to_bytes(target))

  # Rename once serialization and writing finished.
  gfile.rename(ckpt_tmp_path, ckpt_path, overwrite=overwrite)
  logging.info('Saved checkpoint at %s', ckpt_path)

  # Remove old checkpoint files. The tmp staging path also matches the
  # `prefix*` glob; exclude it (as the restore path does) so a stale tmp file
  # is never counted toward `keep` and cannot push out a real checkpoint.
  base_path = os.path.join(ckpt_dir, f'{prefix}')
  checkpoint_files = natural_sort(gfile.glob(base_path + '*'))
  checkpoint_files = [f for f in checkpoint_files if f != ckpt_tmp_path]
  if len(checkpoint_files) > keep:
    old_ckpts = checkpoint_files[:-keep]
    for path in old_ckpts:
      logging.info('Removing checkpoint at %s', path)
      gfile.remove(path)
  return ckpt_path
def latest_checkpoint_path(ckpt_dir, prefix='checkpoint_'):
  """Return the path of the highest-valued checkpoint in `ckpt_dir`.

  Args:
    ckpt_dir: str: directory to search for checkpoint files.
    prefix: str: name prefix of checkpoint files. Defaults to the same
      prefix used by `save_checkpoint` and `restore_checkpoint`.

  Returns:
    Path of the last file, in natural-sort order, matching `prefix*`. The
    temporary `tmp` checkpoint path is excluded. Returns None if no such
    file exists.
  """
  glob_path = os.path.join(ckpt_dir, f'{prefix}*')
  checkpoint_files = natural_sort(gfile.glob(glob_path))
  # The tmp staging file shares the prefix but is not a finished checkpoint.
  ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
  checkpoint_files = [f for f in checkpoint_files if f != ckpt_tmp_path]
  return checkpoint_files[-1] if checkpoint_files else None