Example #1
    def _resolve_batch_sampler(dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
        batch_sampler = getattr(dataloader, "batch_sampler")
        is_predicting = mode == RunningStage.PREDICTING
        # check whether the batch sampler type differs from the PyTorch default.
        if (batch_sampler is not None and type(batch_sampler) is not BatchSampler) or is_predicting:
            batch_sampler = type(batch_sampler)(
                sampler,
                batch_size=batch_sampler.batch_size,
                drop_last=(False if is_predicting else batch_sampler.drop_last),
            )
            if is_predicting:
                batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

            if _fault_tolerant_enabled():
                fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
                fast_forward_sampler.setup(dataloader_batch_size=1)

            return {
                "sampler": None,
                "shuffle": False,
                "batch_sampler": batch_sampler,
                "batch_size": 1,
                "drop_last": False,
            }

        if _fault_tolerant_enabled():
            fast_forward_sampler = sampler = FastForwardSampler(sampler)
            fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

        return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
Example #2
 def next_fn(iterator: Iterator):
     batch = next(iterator)
     if not _fault_tolerant_enabled():
         return batch
     # when fault tolerant is enabled, the iterator will return
     # `FastForwardSampler` state_dict metadata
     # alongside the user data.
     # the metadata is extracted and stored directly on the iterator
     # to simplify the collection on the `state_dict` call.
     batch, samplers_state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
     # store the `sampler_state_dict` on the iterator
     CaptureIterableDataset.store_samplers_state_dict(iterator, samplers_state_dict)
     return batch
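
`extract_samplers_state_dict_from_batch` and `store_samplers_state_dict` are Lightning internals; the following simplified, hypothetical sketch only illustrates the pattern `next_fn` relies on: the iterator yields the user batch packed together with sampler progress metadata, and the consumer splits them apart before the batch reaches the training loop.

from typing import Any, Dict, Iterator, Tuple


def batches_with_state(data, sampler_name: str = "sampler0") -> Iterator[Dict[str, Any]]:
    # hypothetical producer: yields each user batch together with sampler progress metadata
    for idx, batch in enumerate(data):
        yield {"data": batch, "samplers": {sampler_name: {"current_iteration": idx + 1}}}


def split_batch(packed: Dict[str, Any]) -> Tuple[Any, Dict[str, Any]]:
    # hypothetical counterpart of `extract_samplers_state_dict_from_batch`
    return packed["data"], packed["samplers"]


iterator = batches_with_state([[0, 1], [2, 3]])
batch, samplers_state_dict = split_batch(next(iterator))
print(batch, samplers_state_dict)  # [0, 1] {'sampler0': {'current_iteration': 1}}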
Example #3
    def state_dict(self, num_batches_processed: int) -> Dict:
        """
        The state dict includes all states from wrapped dataloaders and their samplers through the
        ``CaptureIterableDataset`` and fast-forward samplers.

        Args:
            num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
                may have already prefetched more batches by the time a state dict is requested.
        """
        if not _fault_tolerant_enabled():
            return DataLoaderDict()

        state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)

        return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn)
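
A minimal, hypothetical sketch of what the `apply_to_collections` call does for a dict of dataloaders: walk the loaders and their live iterators in parallel and build one state entry per loader. The real `_state_dict_fn` reads the fast-forward sampler state; the dummy below only records the batch count.

from typing import Any, Dict


def _dummy_state_dict_fn(dataloader: Any, iterator: Any, num_batches_processed: int) -> Dict[str, int]:
    # hypothetical per-loader state; the real function queries FastForwardSampler progress
    return {"num_batches_processed": num_batches_processed}


def collect_state(loaders: Dict[str, Any], loader_iters: Dict[str, Any], num_batches_processed: int) -> Dict[str, Dict]:
    # pairwise application over two parallel collections keyed the same way
    return {
        name: _dummy_state_dict_fn(loaders[name], loader_iters[name], num_batches_processed)
        for name in loaders
    }


print(collect_state({"train_a": object()}, {"train_a": object()}, num_batches_processed=3))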
Example #4
    def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
        metrics = ([
            m for m in self.trainer.lightning_module.modules()
            if isinstance(m, Metric)
        ] if _fault_tolerant_enabled() else [])

        for metric in metrics:
            metric.persistent(True)
            metric.sync()

        state_dict = self.trainer.accelerator.lightning_module_state_dict()

        for metric in metrics:
            # sync can be a no-op (e.g. on cpu) so `unsync` would raise a user error exception if we don't check
            if metric._is_synced:
                metric.unsync()

        return state_dict
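
A standalone sketch (plain PyTorch, with a hypothetical `DummyMetric`) of the module walk above: `nn.Module.modules()` yields every submodule recursively, so metric instances attached to the LightningModule can be collected and prepared before the state dict is taken, and their buffers end up in the saved state.

import torch
from torch import nn


class DummyMetric(nn.Module):
    """Hypothetical stand-in for a torchmetrics `Metric` attached to the model."""

    def __init__(self):
        super().__init__()
        self.register_buffer("total", torch.zeros(1))


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 2)
        self.metric = DummyMetric()


model = Model()
metrics = [m for m in model.modules() if isinstance(m, DummyMetric)]
print(len(metrics))                           # 1
print("metric.total" in model.state_dict())   # True: the metric buffer is part of the state dict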
Example #5
    def reset_train_dataloader(self, model: 'pl.LightningModule') -> None:
        """Resets the train dataloader and initialises required variables
        (number of batches, when to validate, etc.).

        Args:
            model: The current `LightningModule`
        """
        self.train_dataloader = self.request_dataloader(model, "train")

        if self.overfit_batches > 0:
            if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
                rank_zero_warn(
                    'You requested to overfit but enabled training dataloader shuffling.'
                    ' We are turning off the training dataloader shuffling for you.'
                )
                self.train_dataloader = self.replace_sampler(
                    self.train_dataloader, SequentialSampler(self.train_dataloader.dataset)
                )

        # debugging
        self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])

        # automatically add samplers
        self.train_dataloader = apply_to_collection(
            self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True
        )

        # check the workers recursively
        apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader')

        # add worker_init_fn for correct seeding in worker processes
        apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

        # add collate_fn to collect metadata for fault tolerant training
        if _fault_tolerant_enabled():
            apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)

        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
        self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)

        # allow accelerator to modify dataloader
        self.train_dataloader = self.accelerator.on_reset_train_dataloader(self.train_dataloader)

        self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')

        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
        elif self.num_training_batches != float('inf'):
            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
        elif self.limit_train_batches != 1.0:
            raise MisconfigurationException(
                'When using an IterableDataset for `limit_train_batches`,'
                ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
                ' `num_training_batches` to use.'
            )

        # determine when to check validation
        # if int passed in, val checks that often
        # otherwise, it checks in [0, 1.0] % range of a training epoch
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
                    f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                    f'to the number of the training batches ({self.num_training_batches}). '
                    'If you want to disable validation set `limit_val_batches` to 0.0 instead.'
                )
        else:
            if not has_len(self.train_dataloader):
                if self.val_check_interval == 1.0:
                    self.val_check_batch = float('inf')
                else:
                    raise MisconfigurationException(
                        'When using an IterableDataset for `train_dataloader`,'
                        ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                        ' checking validation every k training batches.'
                    )
            else:
                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
                self.val_check_batch = max(1, self.val_check_batch)

        if self.logger and self.num_training_batches < self.log_every_n_steps:
            rank_zero_warn(
                f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
                f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
                f" you want to see logs for the training epoch."
            )
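
A small worked sketch of the `val_check_interval` arithmetic above: a float interval is turned into a batch count within the epoch and clamped to at least 1.

num_training_batches = 100
for val_check_interval in (0.25, 1.0, 0.001):
    val_check_batch = max(1, int(num_training_batches * val_check_interval))
    print(val_check_interval, "->", val_check_batch)
# 0.25 -> 25, 1.0 -> 100, 0.001 -> 1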
Example #6
def result_collection_reload(**kwargs):
    """
    This test is going to validate ResultCollection is properly being reload
    and final accumulation with Fault Tolerant Training is correct.
    """

    if not _fault_tolerant_enabled():
        pytest.skip("Fault tolerant not available")

    num_processes = kwargs.get("gpus", 1)

    class CustomException(Exception):
        pass

    class ExtendedBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.breaking_batch_idx = 3
            self.has_validated_sum = False
            self.dummy_metric = DummyMeanMetric()

        @property
        def results(self):
            return self.trainer.fit_loop._results

        def training_step(self, batch, batch_idx):

            # In the training step, we will accumulate metrics using batch_idx from 0 to 4
            # Without failure, we would expect to get `total=10 * world_size` and `num_batches=5 * world_size`
            # Therefore, `compute` on `epoch_end` should yield `10 / 5 = 2`.
            # However, below we will simulate a failure on `batch_idx=3`.

            if self.trainer.fit_loop.restarting:
                self.log("tracking", batch_idx, on_step=True, on_epoch=True)
                self.log("tracking_2",
                         batch_idx,
                         on_step=True,
                         on_epoch=True,
                         sync_dist=True)

                self.dummy_metric(batch_idx)
                self.log("tracking_metric",
                         self.dummy_metric,
                         on_step=True,
                         on_epoch=True)

                value = self.results["training_step.tracking_metric"].value
                value_2 = self.results["training_step.tracking"].value

                # On failure, the Metric states are accumulated on rank 0 and zeroed out on the other ranks.
                # The shift indicates we failed while the state was `shift=sign(is_global_zero > 0) * [0..3]`
                shift = 0
                if num_processes == 2:
                    shift = 3 if self.trainer.is_global_zero else -3
                expected = sum(range(batch_idx + 1)) + shift
                assert expected == value == value_2
            else:
                if batch_idx == self.breaking_batch_idx:
                    # simulate failure mid epoch
                    raise CustomException

                self.log("tracking", batch_idx, on_step=True, on_epoch=True)
                self.log("tracking_2",
                         batch_idx,
                         on_step=True,
                         on_epoch=True,
                         sync_dist=True)

                self.dummy_metric(batch_idx)
                self.log("tracking_metric",
                         self.dummy_metric,
                         on_step=True,
                         on_epoch=True)

                value = self.results["training_step.tracking"].value
                assert value == sum(range(batch_idx + 1))

                value = self.results["training_step.tracking_2"]
                assert value == sum(range(batch_idx + 1))

            return super().training_step(batch, batch_idx)

        def on_epoch_end(self) -> None:
            if self.trainer.fit_loop.restarting:
                total = sum(range(5)) * num_processes
                metrics = self.results.metrics(on_step=False)
                assert self.results["training_step.tracking"].value == total
                assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2
                assert self.results["training_step.tracking_2"].value == total
                assert metrics[MetricSource.CALLBACK]["tracking_2"] == self.dummy_metric.compute() == 2
                self.has_validated_sum = True

    model = ExtendedBoringModel()
    trainer_kwargs = {
        "max_epochs": 1,
        "limit_train_batches": 5,
        "limit_val_batches": 0
    }
    trainer_kwargs.update(kwargs)
    trainer = Trainer(**trainer_kwargs)

    with suppress(CustomException):
        trainer.fit(model)
    assert not model.has_validated_sum

    tmpdir = (
        trainer.training_type_plugin.broadcast(trainer_kwargs["default_root_dir"], 0)
        if num_processes >= 2
        else trainer_kwargs["default_root_dir"]
    )
    ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")
    trainer_kwargs["resume_from_checkpoint"] = ckpt_path

    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)
    assert model.has_validated_sum
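
The expected accumulation in this test, spelled out: batch indices 0..4 sum to 10 over 5 batches, and both quantities scale with the world size, so a mean-style metric computes 2 regardless of the number of processes.

world_size = 1
total = sum(range(5)) * world_size   # 10
num_batches = 5 * world_size         # 5
print(total / num_batches)           # 2.0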
Example #7
    def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None) -> DataLoader:
        if not isinstance(dataloader, DataLoader):
            raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

        # get the dataloader instance attributes
        attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
        # not part of `vars`
        attrs["multiprocessing_context"] = dataloader.multiprocessing_context

        # get the dataloader instance `__init__` parameters
        params = dict(inspect.signature(dataloader.__init__).parameters)

        # keep only the params whose default is different to the current attr value
        non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
        # add `dataset` as it might have been replaced with `*args`
        non_defaults.add("dataset")

        # kwargs to re-construct the dataloader
        dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
        dl_kwargs.update(self._resolve_batch_sampler(dataloader, sampler, mode=mode))

        required_args = {
            p.name
            for p in params.values()
            if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
            and p.default is p.empty
            and p.name not in dl_kwargs
        }
        # the dataloader has required args which we could not extract from the existing attributes
        if required_args:
            required_args = sorted(required_args)
            dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                "This would fail as some of the `__init__` arguments are not available as instance attributes. "
                f"The missing attributes are {required_args}. "
                f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
                "manually add the `DistributedSampler` as: "
                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
            )

        has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
        if not has_variadic_kwargs:
            # the dataloader signature does not allow keyword arguments that need to be passed
            missing_kwargs = dl_kwargs.keys() - params.keys()
            if missing_kwargs:
                missing_kwargs = sorted(missing_kwargs)
                dataloader_cls_name = dataloader.__class__.__name__
                raise MisconfigurationException(
                    f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                    "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                    f"The missing arguments are {missing_kwargs}. "
                    f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
                    "manually add the `DistributedSampler` as: "
                    f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
                )

        # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
        if _fault_tolerant_enabled() and isinstance(dl_kwargs["dataset"], IterableDataset):
            dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
            dl_kwargs["sampler"] = None

        if isinstance(dl_kwargs["dataset"], IterableDataset):
            del dl_kwargs["sampler"]
            del dl_kwargs["batch_sampler"]

        dl_cls = type(dataloader)
        dataloader = dl_cls(**dl_kwargs)
        return dataloader
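
A minimal sketch (with a hypothetical `WeirdDataLoader` and a slightly simplified check against the instance attributes) of the situation the first `MisconfigurationException` above guards against: a required `__init__` argument that is never stored as an instance attribute cannot be recovered when the loader is rebuilt.

import inspect

from torch.utils.data import DataLoader


class WeirdDataLoader(DataLoader):
    def __init__(self, dataset, magic, **kwargs):
        # `magic` is required but intentionally not stored as `self.magic`
        super().__init__(dataset, **kwargs)


dataset = list(range(8))
dl = WeirdDataLoader(dataset, magic=123, batch_size=2)

attrs = {k for k in vars(dl) if not k.startswith("_")}
params = dict(inspect.signature(dl.__init__).parameters)
required = {
    p.name
    for p in params.values()
    if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
    and p.default is p.empty
    and p.name not in attrs
}
print(required)  # {'magic'}: `replace_sampler` would raise for this loader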
Example #8
    def dump_checkpoint(self, weights_only: bool = False) -> dict:
        """Creating a model checkpoint dictionary object from various component states.
        Args:
            weights_only: saving model weights only
        Return:
            structured dictionary: {
                'epoch':                     training epoch
                'global_step':               training global step
                'pytorch-lightning_version': PyTorch Lightning's version
                'callbacks':                 "callback specific state"[] # if not weights_only
                'optimizer_states':          "PT optim's state_dict"[]   # if not weights_only
                'lr_schedulers':             "PT sched's state_dict"[]   # if not weights_only
                'native_amp_scaling_state':  PT amp's state_dict         # if not weights_only and use native amp
                'amp_scaling_state':         Apex's state_dict           # if not weights_only and use apex amp
                'state_dict':                Model's state_dict (e.g. network weights)
                CHECKPOINT_HYPER_PARAMS_NAME:
                CHECKPOINT_HYPER_PARAMS_KEY:
                CHECKPOINT_HYPER_PARAMS_TYPE:
                something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
                LightningDataModule.__class__.__name__: pl DataModule's state
            }
        """

        # dump epoch/global_step/pytorch-lightning_version
        current_epoch = self.trainer.current_epoch
        global_step = self.trainer.global_step
        has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step

        global_step += 1
        if not has_reached_max_steps:
            current_epoch += 1

        model = self.trainer.lightning_module

        checkpoint = {
            "epoch": current_epoch,
            "global_step": global_step,
            "pytorch-lightning_version": pl.__version__,
            "state_dict": self.trainer.accelerator.lightning_module_state_dict(),
        }
        if _fault_tolerant_enabled():
            checkpoint["loops"] = self._get_loops_state_dict()

        if not weights_only:
            # dump callbacks
            checkpoint["callbacks"] = self.trainer.on_save_checkpoint(checkpoint)

            optimizer_states = []
            for i, optimizer in enumerate(self.trainer.optimizers):
                # Rely on accelerator to dump optimizer state
                optimizer_state = self.trainer.accelerator.optimizer_state(optimizer)
                optimizer_states.append(optimizer_state)

            checkpoint["optimizer_states"] = optimizer_states

            # dump lr schedulers
            lr_schedulers = []
            for scheduler in self.trainer.lr_schedulers:
                lr_schedulers.append(scheduler["scheduler"].state_dict())
            checkpoint["lr_schedulers"] = lr_schedulers

            self.trainer.precision_plugin.on_save_checkpoint(checkpoint)

        # dump hyper-parameters
        if model.hparams:
            if hasattr(model, "_hparams_name"):
                checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
            # dump arguments
            if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
                checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
                checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
            else:
                checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)

        # give the model a chance to dump a few things
        model.on_save_checkpoint(checkpoint)
        if self.trainer.datamodule is not None:
            self.trainer.datamodule.on_save_checkpoint(checkpoint)

        return checkpoint
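
A small sketch of the counter bump at the top of `dump_checkpoint`: the stored `global_step` always points one past the last finished step, while the epoch is only advanced when training was not stopped by `max_steps`. The helper name below is hypothetical.

from typing import Optional, Tuple


def bump_counters(current_epoch: int, global_step: int, max_steps: Optional[int]) -> Tuple[int, int]:
    # mirrors the logic above: bump the step unconditionally, bump the epoch only if
    # training was not cut short by `max_steps`
    has_reached_max_steps = bool(max_steps) and max_steps <= global_step
    global_step += 1
    if not has_reached_max_steps:
        current_epoch += 1
    return current_epoch, global_step


print(bump_counters(0, 99, max_steps=100))    # (1, 100): max_steps not reached, epoch advances
print(bump_counters(0, 100, max_steps=100))   # (0, 101): stopped by max_steps, epoch unchanged
print(bump_counters(2, 500, max_steps=None))  # (3, 501): no max_steps configured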