def _get_tmp_ckpt_dir(): if checkpoint_path() is None: return None tmp_dir = os.path.join(checkpoint_path(), "_checkpoint") os.makedirs(tmp_dir, exist_ok=True) return tmp_dir
def save_all_states():
    """
    Sync and save every registered `State`, then publish them as one checkpoint.

    This function can be used to trigger a global checkpoint and save every
    `State` in the current job. `State.sync` is invoked for each state in
    `_STATES_TO_NAMES`. On the replica of rank 0, each state is then written
    with `State.save` into a `_checkpoint` staging directory inside
    `checkpoint_dir`. That directory is renamed to
    `<checkpoint_dir>/<CKPT_DIR_PREFIX><num_restarts()>`, and every other
    entry of `checkpoint_dir` whose name starts with `CKPT_DIR_PREFIX` is
    deleted.

    When `from_ray()` is true, `checkpoint_dir` is the path returned by Ray's
    `TrainableUtil.make_checkpoint_dir("/tmp", index=None, override=True)`;
    otherwise it is `checkpoint_path()`.

    Returns:
        `checkpoint_dir` on the replica of rank 0 when a checkpoint directory
        is available, otherwise None.
    """
    if from_ray():
        from ray.tune.trainable import TrainableUtil
        checkpoint_dir = TrainableUtil.make_checkpoint_dir(
            "/tmp", index=None, override=True)
    else:
        checkpoint_dir = checkpoint_path()

    # Stage directly inside `checkpoint_dir`. save_state() as defined in this
    # module takes `sync` (not a directory) as its second argument and writes
    # under checkpoint_path(), which is not the Ray directory.
    tmp_ckpt_dir = None
    if replica_rank() == 0 and checkpoint_dir is not None:
        # Same staging name that _get_tmp_ckpt_dir() uses.
        tmp_ckpt_dir = os.path.join(checkpoint_dir, "_checkpoint")
        os.makedirs(tmp_ckpt_dir, exist_ok=True)

    for state in _STATES_TO_NAMES:
        # Called on every replica, mirroring save_state()'s default sync=True.
        state.sync()
        if tmp_ckpt_dir is not None:
            name = _STATES_TO_NAMES[state]
            with open(os.path.join(tmp_ckpt_dir, name), "wb") as f:
                state.save(f)

    if tmp_ckpt_dir is None:
        return None

    # Prevent corrupting original state files in case the process got killed
    # during state file writing: files were staged above and are published
    # here by renaming the staging directory.
    ckpt_dir = os.path.join(checkpoint_dir,
                            f"{CKPT_DIR_PREFIX}{num_restarts()}")
    # os.rename() cannot overwrite a non-empty directory, so a second
    # checkpoint within the same restart would otherwise raise. Move the
    # existing directory aside first. NOTE(review): if the process is killed
    # between the two renames, only the aside copy remains; this assumes
    # CKPT_DIR_PREFIX is not a prefix of the aside name.
    aside_dir = None
    if os.path.exists(ckpt_dir):
        aside_dir = os.path.join(checkpoint_dir, "_superseded_checkpoint")
        if os.path.exists(aside_dir):
            shutil.rmtree(aside_dir)
        os.rename(ckpt_dir, aside_dir)
    os.rename(tmp_ckpt_dir, ckpt_dir)  # atomic on POSIX within one filesystem
    if aside_dir is not None:
        shutil.rmtree(aside_dir)

    for dir_name in os.listdir(checkpoint_dir):
        dir_path = os.path.join(checkpoint_dir, dir_name)
        if dir_name.startswith(CKPT_DIR_PREFIX) and dir_path != ckpt_dir:
            shutil.rmtree(dir_path)
    return checkpoint_dir
def save_state(state, sync=True):
    """
    Saves a `State` object to persistent storage.

    First invokes `State.sync` on all replicas if `sync` is `True` (default),
    and then invokes `State.save` on the replica of rank 0 only. The state is
    written to `<checkpoint_path()>/<name>`, where `name` is its entry in
    `_STATES_TO_NAMES`. Nothing is written when `checkpoint_path()` is None.

    The bytes are first written to `<name>.tmp` next to the target and then
    moved over it with `os.replace`. A process killed while `State.save` is
    running therefore leaves any previously saved file untouched.

    Arguments:
        state (State): The `State` object to save to persistent storage.
        sync (bool): Whether `State.sync` should be invoked.
    """
    if sync:
        state.sync()
    if replica_rank() != 0:
        return
    name = _STATES_TO_NAMES[state]
    root = checkpoint_path()
    if root is None:
        return
    final_path = os.path.join(root, name)
    tmp_path = final_path + ".tmp"
    try:
        with open(tmp_path, "wb") as f:
            state.save(f)
        # Atomically swap the finished file into place.
        os.replace(tmp_path, final_path)
    except BaseException:
        # Drop the partial temp file; the previous checkpoint stays intact.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def load_state(state):
    """
    Load the given `State` object from persistent storage.

    If the object was previously saved, then State.load will be invoked with
    a readable file object to load from. The state file is looked up in the
    directory `<checkpoint_path()>/<CKPT_DIR_PREFIX><id>` with the largest
    integer `<id>`. Entries whose suffix is not an integer are ignored. A
    warning is logged when that id is not `num_restarts() - 1`.

    Arguments:
        state (State): `State` object to load from persistent storage.

    Returns:
        `True` if state was previously saved and `State.load` was invoked,
        `False` otherwise.
    """
    root = checkpoint_path()
    if root is None:
        return False
    try:
        ckpt_dirs = os.listdir(root)
    except FileNotFoundError:
        LOG.info(f"No checkpoint found in {root}.")
        return False

    latest_restart_id = None
    for dir_name in ckpt_dirs:
        if not dir_name.startswith(CKPT_DIR_PREFIX):
            continue
        try:
            restart_id = int(dir_name[len(CKPT_DIR_PREFIX):])
        except ValueError:
            # Prefix matches but it is not a published checkpoint directory.
            continue
        if latest_restart_id is None or restart_id > latest_restart_id:
            latest_restart_id = restart_id

    if latest_restart_id is None:
        LOG.info(f"No checkpoint found in {root}.")
        return False
    if latest_restart_id != num_restarts() - 1:
        LOG.warning("Cannot find checkpoint from the last restart. "
                    f"Loading checkpoint from restart {latest_restart_id}.")

    ckpt_dir = os.path.join(root, f"{CKPT_DIR_PREFIX}{latest_restart_id}")
    name = _STATES_TO_NAMES[state]
    state_file = os.path.join(ckpt_dir, name)
    if not os.path.isfile(state_file):
        LOG.warning(f"Cannot find state file {state_file}.")
        return False
    with open(state_file, "rb") as f:
        state.load(f)
    return True
