Example #1
def test_get_len():
    assert get_len(DataLoader(RandomDataset(1, 1))) == 1

    value = get_len(DataLoader(RandomIterableDataset(1, 1)))

    assert isinstance(value, float)
    assert value == float("inf")
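
A minimal sketch of the behavior this test relies on, assuming a stand-in helper (the real `get_len` lives in Lightning's utilities): return the dataloader length when it is defined, and fall back to `float("inf")` for dataloaders over iterable-style datasets that define no `__len__`.

from typing import Union

from torch.utils.data import DataLoader


def get_len_sketch(dataloader: DataLoader) -> Union[int, float]:
    """Return ``len(dataloader)`` if it is defined, otherwise ``float("inf")``."""
    try:
        return len(dataloader)
    except TypeError:
        # a DataLoader over an IterableDataset without ``__len__`` raises on ``len()``
        return float("inf")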
Example #2
def test_combined_data_loader_with_max_size_cycle_and_ddp(accelerator, replace_sampler_ddp):
    """This test makes sure distributed sampler has been properly injected in dataloaders when using CombinedLoader
    with ddp and `max_size_cycle` mode."""
    trainer = Trainer(strategy="ddp", accelerator=accelerator, devices=2, replace_sampler_ddp=replace_sampler_ddp)

    dataloader = CombinedLoader(
        {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)},
    )
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
    assert len(dataloader) == (4 if replace_sampler_ddp else 8)

    for a_length in [6, 8, 10]:
        dataloader = CombinedLoader(
            {
                "a": DataLoader(range(a_length), batch_size=1),
                "b": DataLoader(range(8), batch_size=1),
            },
            mode="max_size_cycle",
        )

        length = max(a_length, 8)
        assert len(dataloader) == length
        dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
        assert len(dataloader) == (length // 2 if replace_sampler_ddp else length)
        if replace_sampler_ddp:
            last_batch = list(dataloader)[-1]
            if a_length == 6:
                assert last_batch == {"a": torch.tensor([0]), "b": torch.tensor([6])}
            elif a_length == 8:
                assert last_batch == {"a": torch.tensor([6]), "b": torch.tensor([6])}
            elif a_length == 10:
                assert last_batch == {"a": torch.tensor([8]), "b": torch.tensor([0])}

    class InfiniteDataset(IterableDataset):
        def __iter__(self):
            while True:
                yield 1

    dataloader = CombinedLoader(
        {
            "a": DataLoader(InfiniteDataset(), batch_size=1),
            "b": DataLoader(range(8), batch_size=1),
        },
        mode="max_size_cycle",
    )
    assert get_len(dataloader) == float("inf")
    assert len(dataloader.loaders["b"].loader) == 8
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
    assert len(dataloader.loaders["b"].loader) == (4 if replace_sampler_ddp else 8)
    assert get_len(dataloader) == float("inf")
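
The `last_batch` values asserted above follow from two facts: with `shuffle=False`, a `DistributedSampler` over 2 replicas hands rank 0 every second index, and `max_size_cycle` cycles the shorter loader until the longest per-rank loader is exhausted. A plain-Python sketch of that arithmetic (rank 0 only, no Lightning involved):

from itertools import cycle, islice


def rank0_indices(n: int, num_replicas: int = 2) -> list:
    # DistributedSampler(shuffle=False) gives rank 0 the indices 0, 2, 4, ...
    return list(range(0, n, num_replicas))


for a_length in (6, 8, 10):
    a, b = rank0_indices(a_length), rank0_indices(8)
    per_rank_length = max(len(a), len(b))  # max_size_cycle: the longest loader wins
    last_batch = {
        "a": list(islice(cycle(a), per_rank_length))[-1],
        "b": list(islice(cycle(b), per_rank_length))[-1],
    }
    print(a_length, last_batch)  # 6 -> {'a': 0, 'b': 6}, 8 -> {'a': 6, 'b': 6}, 10 -> {'a': 8, 'b': 0}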
Example #3
    def _apply_cycle_iterator_length(self) -> None:
        """When the model is `max_size_cycle`, compute the length across all ``CycleIterator`` and re-assign it to
        all dataloaders."""
        if self.mode != "max_size_cycle":
            return

        def set_len(cycle_iterator: CycleIterator, length: int) -> None:
            cycle_iterator.length = length

        all_lengths = apply_to_collection(self.loaders, CycleIterator, lambda c: get_len(c.loader))
        max_length = _nested_calc_num_data(all_lengths, max)
        apply_to_collection(self.loaders, CycleIterator, set_len, length=max_length)
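
Stripped of the Lightning helpers, the method above makes two passes over the (possibly nested) collection of `CycleIterator` wrappers: collect each wrapped loader's length, reduce with `max`, then write that maximum back onto every wrapper. A dependency-free sketch of the same idea, using a stand-in `CycleIterator` with only the attributes touched above and a flat dict instead of an arbitrary nesting:

class ToyCycleIterator:
    def __init__(self, loader):
        self.loader = loader
        self.length = len(loader)  # placeholder, re-assigned below


def apply_max_size_cycle_length(loaders: dict) -> None:
    # pass 1: gather every wrapped loader's length
    lengths = [len(it.loader) for it in loaders.values()]
    # pass 2: re-assign the maximum to each CycleIterator
    max_length = max(lengths)
    for it in loaders.values():
        it.length = max_length


loaders = {"a": ToyCycleIterator(range(6)), "b": ToyCycleIterator(range(8))}
apply_max_size_cycle_length(loaders)
assert loaders["a"].length == loaders["b"].length == 8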
Example #4
    def _get_dataloader_init_kwargs(
        dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
    ) -> Dict[str, Any]:
        if not isinstance(dataloader, DataLoader):
            raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

        # get the dataloader instance attributes
        attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
        # not part of `vars`
        attrs["multiprocessing_context"] = dataloader.multiprocessing_context

        # get the dataloader instance `__init__` parameters
        params = dict(inspect.signature(dataloader.__init__).parameters)
        has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
        if has_variadic_kwargs:
            # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
            params.update(inspect.signature(DataLoader.__init__).parameters)
            del params["self"]

        # keep only the params whose default is different to the current attr value
        non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
        # add `dataset` as it might have been replaced with `*args`
        non_defaults.add("dataset")

        # kwargs to re-construct the dataloader
        dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
        dl_kwargs.update(
            TrainerDataLoadingMixin._dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode)
        )

        required_args = {
            p.name
            for p in params.values()
            if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
            and p.default is p.empty
            and p.name not in dl_kwargs
        }
        # the dataloader has required args which we could not extract from the existing attributes
        if required_args:
            required_args = sorted(required_args)
            dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                "This would fail as some of the `__init__` arguments are not available as instance attributes. "
                f"The missing attributes are {required_args}. "
                f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
                "manually add the `DistributedSampler` as: "
                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
            )

        if not has_variadic_kwargs:
            # the dataloader signature does not allow keyword arguments that need to be passed
            missing_kwargs = dl_kwargs.keys() - params.keys()
            if missing_kwargs:
                missing_kwargs = sorted(missing_kwargs)
                dataloader_cls_name = dataloader.__class__.__name__
                raise MisconfigurationException(
                    f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                    "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                    f"The missing arguments are {missing_kwargs}. "
                    f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
                    "manually add the `DistributedSampler` as: "
                    f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
                )

        if isinstance(dl_kwargs["dataset"], IterableDataset):
            dl_kwargs["batch_sampler"] = None
            dl_kwargs["sampler"] = None

        if _fault_tolerant_training():
            dataset = dl_kwargs["dataset"]
            if isinstance(dataset, IterableDataset):
                # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
                dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
            elif get_len(dataset) != float("inf"):
                dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
            else:
                raise MisconfigurationException(
                    "This shouldn't happen, please open an issue on Lightning Github repository."
                )

        return dl_kwargs
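
Both error branches above concern custom `DataLoader` subclasses whose `__init__` arguments cannot be recovered from instance attributes when the loader is re-built with a `DistributedSampler`. A hypothetical pair of subclasses illustrating the failure mode and the fix the error message suggests (the `chunk_size` argument is made up for this example):

from torch.utils.data import DataLoader


class BadLoader(DataLoader):
    def __init__(self, dataset, chunk_size, **kwargs):
        # `chunk_size` is consumed but never stored, so it cannot be recovered
        # when the loader is re-instantiated with a new sampler -> MisconfigurationException.
        super().__init__(dataset, batch_size=chunk_size, **kwargs)


class GoodLoader(DataLoader):
    def __init__(self, dataset, chunk_size, **kwargs):
        self.chunk_size = chunk_size  # exposed as an instance attribute -> recoverable
        super().__init__(dataset, batch_size=chunk_size, **kwargs)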