예제 #1
0
 def _find_next(self, instance_id: int) -> T_co:
     """Advance the shared source iterator until an element for ``instance_id`` shows up.

     Every element pulled from ``self._datapipe_iterator`` is classified with
     ``self.classifier_fn``. An element classified for ``instance_id`` is
     returned immediately; an element classified for another child is appended
     to that child's entry in ``self.child_buffers`` and counted in
     ``self.current_buffer_usage``.

     Args:
         instance_id: Index of the child whose next element is requested.

     Returns:
         The first element whose classification equals ``instance_id``.

     Raises:
         StopIteration: ``self.main_datapipe_exhausted`` is set, or the source
             iterator runs out (propagated from ``next``).
         ValueError: the source iterator has not been set, or the classifier
             returned ``None`` (with ``drop_none`` disabled) or a value outside
             ``[0, num_instances - 1]``.
         BufferError: ``buffer_size`` is non-negative and the number of
             buffered elements exceeds it.
     """
     # The exhaustion flag is re-checked before every pull, including after a dropped element.
     while not self.main_datapipe_exhausted:
         if self._datapipe_iterator is None:
             raise ValueError(
                 "_datapipe_iterator has not been set, likely because this private method is called directly "
                 "without invoking get_next_element_by_instance() first.")
         item = next(self._datapipe_iterator)
         classification = self.classifier_fn(item)

         if classification is None:
             if self.drop_none:
                 # Dropped elements never reach a child, so their streams are closed here.
                 StreamWrapper.close_streams(item)
                 continue
             is_invalid = True
         else:
             is_invalid = classification >= self.num_instances or classification < 0

         if is_invalid:
             raise ValueError(
                 f"Output of the classification fn should be between 0 and {self.num_instances - 1}. "
                 + f"{classification} is returned.")

         if classification == instance_id:
             return item

         # Element belongs to a different child: park it in that child's buffer.
         self.child_buffers[classification].append(item)
         self.current_buffer_usage += 1

         # A negative buffer_size disables the overflow check.
         over_budget = self.buffer_size >= 0 and self.current_buffer_usage > self.buffer_size
         if over_budget:
             raise BufferError(
                 f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.buffer_size} is insufficient."
             )
     raise StopIteration
예제 #2
0
 def __iter__(self) -> Iterator[T_co]:
     """Iterate over ``self.datapipe`` and yield the elements that pass filtering.

     Each source element is passed through ``self._returnIfTrue``. When
     ``self._isNonEmpty`` accepts the result, that result is yielded; otherwise
     ``StreamWrapper.close_streams`` is called on the original source element.
     """
     for item in self.datapipe:
         kept = self._returnIfTrue(item)
         if not self._isNonEmpty(kept):
             # Rejected element is not handed downstream, so close its streams here.
             StreamWrapper.close_streams(item)
             continue
         yield kept
예제 #3
0
    def __iter__(self) -> Iterator[Tuple[T_co]]:
        """Yield tuples that combine one element from each of ``self.datapipes``.

        Iteration stops as soon as the shortest input is exhausted (``zip``
        semantics). When the generator finishes or is closed, every element
        still left in the input iterators is drained and passed to
        ``StreamWrapper.close_streams``.
        """
        iterators = [iter(datapipe) for datapipe in self.datapipes]
        try:
            for data in zip(*iterators):
                yield data
        finally:
            unused = []
            for iterator in iterators:
                try:
                    unused += list(iterator)
                except RuntimeError:
                    # Draining is best-effort cleanup: an iterator that was already
                    # invalidated (e.g. by a single-iterator-per-datapipe constraint)
                    # raises RuntimeError when advanced. Skip it so the remaining
                    # iterators are still drained and the original exit reason of
                    # this generator is not replaced by a cleanup error.
                    pass

            # TODO(VitalyFedyunin): This should be Exception or warning when torchdata.debug is enabled
            for item in unused:
                StreamWrapper.close_streams(item)
