Code example #1
File: loader.py Project: yifeim/gluon-ts
def TrainDataLoader(
    dataset: Dataset,
    *,
    transform: Transformation,
    batch_size: int,
    stack_fn: Callable,
    num_batches_per_epoch: Optional[int] = None,
    num_workers: Optional[int] = None,
    num_prefetch: Optional[int] = None,
    shuffle_buffer_length: Optional[int] = None,
    decode_fn: Callable = lambda x: x,
):
    transformed_dataset = TransformedDataset(Cyclic(dataset),
                                             transform,
                                             is_train=True)
    data_iterable = (PseudoShuffled(
        transformed_dataset, shuffle_buffer_length=shuffle_buffer_length)
                     if shuffle_buffer_length is not None else
                     transformed_dataset)
    data_loader = DataLoader(
        data_iterable=data_iterable,
        batch_size=batch_size,
        stack_fn=stack_fn,
        num_workers=num_workers,
        num_prefetch=num_prefetch,
        decode_fn=decode_fn,
    )
    return (iter(data_loader) if num_batches_per_epoch is None else
            IterableSlice(iter(data_loader), num_batches_per_epoch))
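
A note on the Cyclic wrapper used above: judging from how it is exercised in the tests further down, it simply restarts the underlying dataset whenever it is exhausted, which is what lets the loader hand out batches indefinitely. A minimal sketch of that behavior (an assumption inferred from usage, not gluon-ts's actual implementation):

# Sketch of Cyclic's behavior (an assumption, not the library's code):
# restart the wrapped iterable whenever it runs out, yielding an endless
# stream -- which is why the loader can produce batches forever.
class CyclicSketch:
    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        while True:
            yield from self.iterable

cycled = iter(CyclicSketch(range(3)))
assert [next(cycled) for _ in range(7)] == [0, 1, 2, 0, 1, 2, 0]
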
Code example #2
File: estimator.py Project: vishalbelsare/gluon-ts
    def create_training_data_loader(
        self,
        data: Dataset,
        module: DeepARLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        transformation = self._create_instance_splitter(
            module, "training") + SelectFields(TRAINING_INPUT_NAMES)

        training_instances = transformation.apply(
            Cyclic(data) if shuffle_buffer_length is None else PseudoShuffled(
                Cyclic(data), shuffle_buffer_length=shuffle_buffer_length))

        return IterableSlice(
            iter(
                DataLoader(
                    IterableDataset(training_instances),
                    batch_size=self.batch_size,
                    **kwargs,
                )),
            self.num_batches_per_epoch,
        )
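
Here the endless stream produced by Cyclic (optionally wrapped in PseudoShuffled) is cut into fixed-size epochs by IterableSlice: each epoch draws num_batches_per_epoch batches, and the next epoch resumes where the previous one stopped. A minimal sketch of that pattern with a plain torch DataLoader; Endless and SliceSketch below are illustrative stand-ins, not the library's classes:

import itertools

from torch.utils.data import DataLoader, IterableDataset

class Endless(IterableDataset):
    # Stand-in for IterableDataset(training_instances): an endless stream.
    def __iter__(self):
        return iter(itertools.count())

class SliceSketch:
    # Each pass yields at most `length` items and resumes where the
    # previous pass stopped -- one bounded "epoch" over an endless stream.
    def __init__(self, iterator, length):
        self.iterator, self.length = iterator, length

    def __iter__(self):
        yield from itertools.islice(self.iterator, self.length)

epochs = SliceSketch(iter(DataLoader(Endless(), batch_size=4)), length=2)
print([b.tolist() for b in epochs])  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print([b.tolist() for b in epochs])  # [[8, 9, 10, 11], [12, 13, 14, 15]]
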
Code example #3
File: test_itertools.py Project: RomaKoks/gluon-ts
@pytest.mark.parametrize(
    "data",
    [
        ...
    ],
)
def test_pseudo_shuffled(data: Iterable) -> None:
    list_data = list(data)
    shuffled_iter = PseudoShuffled(iter(list_data), shuffle_buffer_length=5)
    shuffled_data = list(shuffled_iter)
    assert len(shuffled_data) == len(list_data)
    assert all(d in shuffled_data for d in list_data)


@pytest.mark.parametrize(
    "data, expected_elements_per_iteration",
    [
        (Cached(range(4)), (list(range(4)), ) * 5),
        (batcher(range(10), 3), ([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]], [])),
        (IterableSlice(range(10), 3), ([0, 1, 2], ) * 5),
        (
            IterableSlice(iter(range(10)), 3),
            ([0, 1, 2], [3, 4, 5], [6, 7, 8], [9], []),
        ),
        (
            IterableSlice(iter(Cyclic(range(5))), 3),
            ([0, 1, 2], [3, 4, 0], [1, 2, 3], [4, 0, 1]),
        ),
    ],
)
def test_iterate_multiple_times(data: Iterable,
                                expected_elements_per_iteration: Tuple[List]):
    for expected_elements in expected_elements_per_iteration:
        assert list(data) == expected_elements
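
The first test above pins down the defining property of PseudoShuffled: elements leave a bounded buffer in random order, so the output keeps the input's length and membership while only the ordering changes. A buffer-based sketch of that mechanism (an assumption; the library's implementation may differ):

import random

def pseudo_shuffled_sketch(iterator, shuffle_buffer_length):
    # Keep up to `shuffle_buffer_length` items buffered; emit a random one
    # once the buffer is full, then drain the remainder at the end.
    buffer = []
    for item in iterator:
        buffer.append(item)
        if len(buffer) >= shuffle_buffer_length:
            yield buffer.pop(random.randrange(len(buffer)))
    while buffer:
        yield buffer.pop(random.randrange(len(buffer)))

data = list(range(20))
shuffled = list(pseudo_shuffled_sketch(iter(data), shuffle_buffer_length=5))
assert len(shuffled) == len(data)
assert sorted(shuffled) == data  # same elements, locally shuffled order

The second test fixes the re-iteration semantics of the other helpers: Cached replays the same elements on every pass, IterableSlice over a fresh iterable restarts while over a plain iterator it consumes it progressively, and batcher groups an iterable into fixed-size lists with a short final batch. A sketch of batcher's grouping rule, consistent with the expectations above:

from itertools import islice

def batcher_sketch(iterable, batch_size):
    it = iter(iterable)
    # Group into fixed-size lists; the last batch may be shorter.
    while batch := list(islice(it, batch_size)):
        yield batch

assert list(batcher_sketch(range(10), 3)) == [
    [0, 1, 2], [3, 4, 5], [6, 7, 8], [9]
]
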
Code example #4
File: loader.py Project: vishalbelsare/gluon-ts
def TrainDataLoader(
    dataset: Dataset,
    *,
    transform: Transformation = Identity(),
    batch_size: int,
    stack_fn: Callable,
    num_batches_per_epoch: Optional[int] = None,
    num_prefetch: Optional[int] = None,
    num_workers: Optional[int] = None,
    shuffle_buffer_length: Optional[int] = None,
    decode_fn: Callable = lambda x: x,
):
    """Construct an iterator of batches for training purposes.

    This function wraps around ``DataLoader`` to offer training-specific
    behaviour and options, as follows:

        1. The provided dataset is iterated cyclically, so that one can go
           over it multiple times in a single epoch.
        2. A transformation is lazily applied as the dataset is being
           iterated; this is useful, for example, to slice random instances
           of fixed length out of each time series in the dataset.
        3. The resulting batches can be iterated in a pseudo-shuffled order.

    The returned object is a stateful iterator, whose length is either
    ``num_batches_per_epoch`` (if not ``None``) or infinite (otherwise).

    Parameters
    ----------
    dataset
        Data to iterate over.
    transform
        Transformation to be lazily applied as data is being iterated.
        The transformation is applied in "training mode" (``is_train=True``).
    batch_size
        Number of entries to include in a batch.
    stack_fn
        Function to use to stack data entries into batches.
        This can be used to set a specific array type or computing device
        the arrays should end up onto (CPU, GPU).
    num_batches_per_epoch
        Length of the iterator. If ``None``, then the iterator is endless.
    num_workers
        Number of worker processes to use. Default: None.
    num_prefetch
        Sets the length of the queue of batches being produced by worker
        processes. (Only meaningful when ``num_workers is not None``).
    shuffle_buffer_length
        Size of the buffer used for shuffling. Default: None, in which case no
        shuffling occurs.
    decode_fn
        A function called on each batch after it's been taken out of the queue.
        (Only meaningful when ``num_workers is not None``).

    Returns
    -------
    Iterator[DataBatch]
        An iterator of batches.
    """
    dataset = Cyclic(dataset)

    if shuffle_buffer_length:
        dataset = PseudoShuffled(dataset, shuffle_buffer_length)

    transform += Batch(batch_size=batch_size) + AdhocTransform(stack_fn)
    transformed_dataset = transform.apply(dataset, is_train=True)

    if num_workers is not None:
        loader = MultiProcessLoader(
            transformed_dataset,
            decode_fn=decode_fn,
            num_workers=num_workers,
            max_queue_size=num_prefetch,
        )
        batches = iter(loader)
    else:
        batches = iter(transformed_dataset)

    if num_batches_per_epoch is None:
        return batches
    else:
        return IterableSlice(batches, num_batches_per_epoch)
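
To make the calling convention concrete, here is a minimal usage sketch. The ListDataset input and the NumPy-based stack_fn are illustrative assumptions (any gluon-ts Dataset and stacking function would do), and the import path assumes the file shown above lives at gluonts/dataset/loader.py:

import numpy as np

from gluonts.dataset.common import ListDataset
from gluonts.dataset.loader import TrainDataLoader

# A toy dataset: one daily series of 100 points (illustrative values).
dataset = ListDataset(
    [{"start": "2021-01-01", "target": list(range(100))}], freq="D"
)

def stack_fn(entries):
    # Stack the "target" arrays of a list of data entries into one batch.
    return {"target": np.stack([entry["target"] for entry in entries])}

loader = TrainDataLoader(
    dataset,
    batch_size=8,
    stack_fn=stack_fn,
    num_batches_per_epoch=10,   # finite iterator; None would make it endless
    shuffle_buffer_length=16,   # enable pseudo-shuffling
)

for batch in loader:  # yields exactly 10 batches per pass
    ...               # one training step per batch
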