Code Example #1
def _wrap_with_capture_dataset(dataset: Dataset) -> Dataset:
    if isinstance(dataset, IterableDataset):
        # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
        return CaptureIterableDataset(dataset=dataset)
    if get_len(dataset) != float("inf"):
        # a finite length implies a map-style dataset; wrap it so its state can be captured.
        return CaptureMapDataset(dataset=dataset)
    raise RuntimeError(
        "This shouldn't happen, please open an issue on Lightning Github repository."
    )
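
The dispatch above hinges on `get_len`, which reports `float("inf")` for datasets that define no `__len__`. Below is a minimal, self-contained sketch of that behavior; the `get_len` here is a hypothetical stand-in written for illustration, not the Lightning import:

from torch.utils.data import Dataset, IterableDataset

class MapStyle(Dataset):
    # map-style: defines `__len__` and `__getitem__`
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return idx

class Stream(IterableDataset):
    # iterable-style: defines `__iter__` only, so `len()` raises TypeError
    def __iter__(self):
        return iter(range(4))

def get_len(dataset):
    # hypothetical stand-in: fall back to infinity when `__len__` is missing
    try:
        return len(dataset)
    except TypeError:
        return float("inf")

assert get_len(MapStyle()) == 4
assert get_len(Stream()) == float("inf")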
Code Example #2
def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) -> Dict:
    dataset = dl_kwargs["dataset"]
    if isinstance(dataset, IterableDataset):
        # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
        dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dataset)
    elif get_len(dataset) != float("inf"):
        dl_kwargs["dataset"] = CaptureMapDataset(dataset=dataset)
    else:
        raise MisconfigurationException("This shouldn't happen, please open an issue on Lightning Github repository.")
    return dl_kwargs
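
For context, a sketch of how this helper would be applied at a call site that rebuilds a `DataLoader` from captured keyword arguments. The variable names are illustrative, not taken from the Lightning source:

dl_kwargs = {"dataset": train_dataset, "batch_size": 4, "num_workers": 2}
dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs)
# the `dataset` entry is now wrapped in a Capture* class; rebuild the loader
dataloader = DataLoader(**dl_kwargs)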
Code Example #3
    @staticmethod
    def _get_dataloader_init_kwargs(
            dataloader: DataLoader,
            sampler: Optional[Sampler],
            mode: Optional[RunningStage] = None) -> Dict[str, Any]:
        if not isinstance(dataloader, DataLoader):
            raise ValueError(
                f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`"
            )

        # get the dataloader instance attributes
        attrs = {
            k: v
            for k, v in vars(dataloader).items() if not k.startswith("_")
        }
        # not part of `vars`
        attrs["multiprocessing_context"] = dataloader.multiprocessing_context

        # get the dataloader instance `__init__` parameters
        params = dict(inspect.signature(dataloader.__init__).parameters)

        # keep only the params whose default is different to the current attr value
        non_defaults = {
            name
            for name, p in params.items()
            if name in attrs and p.default != attrs[name]
        }
        # add `dataset` as it might have been replaced with `*args`
        non_defaults.add("dataset")

        # kwargs to re-construct the dataloader
        dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
        dl_kwargs.update(
            TrainerDataLoadingMixin._resolve_batch_sampler(dataloader,
                                                           sampler,
                                                           mode=mode))

        required_args = {
            p.name
            for p in params.values()
            if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
            and p.default is p.empty and p.name not in dl_kwargs
        }
        # the dataloader has required args which we could not extract from the existing attributes
        if required_args:
            required_args = sorted(required_args)
            dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                "This would fail as some of the `__init__` arguments are not available as instance attributes. "
                f"The missing attributes are {required_args}. "
                f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
                "manually add the `DistributedSampler` as: "
                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
            )

        has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD
                                  for p in params.values())
        if not has_variadic_kwargs:
            # the dataloader signature does not allow keyword arguments that need to be passed
            missing_kwargs = dl_kwargs.keys() - params.keys()
            if missing_kwargs:
                missing_kwargs = sorted(missing_kwargs)
                dataloader_cls_name = dataloader.__class__.__name__
                raise MisconfigurationException(
                    f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                    "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                    f"The missing arguments are {missing_kwargs}. "
                    f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
                    "manually add the `DistributedSampler` as: "
                    f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
                )

        # samplers are not supported with `IterableDataset`, so drop any that were captured
        if isinstance(dl_kwargs["dataset"], IterableDataset):
            dl_kwargs["batch_sampler"] = None
            dl_kwargs["sampler"] = None

        if _fault_tolerant_training():
            if isinstance(dl_kwargs["dataset"], IterableDataset):
                # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
                dl_kwargs["dataset"] = CaptureIterableDataset(
                    dataset=dl_kwargs["dataset"])
            elif get_len(dl_kwargs["dataset"]) != float("inf"):
                dl_kwargs["dataset"] = CaptureMapDataset(
                    dataset=dl_kwargs["dataset"])
            else:
                raise MisconfigurationException(
                    "This shouldn't happen, please open an issue on Lightning Github repository."
                )

        return dl_kwargs
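
The core trick in `_get_dataloader_init_kwargs` is rebuilding a `DataLoader` from its public instance attributes cross-checked against its `__init__` signature. Below is a stripped-down, runnable sketch of that introspection; it omits Lightning's `_resolve_batch_sampler` and instead simply drops the mutually exclusive `sampler`/`batch_sampler` entries:

import inspect

from torch.utils.data import DataLoader, Dataset

class Squares(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, i):
        return i * i

loader = DataLoader(Squares(), batch_size=2, drop_last=True)

# public attributes mirror the arguments the loader was constructed with
attrs = {k: v for k, v in vars(loader).items() if not k.startswith("_")}
params = dict(inspect.signature(loader.__init__).parameters)

# keep only the attributes that differ from the `__init__` defaults
dl_kwargs = {
    name: attrs[name]
    for name, p in params.items()
    if name in attrs and p.default != attrs[name]
}
dl_kwargs["dataset"] = loader.dataset
# `batch_sampler` is mutually exclusive with `batch_size`/`drop_last`/`sampler`;
# Lightning resolves this via `_resolve_batch_sampler`, here we just drop both
dl_kwargs.pop("batch_sampler", None)
dl_kwargs.pop("sampler", None)

rebuilt = DataLoader(**dl_kwargs)
assert rebuilt.batch_size == 2 and rebuilt.drop_last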
Code Example #4
def create_dataset_sampler():
    # `dataset_class` is supplied by the enclosing test, which parametrizes the dataset type
    dset = CaptureMapDataset(dataset_class(16, 8))
    random_sampler = RandomSampler(dset, generator=torch.Generator())
    return dset, random_sampler
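
Note that the fixture hands `RandomSampler` an explicit `torch.Generator`: a generator's state can be saved and restored (`get_state()`/`set_state()`), which is what makes the shuffle order reproducible across a restart. A small plain-torch sketch of the reproducibility this buys:

import torch
from torch.utils.data import RandomSampler

data = list(range(16))
generator = torch.Generator()

generator.manual_seed(0)
first_order = list(RandomSampler(data, generator=generator))

# re-seeding (or restoring a saved `generator.get_state()`) replays the order
generator.manual_seed(0)
assert list(RandomSampler(data, generator=generator)) == first_order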