Example #1
0
 def __init__(self,
              datapipe: Iterable[Tuple[str, BufferedIOBase]],
              length: int = -1):
     """Store the source datapipe and an optional length hint.

     Args:
         datapipe: iterable of ``(path, stream)`` pairs to wrap.
         length: length hint for the pipe; ``-1`` means unknown.
     """
     super().__init__()
     self.length: int = length
     self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
     # Emit the project's deprecation notice under this class's name.
     deprecation_warning_torchdata(type(self).__name__)
Example #2
0
 def __new__(cls,
             datapipe: Iterable[str],
             mode: str = 'b',
             length: int = -1):
     """Deprecated factory: warn, then delegate to ``FileOpenerIterDataPipe``.

     Args:
         datapipe: iterable of file-path strings to open.
         mode: open mode forwarded to ``FileOpenerIterDataPipe``.
         length: length hint; ``-1`` means unknown.

     Returns:
         A ``FileOpenerIterDataPipe`` built from the arguments.
     """
     # Bug fix: in __new__, `cls` is already the class, so
     # `type(cls).__name__` resolves to the METACLASS name (e.g. "type"),
     # not this class. Use `cls.__name__` so the warning names the
     # deprecated class correctly.
     deprecation_warning_torchdata(cls.__name__)
     return FileOpenerIterDataPipe(datapipe=datapipe,
                                   mode=mode,
                                   length=length)
Example #3
0
 def __init__(self,
              datapipe: Iterable[Tuple[str, BufferedIOBase]],
              *handlers: Callable,
              key_fn: Callable = extension_extract_fn) -> None:
     """Attach a ``Decoder`` built from *handlers* to the datapipe.

     Args:
         datapipe: iterable of ``(path, stream)`` pairs to decode.
         *handlers: decoder callables; when none are supplied, a default
             pair (basic handlers plus the ``'torch'`` image handler)
             is used instead.
         key_fn: callable that extracts the key (by default a file
             extension) used to select a handler.
     """
     super().__init__()
     self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
     # Fall back to the stock handler pair when the caller gave none.
     chosen = handlers if handlers else (
         decoder_basichandlers, decoder_imagehandler('torch'))
     self.decoder = Decoder(*chosen, key_fn=key_fn)
     deprecation_warning_torchdata(type(self).__name__)
Example #4
0
    def __init__(self,
                 datapipe: IterDataPipe[T_co],
                 batch_size: int,
                 drop_last: bool = False,
                 batch_num: int = 100,
                 bucket_num: int = 1,
                 sort_key: Optional[Callable] = None,
                 in_batch_shuffle: bool = True) -> None:
        """Build a shuffle/sort/batch pipeline over *datapipe*.

        Args:
            datapipe: source pipe whose elements are pooled, optionally
                sorted per bucket, and grouped into batches.
            batch_size: elements per emitted batch; must be > 0.
            batch_num: batches per bucket
                (``bucket_size = batch_size * batch_num``); must be > 0.
            bucket_num: buckets held in the shuffle pool
                (``pool_size = bucket_size * bucket_num``); must be > 0.
            drop_last: when True, drop the final partial batch.
            sort_key: optional callable applied (via ``map``) to each
                bucket-sized batch — presumably to sort it; only its use
                as a map fn is visible here.
            in_batch_shuffle: when True, shuffle by materializing a pool
                and permuting it in place; otherwise use a streaming
                shuffle buffer.
        """
        # NOTE(review): `assert` is stripped under `python -O`; raising
        # ValueError would make this validation unconditional.
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        assert batch_num > 0, "Number of batches is required to be larger than 0!"
        assert bucket_num > 0, "Number of buckets is required to be larger than 0!"
        deprecation_warning_torchdata(type(self).__name__)
        super().__init__()

        # TODO: Verify _datapipe is not going to be serialized twice
        # and be able to reconstruct
        self._datapipe = datapipe
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.batch_num = batch_num
        self.bucket_num = bucket_num
        self.sort_key = sort_key
        self.in_batch_shuffle = in_batch_shuffle

        # Derived pool geometry: a bucket holds batch_num batches; the
        # shuffle pool holds bucket_num buckets.
        self.bucket_size = batch_size * batch_num
        self.pool_size = self.bucket_size * bucket_num

        # Stage 1 — pre-shuffle the stream, except in the single-bucket
        # sorted case (bucket_num == 1 and sort_key given), where the
        # per-bucket sort alone determines order.
        if bucket_num > 1 or sort_key is None:
            if in_batch_shuffle:
                # Materialize pool_size elements, permute them in place,
                # then flatten back to a stream.
                datapipe = datapipe.batch(
                    batch_size=self.pool_size,
                    drop_last=False).map(fn=_in_batch_shuffle_fn).unbatch()
            else:
                # Streaming shuffle with a pool_size-element buffer.
                datapipe = datapipe.shuffle(buffer_size=self.pool_size)
        # Stage 2 — apply sort_key to each bucket-sized batch, then flatten.
        if sort_key is not None:
            datapipe = datapipe.batch(
                self.bucket_size).map(fn=sort_key).unbatch()
        # Stage 3 — emit the final fixed-size batches.
        datapipe = datapipe.batch(batch_size, drop_last=drop_last)
        # Stage 4 — when sorted, re-shuffle at batch granularity so
        # consecutive batches don't come from one sorted run.
        if sort_key is not None:
            # In-batch shuffle each bucket seems not that useful
            if in_batch_shuffle:
                datapipe = datapipe.batch(
                    batch_size=bucket_num,
                    drop_last=False).map(fn=_in_batch_shuffle_fn).unbatch()
            else:
                datapipe = datapipe.shuffle(buffer_size=self.bucket_size)
        self.datapipe = datapipe

        # Length is unknown until the pipe is actually iterated.
        self.length = None
Example #5
0
 def __init__(self, datapipe, timeout=None):
     """Store the wrapped datapipe and an optional timeout.

     NOTE(review): the timeout is only stored here — its units and
     semantics must be confirmed against the consuming code.
     """
     self.timeout = timeout
     self.datapipe = datapipe
     # Emit the project's deprecation notice under this class's name.
     deprecation_warning_torchdata(type(self).__name__)
Example #6
0
 def __init__(self, datapipe):
     """Keep a reference to the wrapped datapipe, warning about deprecation."""
     # Warn first; the attribute assignment is independent of it.
     deprecation_warning_torchdata(type(self).__name__)
     self.datapipe = datapipe