def create_dataloader():
    """Build a ``DataLoader`` over ``range(50)`` whose sequential sampler is
    wrapped in a ``FastForwardSampler`` (batch size 8)."""
    data = range(50)
    size = 8
    ff_sampler = FastForwardSampler(SequentialSampler(data))
    ff_sampler.setup(size)
    return DataLoader(data, sampler=ff_sampler, batch_size=size)
def create_dataloader_sampler(dset, sampler):
    """Wrap ``sampler`` in a ``FastForwardSampler``, build a ``DataLoader`` over
    ``dset`` with it, and attach the capture-metadata collate hook.

    Returns the ``(dataloader, wrapped_sampler)`` pair.
    """
    # NOTE(review): ``batch_size`` and ``num_workers`` are free variables here —
    # presumably closed over from an enclosing test scope; confirm they resolve.
    wrapped = FastForwardSampler(sampler)
    wrapped.setup(batch_size)
    loader = DataLoader(dset, num_workers=num_workers, sampler=wrapped, batch_size=batch_size)
    _add_capture_metadata_collate(loader)
    return loader, wrapped
def _test_fast_forward_sampler_with_distributed_sampler(rank, worldsize):
    """DDP worker body: verify ``FastForwardSampler`` wrapped around a
    ``DistributedSampler`` resumes at the correct batch after a restart."""
    _setup_ddp(rank, worldsize)
    initial_seed = seed_everything(42)

    generator = torch.Generator()
    generator.manual_seed(initial_seed)

    num_workers = 2
    batch_size = 4
    dataset = range(30)

    sampler = FastForwardSampler(
        DistributedSampler(dataset, num_replicas=worldsize, rank=rank, seed=initial_seed)
    )
    sampler.setup(batch_size)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, generator=generator, sampler=sampler
    )

    # Drain the loader, counting yielded batches (same as the next()/StopIteration loop).
    batches = []
    for batch in dataloader:
        batches.append(batch)
    num_yielded = len(batches)

    expected = torch.tensor([17, 27, 24]) if rank == 0 else torch.tensor([19, 5, 3])
    assert torch.equal(batches[-1], expected)
    assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16

    # Rewind the checkpoint by one batch (4 samples) before restoring.
    reload_state_dict = sampler.state_dict(num_yielded - 1)
    assert reload_state_dict[0]["current_iteration"] == 12

    sampler = FastForwardSampler(
        DistributedSampler(dataset, num_replicas=worldsize, rank=rank, seed=initial_seed)
    )
    sampler.setup(batch_size)
    sampler.load_state_dict(reload_state_dict)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, generator=generator, sampler=sampler
    )

    # After the restart the final batch and the final iteration count must match.
    batches = list(dataloader)
    assert torch.equal(batches[-1], expected)
    assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16
def test_fast_forward_on_random_sampler():
    """Ensure ``FastForwardSampler`` applied to ``RandomSampler`` correctly
    retrieves the right next batch on restart."""
    seed = 42
    seed_everything(42)

    dataset = range(15)
    # Record the full sampling order for this seed as the reference.
    generator = torch.Generator().manual_seed(seed)
    values = list(RandomSampler(dataset, generator=generator))

    generator = torch.Generator().manual_seed(seed)
    sampler = FastForwardSampler(RandomSampler(dataset, generator=generator))
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)
    batch_sampler_iter = iter(batch_sampler)

    # The first three batches follow the recorded order exactly.
    assert next(batch_sampler_iter) == values[:3]
    assert next(batch_sampler_iter) == values[3:6]
    assert next(batch_sampler_iter) == values[6:9]

    state_dict = sampler.state_dict(3)
    assert state_dict[0]["current_iteration"] == 9
    # Rewind the checkpoint by one batch before restoring.
    state_dict[0]["current_iteration"] = 6

    seed_everything(42)
    generator = torch.Generator().manual_seed(seed)
    sampler = FastForwardSampler(RandomSampler(dataset, generator=generator))
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)
    sampler.load_state_dict(state_dict)

    batch_sampler_iter = iter(batch_sampler)
    assert next(batch_sampler_iter) == values[6:9]

    has_raised = False
    try:
        for _ in range(5):
            next(batch_sampler_iter)
    except StopIteration:
        has_raised = True
        # Exhaustion resets the fast-forward counter.
        assert sampler._current_iteration == 0
        sampler.load_state_dict(sampler.state_dict(0))
    assert has_raised
def create_dataloader():
    """Build a ``CombinedLoader`` mixing iterable-dataset loaders with a
    map-style loader driven by a ``FastForwardSampler``."""
    dataset = range(50)
    num_workers = 2
    batch_size = 8

    ff_sampler = FastForwardSampler(SequentialSampler(dataset))
    ff_sampler.setup(batch_size)
    dataloader = DataLoader(dataset, sampler=ff_sampler, batch_size=batch_size)
    # Expose the wrapped sampler so restart logic can find it on the loader.
    dataloader.fast_forward_sampler = ff_sampler

    loader_dict = {
        "a": [
            DataLoader(create_iterable_dataset(3, num_workers), num_workers=num_workers, batch_size=3),
            dataloader,
        ],
        "b": DataLoader(
            create_iterable_dataset(2, num_workers=1, attr_name="custom_sampler"), num_workers=0, batch_size=2
        ),
    }
    apply_to_collection(loader_dict, DataLoader, Trainer._add_sampler_metadata_collate)
    return CombinedLoader(loader_dict)
def test_fast_forward_on_sequential_sampler():
    """Ensure ``FastForwardSampler`` applied to ``SequentialSampler`` correctly
    retrieves the right next batch on restart."""
    dataset = range(15)
    sampler = FastForwardSampler(SequentialSampler(dataset))
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)

    it = iter(batch_sampler)
    assert next(it) == [0, 1, 2]
    assert next(it) == [3, 4, 5]

    # State after two batches: current_iteration counts samples, not batches.
    state_dict = sampler.state_dict(2)
    assert state_dict[0]["current_iteration"] == 6

    # Restoring the state fast-forwards straight to the third batch.
    sampler.load_state_dict(state_dict)
    it = iter(batch_sampler)
    assert next(it) == [6, 7, 8]