Example #1
    def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None:
        def fn(result_metric, v):
            # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
            result_metric.forward(v.to(self.device), self.batch_size)
            result_metric.has_reset = False

        apply_to_collections(self[key], value, ResultMetric, fn)
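The comment above refers to the hook bookkeeping done in `torch.nn.Module._call_impl`. A minimal, standalone sketch (plain PyTorch, not Lightning code) of the trade-off: calling `forward` directly bypasses `__call__`, so registered forward hooks do not fire.

import torch
from torch import nn

layer = nn.Linear(4, 2)
fired = []
layer.register_forward_hook(lambda module, inputs, output: fired.append("hook"))

x = torch.randn(1, 4)
layer(x)          # routed through `__call__` / `_call_impl`, so the hook fires
layer.forward(x)  # skips `_call_impl`, so the hook does not fire
assert fired == ["hook"]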
Example #2
    def _apply_patch(self):
        def _apply_patch_fn(loader: DataLoader, iterator: Iterator):
            if isinstance(loader, CycleIterator):
                loader = loader.loader
                # cycle_iterator = iterator
                iterator = iterator._loader_iter

            if isinstance(loader, DataLoader) and _fault_tolerant_training():
                loader._lightning_fetcher = self
                patch_dataloader_iterator(loader, iterator, self)

        apply_to_collections(self.loaders, self.loader_iters, (Iterator, DataLoader), _apply_patch_fn)
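For context on the call above: `apply_to_collections` walks both collections in lockstep, so `_apply_patch_fn` receives each loader together with the iterator created from it. A minimal sketch of that pairing, with placeholder strings standing in for real loaders and iterators (the import path matches the Lightning versions these snippets come from and may differ in newer releases):

from pytorch_lightning.utilities.apply_func import apply_to_collections

loaders = {"train": "loader_a", "val": "loader_b"}
iterators = {"train": "iter_a", "val": "iter_b"}

# each call receives the matching pair of leaves from both collections
pairs = apply_to_collections(loaders, iterators, str, lambda loader, it: (loader, it))
assert pairs == {"train": ("loader_a", "iter_a"), "val": ("loader_b", "iter_b")}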
Example #3
    def on_restart(self, iterator: Iterator):
        if not self._loaders_iter_state_dict:
            return

        # this happens inside the workers, if any were specified.

        def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict):
            if isinstance(dataloader.dataset, CaptureIterableDataset):
                # provide the `state_dict` to the `CaptureIterableDataset`
                # as it is responsible for passing down the state to associated `FastForwardSampler`
                dataloader.dataset.load_state_dict(state_dict)
            else:
                # for a `Mapping`-based dataset, the `fast_forward_sampler` was attached
                # to the dataloader for simplicity
                dataloader.fast_forward_sampler.load_state_dict(state_dict)

            # cycle back the iterator to the failed worker if multiple workers were provided
            iterator = _cycle_to_next_worker_and_reset(dataloader, state_dict)

            if isinstance(dataloader.dataset, CaptureIterableDataset):
                # remove keys related to iterator
                state_dict = {k: v for k, v in state_dict.items() if k not in ("num_worker", "previous_worker")}
                # re-attach the state dict to the iterator for future collection.
                iterator._sampler_state_dict = [state_dict]
            return iterator

        # apply `create_loader_iters` to the collection of `DataLoader` / `Iterator` objects;
        # each `Iterator` was created from its `DataLoader`.
        iterator._loader_iters = apply_to_collections(
            self.loaders, self._loaders_iter_state_dict, (DataLoader, DataLoaderDict), create_loader_iters
        )
Example #4
    def state_dict(self, num_batches_processed: int) -> Dict:
        """
        The state dict includes all states from wrapped dataloaders and their samplers through the
        ``CaptureIterableDataset`` and fast-forward samplers.

        Args:
            num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
                may have already prefetched more batches by the time a state dict is requested.
        """
        if not _fault_tolerant_enabled():
            return DataLoaderDict()

        state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)

        return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn)
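The `partial` above only pre-binds `num_batches_processed`; `apply_to_collections` then calls the bound function with one leaf from each collection. A small sketch of the same pattern, with a hypothetical `collect` helper in place of `self._state_dict_fn` and placeholder strings in place of real loaders and iterators:

from functools import partial

from pytorch_lightning.utilities.apply_func import apply_to_collections


def collect(loader, it, num_batches_processed):
    # hypothetical stand-in for `self._state_dict_fn`
    return {"loader": loader, "iterator": it, "num_batches_processed": num_batches_processed}


state = apply_to_collections({"train": "dl"}, {"train": "it"}, str, partial(collect, num_batches_processed=42))
assert state == {"train": {"loader": "dl", "iterator": "it", "num_batches_processed": 42}}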
Example #5
    def on_restart(self, iterator: Iterator) -> None:
        if not self._loaders_iter_state_dict:
            return

        def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
            """Function used to reload the iterator state before the workers are created."""

            dataloader_to_iter_on = dataloader
            if isinstance(dataloader, CycleIterator):
                dataloader = dataloader_to_iter_on.loader

            # dataset states are collected across all ranks
            rank = torch.distributed.get_rank() if distributed_available() else 0
            state_dict = state_dict[rank]

            _reload_dataloader_state_dict(dataloader, state_dict)

            # creating the iterator finally spawns the workers, if any.
            it = iter(dataloader_to_iter_on)

            # restore caching state
            state = MergedIteratorState.from_state_dict(state_dict)

            if isinstance(dataloader_to_iter_on, CycleIterator):
                it._loader_iter.state = state
            else:
                it.state = state
            return it

        # create a type that doesn't exist anywhere else, so the dtype check
        # won't trigger for anything other than an iterator.
        class DataLoaderDict(dict):
            pass

        # apply `create_loader_iters` to the collection of `DataLoader` / `Iterator` objects;
        # each `Iterator` was created from its `DataLoader`.
        iterator._loader_iters = apply_to_collections(
            self.loaders,
            self._loaders_iter_state_dict,
            (Iterable, DataLoaderDict),
            create_loader_iters,
            wrong_dtype=(Sequence, Mapping),
        )

        self._loaders_iter_state_dict = None
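For reference, `wrong_dtype` suppresses the function for leaves that would otherwise match `dtype`; above it keeps plain sequences and mappings from being treated as leaves, so they are recursed into instead. A minimal sketch mirroring the behaviour exercised in the test further below:

from pytorch_lightning.utilities.apply_func import apply_to_collections

# ints match `dtype` but are excluded via `wrong_dtype`, so only the lists get reduced
reduced = apply_to_collections(
    {"a": [1, 2], "c": 5},
    {"a": [3, 4], "c": 6},
    (list, int),
    lambda a, b: a + b,
    wrong_dtype=int,
)
assert reduced == {"a": [1, 2, 3, 4], "c": 5}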
Example #6
    def state_dict(self, has_completed: bool = False) -> Dict:
        """The state dict includes all states from wrapped dataloaders and their samplers through the
        ``CaptureIterableDataset`` and fast-forward samplers.

        Args:
            has_completed: whether the current state of data fetching is considered completed or not. If it is, the
                current state gets returned, otherwise the previously cached state.
        """
        if not _fault_tolerant_training() or self._iterator is None:
            return {}

        return apply_to_collections(
            self.loaders,
            self._iterator.loader_iters,
            (Iterator, DataLoader),
            partial(self._state_dict_fn, has_completed=has_completed),
        )
Example #7
def test_apply_to_collections_dataclass():
    to_reduce_1 = Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]),
                          segment_ids=np.array([4.0, 5.0, 6.0]))
    to_reduce_2 = Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]),
                          segment_ids=np.array([4.0, 5.0, 6.0]))

    def fn(a, b):
        return a + b

    reduced = apply_to_collections(to_reduce_1, to_reduce_2,
                                   (torch.Tensor, numbers.Number, np.ndarray),
                                   fn)

    assert reduced == Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]),
                              segment_ids=np.array([8.0, 10.0, 12.0]))

    model_example = ModelExample(
        example_ids=["i-1", "i-2", "i-3"],
        feature=to_reduce_1,
        label=torch.tensor([7.0, 8.0, 9.0]),
    )

    # different types
    with pytest.raises(
            TypeError,
            match="Expected inputs to be dataclasses of the same type"):
        apply_to_collections(to_reduce_1, [1, 2],
                             (torch.Tensor, numbers.Number, np.ndarray), fn)

    # unmatched fields
    with pytest.raises(TypeError, match="Dataclasses fields do not match"):
        apply_to_collections(to_reduce_1, model_example,
                             (torch.Tensor, numbers.Number, np.ndarray), fn)

    # dataclass with the same number of fields but different field types
    classvar = WithClassVar(torch.arange(3))
    with pytest.raises(TypeError, match="Dataclasses fields do not match"):
        apply_to_collections(to_reduce_1, classvar,
                             (torch.Tensor, numbers.Number, np.ndarray), fn)
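The fixtures `Feature`, `ModelExample` and `WithClassVar` are defined elsewhere in the test module. A rough sketch of what the first two plausibly look like, inferred from how they are constructed above (field types and the custom `__eq__` are assumptions; `WithClassVar` is omitted because only its constructor call is visible):

import dataclasses
from typing import List

import numpy as np
import torch


@dataclasses.dataclass
class Feature:
    input_ids: torch.Tensor
    segment_ids: np.ndarray

    def __eq__(self, other):
        # the generated __eq__ would compare tensors/arrays with `==`, which does not
        # return a plain bool, so an element-wise comparison is spelled out instead
        if not isinstance(other, Feature):
            return NotImplemented
        return torch.equal(self.input_ids, other.input_ids) and np.array_equal(self.segment_ids, other.segment_ids)


@dataclasses.dataclass
class ModelExample:
    example_ids: List[str]
    feature: Feature
    label: torch.Tensor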
Example #8
def test_apply_to_collections():
    to_reduce_1 = {"a": {"b": [1, 2]}, "c": 5}
    to_reduce_2 = {"a": {"b": [3, 4]}, "c": 6}

    def fn(a, b):
        return a + b

    # basic test
    reduced = apply_to_collections(to_reduce_1, to_reduce_2, int, fn)
    assert reduced == {"a": {"b": [4, 6]}, "c": 11}

    with pytest.raises(KeyError):
        # strict mode - if a key does not exist in both we fail
        apply_to_collections({**to_reduce_2, "d": "foo"}, to_reduce_1, float, fn)

    # multiple dtypes
    reduced = apply_to_collections(to_reduce_1, to_reduce_2, (list, int), fn)
    assert reduced == {"a": {"b": [1, 2, 3, 4]}, "c": 11}

    # wrong dtype
    reduced = apply_to_collections(to_reduce_1, to_reduce_2, (list, int), fn, wrong_dtype=int)
    assert reduced == {"a": {"b": [1, 2, 3, 4]}, "c": 5}

    # list takes precedence because it is the type of data1
    reduced = apply_to_collections([1, 2, 3], [4], (int, list), fn)
    assert reduced == [1, 2, 3, 4]

    # different sizes
    with pytest.raises(AssertionError,
                       match="Sequence collections have different sizes"):
        apply_to_collections([[1, 2], [3]], [4], int, fn)

    def fn(a, b):
        return a.keys() | b.keys()

    # base case
    reduced = apply_to_collections(to_reduce_1, to_reduce_2, dict, fn)
    assert reduced == {"a", "c"}

    # type conversion
    to_reduce = [(1, 2), (3, 4)]
    reduced = apply_to_collections(to_reduce, to_reduce, int,
                                   lambda *x: sum(x))
    assert reduced == [(2, 4), (6, 8)]

    # named tuple
    foo = namedtuple("Foo", ["bar"])
    to_reduce = [foo(1), foo(2), foo(3)]
    reduced = apply_to_collections(to_reduce, to_reduce, int,
                                   lambda *x: sum(x))
    assert reduced == [foo(2), foo(4), foo(6)]

    # passing none
    reduced1 = apply_to_collections([1, 2, 3], None, int, lambda x: x * x)
    reduced2 = apply_to_collections(None, [1, 2, 3], int, lambda x: x * x)
    assert reduced1 == reduced2 == [1, 4, 9]
    reduced = apply_to_collections(None, None, int, lambda x: x * x)
    assert reduced is None
Example #9
    def on_restart(self, iterator: Iterator) -> None:
        if not self._loaders_iter_state_dict:
            return

        def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
            """Function used to reload the iterator state before the workers are created."""

            dataloader_to_iter_on = dataloader
            if isinstance(dataloader, CycleIterator):
                dataloader = dataloader_to_iter_on.loader

            dataset = dataloader.dataset

            # We reload the states before creating the workers
            # The specific type of dataset will then decide if the state should be applied before or after
            # spawning the workers
            if isinstance(dataset, CaptureMapDataset):
                iterator_state = state_dict["state"][0]

                if not isinstance(iterator_state, IteratorState):
                    iterator_state = IteratorState.from_state_dict(iterator_state)

                # reload sampler state
                ff_sampler = _find_fast_forward_samplers(dataloader)
                ff_sampler.load_state_dict(iterator_state.sampler_state)
                # reload dataset state
                dataset.load_state_dict(
                    iterator_state.dataset_state,
                    latest_worker_id=state_dict["latest_worker_id"],
                    num_workers=iterator_state.num_workers,
                )

            elif isinstance(dataset, CaptureIterableDataset):
                dataset_dict = {
                    sampler_name: state[0]["sampler_state"]
                    for sampler_name, state in state_dict["state"].items()
                }
                dataset.load_state_dict(dataset_dict)

            else:
                raise MisconfigurationException(
                    "This shouldn't happen. Please, open an issue on PyTorch Lightning Github."
                )

            # creating the iterator finally spawns the workers, if any.
            it = iter(dataloader_to_iter_on)

            # restore caching state
            state = MergedIteratorState.from_state_dict(state_dict)

            if isinstance(dataloader_to_iter_on, CycleIterator):
                it._loader_iter.state = state
            else:
                it.state = state
            return it

        # create a type that doesn't exist anywhere else, so the dtype check
        # won't trigger for anything other than an iterator.
        class DataLoaderDict(dict):
            pass

        # apply `create_loader_iters` to the collection of `DataLoader` / `Iterator` objects;
        # each `Iterator` was created from its `DataLoader`.
        iterator._loader_iters = apply_to_collections(
            self.loaders,
            self._loaders_iter_state_dict,
            (Iterable, DataLoaderDict),
            create_loader_iters,
            wrong_dtype=(Sequence, Mapping),
        )

        self._loaders_iter_state_dict = None