def update_replica_device_attributes(self, inputs: Any) -> None:
    """Sync the wrapped module's device with the device found on the inputs.

    With :class:`~torch.nn.DataParallel`, state changed during ``forward`` is lost once the
    replicas are discarded, so the current device has to be read from the inputs that were
    passed to the model.

    Args:
        inputs: A collection of inputs (typically a tuple). The first visited tensor that is
            not on the CPU determines the device. If there is no such tensor, a warning is
            shown that ``self.device`` will not return the correct device.
    """
    # holds at most one entry: the device of the first non-CPU tensor visited
    detected = []

    def _remember_device(t: torch.Tensor) -> torch.Tensor:
        if not detected and t.device != torch.device("cpu"):
            detected.append(t.device)
        return t

    apply_to_collection(inputs, dtype=torch.Tensor, function=_remember_device)

    if not detected:
        rank_zero_warn(
            "Could not determine on which device the inputs are."
            " When using DataParallel (strategy='dp'), be aware that in case you are using self.device"
            " in your code, it will reference only the root device."
        )
        return

    # by calling .to() we force the update to the self.device property
    self.module.to(device=detected[0])
def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]:
    """Make sure every dataloader in the collection uses a ``SequentialSampler``.

    Args:
        dataloader: a single ``DataLoader`` or a (possibly nested) collection of them.

    Returns:
        The input unchanged when every dataloader already samples sequentially; otherwise a
        collection in which each dataloader was rebuilt with a ``SequentialSampler`` (a warning
        is emitted in that case).
    """
    # one flag per visited dataloader
    is_sequential = []

    def _record_sampler_kind(loader: DataLoader) -> None:
        is_sequential.append(isinstance(loader.sampler, SequentialSampler))

    apply_to_collection(dataloader, DataLoader, _record_sampler_kind)

    if all(is_sequential):
        return dataloader

    rank_zero_warn(
        "You requested to overfit but enabled training dataloader shuffling."
        " We are turning off the training dataloader shuffling for you."
    )

    def _with_sequential_sampler(loader: DataLoader) -> DataLoader:
        return TrainerDataLoadingMixin._update_dataloader(
            loader, SequentialSampler(loader.dataset), mode=RunningStage.TRAINING
        )

    return apply_to_collection(dataloader, DataLoader, _with_sequential_sampler)
