def test_combined_data_loader_with_max_size_cycle_and_ddp(accelerator, replace_sampler_ddp):
    """This test makes sure distributed sampler has been properly injected in dataloaders when using CombinedLoader
    with ddp and `max_size_cycle` mode."""
    trainer = Trainer(strategy="ddp", accelerator=accelerator, devices=2, replace_sampler_ddp=replace_sampler_ddp)

    dataloader = CombinedLoader(
        {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)},
    )
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
    # parenthesized so the conditional applies to the expected value, not the whole assert
    assert len(dataloader) == (4 if replace_sampler_ddp else 8)

    for a_length in [6, 8, 10]:
        dataloader = CombinedLoader(
            {
                "a": DataLoader(range(a_length), batch_size=1),
                "b": DataLoader(range(8), batch_size=1),
            },
            mode="max_size_cycle",
        )

        length = max(a_length, 8)
        assert len(dataloader) == length
        dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
        assert len(dataloader) == (length // 2 if replace_sampler_ddp else length)
        if replace_sampler_ddp:
            last_batch = list(dataloader)[-1]
            if a_length == 6:
                assert last_batch == {"a": torch.tensor([0]), "b": torch.tensor([6])}
            elif a_length == 8:
                assert last_batch == {"a": torch.tensor([6]), "b": torch.tensor([6])}
            elif a_length == 10:
                assert last_batch == {"a": torch.tensor([8]), "b": torch.tensor([0])}

    class InfiniteDataset(IterableDataset):
        def __iter__(self):
            while True:
                yield 1

    dataloader = CombinedLoader(
        {
            "a": DataLoader(InfiniteDataset(), batch_size=1),
            "b": DataLoader(range(8), batch_size=1),
        },
        mode="max_size_cycle",
    )
    assert get_len(dataloader) == float("inf")
    assert len(dataloader.loaders["b"].loader) == 8
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
    assert len(dataloader.loaders["b"].loader) == (4 if replace_sampler_ddp else 8)
    assert get_len(dataloader) == float("inf")
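# Why the expected lengths halve under DDP (our illustration, not part of the test above):
# `DistributedSampler` splits the dataset across ranks, so with 2 replicas each rank sees
# ceil(8 / 2) == 4 samples, which is what the `length // 2` assertions rely on.
from torch.utils.data import DistributedSampler

sampler = DistributedSampler(range(8), num_replicas=2, rank=0)
assert len(sampler) == 4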
def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders):
    """Test `CombinedLoader` with sequence loaders in both 'min_size' and 'max_size_cycle' modes."""
    if use_multiple_dataloaders:
        loaders = [
            torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
            torch.utils.data.DataLoader(TestIterableDataset(20), batch_size=2),
        ]
    else:
        loaders = [
            torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
        ]
    combined_loader = CombinedLoader(loaders, mode)

    has_break = False
    for idx, item in enumerate(combined_loader):
        assert isinstance(item, Sequence)
        # parenthesized so the conditional applies to the expected value, not the whole assert
        assert len(item) == (2 if use_multiple_dataloaders else 1)
        if not use_multiple_dataloaders and idx == 4:
            has_break = True
            break
    if mode == "max_size_cycle":
        assert combined_loader.loaders[0].state.done == (not has_break)
    expected = (10 if mode == "max_size_cycle" else 5) if use_multiple_dataloaders else 5
    assert (expected - 1) == idx, (mode, use_multiple_dataloaders)
def test_combined_data_loader_validation_test(cuda_available_mock, device_count_mock, tmpdir):
    """This test makes sure distributed sampler has been properly injected in dataloaders when using
    CombinedLoader."""

    class CustomDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    dataloader = CombinedLoader(
        {
            "a": DataLoader(CustomDataset(range(10))),
            "b": {"c": DataLoader(CustomDataset(range(10))), "d": DataLoader(CustomDataset(range(10)))},
            "e": [DataLoader(CustomDataset(range(10))), DataLoader(CustomDataset(range(10)))],
        }
    )
    trainer = Trainer(replace_sampler_ddp=True, accelerator="ddp", gpus=2)
    dataloader = trainer.auto_add_sampler(dataloader, shuffle=True)

    _count = 0

    def _assert_distributed_sampler(v):
        nonlocal _count
        _count += 1
        assert isinstance(v, DistributedSampler)

    apply_to_collection(dataloader.sampler, Sampler, _assert_distributed_sampler)
    assert _count == 5
def predict_dataloader(self):
    return CombinedLoader(
        {
            "a": DataLoader(RandomDataset(32, 8), batch_size=2),
            "b": DataLoader(RandomDataset(32, 8), batch_size=4),
        }
    )
def test_combined_loader_sequence_with_map_and_iterable(lengths):
    class MyIterableDataset(IterableDataset):
        def __init__(self, size: int = 10):
            self.size = size

        def __iter__(self):
            self.sampler = SequentialSampler(range(self.size))
            self.iter_sampler = iter(self.sampler)
            return self

        def __next__(self):
            return next(self.iter_sampler)

    class MyMapDataset(Dataset):
        def __init__(self, size: int = 10):
            self.size = size

        def __getitem__(self, index):
            return index

        def __len__(self):
            return self.size

    x, y = lengths
    loaders = [DataLoader(MyIterableDataset(x)), DataLoader(MyMapDataset(y))]
    dataloader = CombinedLoader(loaders, mode="max_size_cycle")
    counter = 0
    for _ in dataloader:
        counter += 1
    assert counter == max(x, y)
def val_dataloader(self):
    val_dataloader_head = DataLoader(
        TestDataset(
            self.val_triples,
            self.train_triples + self.val_triples + self.test_triples,
            len(self.entity2id),
            len(self.relation2id),
            "head-batch",
        ),
        batch_size=self.val_batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        collate_fn=TestDataset.collate_fn,
        drop_last=True,
        pin_memory=True,
    )
    val_dataloader_tail = DataLoader(
        TestDataset(
            self.val_triples,
            self.train_triples + self.val_triples + self.test_triples,
            len(self.entity2id),
            len(self.relation2id),
            "tail-batch",
        ),
        batch_size=self.val_batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        collate_fn=TestDataset.collate_fn,
        drop_last=True,
        pin_memory=True,
    )
    return CombinedLoader([val_dataloader_head, val_dataloader_tail])
def test_prefetch_iterator(use_combined_loader, dataset_cls, prefetch_batches):
    fetcher = DataFetcher(prefetch_batches=prefetch_batches)
    assert fetcher.prefetch_batches == prefetch_batches

    if use_combined_loader:
        loader = CombinedLoader([DataLoader(dataset_cls()), DataLoader(dataset_cls())])
    else:
        loader = DataLoader(dataset_cls())
    fetcher.setup(loader)

    def generate():
        generated = [(fetcher.fetched, data, fetcher.done) for data in fetcher]
        assert fetcher.fetched == 3
        assert fetcher.done
        return generated

    # we can only know the last batch with sized iterables or when we prefetch
    is_last_batch = [False, False, prefetch_batches > 0 or dataset_cls is SizedDataset]
    fetched = list(range(prefetch_batches + 1, 4))
    fetched += [3] * (3 - len(fetched))
    batches = [[1, 1], [2, 2], [3, 3]] if use_combined_loader else [1, 2, 3]
    expected = list(zip(fetched, batches, is_last_batch))
    assert len(expected) == 3

    assert generate() == expected
    # validate reset works properly.
    assert generate() == expected
    assert fetcher.fetched == 3
def train_dataloader(self):
    loader_l = DataLoader(self.train_l, self.batch_size, shuffle=True)
    loader_u = DataLoader(self.train_u, self.batch_size, shuffle=True)
    loader_real = DataLoader(self.train, self.batch_size, shuffle=True)
    loaders = {"u": loader_u, "l": loader_l, "real": loader_real}
    combined_loaders = CombinedLoader(loaders, "max_size_cycle")
    return combined_loaders
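# Hedged usage sketch (not from the source): with a dict-based CombinedLoader such as
# the one above, each training batch arrives as a dict keyed like the loaders dict, so
# `training_step` can unpack it directly. `compute_loss` is a hypothetical helper.
def training_step(self, batch, batch_idx):
    # in "max_size_cycle" mode shorter loaders restart, so all keys are always present
    batch_u, batch_l, batch_real = batch["u"], batch["l"], batch["real"]
    return self.compute_loss(batch_u, batch_l, batch_real)  # hypothetical helper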
def test_combined_dataloader_for_training_with_ddp(
    replace_sampler_ddp: bool, is_min_size_mode: bool, use_combined_loader: bool
):
    """When providing a CombinedLoader as the training data, it should correctly receive the distributed
    samplers."""
    mode = "min_size" if is_min_size_mode else "max_size_cycle"
    dim = 3
    n1 = 8
    n2 = 6
    dataloader = {
        "a": DataLoader(RandomDataset(dim, n1), batch_size=1),
        "b": DataLoader(RandomDataset(dim, n2), batch_size=1),
    }
    if use_combined_loader:
        dataloader = CombinedLoader(dataloader, mode=mode)
    expected_length_before_ddp = min(n1, n2) if is_min_size_mode else max(n1, n2)
    expected_length_after_ddp = (
        expected_length_before_ddp // 2 if replace_sampler_ddp else expected_length_before_ddp
    )
    model = BoringModel()
    trainer = Trainer(
        strategy="ddp",
        accelerator="auto",
        devices=2,
        replace_sampler_ddp=replace_sampler_ddp,
        multiple_trainloader_mode="max_size_cycle" if use_combined_loader else mode,
    )
    trainer._data_connector.attach_data(
        model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None
    )
    trainer.reset_train_dataloader(model=model)
    assert trainer.train_dataloader is not None
    assert isinstance(trainer.train_dataloader, CombinedLoader)
    assert trainer.train_dataloader.mode == mode
    assert trainer.num_training_batches == expected_length_after_ddp
def test_combined_loader_loader_type_error():
    """Test the TypeError raised when wrapping an unsupported loaders type."""
    with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"):
        CombinedLoader(None, "max_size_cycle")
def val_dataloader(self):
    loader_test = DataLoader(self.test_dataset, int(len(self.test_dataset) / 10))
    loader_u = DataLoader(self.train_dataset_u, int(len(self.train_dataset_u) / 10))
    loaders = {"u": loader_u, "test": loader_test}
    combined_loaders = CombinedLoader(loaders, "max_size_cycle")
    return combined_loaders
def test_prefetch_iterator(use_combined_loader):
    """Test the DataFetcher with PyTorch IterableDataset."""

    class IterDataset(IterableDataset):
        def __iter__(self):
            yield 1
            yield 2
            yield 3

    for prefetch_batches in range(5):
        iterator = DataFetcher(prefetch_batches=prefetch_batches)
        assert iterator.prefetch_batches == prefetch_batches

        if use_combined_loader:
            loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
        else:
            loader = DataLoader(IterDataset())
        iterator.setup(loader)

        def generate():
            generated = [(iterator.fetched, *data) for i, data in enumerate(iterator, prefetch_batches + 1)]
            assert iterator.fetched == 3
            assert iterator.done
            return generated

        is_last_batch = [False, False, prefetch_batches > 0]
        fetched = list(range(prefetch_batches + 1, 4))
        fetched += [3] * (3 - len(fetched))
        if use_combined_loader:
            batches = [[tensor(1), tensor(1)], [tensor(2), tensor(2)], [tensor(3), tensor(3)]]
        else:
            batches = [1, 2, 3]
        expected = list(zip(fetched, batches, is_last_batch))
        assert len(expected) == 3

        assert generate() == expected
        # validate reset works properly.
        assert generate() == expected
        assert iterator.fetched == 3

    class EmptyIterDataset(IterableDataset):
        def __iter__(self):
            return iter([])

    loader = DataLoader(EmptyIterDataset())
    iterator = DataFetcher()
    iterator.setup(loader)
    assert not list(iterator)
def val_dataloader(self, *args: Any, **kwargs: Any) -> CombinedLoader:  # type: ignore
    """The val dataloader."""
    dataloaders = {
        SSLDataModuleType.ENCODER: self.encoder_module.val_dataloader(),
        SSLDataModuleType.LINEAR_HEAD: self.linear_head_module.val_dataloader(),
    }
    return CombinedLoader(dataloaders, mode="max_size_cycle")
def get_combined_loader(self, encoder_loader: Sized, linear_head_loader: Sized) -> CombinedLoader:
    """Creates a CombinedLoader from the data loaders for the encoder and the linear head. The cycle mode is
    chosen such that in all cases the encoder dataset is only cycled through once.

    :param encoder_loader: The dataloader to use for the SSL encoder.
    :param linear_head_loader: The dataloader to use for the linear head.
    """
    mode = self._cycle_mode(len(encoder_loader), len(linear_head_loader))
    dataloaders = {
        SSLDataModuleType.ENCODER: encoder_loader,
        SSLDataModuleType.LINEAR_HEAD: linear_head_loader,
    }
    return CombinedLoader(dataloaders, mode=mode)
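# `_cycle_mode` is not shown in the source; below is a minimal sketch of what it
# plausibly does, given the docstring's guarantee that the encoder dataset is cycled
# through only once: if the encoder loader is at least as long as the linear head
# loader, "max_size_cycle" cycles the linear head while the encoder runs exactly once;
# otherwise "min_size" stops when the encoder is exhausted. Treat this as an
# illustrative assumption, not the actual helper.
def _cycle_mode(self, encoder_len: int, linear_head_len: int) -> str:
    return "max_size_cycle" if encoder_len >= linear_head_len else "min_size"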
def test_prefetch_iterator(use_combined_loader):
    """Test the DataFetcher with PyTorch IterableDataset."""

    class IterDataset(IterableDataset):
        def __iter__(self):
            yield 1
            yield 2
            yield 3

    for prefetch_batches in range(0, 4):
        if use_combined_loader:
            loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
            expected = [
                ([tensor([1]), tensor([1])], False),
                ([tensor([2]), tensor([2])], False),
                ([tensor([3]), tensor([3])], True),
            ]
        else:
            loader = DataLoader(IterDataset())
            expected = [(1, False), (2, False), (3, True)]
        iterator = DataFetcher(prefetch_batches=prefetch_batches)
        prefetch_batches += 1
        assert iterator.prefetch_batches == prefetch_batches
        iterator.setup(loader)

        def generate():
            generated = []
            for idx, data in enumerate(iterator, 1):
                if iterator.done:
                    assert iterator.fetched == 3
                else:
                    assert iterator.fetched == (idx + prefetch_batches)
                generated.append(data)
            return generated

        assert generate() == expected
        # validate reset works properly.
        assert generate() == expected
        assert iterator.fetched == 3

    class EmptyIterDataset(IterableDataset):
        def __iter__(self):
            return iter([])

    dataloader = DataLoader(EmptyIterDataset())
    iterator = DataFetcher()
    iterator.setup(dataloader)
    assert list(iterator) == []
def test_combined_loader_sequence_max_size_cycle():
    """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders."""
    loaders = [
        torch.utils.data.DataLoader(range(10), batch_size=4),
        torch.utils.data.DataLoader(range(20), batch_size=5),
    ]
    combined_loader = CombinedLoader(loaders, "max_size_cycle")

    assert len(combined_loader) == max(len(v) for v in loaders)

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, Sequence)
        assert len(item) == 2

    assert idx == len(combined_loader) - 1
def val_dataloader(self):
    dataset_root = self.hydra_conf["trainer"]["dataset_root"]
    dataset_paths = self.hydra_conf["trainer"]["valid_dataset"].split("*")
    scene_dataset, img_dataset = self.setup_dataset(dataset_root, dataset_paths)
    self.valid_scene_dataset = torch.utils.data.ConcatDataset(scene_dataset)
    self.valid_img_dataset = torch.utils.data.ConcatDataset(img_dataset)

    valid_scene_sampler = My_ddp_sampler2(
        self.valid_scene_dataset, self.batch_size, v_sample_mode="internal", shuffle=False
    )
    valid_img_sampler = My_ddp_sampler2(
        self.valid_img_dataset, self.batch_size, v_sample_mode="internal", shuffle=False
    )

    if self.involved_imgs:
        combined_dataset = {
            "scene": DataLoader(
                self.valid_scene_dataset,
                batch_size=self.batch_size,
                num_workers=self.scene_worker,
                shuffle=False,
                pin_memory=True,
                collate_fn=self.dataset_builder.collate_fn,
                sampler=valid_scene_sampler,
                persistent_workers=True,
            ),
            "img": DataLoader(
                self.valid_img_dataset,
                batch_size=self.batch_size,
                num_workers=self.img_worker,
                shuffle=False,
                pin_memory=True,
                collate_fn=self.dataset_builder.collate_fn,
                sampler=valid_img_sampler,
                persistent_workers=True,
            ),
        }
        assert len(combined_dataset["scene"]) == len(combined_dataset["img"])
    else:
        combined_dataset = {
            "scene": DataLoader(
                self.valid_scene_dataset,
                batch_size=self.batch_size,
                num_workers=self.scene_worker,
                shuffle=False,
                pin_memory=True,
                collate_fn=self.dataset_builder.collate_fn,
                sampler=valid_scene_sampler,
                persistent_workers=True,
            )
        }
    return CombinedLoader(combined_dataset, mode="min_size")
def test_combined_loader_dict_max_size_cycle():
    """Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders."""
    loaders = {
        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
    }
    combined_loader = CombinedLoader(loaders, "max_size_cycle")

    assert len(combined_loader) == max(len(v) for v in loaders.values())

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, dict)
        assert len(item) == 2
        assert "a" in item and "b" in item

    assert idx == len(combined_loader) - 1
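# A minimal standalone sketch (our illustration, not library code) of the
# "max_size_cycle" semantics the tests above rely on: iterate until the longest
# loader is exhausted, restarting the shorter ones from the beginning.
from itertools import cycle, islice

short, long_ = [0, 1, 2], [10, 11, 12, 13, 14]
combined = list(zip(islice(cycle(short), len(long_)), long_))
assert combined == [(0, 10), (1, 11), (2, 12), (0, 13), (1, 14)]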
def train_dataloader(self):
    return CombinedLoader(
        {
            "src": DataLoader(
                self.trainset_src,
                batch_size=self.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=self.num_workers,
                drop_last=True,
            ),
            "tgt": DataLoader(
                self.trainset_tgt,
                batch_size=self.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=self.num_workers,
                drop_last=True,
            ),
        },
        "max_size_cycle",
    )
def create_dataloader():
    dataset = range(50)
    num_workers = 2
    batch_size = 8
    sampler = FastForwardSampler(SequentialSampler(dataset))
    sampler.setup(batch_size)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
    dataloader.fast_forward_sampler = sampler

    loader_dict = {
        "a": [DataLoader(create_iterable_dataset(3, num_workers), num_workers=num_workers, batch_size=3), dataloader],
        "b": DataLoader(
            create_iterable_dataset(2, num_workers=1, attr_name="custom_sampler"), num_workers=0, batch_size=2
        ),
    }
    apply_to_collection(loader_dict, DataLoader, Trainer._add_sampler_metadata_collate)
    return CombinedLoader(loader_dict)
def val_dataloader(self):
    loaders = {
        "content": DataLoader(
            self.content_val,
            batch_size=1,
            shuffle=False,
            num_workers=self.workers,
            pin_memory=True,
        ),
        "style": DataLoader(
            self.style_val,
            batch_size=1,
            shuffle=False,
            num_workers=self.workers,
            pin_memory=True,
        ),
    }
    return CombinedLoader(loaders, "max_size_cycle")
def test_dataloader(self):
    dataset_root = self.hydra_conf["trainer"]["dataset_root"]
    dataset_paths = self.hydra_conf["trainer"]["test_dataset"].split("*")
    scene_dataset, img_dataset = self.setup_dataset(dataset_root, dataset_paths)
    self.test_scene_dataset = torch.utils.data.ConcatDataset(scene_dataset)
    self.test_img_dataset = torch.utils.data.ConcatDataset(img_dataset)

    if self.involved_imgs:
        combined_dataset = {
            "scene": DataLoader(
                self.test_scene_dataset,
                batch_size=self.batch_size,
                num_workers=self.scene_worker,
                shuffle=False,
                pin_memory=True,
                collate_fn=self.dataset_builder.collate_fn,
            ),
            "img": DataLoader(
                self.test_img_dataset,
                batch_size=self.batch_size,
                num_workers=self.img_worker,
                shuffle=False,
                pin_memory=True,
                collate_fn=self.dataset_builder.collate_fn,
            ),
        }
        assert len(combined_dataset["scene"]) == len(combined_dataset["img"])
    else:
        combined_dataset = {
            "scene": DataLoader(
                self.test_scene_dataset,
                batch_size=self.batch_size,
                num_workers=self.scene_worker,
                shuffle=False,
                pin_memory=True,
                collate_fn=self.dataset_builder.collate_fn,
            ),
        }
    return CombinedLoader(combined_dataset, mode="min_size")
def val_dataloader(self) -> CombinedLoader:  # annotation fixed: this returns a CombinedLoader, not List[DataLoader]
    main_loader = DataLoader(
        self.val_data.batched(
            self.val_dataloader_conf["batch_size"] - sum(self.batch_size_extra),
            partial=False,
        ),
        batch_size=None,
        pin_memory=False,
        num_workers=self.val_dataloader_conf["num_workers"],
    )
    loaders = {"main": main_loader}
    for cnt, (bs, val_data) in enumerate(zip(self.batch_size_extra, self.extra_valid_data)):
        loaders[f"extra_{cnt}"] = DataLoader(
            val_data.batched(bs, partial=False),
            batch_size=None,
            pin_memory=True,
            num_workers=bs // 2,
        )
    combined_loaders = CombinedLoader(loaders, "max_size_cycle")
    return combined_loaders
with pytest.raises(
    MisconfigurationException,
    match=rf"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`",
):
    trainer._data_connector._reset_eval_dataloader(RunningStage.VALIDATING, model)


@pytest.mark.parametrize(
    "val_dl,warns",
    [
        (DataLoader(dataset=RandomDataset(32, 64), shuffle=True), True),
        (DataLoader(dataset=RandomDataset(32, 64), sampler=list(range(64))), False),
        (CombinedLoader(DataLoader(dataset=RandomDataset(32, 64), shuffle=True)), True),
        (
            CombinedLoader(
                [DataLoader(dataset=RandomDataset(32, 64)), DataLoader(dataset=RandomDataset(32, 64), shuffle=True)]
            ),
            True,
        ),
        (
            CombinedLoader(
                {
                    "dl1": DataLoader(dataset=RandomDataset(32, 64)),
                    "dl2": DataLoader(dataset=RandomDataset(32, 64), shuffle=True),
                }
            ),
            True,
def val_dataloader(self):
    loader_test = DataLoader(self.test_dataset, len(self.test_dataset))
    loader_u = DataLoader(self.train_dataset_u, len(self.train_dataset_u))
    loaders = {"u": loader_u, "test": loader_test}
    combined_loaders = CombinedLoader(loaders)
    return combined_loaders
def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
    """Resets the train dataloader and initialises required variables (number of batches, when to validate,
    etc.).

    Args:
        model: The `LightningModule` if calling this outside of the trainer scope.
    """
    self.train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)

    if self.overfit_batches > 0:
        if hasattr(self.train_dataloader, "sampler") and isinstance(self.train_dataloader.sampler, RandomSampler):
            rank_zero_warn(
                "You requested to overfit but enabled training dataloader shuffling."
                " We are turning off the training dataloader shuffling for you."
            )
            self.train_dataloader = self.replace_sampler(
                self.train_dataloader, SequentialSampler(self.train_dataloader.dataset), mode=RunningStage.TRAINING
            )

    # debugging
    self.dev_debugger.track_load_dataloader_call("train_dataloader", dataloaders=[self.train_dataloader])

    # automatically add samplers
    self.train_dataloader = apply_to_collection(
        self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True, mode=RunningStage.TRAINING
    )

    # check the workers recursively
    apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train_dataloader")

    # add worker_init_fn for correct seeding in worker processes
    apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

    # add collate_fn to collect metadata for fault tolerant training
    if _fault_tolerant_training():
        apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)

    # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
    self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)

    self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float("inf")

    if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
        self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
    elif self.num_training_batches != float("inf"):
        self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
    elif self.limit_train_batches != 1.0:
        raise MisconfigurationException(
            "When using an IterableDataset for `limit_train_batches`,"
            " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
            " `num_training_batches` to use."
        )

    # determine when to check validation
    # if int passed in, val checks that often
    # otherwise, it checks in [0, 1.0] % range of a training epoch
    if isinstance(self.val_check_interval, int):
        self.val_check_batch = self.val_check_interval
        if self.val_check_batch > self.num_training_batches:
            raise ValueError(
                f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
                f"to the number of the training batches ({self.num_training_batches}). "
                "If you want to disable validation set `limit_val_batches` to 0.0 instead."
            )
    else:
        if not has_len(self.train_dataloader):
            if self.val_check_interval == 1.0:
                self.val_check_batch = float("inf")
            else:
                raise MisconfigurationException(
                    "When using an IterableDataset for `train_dataloader`,"
                    " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
                    " checking validation every k training batches."
                )
        else:
            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            self.val_check_batch = max(1, self.val_check_batch)

    if self.logger and self.num_training_batches < self.log_every_n_steps:
        rank_zero_warn(
            f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
            f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
            " you want to see logs for the training epoch."
        )
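# Worked example (our illustration) of the `val_check_batch` arithmetic above: with a
# sized train dataloader of 100 batches and `val_check_interval=0.25`, validation runs
# every int(100 * 0.25) == 25 training batches; a float interval smaller than
# 1 / num_training_batches still yields at least 1 thanks to the max(1, ...) clamp.
num_training_batches = 100
val_check_interval = 0.25
val_check_batch = max(1, int(num_training_batches * val_check_interval))
assert val_check_batch == 25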
def reset_train_dataloader(self, model: LightningModule) -> None:
    """Resets the train dataloader and initialises required variables (number of batches, when to validate,
    etc.).

    Args:
        model: The current `LightningModule`
    """
    self.train_dataloader = self.request_dataloader(model, "train")

    if self.overfit_batches > 0:
        if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
            rank_zero_warn(
                'You requested to overfit but enabled training dataloader shuffling.'
                ' We are turning it off for you.'
            )
            self.train_dataloader = self.replace_sampler(
                self.train_dataloader, SequentialSampler(self.train_dataloader.dataset)
            )

    # debugging
    self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])

    # automatically add samplers
    self.train_dataloader = apply_to_collection(
        self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True
    )

    # check the workers recursively
    apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader')

    # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
    self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)

    self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')

    if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
        self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
    elif self.num_training_batches != float('inf'):
        self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
    elif self.limit_train_batches != 1.0:
        raise MisconfigurationException(
            'When using an IterableDataset for `limit_train_batches`,'
            ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
            ' `num_training_batches` to use.'
        )

    # determine when to check validation
    # if int passed in, val checks that often
    # otherwise, it checks in [0, 1.0] % range of a training epoch
    if isinstance(self.val_check_interval, int):
        self.val_check_batch = self.val_check_interval
        if self.val_check_batch > self.num_training_batches:
            raise ValueError(
                f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                f'to the number of the training batches ({self.num_training_batches}). '
                'If you want to disable validation set `limit_val_batches` to 0.0 instead.'
            )
    else:
        if not has_len(self.train_dataloader):
            if self.val_check_interval == 1.0:
                self.val_check_batch = float('inf')
            else:
                raise MisconfigurationException(
                    'When using an IterableDataset for `train_dataloader`,'
                    ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                    ' checking validation every k training batches.'
                )
        else:
            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            self.val_check_batch = max(1, self.val_check_batch)
def test_combined_data_loader_validation_test(
    cuda_available_mock, device_count_mock, use_fault_tolerant, replace_sampler_ddp, tmpdir
):
    """This test makes sure distributed sampler has been properly injected in dataloaders when using
    CombinedLoader."""

    class CustomDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    class CustomSampler(RandomSampler):
        def __init__(self, data_source, name) -> None:
            super().__init__(data_source)
            self.name = name

    dataset = CustomDataset(range(10))
    dataloader = CombinedLoader(
        {
            "a": DataLoader(CustomDataset(range(10))),
            "b": DataLoader(dataset, sampler=CustomSampler(dataset, "custom_sampler")),
            "c": {"c": DataLoader(CustomDataset(range(10))), "d": DataLoader(CustomDataset(range(10)))},
            "d": [DataLoader(CustomDataset(range(10))), DataLoader(CustomDataset(range(10)))],
        }
    )

    with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(use_fault_tolerant))}):
        trainer = Trainer(replace_sampler_ddp=replace_sampler_ddp, strategy="ddp", gpus=2)
        dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=True)
        _count = 0
        _has_fastforward_sampler = False

        def _assert_distributed_sampler(v):
            nonlocal _count
            nonlocal _has_fastforward_sampler
            _count += 1
            if use_fault_tolerant:
                _has_fastforward_sampler = True
                assert isinstance(v, FastForwardSampler)
                v = v._sampler
            if replace_sampler_ddp:
                assert isinstance(v, DistributedSampler)
            else:
                assert isinstance(v, (SequentialSampler, CustomSampler))

        apply_to_collection(dataloader.sampler, Sampler, _assert_distributed_sampler)
        assert _count == 6
        assert _has_fastforward_sampler == use_fault_tolerant

        def _assert_dataset(loader):
            d = loader.dataset
            if use_fault_tolerant:
                assert isinstance(d, CaptureMapDataset)
            else:
                assert isinstance(d, CustomDataset)

        apply_to_collection(dataloader.loaders, DataLoader, _assert_dataset)
def test_combined_loader_loader_type_error():
    """Test the ValueError when wrapping the loaders."""
    with pytest.raises(ValueError, match="Invalid Datatype"):
        CombinedLoader(None, "max_size_cycle")
def test_combined_loader_init_mode_error():
    """Test the error raised when constructing `CombinedLoader` with an invalid mode."""
    with pytest.raises(MisconfigurationException, match="Invalid Mode"):
        CombinedLoader([range(10)], "testtt")
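# Summary sketch (our illustration, assuming the `CombinedLoader` import used
# throughout these snippets): the two modes the snippets above exercise behave as
# follows. "min_size" stops with the shortest loader, "max_size_cycle" restarts the
# shorter loaders until the longest is exhausted, and anything else is rejected.
from torch.utils.data import DataLoader

loaders = {"a": DataLoader(range(4)), "b": DataLoader(range(8))}
assert len(CombinedLoader(loaders, "min_size")) == 4
assert len(CombinedLoader(loaders, "max_size_cycle")) == 8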