示例#1
0
def test_dataloader_restarts():
    """Resume an AdaptiveDataLoader across two restarts with changing replicas.

    The 100-sample dataset is consumed in three phases:

    * restart 0: 2 batches (20 samples) on 1 replica
      (local_bsz = 10, batch_size = 10), then checkpoint and restart with 4;
    * restart 1: 5 batches (60 samples) on 4 replicas
      (local_bsz = 3, batch_size = 12), then checkpoint and restart with 2;
    * restart 2: the final 2 batches (20 samples) on 2 replicas
      (local_bsz = 5, batch_size = 10).

    Returns the number of replicas to restart with, or ``None`` once the
    final phase has finished.
    """
    import adaptdl.checkpoint
    import adaptdl.collective
    from adaptdl.env import num_restarts, num_replicas
    adaptdl.collective.initialize("0.0.0.0")
    total_samples = 100
    base_batch_size = 10
    loader = AdaptiveDataLoader(TensorDataset(torch.rand(total_samples)),
                                batch_size=base_batch_size)
    # (restart count, batch index) -> replicas to restart with after saving.
    restart_plan = {(0, 2): 4, (1, 5): 2}
    assert current_dataloader() is None
    for step, batch in enumerate(loader):
        next_replicas = restart_plan.get((num_restarts(), step))
        if next_replicas is not None:
            adaptdl.checkpoint.save_all_states()
            return next_replicas
        assert current_dataloader() is loader._elastic
        batch_len = batch[0].size(0)
        assert loader.current_local_bsz == batch_len
        # The initial batch size is split across replicas, rounding up.
        assert batch_len == math.ceil(base_batch_size / num_replicas())
        assert loader.current_batch_size == num_replicas() * batch_len
    # Only the last 2 batches are seen after the final restart.
    assert step == 1
示例#2
0
def test_save_load():
    """Exercise save_state / load_state / save_all_states across restarts.

    Each restart checks what the previous one persisted:

    * restart 0 saves ``state_1`` with the default syncing behaviour;
    * restart 1 loads ``state_1`` and saves ``state_2`` with ``sync=False``;
    * restart 2 loads both states;
    * restart 3 overwrites both values and saves every state at once;
    * restart 4 loads both and checks the overwritten values.

    Returns 2 (restart with 2 replicas) for restarts 0-3 and ``None`` for
    restart 4; any other restart count fails the test.
    """
    import pickle
    from adaptdl.checkpoint import (State, save_all_states,
                                    save_state, load_state)
    from adaptdl.env import replica_rank, num_restarts

    class SyncTrackingState(State):
        # A State holding a picklable ``value`` that records sync() calls.

        def __init__(self, name):
            super().__init__(name)
            self.synced = False

        def sync(self):
            self.synced = True

        def save(self, fileobj):
            # Only replica rank 0 is expected to write the checkpoint.
            assert replica_rank() == 0
            pickle.dump(self.value, fileobj)

        def load(self, fileobj):
            # The value written by save() must come back unchanged.
            self.value = pickle.load(fileobj)

    first = SyncTrackingState("state_1")
    second = SyncTrackingState("state_2")

    restart = num_restarts()
    if restart == 0:
        first.value = 1
        save_state(first)
        assert first.synced
        return 2
    if restart == 1:
        load_state(first)
        assert first.value == 1
        second.value = 2
        save_state(second, sync=False)
        # sync=False must skip State.sync().
        assert not second.synced
        return 2
    if restart == 2:
        load_state(first)
        assert first.value == 1
        load_state(second)
        assert second.value == 2
        return 2
    if restart == 3:
        first.value = 10
        second.value = 20
        save_all_states()
        assert first.synced and second.synced
        return 2
    if restart == 4:
        load_state(first)
        load_state(second)
        assert first.value == 10
        assert second.value == 20
    else:
        # No further restarts are expected.
        assert False
