def _apply_patch_fn(loader: DataLoader, iterator: Iterator):
    """Register the fetcher on a ``DataLoader`` and patch its iterator.

    A ``CycleIterator`` is unwrapped first: its ``loader`` attribute and the
    ``_loader_iter`` attribute of ``iterator`` replace the two arguments.
    Nothing happens unless the resulting loader is a ``DataLoader`` and
    ``_fault_tolerant_training()`` returns a truthy value.

    ``self`` is not a parameter of this function; it is resolved from the
    surrounding scope, which is not visible here.

    Args:
        loader: the object to inspect, or a ``CycleIterator`` whose ``loader``
            attribute is inspected instead.
        iterator: the iterator paired with ``loader``; when ``loader`` is a
            ``CycleIterator`` its ``_loader_iter`` attribute is used instead.
    """
    is_cycle = isinstance(loader, CycleIterator)
    inner_loader = loader.loader if is_cycle else loader
    inner_iter = iterator._loader_iter if is_cycle else iterator

    if not isinstance(inner_loader, DataLoader):
        return
    # Only consulted for real DataLoaders, mirroring the short-circuit `and`.
    if not _fault_tolerant_training():
        return

    # Stored on the loader so the fetcher can be looked up from it later.
    inner_loader._lightning_fetcher = self
    patch_dataloader_iterator(inner_loader, inner_iter, self)
def __next__(self) -> Any:
    """Return the next batch, re-creating the loader iterator once it runs dry.

    Returns:
        Any: the batch produced by the internal loader iterator.

    Raises:
        StopIteration: if :attr:`counter` has reached ``self.__len__()``, or if
            ``self.state.done`` is true, either on entry or right after this
            loader has been recorded as finished.
    """
    # Note: if self.length is `inf` the counter bound below never triggers
    # (presumably why `__len__` is called directly instead of `len()`, which
    # rejects non-integer lengths).
    limit_reached = self.counter >= self.__len__()
    if limit_reached or self.state.done:
        raise StopIteration

    try:
        batch = next(self._loader_iter)
    except StopIteration:
        # Record in the shared state that this loader has been fully consumed.
        self.state.has_finished[id(self.loader)] = True

        # The shared state decides whether iteration ends here or restarts.
        if self.state.done:
            raise StopIteration

        self._loader_iter = iter(self.loader)

        # A fetcher previously attached to the loader means the fresh iterator
        # has to be patched as well before a batch is pulled from it.
        registered_fetcher = getattr(self.loader, "_lightning_fetcher", None)
        if registered_fetcher:
            patch_dataloader_iterator(self.loader, self._loader_iter, registered_fetcher)

        batch = next(self._loader_iter)
    finally:
        # Counts every attempt that got past the initial check, including the
        # ones that end in StopIteration.
        self.counter += 1

    return batch