def save_all_states():
    """
    Trigger a global checkpoint by saving every registered `State`.

    Calls `save_state` once for each key of `_STATES_TO_NAMES`. Afterwards,
    on the replica of rank 0 only, the temporary checkpoint directory is
    renamed to ``<checkpoint_dir>/<CKPT_DIR_PREFIX><num_restarts()>`` and
    every other entry of `checkpoint_dir` whose name starts with
    `CKPT_DIR_PREFIX` is removed.

    The checkpoint directory is created through Ray Tune's
    `TrainableUtil.make_checkpoint_dir` (rooted at "/tmp") when `from_ray()`
    is true; otherwise it is whatever `checkpoint_path()` returns.

    Returns:
        The checkpoint directory that was used. Outside of Ray this is the
        value of `checkpoint_path()`; the code explicitly tolerates `None`.
    """
    if from_ray():
        # Imported lazily so Ray is only needed when actually running on it.
        from ray.tune.trainable import TrainableUtil
        checkpoint_dir = TrainableUtil.make_checkpoint_dir(
            "/tmp", index=None, override=True
        )
    else:
        checkpoint_dir = checkpoint_path()

    for state in _STATES_TO_NAMES:
        save_state(state, checkpoint_dir)

    # Only rank 0 finalizes the checkpoint, and only when there is a
    # directory to finalize.
    if replica_rank() != 0 or checkpoint_dir is None:
        return checkpoint_dir

    # State files were written to a temporary directory so that a process
    # killed mid-write cannot corrupt the previous checkpoint's files.
    staging_dir = _get_tmp_ckpt_dir(checkpoint_dir)
    final_dir = os.path.join(
        checkpoint_dir, f"{CKPT_DIR_PREFIX}{num_restarts()}"
    )
    # Promote the staging directory in one rename(src, dst) step; the
    # original author relies on this rename being atomic.
    os.rename(staging_dir, final_dir)

    # Drop checkpoint directories other than the one just promoted.
    for entry in os.listdir(checkpoint_dir):
        entry_path = os.path.join(checkpoint_dir, entry)
        if not entry.startswith(CKPT_DIR_PREFIX) or entry_path == final_dir:
            continue
        shutil.rmtree(entry_path)

    return checkpoint_dir
def test_accumulator_restarts():
    """
    Exercise an `Accumulator` across checkpointed restarts.

    The body is written to be re-executed from the top on every restart and
    branches on `num_restarts()` to decide which updates to apply. Per the
    inline comments, the integer returned after `save_all_states()` is the
    replica count for the next restart -- presumably interpreted by the test
    harness that invokes this function (not visible here; TODO confirm).
    """
    import adaptdl.checkpoint
    import adaptdl.collective
    from adaptdl.env import num_restarts, replica_rank
    adaptdl.collective.initialize("0.0.0.0")
    accum = Accumulator()

    # --- Phase 0: first run -------------------------------------------------
    # The update is guarded by num_restarts() so it is applied only once,
    # even though this code is re-run after every restart.
    if num_restarts() == 0:
        accum["a"] += 15  # Idempotent.
    # Outside of synchronized() the pending update is not visible.
    assert "a" not in accum
    with accum.synchronized():
        assert "a" in accum
        assert accum["a"] == 15
    assert "a" not in accum
    if num_restarts() == 0:
        accum["a"] -= 5  # Idempotent.
        adaptdl.checkpoint.save_all_states()
        return 4  # Restart with 4 replicas.

    # --- Phase 1: after the first restart ------------------------------------
    # Every replica contributes its own rank to both keys.
    if num_restarts() == 1:  # Idempotent.
        accum.update({"a": replica_rank(), "b": replica_rank()})
    assert len(accum) == 0
    with accum.synchronized():
        assert len(accum) == 2
        # Expected values are consistent with a restored "a" of 10 plus the
        # ranks 0+1+2+3 summed over 4 replicas -- presumably how updates are
        # combined; verify against Accumulator.
        assert accum["a"] == 16
        assert accum["b"] == 6
    assert len(accum) == 0
    if num_restarts() == 1:
        adaptdl.checkpoint.save_all_states()
        return 2  # Restart with 2 replicas.

    # --- Phase 2: after the second restart -----------------------------------
    if num_restarts() == 2:  # Idempotent.
        accum -= {"b": 5, "c": 5}
    with accum.synchronized():
        assert accum["a"] == 16
        # Consistent with each of 2 replicas subtracting 5 (6 - 10 and
        # 0 - 10) -- presumably; verify against Accumulator.
        assert accum["b"] == -4
        assert accum["c"] == -10
    # After clear(), the synchronized view is expected to be empty.
    accum.clear()
    with accum.synchronized():
        assert not accum
