Exemplo n.º 1
0
def save_all_states():
    """
    Triggers a global checkpoint: invokes `save_state` on every `State`
    registered in `_STATES_TO_NAMES`, then lets the replica of rank 0 move the
    freshly written checkpoint into place and delete older checkpoints.

    The checkpoint directory is created through Ray Tune's
    `TrainableUtil.make_checkpoint_dir` when `from_ray()` is true; otherwise
    the value of `checkpoint_path()` is used.

    Returns:
        The checkpoint directory on the replica of rank 0 when a checkpoint
        directory is available, otherwise `None`.
    """
    if not from_ray():
        target_dir = checkpoint_path()
    else:
        # Imported lazily so that Ray is only required when running under it.
        from ray.tune.trainable import TrainableUtil
        target_dir = TrainableUtil.make_checkpoint_dir(
            "/tmp", index=None, override=True)

    for registered_state in _STATES_TO_NAMES:
        save_state(registered_state, target_dir)

    # Only rank 0 finalizes the checkpoint, and only when there is a
    # directory to finalize. Every other case yields None.
    if replica_rank() != 0 or target_dir is None:
        return None

    # Renaming the temporary directory to its final name keeps an existing
    # checkpoint from being corrupted if the process is killed while state
    # files are still being written.
    staging_dir = _get_tmp_ckpt_dir(target_dir)
    final_dir = os.path.join(target_dir, f"{CKPT_DIR_PREFIX}{num_restarts()}")
    os.rename(staging_dir, final_dir)  # atomic, rename(src, dst)

    # Drop every other checkpoint directory, keeping only the newest one.
    for entry in os.listdir(target_dir):
        if not entry.startswith(CKPT_DIR_PREFIX):
            continue
        entry_path = os.path.join(target_dir, entry)
        if entry_path != final_dir:
            shutil.rmtree(entry_path)
    return target_dir
Exemplo n.º 2
0
def test_accumulator_restarts():
    """
    Exercises an `Accumulator` across checkpoint/restart cycles.

    The body is divided into three phases keyed on `num_restarts()` (0, 1 and
    2). Each phase mutates the accumulator only when the restart count
    matches, and checks what is visible outside and inside
    `accum.synchronized()`. The first two phases end by calling
    `adaptdl.checkpoint.save_all_states()` and returning an integer.

    NOTE(review): the returned integers (4, then 2) are presumably consumed by
    an elastic test harness that re-runs this function with that many
    replicas; the harness is not visible here -- confirm.
    """
    import adaptdl.checkpoint
    import adaptdl.collective
    from adaptdl.env import num_restarts, replica_rank
    adaptdl.collective.initialize("0.0.0.0")
    accum = Accumulator()

    # ---- Phase 0: only mutates on the very first run. ----
    if num_restarts() == 0:
        accum["a"] += 15  # Idempotent.
    # The checks below are not guarded by num_restarts(), so they execute on
    # every run of this function, not just the first one.
    # Outside of synchronized() the pending update is not visible as a key.
    assert "a" not in accum
    with accum.synchronized():
        assert "a" in accum
        assert accum["a"] == 15
    # Leaving synchronized() hides the key again.
    assert "a" not in accum
    if num_restarts() == 0:
        accum["a"] -= 5  # Idempotent.
        adaptdl.checkpoint.save_all_states()
        return 4  # Restart with 4 replicas.

    # ---- Phase 1: only mutates after exactly one restart. ----
    if num_restarts() == 1:  # Idempotent.
        accum.update({"a": replica_rank(), "b": replica_rank()})
    assert len(accum) == 0
    with accum.synchronized():
        assert len(accum) == 2
        # NOTE(review): 16 and 6 are consistent with per-replica updates being
        # summed over 4 replicas (15 - 5 + 0+1+2+3 and 0+1+2+3) -- confirm
        # against the Accumulator implementation.
        assert accum["a"] == 16
        assert accum["b"] == 6
    assert len(accum) == 0
    if num_restarts() == 1:
        adaptdl.checkpoint.save_all_states()
        return 2  # Restart with 2 replicas.

    # ---- Phase 2: only mutates after exactly two restarts. ----
    if num_restarts() == 2:  # Idempotent.
        accum -= {"b": 5, "c": 5}
    with accum.synchronized():
        assert accum["a"] == 16
        # NOTE(review): -4 and -10 are consistent with the subtraction being
        # applied once per replica over 2 replicas (6 - 2*5 and 0 - 2*5) --
        # confirm.
        assert accum["b"] == -4
        assert accum["c"] == -10
        accum.clear()
    # After clear() inside synchronized(), the accumulator is expected to be
    # empty (falsy) in the next synchronized section.
    with accum.synchronized():
        assert not accum
