def test_has_len():
    assert has_len(DataLoader(RandomDataset(1, 1)))

    with pytest.raises(ValueError, match="`Dataloader` returned 0 length."):
        assert has_len(DataLoader(RandomDataset(0, 0)))

    assert not has_len(DataLoader(RandomIterableDataset(1, 1)))

def test_has_len():
    assert has_len(DataLoader(RandomDataset(1, 1)))

    with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."):
        assert has_len(DataLoader(RandomDataset(0, 0)))

    assert not has_len(DataLoader(RandomIterableDataset(1, 1)))

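# Taken together, the two `test_has_len` variants above pin down the contract of
# `has_len`: truthy for a map-style DataLoader with a positive length, an error
# (older variant) or a UserWarning (newer variant) for a zero-length loader, and
# falsy when `len()` is undefined, as for an IterableDataset. A minimal sketch
# consistent with the warning variant follows; the name and implementation are
# hypothetical, and the library's actual code may differ:

import warnings


def has_len_sketch(dataloader) -> bool:
    """Return True if ``len(dataloader)`` is defined and usable."""
    try:
        if len(dataloader) == 0:
            # warnings.warn defaults to UserWarning, matching the test above
            warnings.warn("`DataLoader` returned 0 length. Please make sure it returns at least 1 batch.")
        return True
    except TypeError:
        # e.g. a DataLoader backed by an IterableDataset defines no usable __len__
        return False
    except NotImplementedError:
        # e.g. raised by some libraries when a batch_sampler is used
        return False
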
def reset_train_dataloader(self, model: LightningModule) -> None:
    """Resets the train dataloader and initialises required variables
    (number of batches, when to validate, etc.).

    Args:
        model: The current `LightningModule`
    """
    self.train_dataloader = self.request_dataloader(model.train_dataloader)

    # debugging
    self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])

    self.num_training_batches = 0

    # automatically add samplers
    self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

    self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
    self._worker_check(self.train_dataloader, 'train dataloader')

    if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
        self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
    elif self.num_training_batches != float('inf'):
        self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
    elif self.limit_train_batches != 1.0:
        raise MisconfigurationException(
            'When using an IterableDataset for `limit_train_batches`,'
            ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
            ' `num_training_batches` to use.'
        )

    # determine when to check validation
    # if int passed in, val checks that often
    # otherwise, it checks in [0, 1.0] % range of a training epoch
    if isinstance(self.val_check_interval, int):
        self.val_check_batch = self.val_check_interval
        if self.val_check_batch > self.num_training_batches:
            raise ValueError(
                f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                f'to the number of the training batches ({self.num_training_batches}). '
                'If you want to disable validation set `limit_val_batches` to 0.0 instead.'
            )
    else:
        if not has_len(self.train_dataloader):
            if self.val_check_interval == 1.0:
                self.val_check_batch = float('inf')
            else:
                raise MisconfigurationException(
                    'When using an IterableDataset for `train_dataloader`,'
                    ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                    ' checking validation every k training batches.'
                )
        else:
            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            self.val_check_batch = max(1, self.val_check_batch)

def test_warning_with_iterable_dataset_and_len(tmpdir):
    """Tests that a warning message is shown when an IterableDataset defines `__len__`."""
    model = EvalModelTemplate()
    original_dataset = model.train_dataloader().dataset

    class IterableWithLen(IterableDataset):

        def __iter__(self):
            return iter(original_dataset)

        def __len__(self):
            return len(original_dataset)

    dataloader = DataLoader(IterableWithLen(), batch_size=16)
    assert has_len(dataloader)
    assert has_iterable_dataset(dataloader)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=3,
    )
    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
        trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
        trainer.test(model, test_dataloaders=[dataloader])

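# The test above also relies on `has_iterable_dataset` to distinguish the loader
# type. A plausible one-line check (hypothetical sketch; the library's actual
# helper may differ) simply inspects the wrapped dataset:

from torch.utils.data import DataLoader, IterableDataset


def has_iterable_dataset_sketch(dataloader: DataLoader) -> bool:
    # True when the loader wraps an IterableDataset, regardless of whether __len__ is defined
    return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)
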
def training_epoch_end(self, outputs):
    ids = torch.cat([o['ids'] for o in outputs], dim=0)

    # in distributed mode, collect ids from every process (gpu)
    if distributed_available():
        gather_ids = [torch.zeros_like(ids) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(gather_ids, ids)
        ids = torch.cat(gather_ids, dim=0)

    if has_len(self.trainer.datamodule.train_dataset):
        received = torch.zeros(len(self.trainer.datamodule.train_dataset)).to(dtype=bool)
    else:
        received = torch.zeros(len(list(self.trainer.datamodule.train_dataset))).to(dtype=bool)
    received[ids] = True

    if self.check_ids:
        # assert no duplicate element received
        assert len(set(ids.tolist())) == len(ids.tolist()), (
            f"Received {len(ids.tolist())} ids but only"
            f" {len(set(ids.tolist()))} are unique: {ids}"
        )
        # assert all elements received
        assert all(received), (
            f"({self.trainer.max_steps}) Received not all {len(received)} ids: {received}"
        )

def test_index_batch_sampler_methods():
    dataset = range(15)
    sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, 3, False)
    index_batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

    assert isinstance(index_batch_sampler, Iterable)
    assert has_len(index_batch_sampler)

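# For `test_index_batch_sampler_methods` to pass, the wrapper only needs to
# delegate `__iter__` and `__len__` to the wrapped BatchSampler; delegating
# `__len__` is exactly what makes `has_len()` return True. A minimal sketch
# (hypothetical class name; the real wrapper may do more, such as exposing the
# indices of the most recent batch, which is sketched here too):

from typing import Iterator, List

from torch.utils.data.sampler import BatchSampler


class IndexBatchSamplerSketch:
    """Wrap a BatchSampler and remember the indices of the latest batch."""

    def __init__(self, sampler: BatchSampler) -> None:
        self._sampler = sampler
        self.batch_indices: List[int] = []

    def __iter__(self) -> Iterator[List[int]]:
        for batch in self._sampler:
            self.batch_indices = batch  # capture for later inspection
            yield batch

    def __len__(self) -> int:
        return len(self._sampler)
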
def setup(  # type: ignore[override]
    self, dataloader: Iterable, batch_to_device: Optional[Callable[[Any], Any]] = None
) -> None:
    super().setup(dataloader)
    self._has_len = has_len(dataloader)
    if batch_to_device is not None:
        self.batch_to_device = batch_to_device

def _validate_dataloader(dataloaders: Union[List['DataLoader'], 'DataLoader']):
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        if not has_len(dataloader):
            raise MisconfigurationException(
                "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
                " HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
            )

def num_training_steps(self) -> Optional[int]:
    r"""Total training steps inferred from dataset length, number of nodes and devices."""
    if self.trainer.max_steps is not None and self.trainer.max_steps >= 0:
        return self.trainer.max_steps

    if not has_len(self.trainer.datamodule.train_dataset):
        rank_zero_warn("Using IterableDataset, cannot compute max_steps, returning None")
        return None

    # train samples
    train_samples = len(self.trainer.datamodule.train_dataset)

    # number of training devices
    if self.trainer._accelerator_connector.use_dp:
        total_devices = 1  # with dp, a single batch is divided across many gpus
    elif self.trainer._accelerator_connector.use_ddp2:
        total_devices = self.trainer.num_nodes
    else:
        total_devices = self.trainer.num_processes * self.trainer.num_nodes

    # the number of training samples may be modified in distributed training
    # to be divisible by the number of GPUs...
    train_samples_per_device = math.ceil(train_samples / total_devices)

    # train batches from the dataloader
    train_batches_per_device = math.ceil(train_samples_per_device / self.hyperparameters.batch_size)

    # eventually limit train batches
    limit_batches = self.trainer.limit_train_batches
    train_batches_per_device = (
        min(train_batches_per_device, limit_batches)
        if isinstance(limit_batches, int)
        else int(limit_batches * train_batches_per_device)
    )

    # train steps for each device
    train_steps_per_device = math.ceil(train_batches_per_device / self.trainer.accumulate_grad_batches)

    # total train steps across all epochs
    total_train_steps = train_steps_per_device * self.trainer.max_epochs
    rank_zero_warn(f"Automatically computed total steps equal to {total_train_steps}")
    return total_train_steps

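# To make the arithmetic in `num_training_steps` concrete, here is a worked
# example with made-up numbers (all values hypothetical):

import math

train_samples = 10_000
total_devices = 4             # e.g. 2 nodes x 2 processes under plain DDP
batch_size = 32
accumulate_grad_batches = 2
max_epochs = 3

train_samples_per_device = math.ceil(train_samples / total_devices)                     # 2500
train_batches_per_device = math.ceil(train_samples_per_device / batch_size)             # 79
train_steps_per_device = math.ceil(train_batches_per_device / accumulate_grad_batches)  # 40
total_train_steps = train_steps_per_device * max_epochs                                 # 120
assert total_train_steps == 120
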
def _is_valid_batch_size(current_size, dataloader):
    return not has_len(dataloader) or current_size <= len(dataloader)

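# A sketch of how a guard like `_is_valid_batch_size` might be used in a
# batch-size doubling search: infinite loaders always pass, while sized loaders
# cap the candidate at their length. The loop and names below are illustrative
# only, not the library's actual batch-size scaler:

def find_max_batch_size_sketch(dataloader, start: int = 2, max_trials: int = 10) -> int:
    """Double the candidate batch size while it is still valid for the loader."""
    size = start
    for _ in range(max_trials):
        candidate = size * 2
        if not _is_valid_batch_size(candidate, dataloader):
            break  # candidate exceeds len(dataloader); keep the last valid size
        size = candidate
    return size
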
def _reset_eval_dataloader(
    self, model: LightningModule, mode: str
) -> Tuple[List[Union[int, float]], List[DataLoader]]:
    """Generic method to reset a dataloader for evaluation.

    Args:
        model: The current `LightningModule`
        mode: Either `'val'` or `'test'`

    Returns:
        Tuple (num_batches, dataloaders)
    """
    # always get the loaders first so we can count how many there are
    loader_name = f'{mode}_dataloader'
    dataloaders = self.request_dataloader(getattr(model, loader_name))

    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    # when overfitting, use the training loader as val and test
    # duplicate it the number of times needed to match the train loaders
    if self.overfit_batches > 0:
        num_loaders = len(dataloaders)
        train_dataloader = self.request_dataloader(getattr(model, 'train_dataloader'))
        dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)]

    self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders)

    for loader_i in range(len(dataloaders)):
        loader = dataloaders[loader_i]

        # shuffling in val and test set is bad practice
        if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):
            # when overfitting, the dataloader should not have sampler
            if self.overfit_batches > 0:
                rank_zero_warn(
                    'You requested to overfit but enabled test/val dataloader shuffling.'
                    ' We are turning it off for you.'
                )
                dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))
            else:
                rank_zero_warn(
                    f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'
                    ' this off for validation and test dataloaders.'
                )

    if any(dl is None for dl in dataloaders):
        rank_zero_warn("One of given dataloaders is None and it will be skipped.")

    # add samplers
    dataloaders = [self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders if dl is not None]

    loader_num_batches = []

    # determine number of batches
    # datasets could be none, 1 or 2+
    if len(dataloaders) != 0:
        for i, dataloader in enumerate(dataloaders):
            num_batches = len(dataloader) if has_len(dataloader) else float('inf')
            self._worker_check(dataloader, f'{mode} dataloader {i}')

            # percent or num_steps
            limit_eval_batches = getattr(self, f'limit_{mode}_batches')

            # limit num batches either as a percent or num steps
            if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
                num_batches = min(num_batches, int(limit_eval_batches))
            elif num_batches != float('inf'):
                num_batches = int(num_batches * limit_eval_batches)
            elif limit_eval_batches != 1.0:
                raise MisconfigurationException(
                    f'When using an IterableDataset for `limit_{mode}_batches`,'
                    f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
                    f' `num_{mode}_batches` to use.'
                )

            if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                min_pct = 1.0 / len(dataloader)
                raise MisconfigurationException(
                    f'you requested to check {limit_eval_batches} of the {mode} dataloader but'
                    f' {limit_eval_batches}*{num_batches} < 1. Please increase the limit_{mode}_batches.'
                    f' Try at least limit_{mode}_batches={min_pct}'
                )

            loader_num_batches.append(num_batches)

    return loader_num_batches, dataloaders

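# Concrete numbers for the `min_pct` error above (hypothetical values): a
# validation loader of 200 batches combined with `limit_val_batches=0.001`
# yields `int(200 * 0.001) == 0` checked batches, so the smallest workable
# fraction reported to the user is 1/200:

num_batches_full = 200                                     # hypothetical len(dataloader)
limit_eval_batches = 0.001
num_batches = int(num_batches_full * limit_eval_batches)   # 0 -> MisconfigurationException
min_pct = 1.0 / num_batches_full                           # 0.005
assert num_batches == 0 and min_pct == 0.005
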
def reset_train_dataloader(self, model: LightningModule) -> None:
    """Resets the train dataloader and initialises required variables
    (number of batches, when to validate, etc.).

    Args:
        model: The current `LightningModule`
    """
    self.train_dataloader = self.request_dataloader(model, "train")

    if self.overfit_batches > 0:
        if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
            rank_zero_warn(
                'You requested to overfit but enabled training dataloader shuffling.'
                ' We are turning it off for you.'
            )
            self.train_dataloader = self.replace_sampler(
                self.train_dataloader, SequentialSampler(self.train_dataloader.dataset)
            )

    # debugging
    self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])

    # automatically add samplers
    self.train_dataloader = apply_to_collection(
        self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True
    )

    # check the workers recursively
    apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader')

    # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
    self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)

    self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')

    if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
        self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
    elif self.num_training_batches != float('inf'):
        self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
    elif self.limit_train_batches != 1.0:
        raise MisconfigurationException(
            'When using an IterableDataset for `limit_train_batches`,'
            ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
            ' `num_training_batches` to use.'
        )

    # determine when to check validation
    # if int passed in, val checks that often
    # otherwise, it checks in [0, 1.0] % range of a training epoch
    if isinstance(self.val_check_interval, int):
        self.val_check_batch = self.val_check_interval
        if self.val_check_batch > self.num_training_batches:
            raise ValueError(
                f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                f'to the number of the training batches ({self.num_training_batches}). '
                'If you want to disable validation set `limit_val_batches` to 0.0 instead.'
            )
    else:
        if not has_len(self.train_dataloader):
            if self.val_check_interval == 1.0:
                self.val_check_batch = float('inf')
            else:
                raise MisconfigurationException(
                    'When using an IterableDataset for `train_dataloader`,'
                    ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                    ' checking validation every k training batches.'
                )
        else:
            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            self.val_check_batch = max(1, self.val_check_batch)

def _reset_eval_dataloader(
    self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
) -> Tuple[List[Union[int, float]], List[DataLoader]]:
    """Generic method to reset a dataloader for evaluation.

    Args:
        mode: The running stage of the ``Trainer``
        model: The ``LightningModule`` if calling this outside of the trainer scope.

    Returns:
        Tuple (num_batches, dataloaders)
    """
    assert mode.evaluating or mode == RunningStage.PREDICTING

    # always get the loaders first so we can count how many there are
    loader_name = f"{mode.dataloader_prefix}_dataloader"
    dataloaders = self.request_dataloader(mode, model=model)

    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    # when overfitting, use the training loader as val and test
    # duplicate it the number of times needed to match the train loaders
    if self.overfit_batches > 0:
        train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)
        dataloaders = [deepcopy(train_dataloader) for _ in range(len(dataloaders))]

    self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders)

    for loader_i in range(len(dataloaders)):
        loader = dataloaders[loader_i]

        if hasattr(loader, "sampler") and isinstance(loader.sampler, RandomSampler):
            # when overfitting, the dataloader should not have sampler
            if self.overfit_batches > 0 and mode.evaluating:
                rank_zero_warn(
                    "You requested to overfit but enabled val/test dataloader shuffling."
                    " We are turning it off for you."
                )
                dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset), mode=mode)
            else:
                rank_zero_warn(
                    f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`, it is strongly recommended"
                    " that you turn this off for val/test/predict dataloaders."
                )

    if any(dl is None for dl in dataloaders):
        rank_zero_warn("One of given dataloaders is None and it will be skipped.")

    # add samplers
    dataloaders = [self.auto_add_sampler(dl, False, mode=mode) for dl in dataloaders if dl is not None]

    # add worker_init_fn for correct seeding in worker processes
    apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)

    loader_num_batches = []

    # determine number of batches
    # datasets could be none, 1 or 2+
    if len(dataloaders) != 0:
        for i, dataloader in enumerate(dataloaders):
            num_batches = len(dataloader) if has_len(dataloader) else float("inf")
            self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}")

            # percent or num_steps
            limit_eval_batches = getattr(self, f"limit_{mode.dataloader_prefix}_batches")

            # limit num batches either as a percent or num steps
            if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
                num_batches = min(num_batches, int(limit_eval_batches))
            elif num_batches != float("inf"):
                num_batches = int(num_batches * limit_eval_batches)
            elif limit_eval_batches != 1.0:
                raise MisconfigurationException(
                    f"When using an IterableDataset for `limit_{mode.dataloader_prefix}_batches`,"
                    f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `0.0`, `1.0` or an int. An int k"
                    f" specifies `num_{mode.dataloader_prefix}_batches` to use."
                )

            if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                min_pct = 1.0 / len(dataloader)
                raise MisconfigurationException(
                    f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
                    f" {limit_eval_batches}*{num_batches} < 1. Please increase the"
                    f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
                    f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
                )

            loader_num_batches.append(num_batches)

    return loader_num_batches, dataloaders

def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
    """Resets the train dataloader and initialises required variables
    (number of batches, when to validate, etc.).

    Args:
        model: The `LightningModule` if calling this outside of the trainer scope.
    """
    self.train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)

    if self.overfit_batches > 0:
        if hasattr(self.train_dataloader, "sampler") and isinstance(self.train_dataloader.sampler, RandomSampler):
            rank_zero_warn(
                "You requested to overfit but enabled training dataloader shuffling."
                " We are turning off the training dataloader shuffling for you."
            )
            self.train_dataloader = self.replace_sampler(
                self.train_dataloader, SequentialSampler(self.train_dataloader.dataset), mode=RunningStage.TRAINING
            )

    # debugging
    self.dev_debugger.track_load_dataloader_call("train_dataloader", dataloaders=[self.train_dataloader])

    # automatically add samplers
    self.train_dataloader = apply_to_collection(
        self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True, mode=RunningStage.TRAINING
    )

    # check the workers recursively
    apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train_dataloader")

    # add worker_init_fn for correct seeding in worker processes
    apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

    # add collate_fn to collect metadata for fault tolerant training
    if _fault_tolerant_training():
        apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)

    # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
    self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)

    self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float("inf")

    if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
        self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
    elif self.num_training_batches != float("inf"):
        self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
    elif self.limit_train_batches != 1.0:
        raise MisconfigurationException(
            "When using an IterableDataset for `limit_train_batches`,"
            " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
            " `num_training_batches` to use."
        )

    # determine when to check validation
    # if int passed in, val checks that often
    # otherwise, it checks in [0, 1.0] % range of a training epoch
    if isinstance(self.val_check_interval, int):
        self.val_check_batch = self.val_check_interval
        if self.val_check_batch > self.num_training_batches:
            raise ValueError(
                f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
                f"to the number of the training batches ({self.num_training_batches}). "
                "If you want to disable validation set `limit_val_batches` to 0.0 instead."
            )
    else:
        if not has_len(self.train_dataloader):
            if self.val_check_interval == 1.0:
                self.val_check_batch = float("inf")
            else:
                raise MisconfigurationException(
                    "When using an IterableDataset for `train_dataloader`,"
                    " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
                    " checking validation every k training batches."
                )
        else:
            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            self.val_check_batch = max(1, self.val_check_batch)

    if self.logger and self.num_training_batches < self.log_every_n_steps:
        rank_zero_warn(
            f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
            f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
            " you want to see logs for the training epoch."
        )

def _reset_eval_dataloader(
    self, mode: str, model: Optional["pl.LightningModule"] = None
) -> Tuple[List[Union[int, float]], List[DataLoader]]:
    """Generic method to reset a dataloader for evaluation.

    Args:
        mode: Either `'val'`, `'test'` or `'predict'`
        model: The `LightningModule` if calling this outside of the trainer scope.

    Returns:
        Tuple (num_batches, dataloaders)
    """
    # always get the loaders first so we can count how many there are
    loader_name = f"{mode}_dataloader"
    dataloaders = self.request_dataloader(mode, model=model)

    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    # when overfitting, use the training loader as val and test
    # duplicate it the number of times needed to match the train loaders
    if self.overfit_batches > 0:
        num_loaders = len(dataloaders)
        train_dataloader = self.request_dataloader("train", model=model)
        dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)]

    self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders)

    for loader_i in range(len(dataloaders)):
        loader = dataloaders[loader_i]

        # shuffling in val and test set is bad practice
        modes = ("val", "test", "predict")
        if mode in modes and hasattr(loader, "sampler") and isinstance(loader.sampler, RandomSampler):
            # when overfitting, the dataloader should not have sampler
            if self.overfit_batches > 0 and mode != "predict":
                rank_zero_warn(
                    "You requested to overfit but enabled val/test dataloader shuffling."
                    " We are turning it off for you."
                )
                dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))
            else:
                rank_zero_warn(
                    f"Your {mode}_dataloader has `shuffle=True`, it is best practice to turn"
                    " this off for val/test/predict dataloaders."
                )

    if any(dl is None for dl in dataloaders):
        rank_zero_warn("One of given dataloaders is None and it will be skipped.")

    # add samplers
    dataloaders = [
        self.auto_add_sampler(dl, shuffle=False, mode=self.state.stage) for dl in dataloaders if dl is not None
    ]

    # add worker_init_fn for correct seeding in worker processes
    apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)

    # allow accelerator to modify dataloader
    hook_name = f"on_reset_{mode}_dataloader"
    dataloaders = getattr(self.accelerator, hook_name)(dataloaders)

    loader_num_batches = []

    # determine number of batches
    # datasets could be none, 1 or 2+
    if len(dataloaders) != 0:
        for i, dataloader in enumerate(dataloaders):
            num_batches = len(dataloader) if has_len(dataloader) else float("inf")
            self._worker_check(dataloader, f"{mode} dataloader {i}")

            # percent or num_steps
            limit_eval_batches = getattr(self, f"limit_{mode}_batches")

            # limit num batches either as a percent or num steps
            if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
                num_batches = min(num_batches, int(limit_eval_batches))
            elif num_batches != float("inf"):
                num_batches = int(num_batches * limit_eval_batches)
            elif limit_eval_batches != 1.0:
                raise MisconfigurationException(
                    f"When using an IterableDataset for `limit_{mode}_batches`,"
                    f" `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
                    f" `num_{mode}_batches` to use."
                )

            if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                min_pct = 1.0 / len(dataloader)
                raise MisconfigurationException(
                    f"you requested to check {limit_eval_batches} of the {mode} dataloader but"
                    f" {limit_eval_batches}*{num_batches} < 1. Please increase the limit_{mode}_batches."
                    f" Try at least limit_{mode}_batches={min_pct}"
                )

            loader_num_batches.append(num_batches)

    return loader_num_batches, dataloaders