def save_state(state, sync=True):
    """
    Persist a single `State` object.

    When `sync` is true, `State.sync` is first invoked on every replica that
    calls this function. The replica of rank 0 then writes the state, via
    `State.save`, to a file named after the state inside
    `checkpoint_path()`. Nothing is written when `checkpoint_path()` is
    `None`.

    Arguments:
        state (State): The `State` object to save to persistent storage.
        sync (bool): Whether `State.sync` should be invoked.
    """
    if sync:
        state.sync()

    # Only rank 0 touches persistent storage.
    if replica_rank() != 0:
        return

    # Looked up before the path check, so an unregistered state raises
    # KeyError even when there is no checkpoint path.
    name = _STATES_TO_NAMES[state]
    if checkpoint_path() is None:
        return

    target = os.path.join(checkpoint_path(), name)
    with open(target, "wb") as stream:
        state.save(stream)
def save_state(state, checkpoint_dir, sync=True):
    """
    Persist a single `State` object into a temporary checkpoint folder.

    When `sync` is true, `State.sync` is first invoked on every replica that
    calls this function. The replica of rank 0 then writes the state, via
    `State.save`, to a file named after the state inside the directory
    returned by `_get_tmp_ckpt_dir(checkpoint_dir)`.

    The state is deliberately written to a temporary folder first; it is
    meant to be renamed to the formal checkpoint folder once all states have
    been saved.

    Arguments:
        state (State): The `State` object to save to persistent storage.
        checkpoint_dir: Directory the checkpoint belongs to. When `None`,
            no file is written (the sync step still runs).
        sync (bool): Whether `State.sync` should be invoked.
    """
    if sync:
        state.sync()

    # Only rank 0 writes, and only when a checkpoint directory was given.
    if replica_rank() != 0 or checkpoint_dir is None:
        return

    name = _STATES_TO_NAMES[state]
    staging_dir = _get_tmp_ckpt_dir(checkpoint_dir)
    state_file = os.path.join(staging_dir, name)
    with open(state_file, "wb") as stream:
        state.save(stream)
def save_all_states():
    """
    Trigger a global checkpoint by saving every registered `State`.

    Calls `save_state` once for each key of `_STATES_TO_NAMES`. Afterwards,
    on the replica of rank 0 only and provided `checkpoint_path()` is not
    `None`, the temporary checkpoint directory is renamed to
    ``<checkpoint_path()>/<CKPT_DIR_PREFIX><num_restarts()>`` and every other
    entry of the checkpoint path whose name starts with `CKPT_DIR_PREFIX` is
    removed.
    """
    for state in _STATES_TO_NAMES:
        save_state(state)

    # Only rank 0 finalizes the checkpoint, and only when a checkpoint path
    # is configured.
    if replica_rank() != 0 or checkpoint_path() is None:
        return

    # State files go to a temporary directory first so that a process killed
    # mid-write cannot corrupt the previous checkpoint's files.
    staging_dir = _get_tmp_ckpt_dir()
    final_dir = os.path.join(
        checkpoint_path(), f"{CKPT_DIR_PREFIX}{num_restarts()}"
    )
    # Promote the staging directory in one step; the original author relies
    # on this rename being atomic.
    os.rename(staging_dir, final_dir)

    # Drop checkpoint directories other than the one just promoted.
    for entry in os.listdir(checkpoint_path()):
        entry_path = os.path.join(checkpoint_path(), entry)
        if not entry.startswith(CKPT_DIR_PREFIX) or entry_path == final_dir:
            continue
        shutil.rmtree(entry_path)
def save(self, fileobj):
    """
    Pickle `self.value` into `fileobj`.

    Arguments:
        fileobj: Binary file object, open for writing, that receives the
            pickled value.
    """
    # Should only be called from rank 0.
    current_rank = replica_rank()
    assert current_rank == 0
    payload = self.value
    pickle.dump(payload, fileobj)