Exemplo n.º 3
0
def save_state(state, sync=True):
    """
    Persists a single `State` object. `State.sync` is invoked first (on every
    replica that calls this function) unless `sync` is `False`. Afterwards,
    only the replica of rank 0 calls `State.save`, writing to a file named
    after the state inside the directory given by `checkpoint_path()`.
    Nothing is written when `checkpoint_path()` is `None`.

    Arguments:
        state (State): The `State` object to save to persistent storage.
        sync (bool): Whether `State.sync` should be invoked before saving.
    """
    if sync:
        state.sync()

    # Replicas other than rank 0 never write state files.
    if replica_rank() != 0:
        return

    # The name is looked up before checking for a checkpoint path, so an
    # unregistered state raises even when there is nowhere to write.
    file_name = _STATES_TO_NAMES[state]
    if checkpoint_path() is None:
        return

    destination = os.path.join(checkpoint_path(), file_name)
    with open(destination, "wb") as out_file:
        state.save(out_file)
Exemplo n.º 4
0
def save_state(state, checkpoint_dir, sync=True):
    """
    Persists a single `State` object. `State.sync` is invoked first (on every
    replica that calls this function) unless `sync` is `False`. Afterwards,
    only the replica of rank 0 calls `State.save`. The state file is written
    into the temporary directory returned by `_get_tmp_ckpt_dir`, not into the
    final checkpoint directory; that temporary directory is expected to be
    renamed to the formal checkpoint folder once all states are saved.

    Arguments:
        state (State): The `State` object to save to persistent storage.
        checkpoint_dir: Checkpoint directory handed to `_get_tmp_ckpt_dir`.
            When it is `None`, nothing is written.
        sync (bool): Whether `State.sync` should be invoked before saving.
    """
    if sync:
        state.sync()

    # Only rank 0 writes, and only when a checkpoint directory was provided.
    if replica_rank() != 0 or checkpoint_dir is None:
        return

    file_name = _STATES_TO_NAMES[state]
    staging_dir = _get_tmp_ckpt_dir(checkpoint_dir)
    with open(os.path.join(staging_dir, file_name), "wb") as out_file:
        state.save(out_file)
Exemplo n.º 5
0
def save_all_states():
    """
    Triggers a global checkpoint: invokes `save_state` on every `State`
    registered in `_STATES_TO_NAMES`. Afterwards, the replica of rank 0 (when
    `checkpoint_path()` is not `None`) renames the directory returned by
    `_get_tmp_ckpt_dir()` to the checkpoint directory for the current restart
    and removes every other checkpoint directory.
    """
    for registered_state in _STATES_TO_NAMES:
        save_state(registered_state)

    # Only rank 0 finalizes the checkpoint, and only when a checkpoint path
    # is configured.
    if replica_rank() != 0 or checkpoint_path() is None:
        return

    # Renaming the temporary directory to its final name keeps an existing
    # checkpoint from being corrupted if the process is killed while state
    # files are still being written.
    staging_dir = _get_tmp_ckpt_dir()
    final_dir = os.path.join(checkpoint_path(),
                             f"{CKPT_DIR_PREFIX}{num_restarts()}")
    os.rename(staging_dir, final_dir)  # atomic

    # Drop every other checkpoint directory, keeping only the newest one.
    for entry in os.listdir(checkpoint_path()):
        entry_path = os.path.join(checkpoint_path(), entry)
        is_stale = entry.startswith(CKPT_DIR_PREFIX) and entry_path != final_dir
        if is_stale:
            shutil.rmtree(entry_path)
Exemplo n.º 6
0
 def save(self, fileobj):
     """
     Pickles `self.value` into `fileobj`.

     Arguments:
         fileobj: Writable binary file object that receives the pickled
             bytes.
     """
     caller_rank = replica_rank()
     assert caller_rank == 0  # Should only be called from rank 0.
     pickle.Pickler(fileobj).dump(self.value)