def save_all_states():
    """
    Invokes `save_state` on every `State` registered in `_STATES_TO_NAMES`.
    This function can be used to trigger a global checkpoint and save every
    `State` in the current job.

    `save_state` stages each state file in the directory returned by
    `_get_tmp_ckpt_dir()`. The following steps run on the replica of rank 0
    only, and only when `checkpoint_path()` is set. The staging directory is
    renamed to `<checkpoint_path()>/<CKPT_DIR_PREFIX><num_restarts()>`. Every
    other entry whose name starts with `CKPT_DIR_PREFIX` is then deleted.
    """
    for state in _STATES_TO_NAMES:
        save_state(state)
    # Prevent corrupting original state files in case the process got killed
    # during state file writing: files are staged, then published by rename.
    if replica_rank() != 0 or checkpoint_path() is None:
        return
    root = checkpoint_path()
    tmp_ckpt_dir = _get_tmp_ckpt_dir()
    ckpt_dir = os.path.join(root, f"{CKPT_DIR_PREFIX}{num_restarts()}")
    # os.rename() cannot overwrite a non-empty directory, so a second
    # checkpoint within the same restart would otherwise raise. Move the
    # existing directory aside first. NOTE(review): if the process is killed
    # between the two renames, only the aside copy remains; this assumes
    # CKPT_DIR_PREFIX is not a prefix of the aside name.
    aside_dir = None
    if os.path.exists(ckpt_dir):
        aside_dir = os.path.join(root, "_superseded_checkpoint")
        if os.path.exists(aside_dir):
            shutil.rmtree(aside_dir)
        os.rename(ckpt_dir, aside_dir)
    os.rename(tmp_ckpt_dir, ckpt_dir)  # atomic on POSIX within one filesystem
    if aside_dir is not None:
        shutil.rmtree(aside_dir)
    for dir_name in os.listdir(root):
        dir_path = os.path.join(root, dir_name)
        if dir_name.startswith(CKPT_DIR_PREFIX) and dir_path != ckpt_dir:
            shutil.rmtree(dir_path)
def load_state(state):
    """
    Load the given `State` object from persistent storage.

    If the object was previously saved, then State.load will be invoked with
    a readable file object to load from. The state file is looked up in the
    directory `<checkpoint_path()>/<CKPT_DIR_PREFIX><id>` with the largest
    integer `<id>`, which is where `save_all_states` publishes checkpoints.
    If no such directory exists, the flat legacy location
    `<checkpoint_path()>/<name>` is tried instead. Files from different
    checkpoint generations are never mixed.

    Arguments:
        state (State): `State` object to load from persistent storage.

    Returns:
        `True` if state was previously saved and `State.load` was invoked,
        `False` otherwise.
    """
    root = checkpoint_path()
    if root is None:
        return False
    name = _STATES_TO_NAMES[state]

    latest_restart_id = None
    if os.path.isdir(root):
        for dir_name in os.listdir(root):
            if not dir_name.startswith(CKPT_DIR_PREFIX):
                continue
            try:
                restart_id = int(dir_name[len(CKPT_DIR_PREFIX):])
            except ValueError:
                continue
            if latest_restart_id is None or restart_id > latest_restart_id:
                latest_restart_id = restart_id

    if latest_restart_id is None:
        state_file = os.path.join(root, name)  # legacy flat layout
    else:
        state_file = os.path.join(
            root, f"{CKPT_DIR_PREFIX}{latest_restart_id}", name)

    # Only a missing file means "not saved"; errors raised by State.load
    # itself must propagate instead of being reported as False.
    try:
        f = open(state_file, "rb")
    except FileNotFoundError:
        return False
    with f:
        state.load(f)
    return True
def save_state(state, sync=True):
    """
    Write one `State` object into the staging checkpoint directory.

    If `sync` is true (the default), `State.sync` is called first; this runs
    on every replica. Afterwards, `State.save` is given a binary file opened
    for writing at `<_get_tmp_ckpt_dir()>/<name>`, where `name` is the state's
    entry in `_STATES_TO_NAMES`. That step happens only on the replica of
    rank 0 and only when `checkpoint_path()` is configured.

    The file lands in a temporary folder. `save_all_states` renames that
    folder to the formal checkpoint folder once all states are saved.

    Arguments:
        state (State): The `State` object to save to persistent storage.
        sync (bool): Whether `State.sync` should be invoked.
    """
    if sync:
        state.sync()
    if replica_rank() != 0 or checkpoint_path() is None:
        # Only rank 0 persists, and only when a checkpoint path exists.
        return
    name = _STATES_TO_NAMES[state]
    target = os.path.join(_get_tmp_ckpt_dir(), name)
    with open(target, "wb") as out:
        state.save(out)