Ejemplo n.º 1
0
def _get_tmp_ckpt_dir():
    """
    Return the temporary checkpoint folder, creating it when needed.

    The folder is the `_checkpoint` sub-directory of `checkpoint_path()`.

    Returns:
        Path of the temporary folder, or `None` when `checkpoint_path()` is
        `None`.
    """
    if checkpoint_path() is not None:
        scratch_dir = os.path.join(checkpoint_path(), "_checkpoint")
        # exist_ok: an already existing folder is not an error.
        os.makedirs(scratch_dir, exist_ok=True)
        return scratch_dir
    return None
Ejemplo n.º 2
0
def save_all_states():
    """
    Invokes `save_state` on every `State` object registered in
    `_STATES_TO_NAMES`, then publishes the saved files as the checkpoint of
    the current restart. This function can be used to trigger a global
    checkpoint of the current job.

    When `from_ray()` is true, the checkpoint directory is obtained from
    `TrainableUtil.make_checkpoint_dir`; otherwise `checkpoint_path()` is used.

    Returns:
        The checkpoint directory on the replica of rank 0 when that directory
        is not `None`; `None` in every other case.
    """
    if not from_ray():
        checkpoint_dir = checkpoint_path()
    else:
        from ray.tune.trainable import TrainableUtil
        checkpoint_dir = TrainableUtil.make_checkpoint_dir("/tmp",
                                                           index=None,
                                                           override=True)

    for state in _STATES_TO_NAMES:
        save_state(state, checkpoint_dir)

    # Only rank 0 finalizes the checkpoint, and only if there is a place to
    # store it.
    if replica_rank() != 0 or checkpoint_dir is None:
        return None

    # Prevent corrupting original state files in case the process got killed
    # during state file writing: the temporary folder is moved to its final
    # name in a single rename(src, dst) call.
    staging_dir = _get_tmp_ckpt_dir(checkpoint_dir)
    final_dir = os.path.join(checkpoint_dir,
                             f"{CKPT_DIR_PREFIX}{num_restarts()}")
    os.rename(staging_dir, final_dir)

    # Remove every other checkpoint folder, keeping only the new one.
    for entry in os.listdir(checkpoint_dir):
        if not entry.startswith(CKPT_DIR_PREFIX):
            continue
        entry_path = os.path.join(checkpoint_dir, entry)
        if entry_path != final_dir:
            shutil.rmtree(entry_path)
    return checkpoint_dir
Ejemplo n.º 3
0
def save_state(state, sync=True):
    """
    Saves a `State` object to persistent storage. Calls `State.sync` first
    when `sync` is `True` (default), then calls `State.save` on the replica
    of rank 0 only, writing to a file named after the state inside
    `checkpoint_path()`.

    Arguments:
        state (State): The `State` object to save to persistent storage.
        sync (bool): Whether `State.sync` should be invoked.
    """
    if sync:
        state.sync()

    # Replicas other than rank 0 never write anything.
    if replica_rank() != 0:
        return

    filename = _STATES_TO_NAMES[state]
    # Without a checkpoint path there is nowhere to write to.
    if checkpoint_path() is None:
        return

    target = os.path.join(checkpoint_path(), filename)
    with open(target, "wb") as out:
        state.save(out)
