예제 #1
0
def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
    """Restore last/best checkpoint from checkpoints in path.

    When `step` is None the checkpoint files are sorted naturally and the
    highest-valued file is restored, e.g.:
      ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
      ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
      ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

    Args:
      ckpt_dir: str: directory of checkpoints to restore from.
      target: matching object to rebuild via deserialized state-dict.
      step: int: step number to load or None to load latest. A step of 0 is
        a valid explicit step and is not treated as "latest".
      prefix: str: name prefix of checkpoint files.

    Returns:
      Restored `target` updated from checkpoint file, or if no step specified
      and no checkpoint files present, returns the passed-in `target`
      unchanged.

    Raises:
      ValueError: if `step` is given and the matching checkpoint file does
        not exist.
    """
    # Compare against None explicitly: a plain truthiness test would treat
    # step 0 as "not specified" and silently load the latest checkpoint.
    if step is not None:
        ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
        if not gfile.exists(ckpt_path):
            raise ValueError(f'Matching checkpoint not found: {ckpt_path}')
    else:
        glob_path = os.path.join(ckpt_dir, f'{prefix}*')
        checkpoint_files = natural_sort(gfile.glob(glob_path))
        # The temporary file used while saving matches the same glob; it is
        # not a finished checkpoint, so exclude it from the candidates.
        ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
        checkpoint_files = [f for f in checkpoint_files if f != ckpt_tmp_path]
        if not checkpoint_files:
            return target
        ckpt_path = checkpoint_files[-1]

    logging.info('Restoring checkpoint from %s', ckpt_path)
    with gfile.GFile(ckpt_path, 'rb') as fp:
        return serialization.from_bytes(target, fp.read())
예제 #2
0
def save_checkpoint(ckpt_dir,
                    target,
                    step,
                    prefix='checkpoint_',
                    keep=1,
                    overwrite=False):
    """Serialize `target` to a checkpoint file and prune older checkpoints.

    Attempts to be pre-emption safe: the serialized bytes are first written
    to a temporary path, which is renamed to the final checkpoint path only
    after the write has finished. Old checkpoint files are cleaned up last.

    Args:
      ckpt_dir: str: path to store checkpoint files in.
      target: serializable flax object, usually a flax optimizer.
      step: int or float: training step number or other metric number.
      prefix: str: checkpoint file name prefix.
      keep: number of checkpoint files to keep after saving.
      overwrite: bool: allow overwriting when writing a checkpoint.

    Returns:
      Filename of saved checkpoint.
    """
    logging.info('Saving checkpoint at step: %s', step)
    tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
    final_path = _checkpoint_path(ckpt_dir, step, prefix)
    gfile.makedirs(os.path.dirname(final_path))

    # Stage the serialized bytes under the temporary name first.
    logging.info('Writing to temporary checkpoint location: %s', tmp_path)
    with gfile.GFile(tmp_path, 'wb') as tmp_file:
        tmp_file.write(serialization.to_bytes(target))

    # Only a completed write is promoted to the final checkpoint name.
    gfile.rename(tmp_path, final_path, overwrite=overwrite)
    logging.info('Saved checkpoint at %s', final_path)

    # Prune: everything except the last `keep` files in natural-sort order
    # is removed, and only when more than `keep` files match the prefix.
    existing = natural_sort(gfile.glob(os.path.join(ckpt_dir, f'{prefix}*')))
    stale = existing[:-keep] if len(existing) > keep else []
    for stale_path in stale:
        logging.info('Removing checkpoint at %s', stale_path)
        gfile.remove(stale_path)

    return final_path
예제 #3
0
def latest_checkpoint_path(ckpt_dir, prefix):
    """Return the path of the latest checkpoint in `ckpt_dir`, if any.

    Files matching `prefix` are ordered with `natural_sort`; the last one
    that is not the temporary checkpoint path is considered the latest.

    Args:
      ckpt_dir: str: directory to search for checkpoint files.
      prefix: str: name prefix of checkpoint files.

    Returns:
      The path of the last matching checkpoint file in natural-sort order,
      or None when no checkpoint file (other than the temporary one) exists.
    """
    pattern = os.path.join(ckpt_dir, f'{prefix}*')
    ordered = natural_sort(gfile.glob(pattern))
    tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)

    # Walk the sorted candidates and remember the last one that is not the
    # in-progress temporary file.
    newest = None
    for candidate in ordered:
        if candidate != tmp_path:
            newest = candidate
    return newest