def _test_fast_forward_sampler_with_distributed_sampler(rank, worldsize):
    """Check that ``FastForwardSampler`` wrapping a ``DistributedSampler`` resumes at the right batch.

    A first ``DataLoader`` pass is run to exhaustion, and its final batch and ``current_iteration`` are checked.
    A second, freshly built sampler is then restored from the state captured one batch before the end. It must
    yield the same final batch.

    Args:
        rank: rank of this process; passed to ``_setup_ddp`` and to ``DistributedSampler``.
        worldsize: number of replicas; passed to ``_setup_ddp`` and to ``DistributedSampler``.
    """
    _setup_ddp(rank, worldsize)

    initial_seed = seed_everything(42)

    generator = torch.Generator()
    generator.manual_seed(initial_seed)

    num_workers = 2
    batch_size = 4
    dataset = range(30)

    def _make_sampler():
        # Both passes use an identically seeded ``DistributedSampler`` so they see the same shuffled order.
        sampler = FastForwardSampler(
            DistributedSampler(dataset, num_replicas=worldsize, rank=rank, seed=initial_seed)
        )
        sampler.setup(batch_size)
        return sampler

    def _make_dataloader(sampler):
        # Both loaders share the same ``generator`` object.
        return DataLoader(
            dataset, batch_size=batch_size, num_workers=num_workers, generator=generator, sampler=sampler
        )

    # First pass: consume every batch.
    sampler = _make_sampler()
    batches = list(_make_dataloader(sampler))
    num_yielded = len(batches)

    expected = torch.tensor([17, 27, 24]) if rank == 0 else torch.tensor([19, 5, 3])
    assert torch.equal(batches[-1], expected)

    assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16

    # Capture the state as if the last batch had not been processed yet.
    reload_state_dict = sampler.state_dict(num_yielded - 1)
    assert reload_state_dict[0]["current_iteration"] == 12

    # Second pass: restore a fresh sampler from that state; it must end on the same final batch.
    sampler = _make_sampler()
    sampler.load_state_dict(reload_state_dict)
    batches = list(_make_dataloader(sampler))

    assert torch.equal(batches[-1], expected)
    assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16
def test_fast_forward_on_random_sampler():
    """This test ensures ``FastForwardSampler`` applied to ``RandomSampler`` correctly retrieves the right next
    batch on restart.

    The captured state is rewound by one batch (``current_iteration`` 9 -> 6). A freshly seeded sampler restored
    from it must therefore replay the third batch before running out.
    """
    seed = 42
    batch_size = 3
    seed_everything(seed)

    dataset = range(15)
    # Reference order: what an untouched ``RandomSampler`` seeded with ``seed`` produces.
    values = list(RandomSampler(dataset, generator=torch.Generator().manual_seed(seed)))

    def _make_samplers():
        # A freshly seeded generator per build, so each sampler should reproduce ``values``.
        generator = torch.Generator().manual_seed(seed)
        sampler = FastForwardSampler(RandomSampler(dataset, generator=generator))
        sampler.setup(batch_size)
        return sampler, BatchSampler(sampler, batch_size, False)

    sampler, batch_sampler = _make_samplers()
    batch_sampler_iter = iter(batch_sampler)

    assert next(batch_sampler_iter) == values[:3]
    assert next(batch_sampler_iter) == values[3:6]
    assert next(batch_sampler_iter) == values[6:9]

    state_dict = sampler.state_dict(3)
    assert state_dict[0]["current_iteration"] == 9
    # Pretend only two batches were processed so the restored sampler replays ``values[6:9]``.
    state_dict[0]["current_iteration"] = 6

    seed_everything(seed)
    sampler, batch_sampler = _make_samplers()
    sampler.load_state_dict(state_dict)

    batch_sampler_iter = iter(batch_sampler)
    assert next(batch_sampler_iter) == values[6:9]

    # Only two batches of the 15-element dataset remain, so StopIteration must be raised within five calls.
    has_raised = False
    try:
        for _ in range(5):
            next(batch_sampler_iter)
    except StopIteration:
        has_raised = True
        # Once exhausted, the sampler's fast-forward counter is expected to be back at 0.
        assert sampler._current_iteration == 0
        sampler.load_state_dict(sampler.state_dict(0))
    assert has_raised
Example #3
0
def test_fast_forward_on_sequential_sampler():
    """Verify that ``FastForwardSampler`` over a ``SequentialSampler`` resumes at the correct batch.

    After two batches are consumed, the captured state is loaded back into the same sampler. A new iterator over
    the same ``BatchSampler`` must then continue with the third batch.
    """
    batch_size = 3
    fast_forward = FastForwardSampler(SequentialSampler(range(15)))
    fast_forward.setup(batch_size)
    batches = BatchSampler(fast_forward, batch_size, False)

    cursor = iter(batches)
    for expected in ([0, 1, 2], [3, 4, 5]):
        assert next(cursor) == expected

    snapshot = fast_forward.state_dict(2)
    assert snapshot[0]["current_iteration"] == 6

    fast_forward.load_state_dict(snapshot)

    # Rebinding ``cursor`` releases the first iterator only after the new one exists, as before.
    cursor = iter(batches)
    assert next(cursor) == [6, 7, 8]
Example #4
0
def test_fast_forward_on_batch_sampler():
    """Verify that ``FastForwardSampler`` wrapping a ``BatchSampler`` resumes at the correct batch.

    Two batches are consumed and the state is captured. A brand-new ``FastForwardSampler`` restored from that
    state must then yield the third batch first.
    """
    inner = BatchSampler(SequentialSampler(range(15)), 3, False)
    wrapper = FastForwardSampler(inner)

    assert isinstance(wrapper, Iterable)

    cursor = iter(wrapper)
    for expected in ([0, 1, 2], [3, 4, 5]):
        assert next(cursor) == expected

    snapshot = wrapper.state_dict(2)

    # Restore into a fresh wrapper around the same underlying ``BatchSampler``.
    wrapper = FastForwardSampler(inner)
    wrapper.load_state_dict(snapshot)

    cursor = iter(wrapper)
    assert next(cursor) == [6, 7, 8]