Example #1
def _patch_dataloader_get_iterators() -> None:
    """This function is used to replace the DataLoader iterator by their stateful version."""
    if not _FaultTolerantMode.detect_current_mode().is_manual:
        return
    if not hasattr(DataLoader, "_ori_get_iterator"):
        DataLoader._ori_get_iterator = DataLoader._get_iterator
    DataLoader._get_iterator = _get_iterator
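Because the original method is stashed on the class as `_ori_get_iterator`, the patch can be undone later. A minimal sketch of a companion restore helper (the name below is illustrative and not taken from the snippet above):

from torch.utils.data import DataLoader

def _restore_dataloader_get_iterator() -> None:
    # Hypothetical inverse of `_patch_dataloader_get_iterators`: put the original
    # `_get_iterator` back and drop the backup attribute.
    original = getattr(DataLoader, "_ori_get_iterator", None)
    if original is not None:
        DataLoader._get_iterator = original
        del DataLoader._ori_get_iterator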
Example #2
    def state_dict(self,
                   destination: Optional[Dict] = None,
                   prefix: str = "") -> Dict:
        """The state dict is determined by the state and progress of this loop and all its children.

        Args:
            destination: An existing dictionary to update with this loop's state. By default a new dictionary
                is returned.
            prefix: A prefix for each key in the state dictionary
        """
        if destination is None:
            destination = {}

        destination[prefix + "state_dict"] = self.on_save_checkpoint()

        # do not get the mode from `self.trainer` because it might not have been attached yet
        ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
        for k, v in self.__dict__.items():
            key = prefix + k
            if ft_enabled and isinstance(v, BaseProgress):
                destination[key] = v.state_dict()
            elif isinstance(v, Loop):
                v.state_dict(destination, key + ".")
            elif isinstance(v, _ResultCollection):
                # sync / unsync metrics
                v.sync()
                destination[key] = v.state_dict()
                v.unsync()

        return destination
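The recursion flattens nested loops into a single dictionary whose keys carry dotted prefixes. A minimal toy reproduction of that key layout (illustrative classes, not Lightning's `Loop`):

class ToyLoop:
    # Toy stand-in for a Loop: only the dotted-prefix key layout is reproduced.
    def __init__(self, child=None):
        self.child = child

    def state_dict(self, destination=None, prefix=""):
        destination = {} if destination is None else destination
        destination[prefix + "state_dict"] = {"progress": 0}
        if isinstance(self.child, ToyLoop):
            # mirrors `v.state_dict(destination, key + ".")` above
            self.child.state_dict(destination, prefix + "child.")
        return destination

print(ToyLoop(child=ToyLoop()).state_dict(prefix="fit_loop."))
# {'fit_loop.state_dict': {'progress': 0}, 'fit_loop.child.state_dict': {'progress': 0}}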
Example #3
def _validate_fault_tolerant_automatic(
        dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None:
    """This function is used to validate that Fault-tolerance is possible with the user data."""
    if not _FaultTolerantMode.detect_current_mode().is_automatic:
        return

    from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator

    if isinstance(dataloader, CombinedLoader):
        dataloaders = dataloader.loaders
    else:
        dataloaders = dataloader

    dl_loaders = []

    def flatten_dataloader(
            dataloader: Union[DataLoader, CycleIterator, Iterable]) -> None:
        nonlocal dl_loaders
        if isinstance(dataloader, CycleIterator):
            dataloader = dataloader.loader
        dl_loaders.append(dataloader)

    apply_to_collection(dataloaders, (DataLoader, CycleIterator),
                        flatten_dataloader)

    if len(dl_loaders) > 1 and stage == pl.trainer.states.RunningStage.TRAINING:
        raise ValueError("Fault-tolerance supports only a single dataloader.")

    for dataloader in dl_loaders:
        validator_fn = (_validate_iterable_dataset if isinstance(
            dataloader.dataset, IterableDataset) else _validate_map_dataset)
        validator_fn(dataloader)
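Here `apply_to_collection` simply applies `flatten_dataloader` to every `DataLoader`/`CycleIterator` found inside a possibly nested collection. A self-contained sketch of that flattening behaviour using plain Python containers (a rough stand-in, not the Lightning utility):

from typing import Any, List
import torch
from torch.utils.data import DataLoader, TensorDataset

def _flatten_loaders(collection: Any, out: List[DataLoader]) -> None:
    # Rough stand-in for `apply_to_collection(..., (DataLoader, CycleIterator), flatten_dataloader)`:
    # walk nested lists/tuples/dicts and collect every DataLoader found.
    if isinstance(collection, DataLoader):
        out.append(collection)
    elif isinstance(collection, (list, tuple)):
        for item in collection:
            _flatten_loaders(item, out)
    elif isinstance(collection, dict):
        for item in collection.values():
            _flatten_loaders(item, out)

loaders: List[DataLoader] = []
ds = TensorDataset(torch.arange(8))
_flatten_loaders({"a": DataLoader(ds), "b": [DataLoader(ds)]}, loaders)
assert len(loaders) == 2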
Example #4
def patch_dataloader_iterator(
    dataloader: DataLoader,
    iterator: Iterator,
    data_fetcher: "pl.utilities.fetching.DataFetcher",
    num_batches_fetched: int = 0,
) -> None:
    """Patches the iterator of a PyTorch dataloader by injecting logic for fault-tolerant training when it is
    necessary to remove the sampler state dict from provided data batch.

    The custom data has this format:
    .. code-block:: python
        {
            "batch": ...,  # data returned by DataLoader
            "__pl_restart_meta": {
                "sampler0": {
                    0: {"current_iteration": ...},
                    1: {"current_iteration": ...},
                },
                "sampler1": ...,
            },
        }
    Each sampler in the worker process tracks the current iteration. We return all of them to the main process
    as part of the sample and then a special collate function :func:`_capture_metadata_collate`
    will extract the current iteration as part of the metadata returned by a custom batch.
    """

    if not _FaultTolerantMode.detect_current_mode().is_automatic:
        return

    assert isinstance(dataloader.dataset,
                      (CaptureMapDataset, CaptureIterableDataset))
    iterator._next_data = _next_data_wrapper(iterator._next_data, iterator,
                                             dataloader, num_batches_fetched,
                                             data_fetcher)
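The heavy lifting happens in `_next_data_wrapper`, which is not shown in this snippet. A hypothetical, simplified version (ignoring the `data_fetcher` and the batch counting) that only strips the `"__pl_restart_meta"` key from each fetched batch could look like:

from functools import wraps
from typing import Any, Callable, Dict, List

def _simplified_next_data_wrapper(
    fn: Callable[[], Dict[str, Any]],
    captured_meta: List[Dict[str, Any]],
) -> Callable[[], Any]:
    # Hypothetical sketch only: pop the restart metadata off each fetched batch,
    # keep it on the side, and hand the plain batch back to the training loop.
    @wraps(fn)
    def wrapper() -> Any:
        combined = fn()  # {"batch": ..., "__pl_restart_meta": {...}}
        captured_meta.append(combined["__pl_restart_meta"])
        return combined["batch"]

    return wrapper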
Example #5
    def on_run_start(self) -> None:  # type: ignore[override]
        """Calls the ``on_train_start`` hook."""
        # reset train dataloader and val dataloader
        self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)

        ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
        if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (
                0, float("inf")):
            self.trainer.accumulate_grad_batches = self.trainer.accumulation_scheduler.get_accumulate_grad_batches(
                self.trainer.current_epoch)
            expected_steps = math.ceil(self.trainer.num_training_batches /
                                       self.trainer.accumulate_grad_batches)

            # global_step is incremented during checkpointing (#11555)
            if (self.trainer.global_step - 1) % expected_steps != 0:
                rank_zero_warn(
                    "You're resuming from a checkpoint that ended mid-epoch."
                    " Training will start from the beginning of the next epoch."
                    " This can cause unreliable results if further training is done,"
                    " consider using an end of epoch checkpoint or use fault-tolerant training"
                    " to restart as if training did not stop.")

        self._is_fresh_start_epoch = True
        self._results.to(device=self.trainer.lightning_module.device)
        self.trainer._call_callback_hooks("on_train_start")
        self.trainer._call_lightning_module_hook("on_train_start")
        self.trainer._call_strategy_hook("on_train_start")
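The warning hinges on simple arithmetic: the number of optimizer steps per epoch is the batch count divided by the accumulation factor, and the (checkpoint-adjusted) global step must land on a multiple of it. A worked example with illustrative values:

import math

num_training_batches = 10
accumulate_grad_batches = 2
expected_steps = math.ceil(num_training_batches / accumulate_grad_batches)  # 5 optimizer steps per epoch

# `global_step` was incremented once during checkpointing, hence the `- 1`.
print((11 - 1) % expected_steps != 0)  # False -> checkpoint saved at an epoch boundary, no warning
print((13 - 1) % expected_steps != 0)  # True  -> mid-epoch checkpoint, the warning above is emitted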
Example #6
def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
    """Wrap default collate function to retrive captured dataset state dict when fault tolerant is enabled."""
    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
    collate_fn = dataloader.collate_fn
    if not fault_tolerant_mode.is_enabled or (
            isinstance(collate_fn, partial)
            and collate_fn.func is _capture_metadata_collate):
        return
    dataloader.collate_fn = partial(
        _capture_metadata_collate,
        dataset=dataloader.dataset,
        collate_fn=collate_fn,
        fault_tolerant_mode=fault_tolerant_mode,
    )
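The early return makes the wrapping idempotent: once `collate_fn` is a `functools.partial` whose `.func` is `_capture_metadata_collate`, calling the helper again is a no-op. A standalone demonstration of that `partial.func` check (toy names, not the Lightning functions):

from functools import partial

def _my_collate(samples, dataset=None, collate_fn=None):
    return collate_fn(samples)

def wrap_once(collate_fn):
    # Same idempotence check as above: skip if already wrapped.
    if isinstance(collate_fn, partial) and collate_fn.func is _my_collate:
        return collate_fn
    return partial(_my_collate, dataset=None, collate_fn=collate_fn)

wrapped = wrap_once(list)
assert wrap_once(wrapped) is wrapped  # second wrap is a no-op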
Example #7
def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
    """Utility to reload a ``state_dict`` within the dataloader for fault tolerance."""

    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()

    if not fault_tolerant_mode.is_enabled:
        return

    if fault_tolerant_mode.is_automatic:
        _reload_dataloader_state_dict_automatic(dataloader, state_dict)

    elif fault_tolerant_mode.is_manual:
        _reload_dataloader_state_dict_manual(dataloader, state_dict)

    else:
        raise MisconfigurationException("This shouldn't be happening. Please, open an issue.")
Example #8
def _dataloader_init_kwargs_resolve_sampler(
        dataloader: DataLoader,
        sampler: Optional[Sampler],
        mode: Optional[RunningStage] = None) -> Dict[str, Any]:
    """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
    re-instantiation.

    If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`, so
    Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped into a
    `FastForwardSampler`.
    """
    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
    batch_sampler = getattr(dataloader, "batch_sampler")
    is_predicting = mode == RunningStage.PREDICTING
    # check whether the batch sampler type differs from the PyTorch default
    if batch_sampler is not None and (type(batch_sampler) is not BatchSampler
                                      or is_predicting):
        batch_sampler = type(batch_sampler)(
            sampler,
            batch_size=batch_sampler.batch_size,
            drop_last=(False if is_predicting else batch_sampler.drop_last),
        )
        if is_predicting:
            batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

        if fault_tolerant_mode.is_automatic:
            fast_forward_sampler = batch_sampler = FastForwardSampler(
                batch_sampler)
            fast_forward_sampler.setup(dataloader_batch_size=1)

        return {
            "sampler": None,
            "shuffle": False,
            "batch_sampler": batch_sampler,
            "batch_size": 1,
            "drop_last": False,
        }

    if fault_tolerant_mode.is_automatic:
        fast_forward_sampler = sampler = FastForwardSampler(sampler)
        fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

    return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
Example #9
def _get_dataloader_init_kwargs(
        dataloader: DataLoader,
        sampler: Optional[Sampler],
        mode: Optional[RunningStage] = None) -> Dict[str, Any]:
    if not isinstance(dataloader, DataLoader):
        raise ValueError(
            f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`"
        )

    # get the dataloader instance attributes
    attrs = {
        k: v
        for k, v in vars(dataloader).items() if not k.startswith("_")
    }
    # not part of `vars`
    attrs["multiprocessing_context"] = dataloader.multiprocessing_context

    # get the dataloader instance `__init__` parameters
    params = dict(inspect.signature(dataloader.__init__).parameters)
    has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
    if has_variadic_kwargs:
        # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
        params.update(inspect.signature(DataLoader.__init__).parameters)
        del params["self"]

    # keep only the params whose default is different to the current attr value
    non_defaults = {
        name
        for name, p in params.items()
        if name in attrs and p.default != attrs[name]
    }
    # add `dataset` as it might have been replaced with `*args`
    non_defaults.add("dataset")

    # kwargs to re-construct the dataloader
    dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
    if isinstance(dl_kwargs["dataset"], IterableDataset):
        dl_kwargs["batch_sampler"] = None
        dl_kwargs["sampler"] = None
    else:
        dl_kwargs.update(
            _dataloader_init_kwargs_resolve_sampler(dataloader,
                                                    sampler,
                                                    mode=mode))

    required_args = {
        p.name
        for p in params.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty and p.name not in dl_kwargs
    }
    # the dataloader has required args which we could not extract from the existing attributes
    if required_args:
        required_args = sorted(required_args)
        dataloader_cls_name = dataloader.__class__.__name__
        raise MisconfigurationException(
            f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
            f"The missing attributes are {required_args}. "
            f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
            "manually add the `DistributedSampler` as: "
            f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
        )

    if not has_variadic_kwargs:
        # the dataloader signature does not allow keyword arguments that need to be passed
        missing_kwargs = dl_kwargs.keys() - params.keys()
        if missing_kwargs:
            missing_kwargs = sorted(missing_kwargs)
            dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                f"The missing arguments are {missing_kwargs}. "
                f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
                "manually add the `DistributedSampler` as: "
                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
            )

    if _FaultTolerantMode.detect_current_mode().is_automatic:
        dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(
            dl_kwargs)

    return dl_kwargs
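The returned kwargs are meant for re-constructing the dataloader with the new sampler injected. A hedged usage sketch (the caller below is illustrative, not part of the snippet; `DistributedSampler` needs an initialized process group or explicit `num_replicas`/`rank`):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def _reinstantiate_with_distributed_sampler(dataloader: DataLoader) -> DataLoader:
    # Illustrative only: build the kwargs for the existing instance and rebuild
    # it as the same class with a DistributedSampler injected.
    sampler = DistributedSampler(dataloader.dataset)
    dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler)
    return type(dataloader)(**dl_kwargs)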
Example #10
def _fault_tolerant_training() -> bool:
    from pytorch_lightning.utilities.enums import _FaultTolerantMode

    return _FaultTolerantMode.detect_current_mode().is_enabled
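For context, `detect_current_mode` derives the mode from the process environment. Assuming the `PL_FAULT_TOLERANT_TRAINING` environment variable drives the detection (as in Lightning 1.5/1.6; the exact accepted values may differ by version), a usage sketch:

import os

# Assumption: "0" disables fault tolerance, "1" enables the automatic mode and
# "manual" the manual mode in this era of Lightning; check your version.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

from pytorch_lightning.utilities.enums import _FaultTolerantMode

mode = _FaultTolerantMode.detect_current_mode()
print(mode.is_enabled, mode.is_automatic)  # True True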
Example #11
def _get_dataloader_init_args_and_kwargs(
        dataloader: DataLoader,
        sampler: Optional[Sampler],
        mode: Optional[RunningStage] = None
) -> Tuple[Tuple[Any], Dict[str, Any]]:
    if not isinstance(dataloader, DataLoader):
        raise ValueError(
            f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`"
        )

    was_wrapped = hasattr(dataloader, "__pl_dl_args")
    if was_wrapped:
        dl_args = dataloader.__pl_dl_args
        dl_kwargs = dataloader.__pl_dl_kwargs
        arg_names = dataloader.__pl_dl_arg_names
        original_dataset = dataloader.__dataset  # we have this saved from _wrap_init
    else:
        # get the dataloader instance attributes
        attrs = {
            k: v
            for k, v in vars(dataloader).items() if not k.startswith("_")
        }
        # We cannot be 100% sure the class sets the `dataset` argument. Let's set it to None to be safe
        # and hope we can get it from the instance attributes.
        original_dataset = None
        # not part of `vars`
        attrs["multiprocessing_context"] = dataloader.multiprocessing_context
        arg_names = ()

    # get the dataloader instance `__init__` parameters
    params = dict(inspect.signature(dataloader.__init__).parameters)
    has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
    if has_variadic_kwargs:
        # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`

        if was_wrapped:
            # if the dataloader was wrapped in a hook, only take arguments with default values
            # and assume user passes their kwargs correctly
            params.update({
                k: v
                for k, v in inspect.signature(
                    DataLoader.__init__).parameters.items()
                if v.default is not v.empty
            })
        else:
            params.update(inspect.signature(DataLoader.__init__).parameters)
            params.pop("self", None)

    if not was_wrapped:
        # keep only the params whose default is different to the current attr value
        non_defaults = {
            name
            for name, p in params.items()
            if name in attrs and p.default != attrs[name]
        }

        # add `dataset` as it might have been replaced with `*args`
        non_defaults.add("dataset")
        # kwargs to re-construct the dataloader
        dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
        dl_args = ()

    dataset = dl_kwargs.get("dataset", original_dataset)
    if isinstance(dataset, IterableDataset):
        dl_kwargs["batch_sampler"] = None
        dl_kwargs["sampler"] = None
    else:
        dl_kwargs.update(
            _dataloader_init_kwargs_resolve_sampler(dataloader,
                                                    sampler,
                                                    mode=mode))

    required_args = {
        p.name
        for p in params.values()
        if p.kind in (p.POSITIONAL_ONLY,
                      p.POSITIONAL_OR_KEYWORD) and p.default is p.empty
        and p.name not in dl_kwargs and p.name not in arg_names
    }
    # the dataloader has required args which we could not extract from the existing attributes
    if required_args:
        required_args = sorted(required_args)
        dataloader_cls_name = dataloader.__class__.__name__
        missing_args_message = ", ".join(f"`self.{arg_name}`"
                                         for arg_name in required_args)
        raise MisconfigurationException(
            f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
            f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
            "`*_dataloader` hook of your module, we will do this for you."
            f" Otherwise, define {missing_args_message} inside your `__init__`."
        )

    if not has_variadic_kwargs:
        # the dataloader signature does not allow keyword arguments that need to be passed
        missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
        if missing_kwargs:
            missing_kwargs = sorted(missing_kwargs)
            dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
                "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
                "add the `__init__` arguments or allow passing `**kwargs`")

    if _FaultTolerantMode.detect_current_mode().is_automatic:
        dl_args, dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(
            was_wrapped, arg_names, dl_args, dl_kwargs)

    return dl_args, dl_kwargs
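The `was_wrapped` branch relies on `_wrap_init` having recorded the original constructor call on the instance (`__pl_dl_args`, `__pl_dl_kwargs`, `__pl_dl_arg_names`, `__dataset`). That wrapper is not part of this snippet; a hypothetical, simplified version of the idea (not Lightning's actual implementation) might record the attributes like this:

import functools
import inspect
from torch.utils.data import DataLoader

def _record_init_args(init):
    # Hypothetical sketch: remember how a DataLoader subclass was constructed so
    # `_get_dataloader_init_args_and_kwargs` could rebuild it later.
    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        param_names = list(inspect.signature(init).parameters)[1:]  # drop `self`
        object.__setattr__(self, "__pl_dl_args", args)
        object.__setattr__(self, "__pl_dl_kwargs", kwargs)
        object.__setattr__(self, "__pl_dl_arg_names", tuple(param_names[: len(args)]))
        object.__setattr__(self, "__dataset", kwargs.get("dataset", args[0] if args else None))
        init(self, *args, **kwargs)

    return wrapper

# e.g. DataLoader.__init__ = _record_init_args(DataLoader.__init__) while a `*_dataloader` hook runs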