def test_combined_data_loader_validation_test(cuda_available_mock, device_count_mock, tmpdir):
    """Check that every dataloader inside a ``CombinedLoader`` receives a ``DistributedSampler``."""

    class CustomDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    def _new_loader():
        return DataLoader(CustomDataset(range(10)))

    # five dataloaders in total, spread over a nested dict / list structure
    dataloader = CombinedLoader(
        {
            "a": _new_loader(),
            "b": {"c": _new_loader(), "d": _new_loader()},
            "e": [_new_loader(), _new_loader()],
        }
    )

    trainer = Trainer(replace_sampler_ddp=True, accelerator="ddp", gpus=2)
    dataloader = trainer.auto_add_sampler(dataloader, shuffle=True)

    visited = []

    def _check_sampler(sampler):
        visited.append(sampler)
        assert isinstance(sampler, DistributedSampler)

    apply_to_collection(dataloader.sampler, Sampler, _check_sampler)
    assert len(visited) == 5
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Run the wrapped module with inputs cast to the plugin's precision.

    Floating point tensors found in ``args`` / ``kwargs`` are cast to the dtype matching the
    precision plugin's ``precision``, the module is called inside the plugin's forward context,
    and floating point tensors in the output are cast to ``torch.get_default_dtype()``.
    """
    precision = self._precision_plugin.precision
    # TODO (@awaelchli): let the precision plugin handle the conversion
    input_dtype = {
        "bf16": torch.bfloat16,
        16: torch.float16,
        32: torch.float32,
        64: torch.float64,
    }[precision]

    def _float_caster(target_dtype):
        """Build a callback that casts floating point tensors to ``target_dtype``."""

        def _cast(t: Tensor) -> Tensor:
            if torch.is_floating_point(t):
                return t.to(target_dtype)
            return t

        return _cast

    args, kwargs = apply_to_collection([args, kwargs], function=_float_caster(input_dtype), dtype=Tensor)

    with self._precision_plugin.forward_context():
        output = self.module(*args, **kwargs)

    output_dtype = torch.get_default_dtype()
    return apply_to_collection(output, function=_float_caster(output_dtype), dtype=Tensor)
def reset(self):
    """Drop the current iterator and reset the iterator of every wrapped ``DataLoader``."""
    current_iterator = self._iterator
    if current_iterator:
        # release the per-loader iterators held by the active iterator
        current_iterator._loader_iters = None

    loaders = self.loaders
    if loaders is not None:
        apply_to_collection(loaders, DataLoader, self._shutdown_workers_and_reset_iterator)

    self._iterator = None
def _resolve_overfit_batches(dataloaders: Collection[DataLoader], mode: RunningStage) -> Collection[DataLoader]:
    """Make sure every dataloader in the collection uses a ``SequentialSampler``.

    Args:
        dataloaders: a single ``DataLoader`` or a (possibly nested) collection of them.
        mode: the running stage; forwarded to ``_update_dataloader`` and used in the warning text.

    Returns:
        The input unchanged when every dataloader already samples sequentially; otherwise a
        collection in which each dataloader was rebuilt with a ``SequentialSampler`` (a warning
        is emitted in that case).
    """
    # one flag per visited dataloader
    sequential_flags = []

    def _note_sampler_kind(loader: DataLoader) -> None:
        sequential_flags.append(isinstance(loader.sampler, SequentialSampler))

    apply_to_collection(dataloaders, DataLoader, _note_sampler_kind)

    if all(sequential_flags):
        return dataloaders

    rank_zero_warn(
        "You requested to overfit but enabled training dataloader shuffling."
        f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
    )

    def _sequential_copy(loader: DataLoader) -> DataLoader:
        sequential = SequentialSampler(loader.dataset)
        return _update_dataloader(loader, sampler=sequential, mode=mode)

    return apply_to_collection(dataloaders, DataLoader, _sequential_copy)
def _wrap_loaders_max_size_cycle(self) -> Any:
    """Wrap all loaders in ``CycleIterator`` so they are cycled up to the longest loader's length.

    When ``self.loaders`` is a sequence or mapping, each contained loader is replaced by a
    ``CycleIterator`` sharing one ``SharedCycleIteratorState``; the result is written back to
    ``self.loaders``. A single (non-collection) loader is left untouched. Nothing is returned.
    """
    containers = (Sequence, Mapping)

    lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=containers)
    longest = _nested_calc_num_data(lengths, max)

    if not isinstance(self.loaders, containers):
        # a single loader needs no cycling
        return

    # multiple loaders
    shared_state = SharedCycleIteratorState()
    self.loaders = apply_to_collection(
        self.loaders,
        Iterable,
        CycleIterator,
        length=longest,
        state=shared_state,
        wrong_dtype=containers,
    )
    shared_state.reset()
def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None:
    """Validate that fault-tolerance is possible with the user's data.

    Does nothing unless the detected fault-tolerant mode is automatic.

    Args:
        dataloader: a dataloader, a ``CombinedLoader`` or a collection of dataloaders.
        stage: the running stage the dataloader is used for.

    Raises:
        ValueError: if more than one dataloader is found while ``stage`` is ``TRAINING``.
    """
    if not _FaultTolerantMode.detect_current_mode().is_automatic:
        return

    from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator

    loaders = dataloader.loaders if isinstance(dataloader, CombinedLoader) else dataloader

    flat_loaders = []

    def _collect(loader: Union[DataLoader, CycleIterator, Iterable]) -> None:
        # a CycleIterator is unwrapped to the loader it cycles over
        flat_loaders.append(loader.loader if isinstance(loader, CycleIterator) else loader)

    apply_to_collection(loaders, (DataLoader, CycleIterator), _collect)

    if len(flat_loaders) > 1 and stage == pl.trainer.states.RunningStage.TRAINING:
        raise ValueError("Fault-tolerance supports only a single dataloader.")

    for loader in flat_loaders:
        if isinstance(loader.dataset, IterableDataset):
            _validate_iterable_dataset(loader)
        else:
            _validate_map_dataset(loader)
def teardown(self):
    """Move every tensor held in the three metric collections to the CPU."""
    # same order as the attributes are handled one by one: logged, progress bar, callback
    for attr_name in ("_logged_metrics", "_progress_bar_metrics", "_callback_metrics"):
        on_cpu = apply_to_collection(getattr(self, attr_name), torch.Tensor, move_data_to_device, "cpu")
        setattr(self, attr_name, on_cpu)
def test_sklearn_metric(metric_class, sklearn_func, inputs: dict):
    """Compare a Lightning metric against its sklearn reference on the same inputs.

    Both results are converted to numpy first; they must have the same type and be element-wise
    close. Only arrays, mappings and sequences are supported as result types.
    """
    convertible = (torch.Tensor, np.ndarray, numbers.Number)

    def _as_numpy(collection):
        return apply_to_collection(collection, convertible, _convert_to_numpy)

    sklearn_result = sklearn_func(**_as_numpy(inputs))
    lightning_result = metric_class(**inputs)

    sklearn_result = _as_numpy(sklearn_result)
    lightning_result = _as_numpy(lightning_result)

    assert isinstance(lightning_result, type(sklearn_result))

    if isinstance(lightning_result, np.ndarray):
        assert np.allclose(lightning_result, sklearn_result)
    elif isinstance(lightning_result, Mapping):
        for key in lightning_result.keys():
            expected = sklearn_result[key]
            assert np.allclose(lightning_result[key], expected)
    elif isinstance(lightning_result, Sequence):
        for actual, expected in zip(lightning_result, sklearn_result):
            assert np.allclose(actual, expected)
    else:
        raise TypeError
def result_metrics(self) -> List[ResultMetric]:
    """Return a flat list of every ``ResultMetric`` found in this collection's values."""
    collected: List[ResultMetric] = []
    # ``list.append`` serves directly as the visitor callback
    apply_to_collection(list(self.values()), ResultMetric, collected.append)
    return collected
