def update_replica_device_attributes(self, inputs: Any) -> None:
        """Updates the device information of LightningModule by reading the device from the inputs. In
        :class:`~torch.nn.parallel.DataParallel`, changes to the state during the `forward` pass are lost when
        the replicas get discarded. The only way to know the current device is from the inputs passed into the
        model.

        Args:
            inputs: A collection of inputs (typically a tuple). If the inputs don't contain tensors,
                a warning is shown that accessing ``self.device`` will not return the correct device.
        """
        replica_device = None

        def find_tensor_with_device(tensor: torch.Tensor) -> torch.Tensor:
            nonlocal replica_device
            if replica_device is None and tensor.device != torch.device("cpu"):
                replica_device = tensor.device
            return tensor

        apply_to_collection(inputs, dtype=torch.Tensor, function=find_tensor_with_device)

        if replica_device is not None:
            # by calling .to() we force the update to the self.device property
            self.module.to(device=replica_device)
        else:
            rank_zero_warn(
                "Could not determine on which device the inputs are."
                " When using DataParallel (strategy='dp'), be aware that in case you are using self.device"
                " in your code, it will reference only the root device."
            )
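
Every snippet on this page leans on the same helper: apply_to_collection walks an arbitrarily nested collection (dicts, lists, tuples, ...) and applies a function to every element of the requested dtype, returning a collection with the same structure. The following is a rough, simplified sketch of that behaviour; the helper name is made up for illustration, and the real utility in pytorch_lightning handles more container types and extra options.

from typing import Any, Callable

import torch


def _apply_to_collection_sketch(data: Any, dtype: type, function: Callable) -> Any:
    # apply `function` to every element of type `dtype`, keeping the collection structure
    if isinstance(data, dtype):
        return function(data)
    if isinstance(data, dict):
        return {k: _apply_to_collection_sketch(v, dtype, function) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(_apply_to_collection_sketch(v, dtype, function) for v in data)
    return data  # anything else is returned untouched


batch = {"x": torch.zeros(2), "meta": [torch.ones(1), "label"]}
doubled = _apply_to_collection_sketch(batch, torch.Tensor, lambda t: t * 2)
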
    def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]:
        all_have_sequential_sampler = True

        def resolve_has_no_sequential_sampler(dataloader: DataLoader):
            nonlocal all_have_sequential_sampler
            all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
                dataloader.sampler, SequentialSampler
            )

        apply_to_collection(dataloader, DataLoader, resolve_has_no_sequential_sampler)

        if not all_have_sequential_sampler:
            rank_zero_warn(
                "You requested to overfit but enabled training dataloader shuffling."
                " We are turning off the training dataloader shuffling for you."
            )

            def replace_sampler(dataloader: DataLoader) -> DataLoader:
                return TrainerDataLoadingMixin._update_dataloader(
                    dataloader, SequentialSampler(dataloader.dataset), mode=RunningStage.TRAINING
                )

            dataloader = apply_to_collection(dataloader, DataLoader, replace_sampler)

        return dataloader
def test_combined_data_loader_validation_test(cuda_available_mock, device_count_mock, tmpdir):
    """This test makes sure distributed sampler has been properly injected in dataloaders when using
    CombinedLoader."""

    class CustomDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    dataloader = CombinedLoader(
        {
            "a": DataLoader(CustomDataset(range(10))),
            "b": {"c": DataLoader(CustomDataset(range(10))), "d": DataLoader(CustomDataset(range(10)))},
            "e": [DataLoader(CustomDataset(range(10))), DataLoader(CustomDataset(range(10)))],
        }
    )

    trainer = Trainer(replace_sampler_ddp=True, accelerator="ddp", gpus=2)
    dataloader = trainer.auto_add_sampler(dataloader, shuffle=True)
    _count = 0

    def _assert_distributed_sampler(v):
        nonlocal _count
        _count += 1
        assert isinstance(v, DistributedSampler)

    apply_to_collection(dataloader.sampler, Sampler, _assert_distributed_sampler)
    assert _count == 5
Example #4
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Casts all inputs to the right precision and handles autocast for operations in the module forward
        method."""
        precision = self._precision_plugin.precision
        precision_to_type = {
            "bf16": torch.bfloat16,
            16: torch.float16,
            32: torch.float32,
            64: torch.float64,
        }
        # TODO (@awaelchli): let the precision plugin handle the conversion
        to_type = precision_to_type[precision]

        def _convert_float_tensor(t: Tensor) -> Tensor:
            return t.to(to_type) if torch.is_floating_point(t) else t

        args, kwargs = apply_to_collection([args, kwargs],
                                           function=_convert_float_tensor,
                                           dtype=Tensor)

        with self._precision_plugin.forward_context():
            output = self.module(*args, **kwargs)

        to_type = torch.get_default_dtype()
        output = apply_to_collection(output,
                                     function=_convert_float_tensor,
                                     dtype=Tensor)
        return output
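
A note on the cast above: only floating-point tensors are converted, so integer tensors such as label indices keep their dtype. A minimal standalone illustration of that rule (the input names are invented for the example):

import torch


def _convert_float_tensor(t: torch.Tensor, to_type: torch.dtype) -> torch.Tensor:
    # same rule as the inner helper above: cast floating-point tensors only
    return t.to(to_type) if torch.is_floating_point(t) else t


inputs = {"pixels": torch.rand(2, 3), "labels": torch.tensor([0, 1])}
cast = {k: _convert_float_tensor(v, torch.float16) for k, v in inputs.items()}
assert cast["pixels"].dtype == torch.float16
assert cast["labels"].dtype == torch.int64
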
Example #5
    def reset(self):
        if self._iterator:
            self._iterator._loader_iters = None
        if self.loaders is not None:
            apply_to_collection(self.loaders, DataLoader,
                                self._shutdown_workers_and_reset_iterator)
        self._iterator = None
Example #6
    def _resolve_overfit_batches(dataloaders: Collection[DataLoader],
                                 mode: RunningStage) -> Collection[DataLoader]:
        all_have_sequential_sampler = True

        def resolve_has_no_sequential_sampler(dataloader: DataLoader):
            nonlocal all_have_sequential_sampler
            all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
                dataloader.sampler, SequentialSampler)

        apply_to_collection(dataloaders, DataLoader,
                            resolve_has_no_sequential_sampler)

        if not all_have_sequential_sampler:
            rank_zero_warn(
                "You requested to overfit but enabled training dataloader shuffling."
                f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
            )

            def replace_sampler(dataloader: DataLoader) -> DataLoader:
                return _update_dataloader(dataloader,
                                          sampler=SequentialSampler(
                                              dataloader.dataset),
                                          mode=mode)

            dataloaders = apply_to_collection(dataloaders, DataLoader,
                                              replace_sampler)

        return dataloaders
Example #7
    def _wrap_loaders_max_size_cycle(self) -> Any:
        """
        Wraps all loaders to make sure they are cycled until the longest loader is exhausted

        Returns:
            the wrapped loaders
        """
        all_lengths = apply_to_collection(self.loaders,
                                          Iterable,
                                          get_len,
                                          wrong_dtype=(Sequence, Mapping))

        length = _nested_calc_num_data(all_lengths, max)

        # multiple loaders
        if isinstance(self.loaders, (Sequence, Mapping)):
            state = SharedCycleIteratorState()

            self.loaders = apply_to_collection(self.loaders,
                                               Iterable,
                                               CycleIterator,
                                               length=length,
                                               state=state,
                                               wrong_dtype=(Sequence, Mapping))

            state.reset()
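
The effect of the 'max_size_cycle' wrapping can be pictured with plain iterables: every loader is iterated for the length of the longest one, and the shorter ones start over when exhausted. A toy illustration of that idea, without the library's CycleIterator and with invented data:

from itertools import cycle, islice

loaders = {"a": [1, 2, 3], "b": [10, 20]}
length = max(len(v) for v in loaders.values())  # length of the longest loader
wrapped = {k: list(islice(cycle(v), length)) for k, v in loaders.items()}
assert wrapped == {"a": [1, 2, 3], "b": [10, 20, 10]}  # the shorter loader is cycled
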
Example #8
def _validate_fault_tolerant_automatic(
        dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None:
    """This function is used to validate that Fault-tolerance is possible with the user data."""
    if not _FaultTolerantMode.detect_current_mode().is_automatic:
        return

    from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator

    if isinstance(dataloader, CombinedLoader):
        dataloaders = dataloader.loaders
    else:
        dataloaders = dataloader

    dl_loaders = []

    def flatten_dataloader(
            dataloader: Union[DataLoader, CycleIterator, Iterable]) -> None:
        nonlocal dl_loaders
        if isinstance(dataloader, CycleIterator):
            dataloader = dataloader.loader
        dl_loaders.append(dataloader)

    apply_to_collection(dataloaders, (DataLoader, CycleIterator),
                        flatten_dataloader)

    if len(dl_loaders) > 1 and stage == pl.trainer.states.RunningStage.TRAINING:
        raise ValueError("Fault-tolerance supports only a single dataloader.")

    for dataloader in dl_loaders:
        validator_fn = (_validate_iterable_dataset if isinstance(
            dataloader.dataset, IterableDataset) else _validate_map_dataset)
        validator_fn(dataloader)
Example #9
    def teardown(self):
        args = (torch.Tensor, move_data_to_device, "cpu")
        self._logged_metrics = apply_to_collection(self._logged_metrics, *args)
        self._progress_bar_metrics = apply_to_collection(
            self._progress_bar_metrics, *args)
        self._callback_metrics = apply_to_collection(self._callback_metrics,
                                                     *args)
Example #10
def test_sklearn_metric(metric_class, sklearn_func, inputs: dict):
    numpy_inputs = apply_to_collection(
        inputs, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)

    sklearn_result = sklearn_func(**numpy_inputs)
    lightning_result = metric_class(**inputs)

    sklearn_result = apply_to_collection(
        sklearn_result, (torch.Tensor, np.ndarray, numbers.Number),
        _convert_to_numpy)

    lightning_result = apply_to_collection(
        lightning_result, (torch.Tensor, np.ndarray, numbers.Number),
        _convert_to_numpy)

    assert isinstance(lightning_result, type(sklearn_result))

    if isinstance(lightning_result, np.ndarray):
        assert np.allclose(lightning_result, sklearn_result)
    elif isinstance(lightning_result, Mapping):
        for key in lightning_result.keys():
            assert np.allclose(lightning_result[key], sklearn_result[key])

    elif isinstance(lightning_result, Sequence):
        for val_lightning, val_sklearn in zip(lightning_result,
                                              sklearn_result):
            assert np.allclose(val_lightning, val_sklearn)

    else:
        raise TypeError
Example #11
    def result_metrics(self) -> List[ResultMetric]:
        o = []

        def append_fn(v: ResultMetric) -> None:
            nonlocal o
            o.append(v)

        apply_to_collection(list(self.values()), ResultMetric, append_fn)
        return o
Example #12
    def _add_capture_metadata_collate(dataloader: Iterable) -> None:
        if not isinstance(dataloader, (DataLoader, CombinedLoader)):
            return

        if isinstance(dataloader, CombinedLoader):
            dataloader = dataloader.loaders

        apply_to_collection(dataloader, DataLoader,
                            _add_capture_metadata_collate)
Example #13
    def _load_from_state_dict(self, state_dict: Dict, prefix: str,
                              restart_progress: bool) -> None:
        for k, v in self.__dict__.items():
            if isinstance(v, BaseProgress):
                v.load_state_dict(state_dict[prefix + k])
                if restart_progress:
                    apply_to_collection(v, Progress,
                                        lambda p: p.current.reset_on_restart())
        self.on_load_checkpoint(state_dict[prefix + "state_dict"])
        self.restarting = True
Example #14
    def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None:
        def check_fn(v: Tensor) -> Tensor:
            if v.grad_fn is not None:
                rank_zero_deprecation(
                    f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
                    " but this behaviour will change in v1.6. Please detach it manually:"
                    " `return {'loss': ..., 'something': something.detach()}`")
            return v

        apply_to_collection(extra, Tensor, check_fn)
Example #15
def test_apply_to_collection_frozen_dataclass():
    @dataclasses.dataclass(frozen=True)
    class Foo:
        input: torch.Tensor

    foo = Foo(torch.tensor(0))

    with pytest.raises(MisconfigurationException,
                       match="frozen dataclass was passed"):
        apply_to_collection(foo, torch.Tensor, lambda t: t.to(torch.int))
Example #16
    def _prepare_input(self, args: Any):
        def to_tuple(x):
            return tuple(x)

        def to_tensor(x):
            return torch.tensor(x).unsqueeze(0).repeat(self._n_replicate)

        args = apply_to_collection(args, dtype=list, function=to_tuple)
        args = apply_to_collection(args, dtype=(int, float), function=to_tensor)
        return args
Example #17
    def _attach_data_fetcher(self) -> None:
        def _attach_data_fetcher_fn(loader: DataLoader) -> None:
            if isinstance(loader, CycleIterator):
                loader = loader.loader

            if isinstance(loader, DataLoader) and _fault_tolerant_training():
                loader._lightning_fetcher = self

        apply_to_collection(self.loaders, (DataLoader, CycleIterator),
                            _attach_data_fetcher_fn)
Example #18
def collect_tensors(data: Any) -> List[torch.Tensor]:
    """ Filters all tensors in a collection and returns them in a list. """
    tensors = []

    def collect_batches(tensor):
        tensors.append(tensor)
        return tensor

    apply_to_collection(data, dtype=torch.Tensor, function=collect_batches)
    return tensors
Example #19
def convert_to_modules(transforms: Optional[Dict[str, Callable]]):

    if transforms is None or isinstance(transforms, torch.nn.Module):
        return transforms

    transforms = apply_to_collection(transforms, Callable, FuncModule, wrong_dtype=torch.nn.Module)
    transforms = apply_to_collection(transforms, Mapping, torch.nn.ModuleDict, wrong_dtype=torch.nn.ModuleDict)
    transforms = apply_to_collection(
        transforms, Iterable, torch.nn.ModuleList, wrong_dtype=(torch.nn.ModuleList, torch.nn.ModuleDict)
    )
    return transforms
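
The wrapping above exists because plain callables cannot be registered inside an nn.ModuleDict or nn.ModuleList. The class below is a made-up stand-in for the FuncModule used in the snippet, shown only to illustrate the idea:

import torch
from torch import nn


class FuncModuleSketch(nn.Module):
    # hypothetical stand-in for FuncModule: store a plain callable inside an nn.Module
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, *args, **kwargs):
        return self.func(*args, **kwargs)


transforms = {"normalize": lambda x: x / 255.0, "flatten": torch.flatten}
module_dict = nn.ModuleDict({name: FuncModuleSketch(fn) for name, fn in transforms.items()})
out = module_dict["normalize"](torch.full((2,), 255.0))  # tensor([1., 1.])
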
Example #20
def test_apply_to_collection_include_none():
    to_reduce = [1, 2, 3.4, 5.6, 7]

    def fn(x):
        if isinstance(x, float):
            return x

    reduced = apply_to_collection(to_reduce, (int, float), fn)
    assert reduced == [None, None, 3.4, 5.6, None]

    reduced = apply_to_collection(to_reduce, (int, float), fn, include_none=False)
    assert reduced == [3.4, 5.6]
Example #21
    def to(self, *args, **kwargs) -> 'ResultCollection':
        """Move all data to the given device."""

        def to_(item: Union[torch.Tensor, Metric], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Metric]:
            return item.to(*args, **kwargs)

        apply_to_collection(self, (torch.Tensor, Metric), to_, *args, **kwargs)

        if self.minimize is not None:
            self.minimize = self.minimize.to(*args, **kwargs)
        self._batch_size = self._batch_size.to(*args, **kwargs)
        if 'device' in kwargs:
            self.device = kwargs['device']
        return self
Example #22
def test_sklearn_metric(metric_class, sklearn_func, inputs):
    numpy_inputs = apply_to_collection(inputs, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)

    sklearn_result = sklearn_func(**numpy_inputs)
    lightning_result = metric_class(**inputs)
    assert np.allclose(sklearn_result, lightning_result, atol=1e-5)

    sklearn_result = apply_to_collection(
        sklearn_result, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)

    lightning_result = apply_to_collection(
        lightning_result, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)

    assert np.allclose(sklearn_result, lightning_result, atol=1e-5)
    assert isinstance(lightning_result, type(sklearn_result))
Example #23
    def log_histograms(self, batch: Any, group: str = "") -> None:
        """
        Logs the histograms at the interval defined by `log_every_n_steps`, given a logger is available.

        Args:
            batch: torch or numpy arrays, or a collection of it (tuple, list, dict, ...), can be nested.
                If the data appears in a dictionary, the keys are used as labels for the corresponding histogram.
                Otherwise the histograms get labelled with an integer index.
                Each label also has the tensor's shape as a suffix.
            group: Name under which the histograms will be grouped.
        """
        if not self._log or (
                self._train_batch_idx +
                1) % self._log_every_n_steps != 0:  # type: ignore[operator]
            return

        batch = apply_to_collection(batch,
                                    dtype=np.ndarray,
                                    function=torch.from_numpy)
        named_tensors: Dict[str, Tensor] = {}
        collect_and_name_tensors(batch,
                                 output=named_tensors,
                                 parent_name=group)

        for name, tensor in named_tensors.items():
            self.log_histogram(tensor, name)
Example #24
    def all_gather(
            self,
            data: Union[torch.Tensor, Dict, List, Tuple],
            group: Optional[Any] = None,
            sync_grads: bool = False
    ) -> Union[torch.Tensor, Dict, List, Tuple]:
        r"""
        Gather tensors or collections of tensors from multiple processes.

        Args:
            data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
            group: the process group to gather results from. Defaults to all processes (world)
            sync_grads: flag that allows users to synchronize gradients for the all_gather operation

        Return:
            A tensor of shape (world_size, batch, ...), or if the input was a collection
            the output will also be a collection with tensors of this shape.
        """
        group = group if group is not None else torch.distributed.group.WORLD
        data = convert_to_tensors(data, device=self.device)
        return apply_to_collection(data,
                                   torch.Tensor,
                                   self._strategy.all_gather,
                                   group=group,
                                   sync_grads=sync_grads)
Example #25
    def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]:
        for k, v in data.items():
            if isinstance(v, dict):
                for new_key in apply_to_collection(v, dict, EvaluationLoop._get_keys):
                    yield (k, *new_key)  # this needs to be in parentheses for older Python versions
            else:
                yield k,
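
The recursion above flattens nested result dictionaries into key tuples. The same idea can be written without apply_to_collection; the helper name below is invented for the sketch:

from typing import Dict, Iterable, Tuple


def _get_keys_sketch(data: Dict) -> Iterable[Tuple[str, ...]]:
    for k, v in data.items():
        if isinstance(v, dict):
            for new_key in _get_keys_sketch(v):
                yield (k, *new_key)
        else:
            yield (k,)


assert list(_get_keys_sketch({"a": 1, "b": {"c": 2, "d": {"e": 3}}})) == [("a",), ("b", "c"), ("b", "d", "e")]
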
Example #26
    def _prepare_outputs_training_batch_end(
        batch_output: _BATCH_OUTPUTS_TYPE,
        automatic: bool,
        num_optimizers: int,
    ) -> Union[List[List[Dict[str, Any]]], List[Dict[str, Any]]]:
        """Processes the outputs from the batch loop into the format passed to the ``training_batch_end`` hook.

        ``(tbptt_steps, n_opt) -> (n_opt, tbptt_steps)``. The optimizer dimension might have been squeezed.
        """
        if not batch_output:
            return []

        # convert optimizer dicts to list
        if automatic:
            batch_output = apply_to_collection(
                batch_output, dtype=dict, function=_convert_optim_dict, num_optimizers=num_optimizers
            )
        array = np.array(batch_output, dtype=object)
        if array.ndim == 1:
            array = np.expand_dims(array, 1)

        array = array.transpose((1, 0))
        array = array.squeeze()
        array = array.tolist()
        array = _recursive_unpad(array)
        return array
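
The transposition step can be seen on a toy output: two truncated-BPTT steps with three optimizers become three optimizer entries with two steps each (the loss values below are invented):

import numpy as np

batch_output = [
    [{"loss": 0.1}, {"loss": 0.2}, {"loss": 0.3}],  # tbptt step 0, optimizers 0..2
    [{"loss": 0.4}, {"loss": 0.5}, {"loss": 0.6}],  # tbptt step 1, optimizers 0..2
]
array = np.array(batch_output, dtype=object).transpose((1, 0))
assert array.shape == (3, 2)           # now (n_opt, tbptt_steps)
assert array[0, 1] == {"loss": 0.4}    # optimizer 0, tbptt step 1
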
Example #27
    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
        """Moves the state of the optimizers to the TPU if needed."""
        # TODO: `self.root_device` would raise an error if called outside the spawn process
        # while training on 8 or more cores.
        for opt in self.optimizers:
            for p, v in opt.state.items():
                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
Example #28
    def __init__(self, loaders: Any, mode: str = 'min_size'):
        """

        Args:
            loaders: the loaders to sample from. Can be any kind of collection.
            mode: the mode. Supported are 'min_size', which stops once the shortest loader is exhausted, and
                'max_size_cycle', which stops once the longest loader is exhausted and cycles through the smaller ones.

        """
        self.loaders = loaders

        datasets = apply_to_collection(self.loaders,
                                       Iterable,
                                       getattr,
                                       'dataset',
                                       None,
                                       wrong_dtype=(Sequence, Mapping))
        # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader
        self.dataset = CombinedDataset(datasets, mode)

        if mode not in self.SUPPORTED_MODES:
            raise MisconfigurationException(f"Invalid Mode: {mode}")

        self.mode = mode

        if self.mode == 'max_size_cycle':
            self._wrap_loaders_max_size_cycle()
Example #29
    def _calc_num_data(datasets: Union[Sequence, Mapping],
                       mode: str) -> Union[int, float]:
        """
        Compute the length of `CombinedDataset` according to the `mode`.

        Args:
            datasets: a sequence/mapping of datasets. Can be a collection of torch.utils.data.Dataset,
                Iterable or even None.
            mode: determines whether the length of `CombinedDataset` is the maximum or the minimum of
                the dataset lengths.

        Returns:
            length: the length of `CombinedDataset`

        """
        if mode not in CombinedDataset.COMPUTE_FUNCS.keys():
            raise MisconfigurationException(f"Invalid Mode: {mode}")

        # extract the lengths
        all_lengths = apply_to_collection(datasets,
                                          (Dataset, Iterable, type(None)),
                                          get_len,
                                          wrong_dtype=(Sequence, Mapping))

        compute_func = CombinedDataset.COMPUTE_FUNCS[mode]

        if isinstance(all_lengths, (int, float)):
            length = all_lengths
        else:
            length = _nested_calc_num_data(all_lengths, compute_func)

        return length
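
The final reduction collapses a nested structure of per-dataset lengths with max or min. A toy re-implementation of that reduction, with an invented helper name in place of _nested_calc_num_data and get_len:

def _nested_calc_sketch(lengths, compute_func):
    if isinstance(lengths, (int, float)):
        return lengths
    if isinstance(lengths, dict):
        lengths = list(lengths.values())
    return compute_func(_nested_calc_sketch(v, compute_func) for v in lengths)


all_lengths = {"train_a": 100, "train_b": [80, 120]}
assert _nested_calc_sketch(all_lengths, max) == 120  # 'max_size_cycle' length
assert _nested_calc_sketch(all_lengths, min) == 80   # 'min_size' length
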
Example #30
    def __getstate__(self) -> dict:

        def getstate(item: ResultMetric) -> dict:
            return item.__getstate__()

        items = apply_to_collection(dict(self), (ResultMetric, ResultMetricCollection), getstate)
        return {"items": items, "meta": self.meta.__getstate__(), "_class": self.__class__.__name__}