예제 #4
0
 def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
     """Open each incoming tar stream and yield its regular-file members.

     Every element of ``self.datapipe`` is validated with
     ``validate_pathname_binary_tuple`` and unpacked as ``(pathname, stream)``.
     The stream is opened with ``tarfile.open`` using ``self.mode``.

     Yields:
         ``(inner_pathname, StreamWrapper(member_stream))`` for each member
         for which ``isfile()`` is true. ``inner_pathname`` is the normalized
         join of the archive pathname and the member name.

     Raises:
         Exception: any error raised while reading the archive is reported
             through ``warnings.warn`` and then re-raised.
             ``tarfile.ExtractError`` is raised when ``extractfile`` returns
             ``None`` for a file member.
     """
     for source in self.datapipe:
         validate_pathname_binary_tuple(source)
         pathname, data_stream = source
         try:
             # typing.cast is used here to silence mypy's type checker
             archive_stream = cast(Optional[IO[bytes]], data_stream)
             archive = tarfile.open(fileobj=archive_stream, mode=self.mode)
             for member in archive:
                 # Only regular files are emitted; other member types are skipped.
                 if not member.isfile():
                     continue
                 member_stream = archive.extractfile(member)
                 if member_stream is None:
                     warnings.warn(
                         "failed to extract file {} from source tarfile {}".
                         format(member.name, pathname))
                     raise tarfile.ExtractError
                 joined = os.path.join(pathname, member.name)
                 inner_pathname = os.path.normpath(joined)
                 wrapped = StreamWrapper(member_stream)
                 yield inner_pathname, wrapped  # type: ignore[misc]
         except Exception as err:
             warnings.warn(
                 "Unable to extract files from corrupted tarfile stream {} due to: {}, abort!"
                 .format(pathname, err))
             raise err
예제 #5
0
    def __iter__(self) -> Iterator[Tuple[T_co]]:
        """Yield tuples that combine one element from each of ``self.datapipes``.

        Iteration follows ``zip`` semantics and ends with the shortest input.
        On exit, the remaining elements of every input iterator are collected
        (iterators that raise ``RuntimeError`` while being drained are
        skipped) and passed to ``StreamWrapper.close_streams``.
        """
        source_iters = list(map(iter, self.datapipes))
        try:
            for combined in zip(*source_iters):
                yield combined
        finally:
            leftovers = []
            for source_iter in source_iters:
                try:
                    remaining = list(source_iter)
                except RuntimeError:  # Some iterators may have been invalidated by single iterator constraints
                    continue
                leftovers.extend(remaining)

            # TODO(VitalyFedyunin): This should be Exception or warning when torchdata.debug is enabled
            for leftover in leftovers:
                StreamWrapper.close_streams(leftover)
예제 #6
0
 def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
     """Open each incoming zip stream and yield its non-directory entries.

     Every element of ``self.datapipe`` is validated with
     ``validate_pathname_binary_tuple`` and unpacked as ``(pathname, stream)``.

     Yields:
         ``(inner_pathname, StreamWrapper(entry_stream))`` for each archive
         entry that is not a directory. ``inner_pathname`` is the normalized
         join of the directory containing the archive (``os.path.dirname`` of
         its pathname) and the entry's filename.

     Raises:
         Exception: any error raised while reading the archive is reported
             through ``warnings.warn`` and then re-raised unchanged.
     """
     for data in self.datapipe:
         validate_pathname_binary_tuple(data)
         pathname, data_stream = data
         folder_name = os.path.dirname(pathname)
         try:
             # typing.cast is used here to silence mypy's type checker
             zips = zipfile.ZipFile(cast(IO[bytes], data_stream))
             for zipinfo in zips.infolist():
                 # ZipInfo.is_dir() needs Python 3.6+, and so does the f-string
                 # below; any interpreter that can parse this method has it, so
                 # no version switch or filename-suffix fallback is required.
                 if zipinfo.is_dir():
                     continue
                 extracted_fobj = zips.open(zipinfo)
                 inner_pathname = os.path.normpath(
                     os.path.join(folder_name, zipinfo.filename))
                 yield inner_pathname, StreamWrapper(
                     extracted_fobj)  # type: ignore[misc]
         except Exception as e:
             warnings.warn(
                 f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!"
             )
             # Bare raise re-raises the active exception with its original traceback.
             raise