예제 #1
0
파일: grouping.py 프로젝트: donhuvy/pytorch
    def __init__(
        self,
        datapipe: IterDataPipe[T_co],
        group_key_fn: Callable,
        *,
        buffer_size: int = 10000,
        group_size: Optional[int] = None,
        guaranteed_group_size: Optional[int] = None,
        drop_remaining: bool = False,
    ):
        """Record the grouping configuration and start with an empty buffer.

        Args:
            datapipe: Source datapipe; kept as ``self.datapipe``.
            group_key_fn: Callable kept as ``self.group_key_fn`` after being
                passed to ``_check_unpickable_fn``. Presumably maps an element
                to its group key -- confirm against the iteration code.
            buffer_size: Stored as ``self.max_buffer_size``.
            group_size: Stored as ``self.group_size``. When given, it must
                satisfy ``0 < group_size <= buffer_size`` and also becomes the
                default for ``self.guaranteed_group_size``.
            guaranteed_group_size: Optional override for
                ``self.guaranteed_group_size``; requires ``group_size`` to be
                set and ``0 < guaranteed_group_size <= group_size``.
            drop_remaining: Stored as ``self.drop_remaining``.

        Raises:
            AssertionError: If one of the size constraints above is violated.
                The checks are ``assert`` statements, so they are skipped when
                Python runs with ``-O``.
        """
        _check_unpickable_fn(group_key_fn)
        self.datapipe, self.group_key_fn = datapipe, group_key_fn

        # Buffer bookkeeping: elements are bucketed per key in a defaultdict,
        # with a separate running count starting at zero.
        self.max_buffer_size = buffer_size
        self.buffer_elements: DefaultDict[Any, List] = defaultdict(list)
        self.curr_buffer_size = 0

        # The guaranteed size defaults to ``group_size`` (when it is valid)
        # and may then be replaced by an explicit ``guaranteed_group_size``.
        self.group_size = group_size
        self.guaranteed_group_size = None
        if not (group_size is None or buffer_size is None):
            assert 0 < group_size <= buffer_size
            self.guaranteed_group_size = group_size
        if guaranteed_group_size is not None:
            assert group_size is not None and 0 < guaranteed_group_size <= group_size
            self.guaranteed_group_size = guaranteed_group_size

        self.wrapper_class = DataChunk
        self.drop_remaining = drop_remaining
예제 #2
0
    def __init__(
        self,
        datapipe: IterDataPipe,
        filter_fn: Callable,
        drop_empty_batches: Optional[bool] = None,
        input_col=None,
    ) -> None:
        """Store the source datapipe, the filter callable and its options.

        Args:
            datapipe: Source datapipe; kept as ``self.datapipe``.
            filter_fn: Callable kept as ``self.filter_fn`` after being passed
                to ``_check_unpickable_fn``.
            drop_empty_batches: Deprecated argument. ``None`` (the default)
                is treated as ``True`` silently; any explicit value is stored
                as given but triggers ``_deprecation_warning``.
            input_col: Stored unchanged as ``self.input_col``.
        """
        super().__init__()
        self.datapipe = datapipe

        _check_unpickable_fn(filter_fn)
        self.filter_fn = filter_fn  # type: ignore[assignment]

        if drop_empty_batches is not None:
            # The caller passed the argument explicitly: warn that it is
            # deprecated, but still honour the supplied value.
            _deprecation_warning(
                type(self).__name__,
                deprecation_version="1.12",
                removal_version="1.14",
                old_argument_name="drop_empty_batches",
            )
        else:
            drop_empty_batches = True
        self.drop_empty_batches = drop_empty_batches

        self.input_col = input_col
예제 #3
0
 def __init__(self, datapipe: MapDataPipe, fn: Callable = default_fn) -> None:
     """Store the source datapipe and the callable associated with it.

     Args:
         datapipe: Source map-style datapipe; kept as ``self.datapipe``.
         fn: Callable kept as ``self.fn`` after being passed to
             ``_check_unpickable_fn``. Defaults to ``default_fn``.
     """
     super().__init__()
     self.datapipe = datapipe

     # Run the picklability check before the callable is stored.
     _check_unpickable_fn(fn)
     self.fn = fn  # type: ignore[assignment]
예제 #4
0
    def __new__(cls, datapipe: IterDataPipe, num_instances: int,
                classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool = False, buffer_size: int = 1000):
        """Build ``num_instances`` child datapipes backed by one shared container.

        Args:
            datapipe: Source datapipe handed to the shared container.
            num_instances: Number of child datapipes to create; must be >= 1.
            classifier_fn: Callable passed to ``_check_unpickable_fn`` and
                then to the container. Per its annotation it maps an element
                to an ``Optional[int]``.
            drop_none: Forwarded unchanged to the container.
            buffer_size: Forwarded unchanged to the container.

        Returns:
            A list of ``_ChildDataPipe`` objects, one per index in
            ``range(num_instances)``, all sharing a single
            ``_DemultiplexerIterDataPipe``. Note that this is a list, not an
            instance of ``cls``.

        Raises:
            ValueError: If ``num_instances`` is smaller than 1.
        """
        if num_instances < 1:
            raise ValueError(f"Expected `num_instances` larger than 0, but {num_instances} is found")

        _check_unpickable_fn(classifier_fn)

        # When num_instances == 1, demux can be replaced by filter,
        # but keep it as Demultiplexer for the sake of consistency,
        # like throwing an error when the classification result is out of range
        container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size)
        return [_ChildDataPipe(container, i) for i in range(num_instances)]
예제 #5
0
    def __init__(
        self,
        datapipe: IterDataPipe,
        fn: Callable,
        input_col=None,
        output_col=None,
    ) -> None:
        """Store the source datapipe, the callable and the column selectors.

        Args:
            datapipe: Source datapipe; kept as ``self.datapipe``.
            fn: Callable kept as ``self.fn`` after being passed to
                ``_check_unpickable_fn``.
            input_col: Stored unchanged as ``self.input_col``.
            output_col: Must be ``None`` when ``input_col`` is ``None``. A
                list or tuple must hold exactly one element, which is
                unwrapped before being stored as ``self.output_col``; any
                other value is stored as given.

        Raises:
            ValueError: If ``output_col`` is given without ``input_col``, or
                if it is a list/tuple whose length is not exactly 1
                (including an empty one).
        """
        super().__init__()
        self.datapipe = datapipe

        _check_unpickable_fn(fn)
        self.fn = fn  # type: ignore[assignment]

        self.input_col = input_col
        if input_col is None and output_col is not None:
            raise ValueError("`output_col` must be None when `input_col` is None.")
        if isinstance(output_col, (list, tuple)):
            # Require exactly one element: checking only for "> 1" would let
            # an empty sequence through to the indexing below, where it would
            # fail with an uninformative IndexError.
            if len(output_col) != 1:
                raise ValueError("`output_col` must be a single-element list or tuple")
            output_col = output_col[0]
        self.output_col = output_col