def _add_capture_metadata_collate(dataloader: Iterable) -> None:
    """Apply ``_add_capture_metadata_collate`` to each ``DataLoader`` reachable from the input.

    Inputs that are neither a ``DataLoader`` nor a ``CombinedLoader`` are ignored. For a
    ``CombinedLoader`` the callback is applied to its ``loaders`` collection.
    """
    # NOTE(review): the callback name below is looked up at module scope. Presumably this
    # function lives in a class (static method) and the name refers to a module-level helper;
    # if this ``def`` is itself module-level it would call itself -- confirm.
    if isinstance(dataloader, CombinedLoader):
        target = dataloader.loaders
    elif isinstance(dataloader, DataLoader):
        target = dataloader
    else:
        return

    apply_to_collection(target, DataLoader, _add_capture_metadata_collate)
def _load_from_state_dict(self, state_dict: Dict, prefix: str, restart_progress: bool) -> None: for k, v in self.__dict__.items(): if isinstance(v, BaseProgress): v.load_state_dict(state_dict[prefix + k]) if restart_progress: apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart()) self.on_load_checkpoint(state_dict[prefix + "state_dict"]) self.restarting = True
def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None:
    """Emit a deprecation message for every tensor in ``extra`` that still has a ``grad_fn``.

    Args:
        extra: the extra values returned alongside the loss; may contain nested tensors.
    """

    def _warn_if_attached(t: Tensor) -> Tensor:
        if t.grad_fn is None:
            return t
        rank_zero_deprecation(
            f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
            " but this behaviour will change in v1.6. Please detach it manually:"
            " `return {'loss': ..., 'something': something.detach()}`"
        )
        return t

    apply_to_collection(extra, Tensor, _warn_if_attached)
