# This variant of TrainDataLoader delegates batching to a DataLoader class;
# a second variant, with a full docstring, appears further below. Imports
# follow GluonTS's module layout; ``TransformedDataset`` and ``DataLoader``
# come from the surrounding loader module (their paths vary across versions).
from typing import Callable, Optional

from gluonts.dataset.common import Dataset
from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
from gluonts.transform import Transformation


def TrainDataLoader(
    dataset: Dataset,
    *,
    transform: Transformation,
    batch_size: int,
    stack_fn: Callable,
    num_batches_per_epoch: Optional[int] = None,
    num_workers: Optional[int] = None,
    num_prefetch: Optional[int] = None,
    shuffle_buffer_length: Optional[int] = None,
    decode_fn: Callable = lambda x: x,
):
    # Cycle endlessly over the dataset, applying the transformation lazily
    # in training mode.
    transformed_dataset = TransformedDataset(
        Cyclic(dataset), transform, is_train=True
    )
    # Optionally pseudo-shuffle entries within a fixed-size buffer.
    data_iterable = (
        PseudoShuffled(
            transformed_dataset,
            shuffle_buffer_length=shuffle_buffer_length,
        )
        if shuffle_buffer_length is not None
        else transformed_dataset
    )
    data_loader = DataLoader(
        data_iterable=data_iterable,
        batch_size=batch_size,
        stack_fn=stack_fn,
        num_workers=num_workers,
        num_prefetch=num_prefetch,
        decode_fn=decode_fn,
    )
    # Return either an endless iterator of batches, or one capped at
    # num_batches_per_epoch batches per iteration pass.
    return (
        iter(data_loader)
        if num_batches_per_epoch is None
        else IterableSlice(iter(data_loader), num_batches_per_epoch)
    )
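# A minimal usage sketch of the variant above. The dataset and transform
# names are hypothetical, and the stack_fn shown assumes the loader hands
# it a list of dict entries to be stacked field-wise (the exact contract
# of stack_fn varies across GluonTS versions):
import numpy as np

def stack_entries(entries):
    # Stack a list of dict-like entries into one dict of batched arrays.
    return {
        key: np.stack([entry[key] for entry in entries])
        for key in entries[0]
    }

train_iter = TrainDataLoader(
    my_dataset,                  # hypothetical Dataset
    transform=my_transform,      # hypothetical Transformation
    batch_size=32,
    stack_fn=stack_entries,
    num_batches_per_epoch=50,    # one epoch = exactly 50 batches
    shuffle_buffer_length=100,   # pseudo-shuffle within a 100-entry buffer
)
for batch in train_iter:
    ...  # feed the batch to the training step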
# Excerpt from the PyTorch DeepAREstimator: it builds the training data
# loader for a DeepARLightningModule. ``TRAINING_INPUT_NAMES``,
# ``IterableDataset`` and ``DeepARLightningModule`` come from the
# surrounding module; ``DataLoader`` here is torch's.
from typing import Iterable, Optional

from torch.utils.data import DataLoader

from gluonts.dataset.common import Dataset
from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
from gluonts.transform import SelectFields


def create_training_data_loader(
    self,
    data: Dataset,
    module: DeepARLightningModule,
    shuffle_buffer_length: Optional[int] = None,
    **kwargs,
) -> Iterable:
    # Slice training instances out of each series, then keep only the
    # fields the network consumes.
    transformation = self._create_instance_splitter(
        module, "training"
    ) + SelectFields(TRAINING_INPUT_NAMES)
    # Cycle over the data endlessly, optionally pseudo-shuffling entries
    # within a fixed-size buffer.
    training_instances = transformation.apply(
        Cyclic(data)
        if shuffle_buffer_length is None
        else PseudoShuffled(
            Cyclic(data), shuffle_buffer_length=shuffle_buffer_length
        )
    )
    # Cap each epoch at num_batches_per_epoch batches; the underlying
    # iterator is stateful, so the next epoch resumes where this one ended.
    return IterableSlice(
        iter(
            DataLoader(
                IterableDataset(training_instances),
                batch_size=self.batch_size,
                **kwargs,
            )
        ),
        self.num_batches_per_epoch,
    )
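# ``IterableDataset`` above adapts a plain Python iterable so that
# ``torch.utils.data.DataLoader`` can consume it; GluonTS ships its own
# version, but a minimal sketch of the idea looks like this:
import torch

class _IterableDatasetSketch(torch.utils.data.IterableDataset):
    """Wrap an arbitrary iterable as a torch IterableDataset."""

    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        yield from self.iterable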
from typing import Iterable, List, Tuple

import pytest

from gluonts.itertools import (
    Cached,
    Cyclic,
    IterableSlice,
    PseudoShuffled,
    batcher,
)


# The parametrize cases for this first test were truncated in the original
# excerpt (only the decorator's closing brackets survived); a placeholder
# is kept here rather than inventing cases.
@pytest.mark.parametrize(
    "data",
    [
        # ... (cases elided in the original)
    ],
)
def test_pseudo_shuffled(data: Iterable) -> None:
    list_data = list(data)
    shuffled_iter = PseudoShuffled(iter(list_data), shuffle_buffer_length=5)
    shuffled_data = list(shuffled_iter)
    # Pseudo-shuffling must preserve the length and every original element.
    assert len(shuffled_data) == len(list_data)
    assert all(d in shuffled_data for d in list_data)


@pytest.mark.parametrize(
    "data, expected_elements_per_iteration",
    [
        # Cached: repeated iteration always replays the same elements.
        (Cached(range(4)), (list(range(4)),) * 5),
        # batcher returns a single-pass iterator: empty on the second pass.
        (batcher(range(10), 3), ([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]], [])),
        # IterableSlice over a re-iterable restarts from the beginning...
        (IterableSlice(range(10), 3), ([0, 1, 2],) * 5),
        # ...but over a plain iterator it keeps consuming where it stopped.
        (
            IterableSlice(iter(range(10)), 3),
            ([0, 1, 2], [3, 4, 5], [6, 7, 8], [9], []),
        ),
        (
            IterableSlice(iter(Cyclic(range(5))), 3),
            ([0, 1, 2], [3, 4, 0], [1, 2, 3], [4, 0, 1]),
        ),
    ],
)
def test_iterate_multiple_times(
    data: Iterable, expected_elements_per_iteration: Tuple[List]
):
    for expected_elements in expected_elements_per_iteration:
        assert list(data) == expected_elements
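# What the first test above pins down, in isolation: pseudo-shuffling
# reorders elements only locally (draws come from a fixed-size buffer),
# so the output contains exactly the same elements as the input.
from gluonts.itertools import PseudoShuffled

data = list(range(10))
shuffled = list(PseudoShuffled(iter(data), shuffle_buffer_length=3))
assert len(shuffled) == len(data)
assert sorted(shuffled) == data  # same elements, possibly different order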
# Later GluonTS-style variant of TrainDataLoader: batching and stacking are
# expressed as transformations. ``Batch`` and ``MultiProcessLoader`` are
# defined alongside this function in the loader module; the remaining names
# are imported as in the first snippet above, plus ``Identity`` and
# ``AdhocTransform`` from ``gluonts.transform``.
def TrainDataLoader(
    dataset: Dataset,
    *,
    transform: Transformation = Identity(),
    batch_size: int,
    stack_fn: Callable,
    num_batches_per_epoch: Optional[int] = None,
    num_prefetch: Optional[int] = None,
    num_workers: Optional[int] = None,
    shuffle_buffer_length: Optional[int] = None,
    decode_fn: Callable = lambda x: x,
):
    """Construct an iterator of batches for training purposes.

    This function wraps around ``DataLoader`` to offer training-specific
    behaviour and options, as follows:

    1. The provided dataset is iterated cyclically, so that one can go
       over it multiple times in a single epoch.
    2. The transformation is lazily applied as the dataset is being
       iterated; this is useful e.g. to slice random instances of fixed
       length out of each time series in the dataset.
    3. The resulting batches can be iterated in a pseudo-shuffled order.

    The returned object is a stateful iterator, whose length is either
    ``num_batches_per_epoch`` (if not ``None``) or infinite (otherwise).

    Parameters
    ----------
    dataset
        Data to iterate over.
    transform
        Transformation to be lazily applied as data is being iterated.
        The transformation is applied in "training mode" (``is_train=True``).
    batch_size
        Number of entries to include in a batch.
    stack_fn
        Function to use to stack data entries into batches. This can be
        used to set a specific array type or computing device the arrays
        should end up on (CPU, GPU).
    num_batches_per_epoch
        Length of the iterator. If ``None``, then the iterator is endless.
    num_prefetch
        Sets the length of the queue of batches being produced by worker
        processes. (Only meaningful when ``num_workers is not None``.)
    num_workers
        Number of worker processes to use. Default: None.
    shuffle_buffer_length
        Size of the buffer used for shuffling. Default: None, in which
        case no shuffling occurs.
    decode_fn
        A function called on each batch after it's been taken out of the
        queue. (Only meaningful when ``num_workers is not None``.)

    Returns
    -------
    Iterator[DataBatch]
        An iterator of batches.
    """
    # Iterate the data cyclically: the stream is endless.
    dataset = Cyclic(dataset)

    if shuffle_buffer_length:
        dataset = PseudoShuffled(dataset, shuffle_buffer_length)

    # Append batching and stacking to the user transformation: ``Batch``
    # groups entries into lists of size ``batch_size``, then
    # ``AdhocTransform`` applies ``stack_fn`` to each such list.
    transform += Batch(batch_size=batch_size) + AdhocTransform(stack_fn)
    transformed_dataset = transform.apply(dataset, is_train=True)

    if num_workers is not None:
        # Produce batches in worker processes, with a bounded queue.
        loader = MultiProcessLoader(
            transformed_dataset,
            decode_fn=decode_fn,
            num_workers=num_workers,
            max_queue_size=num_prefetch,
        )
        batches = iter(loader)
    else:
        batches = iter(transformed_dataset)

    if num_batches_per_epoch is None:
        return batches
    else:
        return IterableSlice(batches, num_batches_per_epoch)
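# A simplified re-statement of the single-process path above (no workers):
# after the user transform runs, ``Batch`` groups entries into lists of
# ``batch_size`` and ``AdhocTransform(stack_fn)`` stacks each list. The
# helper below is a hypothetical sketch of that tail of the pipeline, not
# GluonTS API:
from gluonts.itertools import batcher

def _batched_and_stacked(entries, batch_size, stack_fn):
    # Group into fixed-size lists (the last one may be shorter), then stack.
    for group in batcher(entries, batch_size):
        yield stack_fn(group)

# Because the returned object wraps a live iterator, consuming
# num_batches_per_epoch batches per epoch resumes where the previous epoch
# stopped, exactly as the IterableSlice cases in the tests above pin down.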