from typing import Callable, Optional

# Note: the imports below assume GluonTS's usual module layout. ``Batch`` and
# ``MultiProcessLoader`` are defined alongside these loaders in
# ``gluonts.dataset.loader`` and are therefore not imported here.
from gluonts.dataset.common import Dataset
from gluonts.itertools import Cyclic, IterableSlice, PseudoShuffled
from gluonts.transform import AdhocTransform, Identity, Transformation


def InferenceDataLoader(
    dataset: Dataset,
    *,
    transform: Transformation = Identity(),
    batch_size: int,
    stack_fn: Callable,
):
    """
    Construct an iterator of batches for inference purposes.

    Parameters
    ----------
    dataset
        Data to iterate over.
    transform
        Transformation to be lazily applied as data is being iterated.
        The transformation is applied in "inference mode" (``is_train=False``).
    batch_size
        Number of entries to include in a batch.
    stack_fn
        Function to use to stack data entries into batches. This can be used
        to set a specific array type or computing device the arrays should
        end up on (CPU, GPU).

    Returns
    -------
    Iterable[DataBatch]
        An iterable sequence of batches.
    """
    transform += Batch(batch_size=batch_size) + AdhocTransform(stack_fn)
    return transform.apply(dataset, is_train=False)
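# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): build an inference
# loader over a small in-memory dataset. ``ListDataset`` comes from
# ``gluonts.dataset.common``; the ``stack_fn`` below is a hypothetical helper
# written for this example only.
# ---------------------------------------------------------------------------
def _example_inference_loader():
    import numpy as np

    from gluonts.dataset.common import ListDataset

    def stack_fn(entries):
        # Illustrative stacking: batch the "target" arrays with numpy and
        # keep all other fields as plain lists.
        batch = {key: [entry[key] for entry in entries] for key in entries[0]}
        batch["target"] = np.stack(batch["target"])
        return batch

    data = ListDataset(
        [{"start": "2021-01-01", "target": [1.0, 2.0, 3.0, 4.0]}],
        freq="D",
    )

    # With the default ``Identity()`` transform, each batch simply contains
    # the stacked raw fields.
    for batch in InferenceDataLoader(data, batch_size=8, stack_fn=stack_fn):
        print(batch["target"].shape)  # e.g. (1, 4): one series, four values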
def ValidationDataLoader(
    dataset: Dataset,
    *,
    transform: Transformation = Identity(),
    batch_size: int,
    stack_fn: Callable,
    num_prefetch: Optional[int] = None,
    num_workers: Optional[int] = None,
    decode_fn: Callable = lambda x: x,
):
    """
    Construct an iterator of batches for validation purposes.

    Parameters
    ----------
    dataset
        Data to iterate over.
    transform
        Transformation to be lazily applied as data is being iterated.
        The transformation is applied in "training mode" (``is_train=True``).
    batch_size
        Number of entries to include in a batch.
    stack_fn
        Function to use to stack data entries into batches. This can be used
        to set a specific array type or computing device the arrays should
        end up on (CPU, GPU).
    num_prefetch
        Length of the queue of batches being produced by worker processes.
        (Only meaningful when ``num_workers is not None``.)
    num_workers
        Number of worker processes to use. Default: None, in which case data
        is loaded sequentially in the main process.
    decode_fn
        A function called on each batch after it has been taken out of the
        queue. (Only meaningful when ``num_workers is not None``.)

    Returns
    -------
    Iterable[DataBatch]
        An iterable sequence of batches.
    """
    transform += Batch(batch_size=batch_size) + AdhocTransform(stack_fn)
    transformed_dataset = transform.apply(dataset, is_train=True)

    if num_workers is None:
        return transformed_dataset

    return MultiProcessLoader(
        transformed_dataset,
        decode_fn=decode_fn,
        num_workers=num_workers,
        max_queue_size=num_prefetch,
    )
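# ---------------------------------------------------------------------------
# Usage sketch (illustrative): a validation loader that reuses an estimator's
# training transformation and offloads batching to worker processes. The
# ``dataset``, ``my_transform`` and ``stack_fn`` arguments are assumed to be
# set up as in the inference example above; all names are hypothetical.
# ---------------------------------------------------------------------------
def _example_validation_loader(dataset, my_transform, stack_fn):
    return ValidationDataLoader(
        dataset,
        transform=my_transform,  # same transformation as used for training
        batch_size=64,
        stack_fn=stack_fn,
        num_prefetch=4,   # keep at most 4 ready-made batches in the queue
        num_workers=2,    # produce batches in 2 worker processes
    )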
def TrainDataLoader(
    dataset: Dataset,
    *,
    transform: Transformation = Identity(),
    batch_size: int,
    stack_fn: Callable,
    num_batches_per_epoch: Optional[int] = None,
    num_prefetch: Optional[int] = None,
    num_workers: Optional[int] = None,
    shuffle_buffer_length: Optional[int] = None,
    decode_fn: Callable = lambda x: x,
):
    """
    Construct an iterator of batches for training purposes.

    This function wraps around ``DataLoader`` to offer training-specific
    behaviour and options, as follows:

    1. The provided dataset is iterated cyclically, so that one can go over
       it multiple times in a single epoch.
    2. A transformation must be provided, that is lazily applied as the
       dataset is being iterated; this is useful e.g. to slice random
       instances of fixed length out of each time series in the dataset.
    3. The resulting batches can be iterated in a pseudo-shuffled order.

    The returned object is a stateful iterator, whose length is either
    ``num_batches_per_epoch`` (if not ``None``) or infinite (otherwise).

    Parameters
    ----------
    dataset
        Data to iterate over.
    transform
        Transformation to be lazily applied as data is being iterated.
        The transformation is applied in "training mode" (``is_train=True``).
    batch_size
        Number of entries to include in a batch.
    stack_fn
        Function to use to stack data entries into batches. This can be used
        to set a specific array type or computing device the arrays should
        end up on (CPU, GPU).
    num_batches_per_epoch
        Length of the iterator. If ``None``, then the iterator is endless.
    num_prefetch
        Length of the queue of batches being produced by worker processes.
        (Only meaningful when ``num_workers is not None``.)
    num_workers
        Number of worker processes to use. Default: None, in which case data
        is loaded sequentially in the main process.
    shuffle_buffer_length
        Size of the buffer used for shuffling. Default: None, in which case
        no shuffling occurs.
    decode_fn
        A function called on each batch after it has been taken out of the
        queue. (Only meaningful when ``num_workers is not None``.)

    Returns
    -------
    Iterator[DataBatch]
        An iterator of batches.
    """
    dataset = Cyclic(dataset)

    if shuffle_buffer_length:
        dataset = PseudoShuffled(dataset, shuffle_buffer_length)

    transform += Batch(batch_size=batch_size) + AdhocTransform(stack_fn)
    transformed_dataset = transform.apply(dataset, is_train=True)

    if num_workers is not None:
        loader = MultiProcessLoader(
            transformed_dataset,
            decode_fn=decode_fn,
            num_workers=num_workers,
            max_queue_size=num_prefetch,
        )
        batches = iter(loader)
    else:
        batches = iter(transformed_dataset)

    if num_batches_per_epoch is None:
        return batches
    else:
        return IterableSlice(batches, num_batches_per_epoch)
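# ---------------------------------------------------------------------------
# Usage sketch (illustrative): a finite, stateful training iterator consumed
# epoch by epoch. ``dataset``, ``my_transform``, ``stack_fn`` and
# ``trainer_step`` are hypothetical stand-ins for user code.
# ---------------------------------------------------------------------------
def _example_training_loop(dataset, my_transform, stack_fn, trainer_step):
    train_iter = TrainDataLoader(
        dataset,
        transform=my_transform,
        batch_size=32,
        stack_fn=stack_fn,
        num_batches_per_epoch=50,   # each epoch consumes exactly 50 batches
        shuffle_buffer_length=512,  # pseudo-shuffle within a 512-entry buffer
    )

    for epoch in range(10):
        # The iterator is stateful: each epoch resumes where the previous one
        # left off in the cyclically-iterated dataset.
        for batch in train_iter:
            trainer_step(batch)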