def test_apply_to_collection_frozen_dataclass():
    """``apply_to_collection`` must raise a ``MisconfigurationException`` for frozen dataclasses."""

    @dataclasses.dataclass(frozen=True)
    class Foo:
        input: torch.Tensor

    def _to_int(t):
        return t.to(torch.int)

    frozen_instance = Foo(torch.tensor(0))

    with pytest.raises(MisconfigurationException, match="frozen dataclass was passed"):
        apply_to_collection(frozen_instance, torch.Tensor, _to_int)
def _prepare_input(self, args: Any):
    """Convert lists to tuples and ``int`` / ``float`` values to replicated tensors.

    Each matching number becomes ``torch.tensor(x).unsqueeze(0).repeat(self._n_replicate)``.
    The two conversions run as separate passes over ``args``, lists first.
    """

    def _replicate(value):
        as_tensor = torch.tensor(value)
        return as_tensor.unsqueeze(0).repeat(self._n_replicate)

    without_lists = apply_to_collection(args, dtype=list, function=tuple)
    return apply_to_collection(without_lists, dtype=(int, float), function=_replicate)
def _attach_data_fetcher(self) -> None:
    """Store this fetcher as ``_lightning_fetcher`` on each ``DataLoader`` in ``self.loaders``.

    ``CycleIterator`` entries are unwrapped to their ``loader`` first. The attribute is only set
    while ``_fault_tolerant_training()`` returns a truthy value.
    """

    def _tag(loader: DataLoader) -> None:
        target = loader.loader if isinstance(loader, CycleIterator) else loader
        if isinstance(target, DataLoader) and _fault_tolerant_training():
            target._lightning_fetcher = self

    apply_to_collection(self.loaders, (DataLoader, CycleIterator), _tag)
def collect_tensors(data: Any) -> List[torch.Tensor]:
    """Gather every tensor found in a (possibly nested) collection into a flat list.

    Args:
        data: a tensor or a collection containing tensors.

    Returns:
        The tensors in the order ``apply_to_collection`` visits them.
    """
    found = []

    def _remember(t):
        found.append(t)
        # hand the tensor back unchanged so the traversal rebuilds the same structure
        return t

    apply_to_collection(data, dtype=torch.Tensor, function=_remember)
    return found
def convert_to_modules(transforms: Optional[Dict[str, Callable]]):
    """Turn a collection of callables into ``torch.nn`` module containers.

    ``None`` and ``torch.nn.Module`` inputs are returned unchanged. Otherwise three passes are
    applied in order: callables are wrapped in ``FuncModule``, mappings become
    ``torch.nn.ModuleDict`` and iterables become ``torch.nn.ModuleList``.
    """
    if transforms is None or isinstance(transforms, torch.nn.Module):
        return transforms

    module_dict = torch.nn.ModuleDict
    module_list = torch.nn.ModuleList

    # (type to match, conversion, types to leave alone) -- order matters
    passes = (
        (Callable, FuncModule, torch.nn.Module),
        (Mapping, module_dict, module_dict),
        (Iterable, module_list, (module_list, module_dict)),
    )
    for matched_type, converter, skipped in passes:
        transforms = apply_to_collection(transforms, matched_type, converter, wrong_dtype=skipped)
    return transforms
def test_apply_to_collection_include_none():
    """``include_none=False`` must drop entries for which the function returned ``None``."""
    values = [1, 2, 3.4, 5.6, 7]
    numeric = (int, float)

    def keep_floats(x):
        # integers map to ``None``
        return x if isinstance(x, float) else None

    with_none = apply_to_collection(values, numeric, keep_floats)
    assert with_none == [None, None, 3.4, 5.6, None]

    without_none = apply_to_collection(values, numeric, keep_floats, include_none=False)
    assert without_none == [3.4, 5.6]