示例#3
0
def test_profile(num_replicas):
    """Check that committed profile entries survive a checkpoint/restart.

    Restart 0:
      * starts a step at local_bsz=1 without committing it; its sync time
        must not be counted;
      * commits one step at local_bsz=2 with sync times 1.0 + 2.0;
      * checkpoints and returns ``num_replicas`` as the replica count to
        restart with.

    Restart 1:
      * verifies the restored entry;
      * commits two steps at local_bsz=3 and checks that their counts and
        sync times accumulate in a single profile entry.
    """
    import adaptdl.checkpoint
    from adaptdl.env import num_restarts
    from adaptdl.torch._metrics import (
            profile_step_start, profile_sync_time,
            profile_step_commit, _metrics_state)

    def committed_step(local_bsz, *sync_times):
        # One profiled step: start it, report each sync time, commit it.
        profile_step_start(local_bsz)
        for sync_time in sync_times:
            profile_sync_time(sync_time)
        profile_step_commit()

    def check_entry(entry, optim_count, optim_sync_time):
        # No accumulation steps are committed anywhere in this test.
        assert entry["accum_count"] == 0
        assert entry["optim_count"] == optim_count
        assert entry["optim_sync_time"] == optim_sync_time

    first_key = (1, 1, 2)
    restart = num_restarts()
    if restart == 0:
        assert len(_metrics_state().profile) == 0
        # Started but never committed: must not be counted below.
        profile_step_start(1)
        profile_sync_time(1.0)
        committed_step(2, 1.0, 2.0)
        profile = _metrics_state().profile
        assert len(profile) == 1
        check_entry(profile[first_key], optim_count=1, optim_sync_time=3.0)
        assert profile[first_key]["optim_step_time"] > 0.0
        adaptdl.checkpoint.save_all_states()
        return num_replicas
    if restart == 1:
        profile = _metrics_state().profile
        # The entry committed before the restart must be restored intact.
        assert len(profile) == 1
        check_entry(profile[first_key], optim_count=1, optim_sync_time=3.0)
        assert profile[first_key]["optim_step_time"] > 0.0
        committed_step(3, 2.0, 3.0)
        new_key = (1, num_replicas, 3)
        previous_step_time = profile[new_key]["optim_step_time"]
        committed_step(3, 3.0, 4.0)
        assert len(profile) == 2
        # Sync times of both steps add up: 2 + 3 + 3 + 4.
        check_entry(profile[new_key], optim_count=2, optim_sync_time=12.0)
        assert profile[new_key]["optim_step_time"] > previous_step_time > 0.0
示例#4
0
def test_bptt_iterator():
    """Resume an AdaptiveBPTTIterator after scaling from 1 to 2 replicas.

    A single example of 500 words is iterated with batch_size=10 and
    bptt_len=5.

    * With 1 replica: one (5x10) batch, then checkpoint and request a
      restart with 2 replicas.
    * With 2 replicas: 9 batches of shape (5x5), the last possibly (4x5),
      to consume the remaining data.
    """
    import adaptdl.checkpoint
    import adaptdl.collective
    from adaptdl.env import num_restarts
    adaptdl.collective.initialize("0.0.0.0")
    text_field = torchtext.data.Field(tokenize=get_tokenizer("basic_english"),
                                      init_token='<sos>',
                                      eos_token='<eos>')
    field_spec = [('text', text_field)]
    corpus = torchtext.data.Dataset(
        [torchtext.data.Example.fromlist([['The'] * 500], field_spec)],
        field_spec)
    text_field.build_vocab(corpus)
    iterator = AdaptiveBPTTIterator(corpus, batch_size=10, bptt_len=5)
    for step, batch in enumerate(iterator):
        if num_restarts() == 0 and step == 1:
            assert batch.text.shape == (5, 10)
            adaptdl.checkpoint.save_all_states()
            return 2
        if adaptdl.env.num_replicas() == 2:
            assert batch.text.shape in ((5, 5), (4, 5))
    if adaptdl.env.num_replicas() == 2:
        assert step == 8
示例#5
0
def test_duplicate():
    """Creating a second State with an already-used name raises ValueError.

    Returns 2 on the first run (restart with 2 replicas) and 0 on the run
    after that restart.
    """
    from adaptdl.env import num_restarts
    from adaptdl.checkpoint import State
    # Keep both states referenced for the whole test.
    first = State("state_1")  # noqa: F841
    second = State("state_2")  # noqa: F841
    with pytest.raises(ValueError):
        duplicate = State("state_1")  # noqa: F841
    replicas_by_restart = (2, 0)
    return replicas_by_restart[num_restarts()]
