    def on_load_checkpoint(self, state_dict: Dict) -> None:
        # cache the dataloader state dict until the dataloader objects are available
        # dataset states are collected across all ranks
        dataloader_state_dict = state_dict.get("dataloader_state_dict", None)
        if not _fault_tolerant_training() or not dataloader_state_dict:
            return
        self._dataloader_state_dict = dataloader_state_dict[self.trainer.global_rank]
    def register_signal_handlers(self) -> None:
        self._original_handlers = self._get_current_signal_handlers()

        sigusr1_handlers: List[_HANDLER] = []
        sigterm_handlers: List[_HANDLER] = []

        if _fault_tolerant_training():
            sigterm_handlers.append(self.fault_tolerant_sigterm_handler_fn)

        environment = self.trainer._accelerator_connector.cluster_environment
        if isinstance(environment,
                      SLURMEnvironment) and environment.auto_requeue:
            log.info("SLURM auto-requeueing enabled. Setting signal handlers.")
            sigusr1_handlers.append(self.slurm_sigusr1_handler_fn)
            sigterm_handlers.append(self.sigterm_handler_fn)

        # signal.SIGUSR1 doesn't seem available on windows
        if not self._is_on_windows():
            if sigusr1_handlers and not self._has_already_handler(
                    signal.SIGUSR1):
                self._register_signal(signal.SIGUSR1,
                                      HandlersCompose(sigusr1_handlers))

            if sigterm_handlers and not self._has_already_handler(
                    signal.SIGTERM):
                self._register_signal(signal.SIGTERM,
                                      HandlersCompose(sigterm_handlers))
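As a reference for the composition pattern used above, here is a minimal, self-contained sketch (not the library implementation) of what a handler-composing object like `HandlersCompose` is expected to do: call each registered handler in turn when the signal fires. All names below are hypothetical.

import signal
from types import FrameType
from typing import Callable, List, Optional


class ComposedHandler:
    """Toy composite handler: invokes every registered handler with the signal info."""

    def __init__(self, handlers: List[Callable]) -> None:
        self.handlers = handlers

    def __call__(self, signum: int, frame: Optional[FrameType]) -> None:
        for handler in self.handlers:
            if callable(handler):  # previously installed handlers may be plain callables
                handler(signum, frame)


def _log_signal(signum: int, frame: Optional[FrameType]) -> None:
    print(f"received signal {signum}")


signal.signal(signal.SIGTERM, ComposedHandler([_log_signal]))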
Example #3
    def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
        """The state dict is determined by the state and progress of this loop and all its children.

        Args:
            destination: An existing dictionary to update with this loop's state. By default a new dictionary
                is returned.
            prefix: A prefix for each key in the state dictionary
        """
        if destination is None:
            destination = {}

        destination[prefix + "state_dict"] = self.on_save_checkpoint()

        # do not get the mode from `self.trainer` because it might not have been attached yet
        ft_enabled = _fault_tolerant_training()
        for k, v in self.__dict__.items():
            key = prefix + k
            if isinstance(v, BaseProgress):
                destination[key] = v.state_dict()
            elif isinstance(v, Loop):
                v.state_dict(destination, key + ".")
            elif ft_enabled and isinstance(v, _ResultCollection):
                # sync / unsync metrics
                v.sync()
                destination[key] = v.state_dict()
                v.unsync()

        return destination
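To make the `prefix` recursion above concrete, here is a self-contained toy example; `ToyLoop` is hypothetical and unrelated to the real `Loop` hierarchy, it only mirrors the key-building pattern.

from typing import Dict, Optional


class ToyLoop:
    def __init__(self, name: str, child: Optional["ToyLoop"] = None) -> None:
        self.name = name
        self.child = child

    def on_save_checkpoint(self) -> Dict:
        return {"name": self.name}

    def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
        if destination is None:
            destination = {}
        destination[prefix + "state_dict"] = self.on_save_checkpoint()
        if self.child is not None:
            # child keys get nested under the parent's prefix, as in the snippet above
            self.child.state_dict(destination, prefix + "child.")
        return destination


print(ToyLoop("fit", ToyLoop("epoch")).state_dict())
# {'state_dict': {'name': 'fit'}, 'child.state_dict': {'name': 'epoch'}}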
Example #4
        def _apply_patch_fn(loader: DataLoader, iterator: Iterator):
            if isinstance(loader, CycleIterator):
                loader = loader.loader
                # cycle_iterator = iterator
                iterator = iterator._loader_iter

            if isinstance(loader, DataLoader) and _fault_tolerant_training():
                loader._lightning_fetcher = self
                patch_dataloader_iterator(loader, iterator, self)
Example #5
    def _prepare_dataloader(self,
                            dataloader: Any,
                            shuffle: Optional[bool] = None,
                            mode: Optional[RunningStage] = None) -> Any:
        """This function handles to following functionalities:

        - Injecting a `DistributedDataSampler` into the `DataLoader` if on a distributed environment
        - Wrapping the datasets and samplers into fault-tolerant components
        - Wrapping the dataloader based on strategy-specific logic
        """
        if isinstance(dataloader, CombinedLoader):
            # apply `_prepare_dataloader` on all the collection of loaders
            dataloader.loaders = apply_to_collection(
                dataloader.loaders, (DataLoader, CycleIterator),
                self._prepare_dataloader,
                shuffle,
                mode=mode)
            # the length needs to be recomputed across all dataloaders in case of special behavior.
            dataloader._apply_cycle_iterator_length()
            return dataloader

        # don't do anything if it's not a dataloader
        if not isinstance(dataloader, (DataLoader, CycleIterator)):
            return dataloader

        cycle_iterator: Optional[CycleIterator] = None

        if isinstance(dataloader, CycleIterator):
            cycle_iterator = dataloader
            dataloader = dataloader.loader

        if (
            _fault_tolerant_training()  # injects components to track the state
            or self._requires_distributed_sampler(dataloader)  # sets the distributed sampler
            or mode == RunningStage.PREDICTING  # to track indices for the predictions
            # IPUs use a custom `poptorch.DataLoader` which we might need to convert to
            or isinstance(self.trainer.accelerator, IPUAccelerator)
        ):
            if shuffle is None:
                # for training, set to True always
                # for evaluation, decide based on existing sampler
                shuffle = True if mode == RunningStage.TRAINING else _is_dataloader_shuffled(dataloader)

            sampler = self._resolve_sampler(dataloader,
                                            shuffle=shuffle,
                                            mode=mode)
            dataloader = _update_dataloader(dataloader, sampler, mode=mode)

        dataloader = self.trainer.strategy.process_dataloader(dataloader)

        if cycle_iterator is not None:
            cycle_iterator.loader = dataloader
            return cycle_iterator

        return dataloader
Example #6
    def _dataloader_init_kwargs_resolve_sampler(
        dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
    ) -> Dict[str, Any]:
        """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for
        its re-instantiation.

        If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`,
        so Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped
        into a `FastForwardSampler`.
        """
        batch_sampler = getattr(dataloader, "batch_sampler")
        is_predicting = mode == RunningStage.PREDICTING
        # check whether the batch sampler type differs from the PyTorch default
        if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
            batch_sampler = type(batch_sampler)(
                sampler,
                batch_size=batch_sampler.batch_size,
                drop_last=(False if is_predicting else batch_sampler.drop_last),
            )
            if is_predicting:
                batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

            if _fault_tolerant_training():
                fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
                fast_forward_sampler.setup(dataloader_batch_size=1)

            return {
                "sampler": None,
                "shuffle": False,
                "batch_sampler": batch_sampler,
                "batch_size": 1,
                "drop_last": False,
            }

        if _fault_tolerant_training():
            fast_forward_sampler = sampler = FastForwardSampler(sampler)
            fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

        return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
Example #7
    def _resolve_batch_sampler(
            dataloader: DataLoader,
            sampler: Optional[Sampler],
            mode: Optional[RunningStage] = None) -> Dict[str, Any]:
        batch_sampler = getattr(dataloader, "batch_sampler")
        is_predicting = mode == RunningStage.PREDICTING
        # check whether the batch sampler type differs from the PyTorch default
        if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
            batch_sampler = type(batch_sampler)(
                sampler,
                batch_size=batch_sampler.batch_size,
                drop_last=(False
                           if is_predicting else batch_sampler.drop_last),
            )
            if is_predicting:
                batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

            if _fault_tolerant_training():
                fast_forward_sampler = batch_sampler = FastForwardSampler(
                    batch_sampler)
                fast_forward_sampler.setup(dataloader_batch_size=1)

            return {
                "sampler": None,
                "shuffle": False,
                "batch_sampler": batch_sampler,
                "batch_size": 1,
                "drop_last": False,
            }

        if _fault_tolerant_training():
            fast_forward_sampler = sampler = FastForwardSampler(sampler)
            fast_forward_sampler.setup(
                dataloader_batch_size=dataloader.batch_size)

        return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
Example #8
def next_fn(iterator: Iterator):
    batch = next(iterator)
    if not _fault_tolerant_training():
        return batch
    # When fault tolerance is enabled, the iterator returns the `FastForwardSampler`
    # state_dict metadata alongside the user data. The metadata is extracted and stored
    # directly on the iterator to simplify the collection on the `state_dict` call.
    batch, samplers_state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
    # store the `sampler_state_dict` on the iterator
    CaptureIterableDataset.store_samplers_state_dict(iterator, samplers_state_dict)
    return batch
Example #9
    def state_dict(self, has_completed: bool = False) -> Dict:
        """The state dict includes all states from wrapped dataloaders and their samplers through the
        ``CaptureIterableDataset`` and fast-forward samplers.

        Args:
            has_completed: whether the current state of data fetching is considered completed or not. If it is, the
                current state gets returned, otherwise the previously cached state.
        """
        if not _fault_tolerant_training() or self._iterator is None:
            return {}

        return apply_to_collection(
            self._iterator.loader_iters,
            Iterator,
            self._state_dict_fn,
            has_completed=has_completed,
        )
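`apply_to_collection`, used here and in several snippets above, applies a function to every element of a (possibly nested) collection that matches a given type. A small standalone example; the import path is assumed for the Lightning versions these snippets target:

from pytorch_lightning.utilities.apply_func import apply_to_collection  # assumed import path

data = {"train": [1, 2], "val": (3.0, 4)}
# multiply only the `int` leaves; other types and the container structure are preserved
print(apply_to_collection(data, int, lambda x: x * 10))
# {'train': [10, 20], 'val': (3.0, 40)}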
Example #10
    def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
        metrics = ([
            m for m in self.trainer.lightning_module.modules()
            if isinstance(m, Metric)
        ] if _fault_tolerant_training() else [])

        for metric in metrics:
            metric.persistent(True)
            metric.sync()

        state_dict = self.trainer.strategy.lightning_module_state_dict()

        for metric in metrics:
            # sync can be a no-op (e.g. on cpu) so `unsync` would raise a user error exception if we don't check
            if metric._is_synced:
                metric.unsync()

        return state_dict
    def register_signal_handlers(self) -> None:
        sigusr1_handlers: List[Callable] = []
        sigterm_handlers: List[Callable] = []

        if _fault_tolerant_training():
            sigusr1_handlers.append(self.fault_tolerant_sigusr1_handler_fn)

        if self._is_on_slurm():
            log.info("Set SLURM handle signals.")
            sigusr1_handlers.append(self.slurm_sigusr1_handler_fn)
            sigterm_handlers.append(self.sigterm_handler_fn)

        # signal.SIGUSR1 doesn't seem available on windows
        if not self._is_on_windows():
            if not self._has_already_handler(signal.SIGUSR1):
                signal.signal(signal.SIGUSR1,
                              HandlersCompose(sigusr1_handlers))

            if not self._has_already_handler(signal.SIGTERM):
                signal.signal(signal.SIGTERM,
                              HandlersCompose(sigterm_handlers))
    def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
        """This function handles to following functionalities:

        - Injecting a `DistributedDataSampler` into the `DataLoader` if on a distributed environment
        - Wrapping the datasets and samplers into fault-tolerant components
        """
        if isinstance(dataloader, CombinedLoader):
            # apply `prepare_dataloader` on all the collection of loaders
            dataloader.loaders = apply_to_collection(
                dataloader.loaders, (DataLoader, CycleIterator), self.prepare_dataloader, shuffle, mode=mode
            )
            # the length needs to be recomputed across all dataloaders in case of special behavior.
            dataloader._apply_cycle_iterator_length()
            return dataloader

        # don't do anything if it's not a dataloader
        if not isinstance(dataloader, (DataLoader, CycleIterator)):
            return dataloader

        cycle_iterator: Optional[CycleIterator] = None

        if isinstance(dataloader, CycleIterator):
            cycle_iterator = dataloader
            dataloader = dataloader.loader

        if (
            _fault_tolerant_training()  # injects components to track the state
            or self._requires_distributed_sampler(dataloader)  # sets the distributed sampler
            or mode == RunningStage.PREDICTING  # to track indices for the predictions
            or self._accelerator_connector.use_ipu  # IPUs use a custom `DataLoader`
        ):
            sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
            dataloader = _update_dataloader(dataloader, sampler, mode=mode)

        if cycle_iterator is not None:
            cycle_iterator.loader = dataloader
            return cycle_iterator

        return dataloader
Example #13
    def _get_cache(result_metric: _ResultMetric,
                   on_step: bool) -> Optional[Tensor]:
        cache = None
        if on_step and result_metric.meta.on_step:
            cache = result_metric._forward_cache
        elif not on_step and result_metric.meta.on_epoch:
            if result_metric._computed is None:
                should = result_metric.meta.sync.should
                if not result_metric.meta.sync.should and distributed_available():
                    # ensure sync happens for FT since during a failure, the metrics are synced and saved to the
                    # checkpoint, so during restart, metrics on rank 0 are from the accumulated ones from the previous
                    # run, and on other ranks, they are 0. So we need to make sure they are synced in further training
                    # to ensure correct calculation.
                    if _fault_tolerant_training():
                        result_metric.meta.sync.should = True
                    else:
                        warning_cache.warn(
                            f"It is recommended to use `self.log({result_metric.meta.name!r}, ..., sync_dist=True)`"
                            " when logging on epoch level in distributed setting to accumulate the metric across"
                            " devices.",
                            category=PossibleUserWarning,
                        )
                result_metric.compute()
                result_metric.meta.sync.should = should

            cache = result_metric._computed

        if cache is not None:
            if not isinstance(cache, Tensor):
                raise ValueError(
                    f"The `.compute()` return of the metric logged as {result_metric.meta.name!r} must be a tensor."
                    f" Found {cache}")
            if not result_metric.meta.enable_graph:
                return cache.detach()

        return cache
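The warning above recommends logging epoch-level metrics with `sync_dist=True` in a distributed setting. A minimal, illustrative sketch of that pattern (hypothetical module):

import torch
import pytorch_lightning as pl


class SyncedLoggingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def validation_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # reduce the epoch-level value across devices, as the warning recommends
        self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)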
Example #14
def reload_dataloader_state_dict(dataloader: DataLoader,
                                 state_dict: Dict[str, Any]) -> None:
    """Utility to reload state_dict within dataloader for fault tolerance."""

    if not _fault_tolerant_training():
        return

    dataset = dataloader.dataset

    if isinstance(dataset, CaptureMapDataset):
        iterator_state = state_dict["state"][0]

        if not isinstance(iterator_state, IteratorState):
            iterator_state = IteratorState.from_state_dict(iterator_state)

        # reload sampler state
        ff_sampler = _find_fast_forward_samplers(dataloader)
        ff_sampler.load_state_dict(iterator_state.sampler_state)

        # reload dataset state
        dataset.load_state_dict(
            iterator_state.dataset_state,
            latest_worker_id=state_dict["latest_worker_id"],
            num_workers=iterator_state.num_workers,
        )

    elif isinstance(dataset, CaptureIterableDataset):
        dataset.load_state_dict({
            sampler_name: state[0]["sampler_state"]
            for sampler_name, state in state_dict["state"].items()
        })

    else:
        raise MisconfigurationException(
            "This shouldn't happen. Please, open an issue on PyTorch Lightning Github."
        )
    def dump_checkpoint(self, weights_only: bool = False) -> dict:
        """Creating a model checkpoint dictionary object from various component states.
        Args:
            weights_only: saving model weights only
        Return:
            structured dictionary: {
                'epoch':                     training epoch
                'global_step':               training global step
                'pytorch-lightning_version': The version of PyTorch Lightning that produced this checkpoint
                'callbacks':                 "callback specific state"[] # if not weights_only
                'optimizer_states':          "PT optim's state_dict"[]   # if not weights_only
                'lr_schedulers':             "PT sched's state_dict"[]   # if not weights_only
                'native_amp_scaling_state':  PT amp's state_dict         # if not weights_only and use native amp
                'amp_scaling_state':         Apex's state_dict           # if not weights_only and use apex amp
                'state_dict':                Model's state_dict (e.g. network weights)
                CHECKPOINT_HYPER_PARAMS_NAME:
                CHECKPOINT_HYPER_PARAMS_KEY:
                CHECKPOINT_HYPER_PARAMS_TYPE:
                something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
                LightningDataModule.__class__.__name__: pl DataModule's state
            }
        """

        # dump epoch/global_step/pytorch-lightning_version
        current_epoch = self.trainer.current_epoch
        global_step = self.trainer.global_step
        has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step

        global_step += 1
        if not has_reached_max_steps:
            current_epoch += 1

        model = self.trainer.lightning_module

        checkpoint = {
            "epoch": current_epoch,
            "global_step": global_step,
            "pytorch-lightning_version": pl.__version__,
            "state_dict": self._get_lightning_module_state_dict(),
        }
        if _fault_tolerant_training():
            checkpoint["loops"] = self._get_loops_state_dict()

        if not weights_only:
            # dump callbacks
            checkpoint["callbacks"] = self.trainer.on_save_checkpoint(
                checkpoint)

            optimizer_states = []
            for i, optimizer in enumerate(self.trainer.optimizers):
                # Rely on accelerator to dump optimizer state
                optimizer_state = self.trainer.accelerator.optimizer_state(
                    optimizer)
                optimizer_states.append(optimizer_state)

            checkpoint["optimizer_states"] = optimizer_states

            # dump lr schedulers
            lr_schedulers = []
            for scheduler in self.trainer.lr_schedulers:
                lr_schedulers.append(scheduler["scheduler"].state_dict())
            checkpoint["lr_schedulers"] = lr_schedulers

            self.trainer.precision_plugin.on_save_checkpoint(checkpoint)

        # dump hyper-parameters
        if model.hparams:
            if hasattr(model, "_hparams_name"):
                checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
            # dump arguments
            if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
                checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
                checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
            else:
                checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)

        # give the model a chance to dump a few things
        model.on_save_checkpoint(checkpoint)
        if self.trainer.datamodule is not None:
            self.trainer.datamodule.on_save_checkpoint(checkpoint)

        return checkpoint
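A quick way to see the checkpoint structure documented in the docstring above is to load a saved checkpoint and inspect its keys; the path below is a placeholder:

import torch

checkpoint = torch.load("path/to/checkpoint.ckpt", map_location="cpu")  # placeholder path
print(checkpoint["epoch"], checkpoint["global_step"])
print(checkpoint["pytorch-lightning_version"])
print(sorted(checkpoint["state_dict"]))     # model weights
print("loops" in checkpoint)                # only present when fault-tolerant training is enabled
print("optimizer_states" in checkpoint)     # absent when saved with weights_only=True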
Example #16
def result_collection_reload(**kwargs):
    """This test is going to validate ResultCollection is properly being reload and final accumulation with Fault
    Tolerant Training is correct."""

    if not _fault_tolerant_training():
        pytest.skip("Fault tolerant not available")

    num_processes = kwargs.get("gpus", 1)

    class CustomException(Exception):
        pass

    class ExtendedBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.breaking_batch_idx = 3
            self.has_validated_sum = False
            self.dummy_metric = DummyMeanMetric()

        @property
        def results(self):
            return self.trainer.fit_loop._results

        def training_step(self, batch, batch_idx):

            # In the training step, we will accumulate metrics using batch_idx from 0 to 4
            # Without failure, we would expect to get `total=10 * world_size` and `num_batches=5 * world_size`
            # Therefore, compute on `epoch_end` should provide 2 as `10 / 5`.
            # However, below we will simulate a failure on `batch_idx=3`.

            if self.trainer.fit_loop.restarting:
                self.log("tracking", batch_idx, on_step=True, on_epoch=True)
                self.log("tracking_2",
                         batch_idx,
                         on_step=True,
                         on_epoch=True,
                         sync_dist=True)

                self.dummy_metric(batch_idx)
                self.log("tracking_metric",
                         self.dummy_metric,
                         on_step=True,
                         on_epoch=True)

                value = self.results["training_step.tracking_metric"].value
                value_2 = self.results["training_step.tracking"].value

                # On failure, the Metric states are being accumulated on rank 0 and zeroed-out on other ranks.
                # The shift indicates we failed while the state was `shift=sign(is_global_zero > 0) * [0..3]`
                shift = 0
                if num_processes == 2:
                    shift = 3 if self.trainer.is_global_zero else -3
                expected = sum(range(batch_idx + 1)) + shift
                assert expected == value == value_2
            else:
                if batch_idx == self.breaking_batch_idx:
                    # simulate failure mid epoch
                    raise CustomException

                self.log("tracking", batch_idx, on_step=True, on_epoch=True)
                self.log("tracking_2",
                         batch_idx,
                         on_step=True,
                         on_epoch=True,
                         sync_dist=True)

                self.dummy_metric(batch_idx)
                self.log("tracking_metric",
                         self.dummy_metric,
                         on_step=True,
                         on_epoch=True)

                value = self.results["training_step.tracking"].value
                assert value == sum(range(batch_idx + 1))

                value = self.results["training_step.tracking_2"]
                assert value == sum(range(batch_idx + 1))

            return super().training_step(batch, batch_idx)

        def on_epoch_end(self) -> None:
            if self.trainer.fit_loop.restarting:
                total = sum(range(5)) * num_processes
                metrics = self.results.metrics(on_step=False)
                assert self.results["training_step.tracking"].value == total
                assert metrics["callback"][
                    "tracking"] == self.dummy_metric.compute() == 2
                assert self.results["training_step.tracking_2"].value == total
                assert metrics["callback"][
                    "tracking_2"] == self.dummy_metric.compute() == 2
                self.has_validated_sum = True

    model = ExtendedBoringModel()
    trainer_kwargs = {
        "max_epochs": 1,
        "limit_train_batches": 5,
        "limit_val_batches": 0
    }
    trainer_kwargs.update(kwargs)
    trainer = Trainer(**trainer_kwargs)

    with suppress(CustomException):
        trainer.fit(model)
    assert not model.has_validated_sum

    tmpdir = (
        trainer.training_type_plugin.broadcast(trainer_kwargs["default_root_dir"], 0)
        if num_processes >= 2
        else trainer_kwargs["default_root_dir"]
    )
    ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")
    trainer_kwargs["resume_from_checkpoint"] = ckpt_path

    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)
    assert model.has_validated_sum
Example #17
def test_fault_tolerant_not_supported():
    assert not _fault_tolerant_training()
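For context, fault tolerance is opt-in in the Lightning versions these snippets appear to target; the sketch below assumes it is toggled through the `PL_FAULT_TOLERANT_TRAINING` environment variable, and the helper's import path may differ between releases:

import os

# Assumption: `PL_FAULT_TOLERANT_TRAINING=1` enables fault-tolerant training;
# set it before the Trainer is instantiated.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

# Assumed import path; the helper has moved between Lightning releases.
from pytorch_lightning.utilities.imports import _fault_tolerant_training

assert _fault_tolerant_training()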
Example #18
        def _attach_data_fetcher_fn(loader: DataLoader) -> None:
            if isinstance(loader, CycleIterator):
                loader = loader.loader

            if isinstance(loader, DataLoader) and _fault_tolerant_training():
                loader._lightning_fetcher = self
Example #19
    def reset_train_dataloader(self,
                               model: Optional["pl.LightningModule"] = None
                               ) -> None:
        """Resets the train dataloader and initialises required variables
        (number of batches, when to validate, etc.).

        Args:
            model: The `LightningModule` if calling this outside of the trainer scope.
        """
        self.train_dataloader = self.request_dataloader(RunningStage.TRAINING,
                                                        model=model)

        if self.overfit_batches > 0:
            if hasattr(self.train_dataloader, "sampler") and isinstance(
                    self.train_dataloader.sampler, RandomSampler):
                rank_zero_warn(
                    "You requested to overfit but enabled training dataloader shuffling."
                    " We are turning off the training dataloader shuffling for you."
                )
                self.train_dataloader = self.replace_sampler(
                    self.train_dataloader,
                    SequentialSampler(self.train_dataloader.dataset),
                    mode=RunningStage.TRAINING)

        # debugging
        self.dev_debugger.track_load_dataloader_call(
            "train_dataloader", dataloaders=[self.train_dataloader])

        # automatically add samplers
        self.train_dataloader = apply_to_collection(self.train_dataloader,
                                                    DataLoader,
                                                    self.auto_add_sampler,
                                                    shuffle=True,
                                                    mode=RunningStage.TRAINING)

        # check the workers recursively
        apply_to_collection(self.train_dataloader, DataLoader,
                            self._worker_check, "train_dataloader")

        # add worker_init_fn for correct seeding in worker processes
        apply_to_collection(self.train_dataloader, DataLoader,
                            self.auto_add_worker_init_fn)

        # add collate_fn to collect metadata for fault tolerant training
        if _fault_tolerant_training():
            apply_to_collection(self.train_dataloader, DataLoader,
                                self._add_sampler_metadata_collate)

        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
        self.train_dataloader = CombinedLoader(
            self.train_dataloader,
            self.data_connector.multiple_trainloader_mode)

        self.num_training_batches = len(self.train_dataloader) if has_len(
            self.train_dataloader) else float("inf")

        if isinstance(self.limit_train_batches,
                      int) or self.limit_train_batches == 0.0:
            self.num_training_batches = min(self.num_training_batches,
                                            int(self.limit_train_batches))
        elif self.num_training_batches != float("inf"):
            self.num_training_batches = int(self.num_training_batches *
                                            self.limit_train_batches)
        elif self.limit_train_batches != 1.0:
            raise MisconfigurationException(
                "When using an IterableDataset for `limit_train_batches`,"
                " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
                " `num_training_batches` to use.")

        # determine when to check validation
        # if int passed in, val checks that often
        # otherwise, it checks in [0, 1.0] % range of a training epoch
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
                    f"to the number of the training batches ({self.num_training_batches}). "
                    "If you want to disable validation set `limit_val_batches` to 0.0 instead."
                )
        else:
            if not has_len(self.train_dataloader):
                if self.val_check_interval == 1.0:
                    self.val_check_batch = float("inf")
                else:
                    raise MisconfigurationException(
                        "When using an IterableDataset for `train_dataloader`,"
                        " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
                        " checking validation every k training batches.")
            else:
                self.val_check_batch = int(self.num_training_batches *
                                           self.val_check_interval)
                self.val_check_batch = max(1, self.val_check_batch)

        if self.logger and self.num_training_batches < self.log_every_n_steps:
            rank_zero_warn(
                f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
                f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
                " you want to see logs for the training epoch.")
Example #20
    def _get_dataloader_init_kwargs(
            dataloader: DataLoader,
            sampler: Optional[Sampler],
            mode: Optional[RunningStage] = None) -> Dict[str, Any]:
        if not isinstance(dataloader, DataLoader):
            raise ValueError(
                f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`"
            )

        # get the dataloader instance attributes
        attrs = {
            k: v
            for k, v in vars(dataloader).items() if not k.startswith("_")
        }
        # not part of `vars`
        attrs["multiprocessing_context"] = dataloader.multiprocessing_context

        # get the dataloader instance `__init__` parameters
        params = dict(inspect.signature(dataloader.__init__).parameters)

        # keep only the params whose default is different to the current attr value
        non_defaults = {
            name
            for name, p in params.items()
            if name in attrs and p.default != attrs[name]
        }
        # add `dataset` as it might have been replaced with `*args`
        non_defaults.add("dataset")

        # kwargs to re-construct the dataloader
        dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
        dl_kwargs.update(
            TrainerDataLoadingMixin._resolve_batch_sampler(dataloader,
                                                           sampler,
                                                           mode=mode))

        required_args = {
            p.name
            for p in params.values()
            if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
            and p.default is p.empty and p.name not in dl_kwargs
        }
        # the dataloader has required args which we could not extract from the existing attributes
        if required_args:
            required_args = sorted(required_args)
            dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                "This would fail as some of the `__init__` arguments are not available as instance attributes. "
                f"The missing attributes are {required_args}. "
                f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
                "manually add the `DistributedSampler` as: "
                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
            )

        has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD
                                  for p in params.values())
        if not has_variadic_kwargs:
            # the dataloader signature does not allow keyword arguments that need to be passed
            missing_kwargs = dl_kwargs.keys() - params.keys()
            if missing_kwargs:
                missing_kwargs = sorted(missing_kwargs)
                dataloader_cls_name = dataloader.__class__.__name__
                raise MisconfigurationException(
                    f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                    "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                    f"The missing arguments are {missing_kwargs}. "
                    f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
                    "manually add the `DistributedSampler` as: "
                    f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
                )

        if isinstance(dl_kwargs["dataset"], IterableDataset):
            dl_kwargs["batch_sampler"] = None
            dl_kwargs["sampler"] = None

        if _fault_tolerant_training():
            if isinstance(dl_kwargs["dataset"], IterableDataset):
                # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
                dl_kwargs["dataset"] = CaptureIterableDataset(
                    dataset=dl_kwargs["dataset"])
            elif len(dl_kwargs["dataset"]):
                dl_kwargs["dataset"] = CaptureMapDataset(
                    dataset=dl_kwargs["dataset"])
            else:
                raise MisconfigurationException(
                    "This shouldn't happen, please open an issue on Lightning Github repository."
                )

        return dl_kwargs
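The error messages above spell out a requirement for custom dataloaders: every extra `__init__` argument must be recoverable from an instance attribute (or the class must accept `**kwargs`), otherwise the dataloader cannot be re-instantiated with a new sampler. A hypothetical subclass that satisfies this:

from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    def __len__(self) -> int:
        return 8

    def __getitem__(self, index: int) -> int:
        return index * index


class TaggedDataLoader(DataLoader):
    def __init__(self, dataset, tag: str = "default", **kwargs):
        super().__init__(dataset, **kwargs)
        self.tag = tag  # stored as an attribute, so it can be recovered on re-instantiation


loader = TaggedDataLoader(SquaresDataset(), tag="train", batch_size=2)
print(loader.tag, len(loader))  # train 4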