def _find_next(self, instance_id: int) -> T_co:
    while True:
        if self.main_datapipe_exhausted:
            raise StopIteration
        if self._datapipe_iterator is None:
            raise ValueError(
                "_datapipe_iterator has not been set, likely because this private method is called directly "
                "without invoking get_next_element_by_instance() first.")
        value = next(self._datapipe_iterator)
        classification = self.classifier_fn(value)
        if classification is None and self.drop_none:
            StreamWrapper.close_streams(value)
            continue
        if classification is None or classification >= self.num_instances or classification < 0:
            raise ValueError(
                f"Output of the classification fn should be between 0 and {self.num_instances - 1}. "
                f"{classification} is returned.")
        if classification == instance_id:
            return value
        # The element belongs to a different child instance: buffer it for that
        # instance and account for it against the shared buffer limit.
        self.child_buffers[classification].append(value)
        self.current_buffer_usage += 1
        if self.buffer_size >= 0 and self.current_buffer_usage > self.buffer_size:
            raise BufferError(
                f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.buffer_size} is insufficient.")
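# A minimal usage sketch for the demultiplexer that `_find_next` serves.
# Assumes torchdata's IterableWrapper and the `demux` functional form; the
# classifier_fn must return an int in [0, num_instances) or None.
from torchdata.datapipes.iter import IterableWrapper

source = IterableWrapper(range(10))
evens, odds = source.demux(num_instances=2, classifier_fn=lambda x: x % 2)

print(list(evens))  # [0, 2, 4, 6, 8]
print(list(odds))   # [1, 3, 5, 7, 9]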
def __iter__(self) -> Iterator[T_co]:
    for data in self.datapipe:
        filtered = self._returnIfTrue(data)
        if self._isNonEmpty(filtered):
            yield filtered
        else:
            StreamWrapper.close_streams(data)
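# A brief usage sketch of the Filter datapipe this `__iter__` implements.
# Assumes torchdata's IterableWrapper and the `filter` functional form.
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(10)).filter(filter_fn=lambda x: x % 2 == 0)
print(list(dp))  # [0, 2, 4, 6, 8]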
def __iter__(self) -> Iterator[Tuple[T_co]]:
    iterators = [iter(datapipe) for datapipe in self.datapipes]
    try:
        for data in zip(*iterators):
            yield data
    finally:
        unused = []
        for iterator in iterators:
            unused += list(iterator)
        # TODO(VitalyFedyunin): This should be Exception or warning when torchdata.debug is enabled
        for item in unused:
            StreamWrapper.close_streams(item)
def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
    for data in self.datapipe:
        validate_pathname_binary_tuple(data)
        pathname, data_stream = data
        try:
            # typing.cast is used here to silence mypy's type checker
            tar = tarfile.open(fileobj=cast(Optional[IO[bytes]], data_stream), mode=self.mode)
            for tarinfo in tar:
                if not tarinfo.isfile():
                    continue
                extracted_fobj = tar.extractfile(tarinfo)
                if extracted_fobj is None:
                    warnings.warn("failed to extract file {} from source tarfile {}".format(tarinfo.name, pathname))
                    raise tarfile.ExtractError
                inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
                yield inner_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
        except Exception as e:
            warnings.warn("Unable to extract files from corrupted tarfile stream {} due to: {}, abort!".format(pathname, e))
            raise e
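# A hypothetical end-to-end sketch of the tar loader above. Assumes torchdata's
# FileLister/FileOpener and the `load_from_tar` functional form, and that a
# local "archive.tar" exists (the filename is illustrative only).
from torchdata.datapipes.iter import FileLister, FileOpener

dp = FileLister(".", masks="archive.tar")
dp = FileOpener(dp, mode="b")   # yields (pathname, binary stream) tuples
dp = dp.load_from_tar()         # yields (inner_pathname, StreamWrapper) per member
for inner_path, stream in dp:
    print(inner_path, len(stream.read()))
    stream.close()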
def __iter__(self) -> Iterator[Tuple[T_co]]:
    iterators = [iter(datapipe) for datapipe in self.datapipes]
    try:
        for data in zip(*iterators):
            yield data
    finally:
        unused = []
        for iterator in iterators:
            try:
                unused += list(iterator)
            except RuntimeError:
                # Some iterators may have been invalidated by single iterator constraints
                pass
        # TODO(VitalyFedyunin): This should be Exception or warning when torchdata.debug is enabled
        for item in unused:
            StreamWrapper.close_streams(item)
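# A sketch of the Zipper behavior implemented above: zipping stops at the
# shortest input, and leftovers from longer inputs are drained in the finally
# block so their streams can be closed. Assumes torchdata's IterableWrapper
# and the `zip` functional form.
from torchdata.datapipes.iter import IterableWrapper

dp1 = IterableWrapper(range(3))
dp2 = IterableWrapper(["a", "b", "c", "d"])  # one element longer
print(list(dp1.zip(dp2)))  # [(0, 'a'), (1, 'b'), (2, 'c')]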
def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
    for data in self.datapipe:
        validate_pathname_binary_tuple(data)
        pathname, data_stream = data
        folder_name = os.path.dirname(pathname)
        try:
            # typing.cast is used here to silence mypy's type checker
            zips = zipfile.ZipFile(cast(IO[bytes], data_stream))
            for zipinfo in zips.infolist():
                # The major version is assumed to always be 3 here, so only the
                # minor version is checked; ZipInfo.is_dir() was added in Python 3.6.
                if sys.version_info[1] >= 6:
                    if zipinfo.is_dir():
                        continue
                elif zipinfo.filename.endswith('/'):
                    continue
                extracted_fobj = zips.open(zipinfo)
                inner_pathname = os.path.normpath(os.path.join(folder_name, zipinfo.filename))
                yield inner_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
        except Exception as e:
            warnings.warn(f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!")
            raise e
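# An analogous sketch for the zip loader. Assumes torchdata's FileLister/
# FileOpener and the `load_from_zip` functional form, with an illustrative
# local "archive.zip".
from torchdata.datapipes.iter import FileLister, FileOpener

dp = FileOpener(FileLister(".", masks="archive.zip"), mode="b")
for inner_path, stream in dp.load_from_zip():
    print(inner_path, len(stream.read()))
    stream.close()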