def test_accumulator_restarts():
    """Check Accumulator contents across restarts with 1, 4, then 2 replicas.

    Each update is guarded by the restart count so it is applied only once,
    even though the whole function re-runs after every restart. Updated
    values are only visible inside ``synchronized()``, where they are
    combined across replicas.
    """
    import adaptdl.checkpoint
    import adaptdl.collective
    from adaptdl.env import num_restarts, replica_rank
    adaptdl.collective.initialize("0.0.0.0")
    totals = Accumulator()
    restart = num_restarts()

    # Phase 1: a single replica.
    if restart == 0:
        totals["a"] += 15
    assert "a" not in totals
    with totals.synchronized():
        assert "a" in totals
        assert totals["a"] == 15
    assert "a" not in totals
    if restart == 0:
        totals["a"] -= 5
        adaptdl.checkpoint.save_all_states()
        return 4  # Restart with 4 replicas.

    # Phase 2: four replicas, each contributing its own rank.
    if restart == 1:
        rank = replica_rank()
        totals.update({"a": rank, "b": rank})
    assert len(totals) == 0
    with totals.synchronized():
        assert len(totals) == 2
        # 10 carried over from phase 1, plus ranks 0 + 1 + 2 + 3.
        assert totals["a"] == 16
        assert totals["b"] == 6
    assert len(totals) == 0
    if restart == 1:
        adaptdl.checkpoint.save_all_states()
        return 2  # Restart with 2 replicas.

    # Phase 3: two replicas, each subtracting 5.
    if restart == 2:
        totals -= {"b": 5, "c": 5}
    with totals.synchronized():
        assert totals["a"] == 16
        assert totals["b"] == -4
        assert totals["c"] == -10
        totals.clear()
    with totals.synchronized():
        assert not totals
示例#7
0
def load_state(state):
    """
    Load the given `State` object from persistent storage. If the object was
    previously saved, then State.load will be invoked with a readable file
    object to load from.

    The checkpoint root comes from the Ray Tune session when running under
    Ray, and from ``checkpoint_path()`` otherwise. Inside it, the entry named
    ``CKPT_DIR_PREFIX`` followed by the largest integer restart id is used.
    A warning is logged when that id is not ``num_restarts() - 1``.
    Entries whose suffix is not an integer are ignored, with a warning.

    Arguments:
        state (State): `State` object to load from persistent storage.

    Returns:
        `True` if state was previously saved and `State.load` was invoked,
        `False` otherwise. `False` covers: no checkpoint root, no checkpoint
        entries in it, or no file saved under this state's name.
    """
    if from_ray():
        from ray.tune import session
        checkpoint_dir = session.get_session().get_checkpoint()
    else:
        checkpoint_dir = checkpoint_path()
    if checkpoint_dir is None:
        return False
    # A checkpoint root that was never created means nothing was saved yet.
    if not os.path.isdir(checkpoint_dir):
        LOG.info(f"No checkpoint found in {checkpoint_dir}.")
        return False

    restart_ids = []
    for dir_name in os.listdir(checkpoint_dir):
        if not dir_name.startswith(CKPT_DIR_PREFIX):
            continue
        try:
            restart_ids.append(int(dir_name[len(CKPT_DIR_PREFIX):]))
        except ValueError:
            # A stray entry sharing the prefix must not abort loading.
            LOG.warning(f"Ignoring unrecognized checkpoint entry {dir_name}.")
    # Without any real checkpoint entry there is nothing to load; don't
    # fall back to probing a restart id that was never written.
    if not restart_ids:
        LOG.info(f"No checkpoint found in {checkpoint_dir}.")
        return False

    latest_restart_id = max(restart_ids)
    if latest_restart_id != num_restarts() - 1:
        LOG.warning("Cannot find checkpoint from the last restart. "
                    f"Loading checkpoint from restart {latest_restart_id}.")

    ckpt_dir = os.path.join(checkpoint_dir,
                            f"{CKPT_DIR_PREFIX}{latest_restart_id}")
    name = _STATES_TO_NAMES[state]
    state_file = os.path.join(ckpt_dir, name)
    if not os.path.isfile(state_file):
        LOG.warning(f"Cannot find state file {state_file}.")
        return False

    with open(state_file, "rb") as f:
        state.load(f)

    return True
示例#8
0
def test_dataloader_break():
    """Breaking out of an AdaptiveDataLoader loop must not disturb the next.

    The first run only requests a restart with 2 replicas. After that, one
    loop over the 100-sample dataset is abandoned at batch index 5. A fresh
    loop must then run 10 batches in total.
    """
    import adaptdl.checkpoint
    import adaptdl.collective
    from adaptdl.env import num_restarts
    if num_restarts() == 0:
        return 2
    adaptdl.collective.initialize("0.0.0.0")
    loader = AdaptiveDataLoader(TensorDataset(torch.rand(100)), batch_size=10)
    assert current_dataloader() is None
    for step, _ in enumerate(loader):
        assert current_dataloader() is loader._elastic
        if step == 5:
            break
    # Leaving the loop early must clear the active dataloader.
    assert current_dataloader() is None
    for step, _ in enumerate(loader):
        assert current_dataloader() is loader._elastic
    assert step == 9  # Run 10 batches total.
