def _test_fast_forward_sampler_with_distributed_sampler(rank, worldsize):
    _setup_ddp(rank, worldsize)

    initial_seed = seed_everything(42)

    generator = torch.Generator()
    generator.manual_seed(initial_seed)

    num_workers = 2
    batch_size = 4

    dataset = range(30)
    sampler = FastForwardSampler(DistributedSampler(dataset, num_replicas=worldsize, rank=rank, seed=initial_seed))
    sampler.setup(batch_size)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, generator=generator, sampler=sampler
    )

    # Consume the full epoch and record every batch yielded on this rank.
    iter_dataloader = iter(dataloader)
    num_yielded = 0
    batches = []
    while True:
        try:
            batches.append(next(iter_dataloader))
            num_yielded += 1
        except StopIteration:
            break

    expected = torch.tensor([17, 27, 24]) if rank == 0 else torch.tensor([19, 5, 3])
    assert torch.equal(batches[-1], expected)

    assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16

    # Capture the state one batch before the end and restart from it.
    reload_state_dict = sampler.state_dict(num_yielded - 1)
    assert reload_state_dict[0]["current_iteration"] == 12

    sampler = FastForwardSampler(DistributedSampler(dataset, num_replicas=worldsize, rank=rank, seed=initial_seed))
    sampler.setup(batch_size)
    sampler.load_state_dict(reload_state_dict)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, generator=generator, sampler=sampler
    )

    iter_dataloader = iter(dataloader)
    batches = []
    while True:
        try:
            batches.append(next(iter_dataloader))
        except StopIteration:
            break

    # The restarted sampler should reproduce the same final batch and end state.
    assert torch.equal(batches[-1], expected)
    assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16
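

# A minimal sketch of how the per-rank helper above might be driven from a pytest
# entry point. The ``mp.spawn`` wrapper and the two-process world size are
# assumptions about the harness, not part of the original section; the import
# would normally live at the top of the module.
import torch.multiprocessing as mp


def test_fast_forward_sampler_with_distributed_sampler():
    worldsize = 2
    mp.spawn(_test_fast_forward_sampler_with_distributed_sampler, args=(worldsize,), nprocs=worldsize)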
def test_fast_forward_on_random_sampler():
    """This test ensures ``FastForwardSampler`` applied to ``RandomSampler`` correctly retrieves the right batch on
    restart."""
    seed = 42
    seed_everything(42)

    dataset = range(15)
    generator = torch.Generator().manual_seed(seed)
    # Reference ordering produced by a plain ``RandomSampler`` with the same seed.
    values = list(RandomSampler(dataset, generator=generator))

    generator = torch.Generator().manual_seed(seed)
    random_sampler = RandomSampler(dataset, generator=generator)
    sampler = FastForwardSampler(random_sampler)
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)

    batch_sampler_iter = iter(batch_sampler)

    assert next(batch_sampler_iter) == values[:3]
    assert next(batch_sampler_iter) == values[3:6]
    assert next(batch_sampler_iter) == values[6:9]

    state_dict = sampler.state_dict(3)
    assert state_dict[0]["current_iteration"] == 9
    # Rewind the captured state by one batch before reloading.
    state_dict[0]["current_iteration"] = 6

    seed_everything(42)
    generator = torch.Generator().manual_seed(seed)
    random_sampler = RandomSampler(dataset, generator=generator)
    sampler = FastForwardSampler(random_sampler)
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)
    sampler.load_state_dict(state_dict)

    batch_sampler_iter = iter(batch_sampler)
    # The restored sampler fast-forwards to the third batch of the reference ordering.
    assert next(batch_sampler_iter) == values[6:9]
    has_raised = False
    try:
        for _ in range(5):
            next(batch_sampler_iter)
    except StopIteration:
        has_raised = True
        # Exhausting the sampler resets its internal counter.
        assert sampler._current_iteration == 0
        sampler.load_state_dict(sampler.state_dict(0))
    assert has_raised
def test_fast_forward_on_sequential_sampler():
    """This test ensures ``FastForwardSampler`` applied to ``SequentialSampler`` correctly retrieves the right batch
    on restart."""
    dataset = range(15)
    sequential_sampler = SequentialSampler(dataset)
    sampler = FastForwardSampler(sequential_sampler)
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)

    batch_sampler_iter = iter(batch_sampler)

    assert next(batch_sampler_iter) == [0, 1, 2]
    assert next(batch_sampler_iter) == [3, 4, 5]

    state_dict = sampler.state_dict(2)
    assert state_dict[0]["current_iteration"] == 6

    sampler.load_state_dict(state_dict)

    # After reloading, iteration resumes at the third batch.
    batch_sampler_iter = iter(batch_sampler)
    assert next(batch_sampler_iter) == [6, 7, 8]
def test_fast_forward_on_batch_sampler():
    """This test ensures ``FastForwardSampler`` applied to ``BatchSampler`` correctly retrieves the right batch on
    restart."""
    dataset = range(15)
    sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, 3, False)
    index_batch_sampler = FastForwardSampler(batch_sampler)

    assert isinstance(index_batch_sampler, Iterable)

    index_batch_sampler_iter = iter(index_batch_sampler)

    assert next(index_batch_sampler_iter) == [0, 1, 2]
    assert next(index_batch_sampler_iter) == [3, 4, 5]

    state_dict = index_batch_sampler.state_dict(2)

    # Rebuild the wrapper from the captured state and resume at the third batch.
    index_batch_sampler = FastForwardSampler(batch_sampler)
    index_batch_sampler.load_state_dict(state_dict)

    index_batch_sampler_iter = iter(index_batch_sampler)
    assert next(index_batch_sampler_iter) == [6, 7, 8]