def to(self, *args, **kwargs) -> 'ResultCollection':
    """Move all data to the given device.

    Arguments are forwarded unchanged to the ``.to`` method of every tensor / ``Metric`` in the
    collection, of ``self.minimize`` (when set) and of ``self._batch_size``.

    Returns:
        ``self``, to allow chaining.
    """

    def _move(item: Union[torch.Tensor, Metric], *to_args: Any, **to_kwargs: Any) -> Union[torch.Tensor, Metric]:
        return item.to(*to_args, **to_kwargs)

    # NOTE(review): the result of this traversal is discarded, so it only has an effect for
    # items whose ``.to`` works in place -- confirm no bare tensors are stored as values.
    apply_to_collection(self, (torch.Tensor, Metric), _move, *args, **kwargs)

    minimize = self.minimize
    if minimize is not None:
        self.minimize = minimize.to(*args, **kwargs)

    self._batch_size = self._batch_size.to(*args, **kwargs)

    # NOTE(review): only a keyword ``device`` is tracked; a positional device leaves
    # ``self.device`` untouched.
    if 'device' in kwargs:
        self.device = kwargs['device']
    return self
def test_sklearn_metric(metric_class, sklearn_func, inputs):
    """Compare a Lightning metric against its sklearn reference on the same inputs.

    Closeness (``atol=1e-5``) is checked both on the raw results and after converting both to
    numpy; finally the converted results must have compatible types.
    """
    convertible = (torch.Tensor, np.ndarray, numbers.Number)

    def _as_numpy(collection):
        return apply_to_collection(collection, convertible, _convert_to_numpy)

    sklearn_result = sklearn_func(**_as_numpy(inputs))
    lightning_result = metric_class(**inputs)
    assert np.allclose(sklearn_result, lightning_result, atol=1e-5)

    sklearn_result = _as_numpy(sklearn_result)
    lightning_result = _as_numpy(lightning_result)
    assert np.allclose(sklearn_result, lightning_result, atol=1e-5)

    assert isinstance(lightning_result, type(sklearn_result))
def log_histograms(self, batch: Any, group: str = "") -> None:
    """Log one histogram per tensor in ``batch`` on every ``_log_every_n_steps``-th training batch.

    Nothing happens when logging is disabled (``self._log`` is falsy) or when the current
    training batch is not on the logging interval.

    Args:
        batch: torch or numpy arrays, or a (possibly nested) collection of them. Numpy arrays are
            converted with ``torch.from_numpy`` before naming.
        group: passed to ``collect_and_name_tensors`` as ``parent_name``; that helper decides the
            label under which each tensor is logged.
    """
    if not self._log:
        return
    if (self._train_batch_idx + 1) % self._log_every_n_steps != 0:  # type: ignore[operator]
        return

    as_tensors = apply_to_collection(batch, dtype=np.ndarray, function=torch.from_numpy)

    named_tensors: Dict[str, Tensor] = {}
    collect_and_name_tensors(as_tensors, output=named_tensors, parent_name=group)

    for label, tensor in named_tensors.items():
        self.log_histogram(tensor, label)
def all_gather(
    self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
) -> Union[torch.Tensor, Dict, List, Tuple]:
    r"""
    Gather tensors or collections of tensors from multiple processes.

    The input is first passed through ``convert_to_tensors`` (targeting ``self.device``); every
    tensor in the result is then handed to the strategy's ``all_gather``.

    Args:
        data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
        group: the process group to gather results from. Defaults to all processes (world)
        sync_grads: flag that allows users to synchronize gradients for the all_gather operation

    Return:
        A tensor of shape (world_size, batch, ...), or if the input was a collection
        the output will also be a collection with tensors of this shape.
    """
    if group is None:
        group = torch.distributed.group.WORLD

    tensors = convert_to_tensors(data, device=self.device)
    gather = self._strategy.all_gather
    return apply_to_collection(tensors, torch.Tensor, gather, group=group, sync_grads=sync_grads)
