def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None:
    def fn(result_metric, v):
        # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
        result_metric.forward(v.to(self.device), self.batch_size)
        result_metric.has_reset = False

    apply_to_collections(self[key], value, ResultMetric, fn)

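# NOTE: hedged illustration, not part of the snippet above. It shows the pairwise-apply pattern that
# `update_metrics` relies on: `apply_to_collections` walks two collections with the same structure and
# calls the given function on every pair of leaves whose type matches `dtype`. Plain floats stand in for
# the `ResultMetric` objects, and the names `metrics` / `values` / `update` are made up for the example;
# the import path is the one used by the PyTorch Lightning versions these snippets appear to come from.
from pytorch_lightning.utilities.apply_func import apply_to_collections

metrics = {"loss": 0.0, "acc": 0.0}  # stand-in for `self[key]`, a collection of `ResultMetric`s
values = {"loss": 0.25, "acc": 0.9}  # stand-in for the freshly logged values


def update(metric, value):
    # in `update_metrics`, this is where `result_metric.forward(...)` would be called
    return metric + value


updated = apply_to_collections(metrics, values, float, update)
assert updated == {"loss": 0.25, "acc": 0.9}
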
def _apply_patch(self):
    def _apply_patch_fn(loader: DataLoader, iterator: Iterator):
        if isinstance(loader, CycleIterator):
            loader = loader.loader
            # cycle_iterator = iterator
            iterator = iterator._loader_iter

        if isinstance(loader, DataLoader) and _fault_tolerant_training():
            loader._lightning_fetcher = self
            patch_dataloader_iterator(loader, iterator, self)

    apply_to_collections(self.loaders, self.loader_iters, (Iterator, DataLoader), _apply_patch_fn)

def on_restart(self, iterator: Iterator):
    if not self._loaders_iter_state_dict:
        return

    # this happens inside the workers, if any were specified.

    def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict):
        if isinstance(dataloader.dataset, CaptureIterableDataset):
            # provide the `state_dict` to the `CaptureIterableDataset`,
            # as it is responsible for passing down the state to the associated `FastForwardSampler`
            dataloader.dataset.load_state_dict(state_dict)
        else:
            # for a `Mapping`-based dataset, the `fast_forward_sampler` was attached
            # to the dataloader for simplicity
            dataloader.fast_forward_sampler.load_state_dict(state_dict)

        # cycle back the iterator to the failed worker if multiple workers were provided
        iterator = _cycle_to_next_worker_and_reset(dataloader, state_dict)

        if isinstance(dataloader.dataset, CaptureIterableDataset):
            # remove the keys related to the iterator state
            state_dict = {k: v for k, v in state_dict.items() if k not in ("num_worker", "previous_worker")}
            # re-attach the state dict to the iterator so it can be collected again later
            iterator._sampler_state_dict = [state_dict]
        return iterator

    # apply `create_loader_iters` on the collection of `DataLoader / Iterator`.
    # each `Iterator` was created from a `DataLoader`.
    iterator._loader_iters = apply_to_collections(
        self.loaders, self._loaders_iter_state_dict, (DataLoader, DataLoaderDict), create_loader_iters
    )

def state_dict(self, num_batches_processed: int) -> Dict:
    """The state dict includes all states from wrapped dataloaders and their samplers through the
    ``CaptureIterableDataset`` and fast-forward samplers.

    Args:
        num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
            may have already prefetched more batches by the time a state dict is requested.
    """
    if not _fault_tolerant_enabled():
        return DataLoaderDict()

    state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)

    return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn)

def on_restart(self, iterator: Iterator) -> None:
    if not self._loaders_iter_state_dict:
        return

    def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
        """Function used to reload the iterator state before the workers are created."""

        dataloader_to_iter_on = dataloader
        if isinstance(dataloader, CycleIterator):
            dataloader = dataloader_to_iter_on.loader

        # dataset states are collected across all ranks
        rank = torch.distributed.get_rank() if distributed_available() else 0
        state_dict = state_dict[rank]

        _reload_dataloader_state_dict(dataloader, state_dict)

        # creating the iterator finally spawns the workers, if any
        it = iter(dataloader_to_iter_on)

        # restore the caching state
        state = MergedIteratorState.from_state_dict(state_dict)

        if isinstance(dataloader_to_iter_on, CycleIterator):
            it._loader_iter.state = state
        else:
            it.state = state
        return it

    # create a token that doesn't exist anywhere else, so it only matches the iterator state mappings
    class DataLoaderDict(dict):
        pass

    # apply `create_loader_iters` on the collection of `DataLoader / Iterator`.
    # each `Iterator` was created from a `DataLoader`.
    iterator._loader_iters = apply_to_collections(
        self.loaders,
        self._loaders_iter_state_dict,
        (Iterable, DataLoaderDict),
        create_loader_iters,
        wrong_dtype=(Sequence, Mapping),
    )

    self._loaders_iter_state_dict = None

def state_dict(self, has_completed: bool = False) -> Dict:
    """The state dict includes all states from wrapped dataloaders and their samplers through the
    ``CaptureIterableDataset`` and fast-forward samplers.

    Args:
        has_completed: whether the current state of data fetching is considered completed or not. If it is, the
            current state gets returned, otherwise the previously cached state.
    """
    if not _fault_tolerant_training() or self._iterator is None:
        return {}

    return apply_to_collections(
        self.loaders,
        self._iterator.loader_iters,
        (Iterator, DataLoader),
        partial(self._state_dict_fn, has_completed=has_completed),
    )

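# NOTE: hedged illustration, not part of the snippet above. It mimics how `state_dict` zips the
# (possibly nested) collection of dataloaders with the parallel collection of live iterators:
# `functools.partial` pins the extra keyword argument and `apply_to_collections` returns a collection
# of per-loader states with the same shape. Toy strings stand in for the real loaders/iterators, and
# `_collect_state` is a made-up stand-in for `self._state_dict_fn`.
from functools import partial

from pytorch_lightning.utilities.apply_func import apply_to_collections


def _collect_state(loader, it, has_completed=False):
    # the real `_state_dict_fn` reads the cached fetching state; here we only record the pairing
    return {"loader": loader, "iterator": it, "has_completed": has_completed}


loaders = {"train": "loader_a", "extra": "loader_b"}  # stand-in for `self.loaders`
iterators = {"train": "iter_a", "extra": "iter_b"}  # stand-in for `self._iterator.loader_iters`

states = apply_to_collections(loaders, iterators, str, partial(_collect_state, has_completed=True))
assert states == {
    "train": {"loader": "loader_a", "iterator": "iter_a", "has_completed": True},
    "extra": {"loader": "loader_b", "iterator": "iter_b", "has_completed": True},
}
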
def test_apply_to_collections_dataclass():
    to_reduce_1 = Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0]))
    to_reduce_2 = Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0]))

    def fn(a, b):
        return a + b

    reduced = apply_to_collections(to_reduce_1, to_reduce_2, (torch.Tensor, numbers.Number, np.ndarray), fn)
    assert reduced == Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0]))

    model_example = ModelExample(
        example_ids=["i-1", "i-2", "i-3"],
        feature=to_reduce_1,
        label=torch.tensor([7.0, 8.0, 9.0]),
    )

    # different types
    with pytest.raises(TypeError, match="Expected inputs to be dataclasses of the same type"):
        apply_to_collections(to_reduce_1, [1, 2], (torch.Tensor, numbers.Number, np.ndarray), fn)

    # unmatched fields
    with pytest.raises(TypeError, match="Dataclasses fields do not match"):
        apply_to_collections(to_reduce_1, model_example, (torch.Tensor, numbers.Number, np.ndarray), fn)

    # dataclass with the same number of fields but different field types
    classvar = WithClassVar(torch.arange(3))
    with pytest.raises(TypeError, match="Dataclasses fields do not match"):
        apply_to_collections(to_reduce_1, classvar, (torch.Tensor, numbers.Number, np.ndarray), fn)

def test_apply_to_collections():
    to_reduce_1 = {"a": {"b": [1, 2]}, "c": 5}
    to_reduce_2 = {"a": {"b": [3, 4]}, "c": 6}

    def fn(a, b):
        return a + b

    # basic test
    reduced = apply_to_collections(to_reduce_1, to_reduce_2, int, fn)
    assert reduced == {"a": {"b": [4, 6]}, "c": 11}

    with pytest.raises(KeyError):
        # strict mode - if a key does not exist in both we fail
        apply_to_collections({**to_reduce_2, "d": "foo"}, to_reduce_1, float, fn)

    # multiple dtypes
    reduced = apply_to_collections(to_reduce_1, to_reduce_2, (list, int), fn)
    assert reduced == {"a": {"b": [1, 2, 3, 4]}, "c": 11}

    # wrong dtype
    reduced = apply_to_collections(to_reduce_1, to_reduce_2, (list, int), fn, wrong_dtype=int)
    assert reduced == {"a": {"b": [1, 2, 3, 4]}, "c": 5}

    # list takes precedence because it is the type of data1
    reduced = apply_to_collections([1, 2, 3], [4], (int, list), fn)
    assert reduced == [1, 2, 3, 4]

    # different sizes
    with pytest.raises(AssertionError, match="Sequence collections have different sizes"):
        apply_to_collections([[1, 2], [3]], [4], int, fn)

    def fn(a, b):
        return a.keys() | b.keys()

    # base case
    reduced = apply_to_collections(to_reduce_1, to_reduce_2, dict, fn)
    assert reduced == {"a", "c"}

    # type conversion
    to_reduce = [(1, 2), (3, 4)]
    reduced = apply_to_collections(to_reduce, to_reduce, int, lambda *x: sum(x))
    assert reduced == [(2, 4), (6, 8)]

    # named tuple
    foo = namedtuple("Foo", ["bar"])
    to_reduce = [foo(1), foo(2), foo(3)]
    reduced = apply_to_collections(to_reduce, to_reduce, int, lambda *x: sum(x))
    assert reduced == [foo(2), foo(4), foo(6)]

    # passing none
    reduced1 = apply_to_collections([1, 2, 3], None, int, lambda x: x * x)
    reduced2 = apply_to_collections(None, [1, 2, 3], int, lambda x: x * x)
    assert reduced1 == reduced2 == [1, 4, 9]

    reduced = apply_to_collections(None, None, int, lambda x: x * x)
    assert reduced is None

def on_restart(self, iterator: Iterator) -> None:
    if not self._loaders_iter_state_dict:
        return

    def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
        """Function used to reload the iterator state before the workers are created."""

        dataloader_to_iter_on = dataloader
        if isinstance(dataloader, CycleIterator):
            dataloader = dataloader_to_iter_on.loader

        dataset = dataloader.dataset

        # we reload the states before creating the workers.
        # the specific type of dataset then decides whether the state should be applied
        # before or after spawning the workers.
        if isinstance(dataset, CaptureMapDataset):
            iterator_state = state_dict["state"][0]

            if not isinstance(iterator_state, IteratorState):
                iterator_state = IteratorState.from_state_dict(iterator_state)

            # reload the sampler state
            ff_sampler = _find_fast_forward_samplers(dataloader)
            ff_sampler.load_state_dict(iterator_state.sampler_state)
            # reload the dataset state
            dataset.load_state_dict(
                iterator_state.dataset_state,
                latest_worker_id=state_dict["latest_worker_id"],
                num_workers=iterator_state.num_workers,
            )

        elif isinstance(dataset, CaptureIterableDataset):
            dataset_dict = {
                sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()
            }
            dataset.load_state_dict(dataset_dict)

        else:
            raise MisconfigurationException(
                "This shouldn't happen. Please, open an issue on PyTorch Lightning Github."
            )

        # creating the iterator finally spawns the workers, if any
        it = iter(dataloader_to_iter_on)

        # restore the caching state
        state = MergedIteratorState.from_state_dict(state_dict)

        if isinstance(dataloader_to_iter_on, CycleIterator):
            it._loader_iter.state = state
        else:
            it.state = state
        return it

    # create a token that doesn't exist anywhere else, so it only matches the iterator state mappings
    class DataLoaderDict(dict):
        pass

    # apply `create_loader_iters` on the collection of `DataLoader / Iterator`.
    # each `Iterator` was created from a `DataLoader`.
    iterator._loader_iters = apply_to_collections(
        self.loaders,
        self._loaders_iter_state_dict,
        (Iterable, DataLoaderDict),
        create_loader_iters,
        wrong_dtype=(Sequence, Mapping),
    )

    self._loaders_iter_state_dict = None

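# NOTE: hedged illustration, not part of the snippet above. It sketches the `DataLoaderDict` trick:
# subclassing `dict` yields a sentinel type that `apply_to_collection` can target explicitly, so a
# mapping holding a state payload is treated as a single leaf instead of being recursed into like an
# ordinary container. `StateDict` and the sample data are made up for the example.
from pytorch_lightning.utilities.apply_func import apply_to_collection


class StateDict(dict):
    """Sentinel mapping: marks "this is one loader's state", not a container of loaders."""


states = {
    "train": StateDict(num_batches_fetched=10),
    "val": StateDict(num_batches_fetched=3),
}

# the outer plain dict is recursed into; each `StateDict` leaf is handed to the function whole
restored = apply_to_collection(states, StateDict, lambda sd: dict(sd, restored=True))
assert restored == {
    "train": {"num_batches_fetched": 10, "restored": True},
    "val": {"num_batches_fetched": 3, "restored": True},
}
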