Ejemplo n.º 4
0
def load_state(state):
    """
    Load the given `State` object from persistent storage. If the object was
    previously saved, then State.load will be invoked with a readable file
    object to load from.

    The checkpoint folder with the highest restart id found under
    `checkpoint_path()` is used. Entries whose name starts with
    `CKPT_DIR_PREFIX` but does not end in an integer are ignored.

    Arguments:
        state (State): `State` object to load from persistent storage.

    Returns:
        `True` if state was previously saved and `State.load` was invoked,
        `False` otherwise.
    """
    if checkpoint_path() is None:
        return False

    try:
        ckpt_dirs = os.listdir(checkpoint_path())
    except FileNotFoundError:
        # A checkpoint directory that does not exist yet holds no checkpoint.
        ckpt_dirs = []
    if not ckpt_dirs:
        LOG.info(f"No checkpoint found in {checkpoint_path()}.")
        return False

    # None means "no checkpoint folder seen yet"; starting from 0 would make
    # an absent checkpoint indistinguishable from a checkpoint of restart 0.
    latest_restart_id = None
    for dir_name in ckpt_dirs:
        if not dir_name.startswith(CKPT_DIR_PREFIX):
            continue
        try:
            restart_id = int(dir_name[len(CKPT_DIR_PREFIX):])
        except ValueError:
            # Same prefix but not a restart folder: skip it instead of
            # aborting the whole load.
            LOG.warning(f"Ignoring unrecognized checkpoint entry {dir_name}.")
            continue
        if latest_restart_id is None or restart_id > latest_restart_id:
            latest_restart_id = restart_id

    if latest_restart_id is None:
        LOG.info(f"No checkpoint found in {checkpoint_path()}.")
        return False

    if latest_restart_id != num_restarts() - 1:
        LOG.warning("Cannot find checkpoint from the last restart. "
                    f"Loading checkpoint from restart {latest_restart_id}.")

    ckpt_dir = os.path.join(checkpoint_path(),
                            f"{CKPT_DIR_PREFIX}{latest_restart_id}")
    name = _STATES_TO_NAMES[state]
    state_file = os.path.join(ckpt_dir, name)
    if not os.path.isfile(state_file):
        LOG.warning(f"Cannot find state file {state_file}.")
        return False

    with open(state_file, "rb") as f:
        state.load(f)

    return True
Ejemplo n.º 5
0
def save_all_states():
    """
    Invokes `save_state` on every `State` object registered in
    `_STATES_TO_NAMES`, then publishes the saved files as the checkpoint of
    the current restart. This function can be used to trigger a global
    checkpoint of the current job.
    """
    for state in _STATES_TO_NAMES:
        save_state(state)

    # Only rank 0 finalizes the checkpoint, and only if a checkpoint path is
    # configured.
    if replica_rank() != 0 or checkpoint_path() is None:
        return

    # Prevent corrupting original state files in case the process got killed
    # during state file writing: the temporary folder is moved to its final
    # name in a single rename call.
    staging_dir = _get_tmp_ckpt_dir()
    final_dir = os.path.join(checkpoint_path(),
                             f"{CKPT_DIR_PREFIX}{num_restarts()}")
    os.rename(staging_dir, final_dir)

    # Remove every other checkpoint folder, keeping only the new one.
    for entry in os.listdir(checkpoint_path()):
        entry_path = os.path.join(checkpoint_path(), entry)
        if not entry.startswith(CKPT_DIR_PREFIX):
            continue
        if entry_path != final_dir:
            shutil.rmtree(entry_path)
Ejemplo n.º 6
0
def load_state(state):
    """
    Load the given `State` object from persistent storage. The state is read
    from the file named after it inside `checkpoint_path()`; when that file
    can be opened, `State.load` is invoked with the readable file object.

    Arguments:
        state (State): `State` object to load from persistent storage.

    Returns:
        `True` if `State.load` was invoked and completed, `False` if there is
        no checkpoint path or a `FileNotFoundError` was raised while loading.
    """
    if checkpoint_path() is None:
        return False

    try:
        filename = _STATES_TO_NAMES[state]
        source = os.path.join(checkpoint_path(), filename)
        with open(source, "rb") as src:
            state.load(src)
    except FileNotFoundError:
        # A missing file is reported as "nothing loaded", not as an error.
        return False
    else:
        return True
Ejemplo n.º 7
0
def save_state(state, sync=True):
    """
    Saves a `State` object to persistent storage. Calls `State.sync` first
    when `sync` is `True` (default), then calls `State.save` on the replica
    of rank 0 only. The state is written into the temporary checkpoint
    folder returned by `_get_tmp_ckpt_dir()`; that folder is expected to be
    renamed to the formal checkpoint folder once all states are saved.

    Arguments:
        state (State): The `State` object to save to persistent storage.
        sync (bool): Whether `State.sync` should be invoked.
    """
    if sync:
        state.sync()

    # Nothing to write on other replicas or without a checkpoint path.
    if replica_rank() != 0 or checkpoint_path() is None:
        return

    filename = _STATES_TO_NAMES[state]
    destination = os.path.join(_get_tmp_ckpt_dir(), filename)
    with open(destination, "wb") as out:
        state.save(out)