def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]: for k, v in data.items(): if isinstance(v, dict): for new_key in apply_to_collection(v, dict, EvaluationLoop._get_keys): yield (k, *new_key) # this need to be in parenthesis for older python versions else: yield k,
def _prepare_outputs_training_batch_end(
    batch_output: _BATCH_OUTPUTS_TYPE,
    automatic: bool,
    num_optimizers: int,
) -> Union[List[List[Dict[str, Any]]], List[Dict[str, Any]]]:
    """Process the batch loop outputs into the format passed to the ``training_batch_end`` hook.

    ``(tbptt_steps, n_opt) -> (n_opt, tbptt_steps)``. The optimizer dimension might have been squeezed.

    Args:
        batch_output: the outputs collected by the batch loop; an empty value yields ``[]``.
        automatic: when true, dict entries are first converted with ``_convert_optim_dict``.
        num_optimizers: forwarded to ``_convert_optim_dict``.
    """
    if not batch_output:
        return []

    if automatic:
        # convert optimizer dicts to list
        batch_output = apply_to_collection(
            batch_output, dtype=dict, function=_convert_optim_dict, num_optimizers=num_optimizers
        )

    grid = np.array(batch_output, dtype=object)
    if grid.ndim == 1:
        # add the missing second axis so the transpose below is well defined
        grid = np.expand_dims(grid, 1)

    # swap the two axes, drop axes of length one, and go back to plain lists
    nested = grid.transpose((1, 0)).squeeze().tolist()
    return _recursive_unpad(nested)
def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
    """Moves the state of the optimizers to the given device.

    Args:
        device: the device to move the optimizer state to. When ``None`` (the default), the
            state is moved to ``self.root_device``.
    """
    # TODO: `self.root_device` would raise error if called outside the spawn process
    # while training on 8 and more cores.
    # Honour an explicitly requested device; only fall back to (and therefore only access)
    # ``self.root_device`` when the caller did not provide one.
    target_device = self.root_device if device is None else device
    for opt in self.optimizers:
        # re-assigning existing keys while iterating does not change the dict's size
        for p, v in opt.state.items():
            opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, target_device)
def __init__(self, loaders: Any, mode: str = 'min_size'):
    """
    Args:
        loaders: the loaders to sample from. Can be all kind of collection
        mode: the mode. Supported are 'min_size' which stops if the shortest loader is exhausted and
            'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones.

    Raises:
        MisconfigurationException: if ``mode`` is not one of ``self.SUPPORTED_MODES``.
    """
    self.loaders = loaders

    # read the ``dataset`` attribute of every loader (``None`` when a loader has none)
    loader_datasets = apply_to_collection(
        self.loaders, Iterable, getattr, 'dataset', None, wrong_dtype=(Sequence, Mapping)
    )
    # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader
    self.dataset = CombinedDataset(loader_datasets, mode)

    if mode not in self.SUPPORTED_MODES:
        raise MisconfigurationException(f"Invalid Mode: {mode}")
    self.mode = mode

    if self.mode == 'max_size_cycle':
        self._wrap_loaders_max_size_cycle()
def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]:
    """
    Compute the length of `CombinedDataset` according to the `mode`.

    Args:
        datasets: a sequence/mapping datasets. Can be a collections of torch.utils.data.Dataset,
            Iterable or even None.
        mode: Determine `CombinedDataset`'s length is the maximum or minimum of
            the datasets.

    Returns:
        length: the length of `CombinedDataset`

    Raises:
        MisconfigurationException: if ``mode`` is not a key of ``CombinedDataset.COMPUTE_FUNCS``.
    """
    if mode not in CombinedDataset.COMPUTE_FUNCS.keys():
        raise MisconfigurationException(f"Invalid Mode: {mode}")

    # extract the lengths
    leaf_types = (Dataset, Iterable, type(None))
    lengths = apply_to_collection(datasets, leaf_types, get_len, wrong_dtype=(Sequence, Mapping))

    reducer = CombinedDataset.COMPUTE_FUNCS[mode]

    # a single dataset yields a bare number that needs no reduction
    if isinstance(lengths, (int, float)):
        return lengths
    return _nested_calc_num_data(lengths, reducer)
def __getstate__(self) -> dict:
    """Build a picklable state from the stored metrics, the meta object and the class name.

    Returns:
        A dict with the keys ``"items"`` (the collection with every metric replaced by its own
        ``__getstate__()`` result), ``"meta"`` and ``"_class"``.
    """

    def _metric_state(metric: ResultMetric) -> dict:
        return metric.__getstate__()

    metric_states = apply_to_collection(dict(self), (ResultMetric, ResultMetricCollection), _metric_state)
    meta_state = self.meta.__getstate__()
    return {"items": metric_states, "meta": meta_state, "_class": self.__class__.__name__}