Example #1
    def _train_loader_specifics(self, dataset, loader_kwargs):
        sampler = loader_kwargs.get("sampler", None)
        # Shuffling should really only matter for the train stage. Shuffling
        # will also lead to more padding in batches if the order was otherwise
        # sorted by length.
        shuffle = loader_kwargs.get("shuffle", False)
        if shuffle and not self.distributed_launch:
            if sampler is not None:
                raise ValueError(
                    "Cannot specify both shuffle=True"
                    "and a sampler in loader_kwargs"
                )
            sampler = ReproducibleRandomSampler(dataset)
            self.train_sampler = sampler
            loader_kwargs["sampler"] = self.train_sampler
            # Delete the shuffle flag, since you cannot specify both a sampler and
            # shuffling:
            del loader_kwargs["shuffle"]

        # Possibly make a DistributedSampler or a wrapper for some other sampler
        if self.distributed_launch and not isinstance(dataset, IterableDataset):
            drop_last = loader_kwargs.get("drop_last", False)
            # num_replicas arg is equal to world_size
            # and retrieved automatically within
            # DistributedSampler obj.
            if sampler is not None:
                self.train_sampler = DistributedSamplerWrapper(
                    sampler,
                    rank=self.rank,
                    drop_last=drop_last,
                    shuffle=shuffle,
                )

                # with DistributedSamplerWrapper, one must disable shuffling for dataloader
                loader_kwargs["shuffle"] = False
                loader_kwargs["sampler"] = self.train_sampler
            elif loader_kwargs.get("batch_sampler") is None:
                # neither a sampler nor a batch_sampler was specified
                self.train_sampler = DistributedSampler(
                    dataset, rank=self.rank, shuffle=False, drop_last=drop_last
                )

                # with a plain DistributedSampler, shuffling must likewise be disabled in the DataLoader
                loader_kwargs["shuffle"] = False
                loader_kwargs["sampler"] = self.train_sampler
            else:  # batch_sampler was specified
                self.train_sampler = DistributedSamplerWrapper(
                    loader_kwargs.get("batch_sampler", None),
                    rank=self.rank,
                    shuffle=False,
                )
                loader_kwargs["batch_sampler"] = self.train_sampler
        elif self.distributed_launch and isinstance(dataset, IterableDataset):
            logger.warning(
                "Cannot automatically solve distributed sampling "
                "for IterableDataset."
            )
        return loader_kwargs
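
For reference, here is a minimal sketch (not part of the SpeechBrain source) of the single-process path above: when shuffle=True is requested without an explicit sampler, the flag is swapped for a ReproducibleRandomSampler before the DataLoader is built. It assumes speechbrain.dataio.sampler.ReproducibleRandomSampler is importable, as in Example #3 below.

import torch
from torch.utils.data import DataLoader, TensorDataset

from speechbrain.dataio.sampler import ReproducibleRandomSampler

dataset = TensorDataset(torch.arange(10))
loader_kwargs = {"batch_size": 2, "shuffle": True}

# Mirror the shuffle -> sampler resolution done in _train_loader_specifics
# (single-process case only; the distributed branches are omitted).
if loader_kwargs.pop("shuffle", False):
    loader_kwargs["sampler"] = ReproducibleRandomSampler(dataset)

loader = DataLoader(dataset, **loader_kwargs)
for batch in loader:
    print(batch)
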
Example #2
def make_dataloader(dataset, **loader_kwargs):
    """Makes a basic DataLoader with SpeechBrain defaults.

    For DynamicItemDatasets (which return dicts), use
    PaddedBatch as the default collate_fn.

    Shuffling is implemented with ReproducibleRandomSampler.

    If the Dataset is not an IterableDataset, the DataLoader
    is a SaveableDataLoader.

    Arguments
    ---------
    dataset : Dataset
        The dataset to make a DataLoader for.
    **loader_kwargs : dict
        Keyword args to DataLoader, see PyTorch DataLoader for
        options.

    Returns
    -------
    DataLoader
    """
    # PaddedBatch as default collation for DynamicItemDataset
    if "collate_fn" not in loader_kwargs and isinstance(
        dataset, DynamicItemDataset
    ):
        loader_kwargs["collate_fn"] = PaddedBatch
    # Reproducible random sampling
    if loader_kwargs.get("shuffle", False):
        if loader_kwargs.get("sampler") is not None:
            raise ValueError(
                "Cannot specify both shuffle=True and a "
                "sampler in loader_kwargs"
            )
        sampler = ReproducibleRandomSampler(dataset)
        loader_kwargs["sampler"] = sampler
        # Should delete shuffle because you can't set both Sampler and
        # shuffle
        # NOTE: the dict of loader options may get used elsewhere!
        # However, this del doesn't touch those because loader_kwargs comes
        # from a **kwargs dict.
        del loader_kwargs["shuffle"]
    # Create the loader
    if isinstance(dataset, IterableDataset):
        dataloader = DataLoader(dataset, **loader_kwargs)
    else:
        dataloader = SaveableDataLoader(dataset, **loader_kwargs)
    return dataloader
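
A short usage sketch, assuming make_dataloader is importable from speechbrain.dataio.dataloader. A plain TensorDataset is used, so the PaddedBatch collation branch for DynamicItemDataset is not exercised; per the docstring, shuffle=True is turned into a ReproducibleRandomSampler and a SaveableDataLoader is returned for this map-style dataset.

import torch
from torch.utils.data import TensorDataset

from speechbrain.dataio.dataloader import make_dataloader

dataset = TensorDataset(torch.arange(8))

# shuffle=True is handled internally via ReproducibleRandomSampler (see above),
# so the "shuffle" key never reaches the underlying DataLoader.
loader = make_dataloader(dataset, batch_size=2, shuffle=True)

for batch in loader:
    print(batch)
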
Example #3
def test_ConcatDatasetBatchSampler():
    import torch
    from torch.utils.data import TensorDataset, ConcatDataset, DataLoader
    from speechbrain.dataio.sampler import (
        ReproducibleRandomSampler,
        ConcatDatasetBatchSampler,
    )
    import numpy as np

    # Build three datasets of unequal length (10, 6, and 6 examples).
    datasets = []
    for i in range(3):
        if i == 0:
            datasets.append(TensorDataset(torch.arange(i * 10, (i + 1) * 10)))
        else:
            datasets.append(TensorDataset(torch.arange(i * 6, (i + 1) * 6)))

    samplers = [ReproducibleRandomSampler(x) for x in datasets]
    dataset = ConcatDataset(datasets)
    loader = DataLoader(
        dataset, batch_sampler=ConcatDatasetBatchSampler(samplers, [1, 1, 1]),
    )

    # Collect batches from the concatenated loader; batch sizes [1, 1, 1]
    # yield one example per sub-dataset in every batch.
    concat_data = []

    for data in loader:
        concat_data.append([x.item() for x in data[0]])
    concat_data = np.array(concat_data)

    # Iterate each sub-dataset separately with the same (reproducible)
    # samplers to recover the per-dataset sampling order.
    non_cat_data = []
    for i in range(len(samplers)):
        c_data = []
        loader = DataLoader(dataset.datasets[i], sampler=samplers[i])

        for data in loader:
            c_data.append(data[0].item())

        non_cat_data.append(c_data)

    minlen = min([len(x) for x in non_cat_data])
    non_cat_data = [x[:minlen] for x in non_cat_data]
    non_cat_data = np.array(non_cat_data)
    # Each batch should contain one example from each sub-dataset, in the
    # order produced by the individual samplers.
    np.testing.assert_array_equal(non_cat_data.T, concat_data)