def _resolve_batch_sampler(dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
    batch_sampler = getattr(dataloader, "batch_sampler")
    is_predicting = mode == RunningStage.PREDICTING
    # re-instantiate the batch sampler if its type differs from the PyTorch default
    if (batch_sampler is not None and type(batch_sampler) is not BatchSampler) or is_predicting:
        batch_sampler = type(batch_sampler)(
            sampler,
            batch_size=batch_sampler.batch_size,
            drop_last=(False if is_predicting else batch_sampler.drop_last),
        )
        if is_predicting:
            batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

        if _fault_tolerant_enabled():
            fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
            fast_forward_sampler.setup(dataloader_batch_size=1)

        return {
            "sampler": None,
            "shuffle": False,
            "batch_sampler": batch_sampler,
            "batch_size": 1,
            "drop_last": False,
        }

    if _fault_tolerant_enabled():
        fast_forward_sampler = sampler = FastForwardSampler(sampler)
        fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

    return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
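
# Hedged usage sketch (not the exact Lightning call site): the kwargs returned by
# `_resolve_batch_sampler` are merged into the arguments used to re-instantiate the DataLoader,
# as `replace_sampler` below does. The toy dataset/loader names are hypothetical and the snippet
# assumes this module's imports (e.g. `_fault_tolerant_enabled`, `RunningStage`) are available.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

_toy_dataset = TensorDataset(torch.arange(16).float())
_toy_loader = DataLoader(_toy_dataset, batch_size=4)
# explicit `num_replicas`/`rank` so no process group is required for the illustration
_toy_sampler = DistributedSampler(_toy_dataset, num_replicas=2, rank=0)

_dl_kwargs = {"dataset": _toy_dataset, "batch_size": _toy_loader.batch_size}
_dl_kwargs.update(_resolve_batch_sampler(_toy_loader, _toy_sampler, mode=None))
_distributed_loader = DataLoader(**_dl_kwargs)  # iterates only this "rank"'s half of the dataset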
def next_fn(iterator: Iterator):
    batch = next(iterator)
    if not _fault_tolerant_enabled():
        return batch
    # when fault tolerant training is enabled, the iterator returns the
    # `FastForwardSampler` state_dict metadata alongside the user data.
    # The metadata is extracted and stored directly on the iterator
    # to simplify its collection on the `state_dict` call.
    batch, samplers_state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
    # store the `sampler_state_dict` on the iterator
    CaptureIterableDataset.store_samplers_state_dict(iterator, samplers_state_dict)
    return batch
def state_dict(self, num_batches_processed: int) -> Dict:
    """
    The state dict includes all states from wrapped dataloaders and their samplers through the
    ``CaptureIterableDataset`` and fast-forward samplers.

    Args:
        num_batches_processed: The number of batches processed so far, needed because the individual
            dataloaders may have already prefetched more batches by the time a state dict is requested.
    """
    if not _fault_tolerant_enabled():
        return DataLoaderDict()

    state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)

    return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn)
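
# Toy illustration of the `apply_to_collections` pattern used above: the function is applied
# pairwise to matching leaves of two parallel collections (in the real code, the loaders and
# their iterators). The values below are made up, and the import path assumes the Lightning
# version this module ships with.
from pytorch_lightning.utilities.apply_func import apply_to_collections

_loaders = {"a": [1, 2], "b": [3, 4]}
_iterators = {"a": [10, 20], "b": [30, 40]}
_merged = apply_to_collections(_loaders, _iterators, list, lambda x, y: x + y)
assert _merged == {"a": [1, 2, 10, 20], "b": [3, 4, 30, 40]}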
def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
    metrics = (
        [m for m in self.trainer.lightning_module.modules() if isinstance(m, Metric)]
        if _fault_tolerant_enabled()
        else []
    )

    for metric in metrics:
        metric.persistent(True)
        metric.sync()

    state_dict = self.trainer.accelerator.lightning_module_state_dict()

    for metric in metrics:
        # sync can be a no-op (e.g. on cpu) so `unsync` would raise a user error exception if we don't check
        if metric._is_synced:
            metric.unsync()

    return state_dict
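
# Minimal sketch of the persist-and-sync step above with a bare torchmetrics Metric
# (illustrative only; `_RunningSum` is a made-up example class, not a Lightning helper).
import torch
from torchmetrics import Metric


class _RunningSum(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value):
        self.total += value

    def compute(self):
        return self.total


_metric = _RunningSum()
_metric.update(3.0)
_metric.persistent(True)   # include the metric states in `state_dict()`
_metric.sync()             # reduces across ranks; returns early without a distributed setup
_snapshot = _metric.state_dict()
if _metric._is_synced:     # `unsync` raises if the metric was never actually synced
    _metric.unsync()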
def reset_train_dataloader(self, model: 'pl.LightningModule') -> None:
    """Resets the train dataloader and initialises required variables (number of batches, when to validate, etc.).

    Args:
        model: The current `LightningModule`
    """
    self.train_dataloader = self.request_dataloader(model, "train")

    if self.overfit_batches > 0:
        if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
            rank_zero_warn(
                'You requested to overfit but enabled training dataloader shuffling.'
                ' We are turning off the training dataloader shuffling for you.'
            )
            self.train_dataloader = self.replace_sampler(
                self.train_dataloader, SequentialSampler(self.train_dataloader.dataset)
            )

    # debugging
    self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])

    # automatically add samplers
    self.train_dataloader = apply_to_collection(
        self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True
    )

    # check the workers recursively
    apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader')

    # add worker_init_fn for correct seeding in worker processes
    apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

    # add collate_fn to collect metadata for fault tolerant training
    if _fault_tolerant_enabled():
        apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)

    # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
    self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)

    # allow accelerator to modify dataloader
    self.train_dataloader = self.accelerator.on_reset_train_dataloader(self.train_dataloader)

    self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')

    if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
        self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
    elif self.num_training_batches != float('inf'):
        self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
    elif self.limit_train_batches != 1.0:
        raise MisconfigurationException(
            'When using an IterableDataset for `train_dataloader`,'
            ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
            ' `num_training_batches` to use.'
        )

    # determine when to check validation
    # if an int is passed in, validation is checked that often
    # otherwise, it is checked as a fraction in [0.0, 1.0] of a training epoch
    if isinstance(self.val_check_interval, int):
        self.val_check_batch = self.val_check_interval
        if self.val_check_batch > self.num_training_batches:
            raise ValueError(
                f'`val_check_interval` ({self.val_check_interval}) must be less than or equal'
                f' to the number of training batches ({self.num_training_batches}).'
                ' If you want to disable validation, set `limit_val_batches` to 0.0 instead.'
            )
    else:
        if not has_len(self.train_dataloader):
            if self.val_check_interval == 1.0:
                self.val_check_batch = float('inf')
            else:
                raise MisconfigurationException(
                    'When using an IterableDataset for `train_dataloader`,'
                    ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                    ' checking validation every k training batches.'
                )
        else:
            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            self.val_check_batch = max(1, self.val_check_batch)

    if self.logger and self.num_training_batches < self.log_every_n_steps:
        rank_zero_warn(
            f"The number of training batches ({self.num_training_batches}) is smaller than the logging interval"
            f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
            f" you want to see logs for the training epoch."
        )
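
# Worked example of the `val_check_interval` arithmetic above (numbers are made up): with 100
# training batches and `val_check_interval=0.25`, validation runs every `int(100 * 0.25) = 25`
# batches; an integer interval is used verbatim, and a fractional interval on an infinite
# (IterableDataset) loader is only allowed when it is exactly `1.0`.
_num_training_batches = 100
_val_check_interval = 0.25
assert max(1, int(_num_training_batches * _val_check_interval)) == 25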
def result_collection_reload(**kwargs):
    """
    This test validates that the ResultCollection is properly reloaded and that the final
    accumulation with Fault Tolerant Training is correct.
    """

    if not _fault_tolerant_enabled():
        pytest.skip("Fault tolerant not available")

    num_processes = kwargs.get("gpus", 1)

    class CustomException(Exception):
        pass

    class ExtendedBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.breaking_batch_idx = 3
            self.has_validated_sum = False
            self.dummy_metric = DummyMeanMetric()

        @property
        def results(self):
            return self.trainer.fit_loop._results

        def training_step(self, batch, batch_idx):
            # In the training step, we will accumulate metrics using batch_idx from 0 to 4.
            # Without failure, we would expect to get `total=10 * world_size` and `num_batches=5 * world_size`,
            # so `compute` on `epoch_end` should return 2 as `10 / 5`.
            # However, below we will simulate a failure on `batch_idx=3`.

            if self.trainer.fit_loop.restarting:
                self.log("tracking", batch_idx, on_step=True, on_epoch=True)
                self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True)

                self.dummy_metric(batch_idx)
                self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True)

                value = self.results["training_step.tracking_metric"].value
                value_2 = self.results["training_step.tracking"].value

                # On failure, the Metric states are being accumulated on rank 0 and zeroed-out on other ranks.
                # The shift indicates we failed while the state was `shift=sign(is_global_zero > 0) * [0..3]`
                shift = 0
                if num_processes == 2:
                    shift = 3 if self.trainer.is_global_zero else -3
                expected = sum(range(batch_idx + 1)) + shift
                assert expected == value == value_2
            else:
                if batch_idx == self.breaking_batch_idx:
                    # simulate failure mid epoch
                    raise CustomException

                self.log("tracking", batch_idx, on_step=True, on_epoch=True)
                self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True)

                self.dummy_metric(batch_idx)
                self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True)

                value = self.results["training_step.tracking"].value
                assert value == sum(range(batch_idx + 1))

                value = self.results["training_step.tracking_2"].value
                assert value == sum(range(batch_idx + 1))

            return super().training_step(batch, batch_idx)

        def on_epoch_end(self) -> None:
            if self.trainer.fit_loop.restarting:
                total = sum(range(5)) * num_processes
                metrics = self.results.metrics(on_step=False)
                assert self.results["training_step.tracking"].value == total
                assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2
                assert self.results["training_step.tracking_2"].value == total
                assert metrics[MetricSource.CALLBACK]["tracking_2"] == self.dummy_metric.compute() == 2
                self.has_validated_sum = True

    model = ExtendedBoringModel()
    trainer_kwargs = {"max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0}
    trainer_kwargs.update(kwargs)
    trainer = Trainer(**trainer_kwargs)

    with suppress(CustomException):
        trainer.fit(model)
    assert not model.has_validated_sum

    tmpdir = (
        trainer.training_type_plugin.broadcast(trainer_kwargs["default_root_dir"], 0)
        if num_processes >= 2
        else trainer_kwargs["default_root_dir"]
    )
    ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")
    trainer_kwargs["resume_from_checkpoint"] = ckpt_path

    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)
    assert model.has_validated_sum
def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None) -> DataLoader:
    if not isinstance(dataloader, DataLoader):
        raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

    # get the dataloader instance attributes
    attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
    # not part of `vars`
    attrs["multiprocessing_context"] = dataloader.multiprocessing_context

    # get the dataloader instance `__init__` parameters
    params = dict(inspect.signature(dataloader.__init__).parameters)

    # keep only the params whose default is different to the current attr value
    non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
    # add `dataset` as it might have been replaced with `*args`
    non_defaults.add("dataset")

    # kwargs to re-construct the dataloader
    dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
    dl_kwargs.update(self._resolve_batch_sampler(dataloader, sampler, mode=mode))

    required_args = {
        p.name
        for p in params.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty
        and p.name not in dl_kwargs
    }
    # the dataloader has required args which we could not extract from the existing attributes
    if required_args:
        required_args = sorted(required_args)
        dataloader_cls_name = dataloader.__class__.__name__
        raise MisconfigurationException(
            f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
            f"The missing attributes are {required_args}. "
            f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
            "manually add the `DistributedSampler` as: "
            f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
        )

    has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
    if not has_variadic_kwargs:
        # the dataloader signature does not allow keyword arguments that need to be passed
        missing_kwargs = dl_kwargs.keys() - params.keys()
        if missing_kwargs:
            missing_kwargs = sorted(missing_kwargs)
            dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                f"The missing arguments are {missing_kwargs}. "
                f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
                "manually add the `DistributedSampler` as: "
                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
            )

    # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
    if _fault_tolerant_enabled() and isinstance(dl_kwargs["dataset"], IterableDataset):
        dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
        dl_kwargs["sampler"] = None

    if isinstance(dl_kwargs["dataset"], IterableDataset):
        del dl_kwargs["sampler"]
        del dl_kwargs["batch_sampler"]

    dl_cls = type(dataloader)
    dataloader = dl_cls(**dl_kwargs)
    return dataloader
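
# Standalone illustration of the re-instantiation trick used by `replace_sampler`, on a toy class
# rather than a real DataLoader: keep only the attributes whose value differs from the `__init__`
# default, then rebuild the object with one argument swapped out (the sampler, in the real code).
import inspect


class _ToyLoader:
    def __init__(self, dataset, batch_size=1, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle


_orig = _ToyLoader(dataset=[1, 2, 3], batch_size=8)
_attrs = {k: v for k, v in vars(_orig).items() if not k.startswith("_")}
_params = dict(inspect.signature(_ToyLoader.__init__).parameters)
_non_defaults = {name for name, p in _params.items() if name in _attrs and p.default != _attrs[name]}
_non_defaults.add("dataset")  # mirror the real code, which force-keeps `dataset`
_kwargs = {k: v for k, v in _attrs.items() if k in _non_defaults}
_kwargs["batch_size"] = 4  # the swapped-in argument
_rebuilt = _ToyLoader(**_kwargs)
assert _rebuilt.batch_size == 4 and _rebuilt.dataset == [1, 2, 3] and _rebuilt.shuffle is False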
def dump_checkpoint(self, weights_only: bool = False) -> dict:
    """Creates a model checkpoint dictionary object from various component states.

    Args:
        weights_only: saving model weights only

    Return:
        structured dictionary: {
            'epoch':                     training epoch
            'global_step':               training global step
            'pytorch-lightning_version': PyTorch Lightning's version
            'callbacks':                 "callback specific state"[] # if not weights_only
            'optimizer_states':          "PT optim's state_dict"[]   # if not weights_only
            'lr_schedulers':             "PT sched's state_dict"[]   # if not weights_only
            'native_amp_scaling_state':  PT amp's state_dict         # if not weights_only and use native amp
            'amp_scaling_state':         Apex's state_dict           # if not weights_only and use apex amp
            'state_dict':                Model's state_dict (e.g. network weights)
            CHECKPOINT_HYPER_PARAMS_NAME:
            CHECKPOINT_HYPER_PARAMS_KEY:
            CHECKPOINT_HYPER_PARAMS_TYPE:
            something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
            LightningDataModule.__class__.__name__: pl DataModule's state
        }
    """
    # dump epoch/global_step/pytorch-lightning_version
    current_epoch = self.trainer.current_epoch
    global_step = self.trainer.global_step
    has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step

    global_step += 1
    if not has_reached_max_steps:
        current_epoch += 1

    model = self.trainer.lightning_module

    checkpoint = {
        "epoch": current_epoch,
        "global_step": global_step,
        "pytorch-lightning_version": pl.__version__,
        "state_dict": self.trainer.accelerator.lightning_module_state_dict(),
    }

    if _fault_tolerant_enabled():
        checkpoint["loops"] = self._get_loops_state_dict()

    if not weights_only:
        # dump callbacks
        checkpoint["callbacks"] = self.trainer.on_save_checkpoint(checkpoint)

        optimizer_states = []
        for i, optimizer in enumerate(self.trainer.optimizers):
            # Rely on accelerator to dump optimizer state
            optimizer_state = self.trainer.accelerator.optimizer_state(optimizer)
            optimizer_states.append(optimizer_state)

        checkpoint["optimizer_states"] = optimizer_states

        # dump lr schedulers
        lr_schedulers = []
        for scheduler in self.trainer.lr_schedulers:
            lr_schedulers.append(scheduler["scheduler"].state_dict())
        checkpoint["lr_schedulers"] = lr_schedulers

        self.trainer.precision_plugin.on_save_checkpoint(checkpoint)

    # dump hyper-parameters
    if model.hparams:
        if hasattr(model, "_hparams_name"):
            checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
        # dump arguments
        if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
            checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
            checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
        else:
            checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)

    # give the model a chance to dump a few things
    model.on_save_checkpoint(checkpoint)
    if self.trainer.datamodule is not None:
        self.trainer.datamodule.on_save_checkpoint(checkpoint)

    return checkpoint
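
# Hedged usage sketch: the dictionary assembled above is what ultimately gets written to disk.
# This assumes the public `pytorch_lightning` API of this era (`trainer.checkpoint_connector` is
# an internal attribute) and reuses the `BoringModel` test helper from the test above.
import torch
import pytorch_lightning as pl

_trainer = pl.Trainer(max_steps=1, logger=False, checkpoint_callback=False)
_trainer.fit(BoringModel())
_ckpt = _trainer.checkpoint_connector.dump_checkpoint(weights_only=False)
torch.save(_ckpt, "example.ckpt")
assert {"epoch", "global_step", "pytorch-lightning_version", "state_dict"} <= set(_ckpt)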