示例#9
0
def test_profile_accumulation(num_replicas):
    """Check profiling of accumulation steps across a checkpoint/restart.

    Restart 0, on one replica:
      * commits accumulation steps (``accumulation_step=True``) followed by
        an optimizer step at local_bsz 2 and at local_bsz 5;
      * commits accumulation steps only at local_bsz 3;
      * fits the performance parameters, checkpoints, and returns
        ``num_replicas`` as the replica count to restart with.

    Restart 1:
      * checks the restored profile;
      * commits two optimizer steps at local_bsz 3.
    """
    import adaptdl.checkpoint
    from adaptdl.env import num_restarts
    from adaptdl.torch._metrics import (
            profile_step_start, profile_sync_time,
            profile_step_commit, _metrics_state, _fit_perf_params)
    if num_restarts() == 0:
        profile = _metrics_state().profile
        assert len(profile) == 0
        # Profile local_bsz=1 but don't commit.
        profile_step_start(1)
        profile_sync_time(1.0)
        # Profile local_bsz=2: two accumulation steps, then one optimizer
        # step, which is the only step that reports a sync time.
        profile_step_start(2)
        profile_step_commit(accumulation_step=True)
        profile_step_start(2)
        profile_step_commit(accumulation_step=True)
        profile_step_start(2)
        profile_sync_time(4.0)
        profile_step_commit(accumulation_step=False)
        # Same pattern for local_bsz=5.
        profile_step_start(5)
        profile_step_commit(accumulation_step=True)
        profile_step_start(5)
        profile_step_commit(accumulation_step=True)
        profile_step_start(5)
        profile_sync_time(6.0)
        profile_step_commit(accumulation_step=False)
        # Ensure profile is updated correctly.
        profile = _metrics_state().profile
        key = (1, 1, 2)
        assert len(profile) == 2
        assert profile[key]["accum_count"] == 2
        assert profile[key]["optim_count"] == 1
        assert profile[key]["optim_sync_time"] == 4.0
        assert profile[key]["accum_step_time"] > 0.0
        assert profile[key]["optim_step_time"] > 0.0
        # local_bsz=3: accumulation steps only, no optimizer step.
        profile_step_start(3)
        profile_step_commit(accumulation_step=True)
        profile_step_start(3)
        profile_step_commit(accumulation_step=True)
        # Check that fitting parameters works even
        # without a final accumulation_step=False commit
        for val in profile.values():
            # Ensure step time is at least sync time.
            val["optim_step_time"] += val["optim_sync_time"]
        _fit_perf_params()
        # Checkpoint and restart.
        adaptdl.checkpoint.save_all_states()
        return num_replicas
    elif num_restarts() == 1:
        profile = _metrics_state().profile
        # Ensure checkpoint is loaded correctly. Entries for local_bsz 2, 5
        # and 3 were profiled before the restart.
        key = (1, 1, 2)
        assert len(profile) == 3
        assert profile[key]["accum_count"] == 2
        assert profile[key]["optim_count"] == 1
        assert profile[key]["optim_sync_time"] == 4.0
        assert profile[key]["optim_step_time"] > 0.0
        # Profile local_bsz=3 and commit twice.
        profile_step_start(3)
        profile_sync_time(2.0)
        profile_sync_time(3.0)
        profile_step_commit()
        key = (1, num_replicas, 3)
        old_step_time = profile[key]["optim_step_time"]
        profile_step_start(3)
        profile_sync_time(3.0)
        profile_sync_time(4.0)
        profile_step_commit()
        # Ensure profile is updated correctly. With one replica, `key` is
        # the local_bsz=3 entry from restart 0, which already holds two
        # accumulation steps. Otherwise it is a brand-new entry.
        if num_replicas == 1:
            assert len(profile) == 3
        else:
            assert len(profile) == 4
        # Parenthesized so only the expected value is conditional. Without
        # the parentheses, `assert a == b if c else d` asserts the constant
        # `d` (always truthy) whenever `c` is false.
        assert profile[key]["accum_count"] == (0 if num_replicas > 1 else 2)
        assert profile[key]["optim_count"] == 2
        assert profile[key]["optim_sync_time"] == 12.0
        assert profile[key]["optim_step_time"] > old_step_time > 0.0