def _simple_graph_snapshot_restoration(datapipe: IterDataPipe, n_iterations: int, rng=None) -> None:
    r"""
    Restore a snapshot by fast-forwarding the given DataPipe by ``n_iterations``, which in turn
    fast-forwards every parent DataPipe in the graph at the cost of re-doing all of their work.
    Applying this to the final DataPipe of a graph therefore restores the snapshot for the whole
    graph. After deserializing a DataPipe, its `_number_of_samples_yielded` attribute can be passed
    in as ``n_iterations``. A DataPipe cannot be restored twice in a row unless an iteration is
    started between the two restoration attempts.

    Note:
        This is the simplest but least efficient fast-forward strategy; prefer other
        (possibly custom) fast-forwarding methods where available.

    Args:
        datapipe: IterDataPipe to be fast-forwarded
        n_iterations: number of iterations to fast-forward
        rng: ``Optional[torch.Generator]``. If not ``None``, this RNG will be used for shuffling. The generator
            should be in its `initial` state as it was first passed into ``DataLoader`` or ``ReadingService``.
    """
    if datapipe._snapshot_state == _SnapshotState.Restored:
        raise RuntimeError(
            "Snapshot restoration cannot be applied. You can only restore simple snapshot to the graph "
            "if your graph has not been restored.")

    # The graph must be at its initial state before fast-forwarding. Conceptually `reset` happens
    # twice (once here, once implicitly inside `iter` below), because this first call is a no-op
    # while `SnapshotState` is still `Restored`.
    datapipe.reset()  # Guarantees `SnapshotState` is `Iterating` from here on, even if it was `Restored`.

    apply_shuffle_seed(datapipe, rng)

    it = iter(datapipe)  # Resets the DataPipe if that has not already happened above.
    for _ in range(n_iterations):
        try:
            next(it)
        except StopIteration:
            raise RuntimeError(
                f"Fast-forward {datapipe} by {n_iterations} iterations "
                "exceeds the number of samples available.")
    # While `_fast_forward_iterator` is set, `next()` pulls results from it instead of elsewhere,
    # which stops `iter()` from resetting the DataPipe — a downstream consumer resumes mid-stream
    # rather than starting over.
    datapipe._fast_forward_iterator = it
    datapipe._snapshot_state = _SnapshotState.Restored
def __getstate__(self):
    """Return this datapipe's picklable state: its source datapipes and `length`.

    Delegates entirely to ``IterDataPipe.getstate_hook`` when one is installed.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)
    return (
        self.datapipes,
        self.length,
    )
def __getstate__(self):
    """Return the picklable state tuple (main datapipe, instance count, buffer size).

    If ``IterDataPipe.getstate_hook`` is set, its result is returned instead.
    """
    if IterDataPipe.getstate_hook is None:
        return (
            self.main_datapipe,
            self.num_instances,
            self.buffer_size,
        )
    return IterDataPipe.getstate_hook(self)
def __getstate__(self):
    """Return the picklable state: source datapipes, `length`, plus the
    iterator-tracking fields (`_valid_iterator_id`, `_number_of_samples_yielded`)
    used for snapshot restoration.

    Delegates to ``IterDataPipe.getstate_hook`` when one is installed.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)
    return (
        self.datapipes,
        self.length,
        self._valid_iterator_id,
        self._number_of_samples_yielded,
    )
def __getstate__(self):
    """Return the picklable state: source datapipe, buffer size, shuffle enable
    flag, seed, and the RNG's internal state (via ``getstate()``).

    If ``IterDataPipe.getstate_hook`` is set, its result is returned instead.
    """
    if IterDataPipe.getstate_hook is None:
        return (
            self.datapipe,
            self.buffer_size,
            self._enabled,
            self._seed,
            self._rng.getstate(),
        )
    return IterDataPipe.getstate_hook(self)
def __getstate__(self):
    """Return the picklable state: main datapipe, instance count, buffer size,
    plus the iterator-tracking fields used for snapshot restoration.

    Delegates entirely to ``IterDataPipe.getstate_hook`` when one is installed.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)
    return (
        self.main_datapipe,
        self.num_instances,
        self.buffer_size,
        self._valid_iterator_id,
        self._number_of_samples_yielded,
    )
def __call__(self, cls):
    """Decorator body: validate ``cls`` and register it as a functional datapipe.

    Registers ``cls`` under ``self.name`` on either ``IterDataPipe`` or
    ``MapDataPipe`` so it becomes available in functional form, then returns
    ``cls`` unchanged (standard decorator protocol).

    Raises:
        TypeError: if ``cls`` claims to be an ``IterDataPipe`` but is neither a
            proper ``_DataPipeMeta`` class nor wrapped by ``non_deterministic``.
    """
    if issubclass(cls, IterDataPipe):
        if isinstance(cls, Type):  # type: ignore[arg-type]
            # `cls` is an actual class: it must have been created through the
            # DataPipe metaclass to be registered as a functional IterDataPipe.
            if not isinstance(cls, _DataPipeMeta):
                raise TypeError('`functional_datapipe` can only decorate IterDataPipe')
        # with non_deterministic decorator
        else:
            # `cls` is not a plain class (e.g. wrapped by `non_deterministic`);
            # accept it only if it is the wrapper itself or a bound method whose
            # `__self__` is such a wrapper.
            if not isinstance(cls, non_deterministic) and \
                not (hasattr(cls, '__self__') and isinstance(cls.__self__, non_deterministic)):
                raise TypeError('`functional_datapipe` can only decorate IterDataPipe')
        IterDataPipe.register_datapipe_as_function(self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing)
    elif issubclass(cls, MapDataPipe):
        MapDataPipe.register_datapipe_as_function(self.name, cls)
    return cls
def __getstate__(self):
    """Return the picklable state: source datapipe plus every grouping
    parameter (key fn, buffer/group sizes, guaranteed group size, drop flag,
    and wrapper class).

    Delegates to ``IterDataPipe.getstate_hook`` when one is installed.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)
    return (
        self.datapipe,
        self.group_key_fn,
        self.max_buffer_size,
        self.group_size,
        self.guaranteed_group_size,
        self.drop_remaining,
        self.wrapper_class,
    )
def __getstate__(self):
    """Return the picklable state: main datapipe, instance count, buffer size,
    the classifier function (run through ``serialize_fn`` so bound methods
    survive pickling), and the drop-``None`` flag.

    If ``IterDataPipe.getstate_hook`` is set, its result is returned instead.
    """
    if IterDataPipe.getstate_hook is not None:
        return IterDataPipe.getstate_hook(self)
    return (
        self.main_datapipe,
        self.num_instances,
        self.buffer_size,
        serialize_fn(self.classifier_fn),
        self.drop_none,
    )
def __getstate__(self):
    """Return the picklable state: source datapipe, buffer size, shuffle enable
    flag, seed, the iterator-tracking fields used for snapshot restoration, and
    the RNG's internal state (via ``getstate()``).

    Delegates to ``IterDataPipe.getstate_hook`` when one is installed.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)
    return (
        self.datapipe,
        self.buffer_size,
        self._enabled,
        self._seed,
        self._valid_iterator_id,
        self._number_of_samples_yielded,
        self._rng.getstate(),
    )
def __getstate__(self):
    """Return the picklable state: main datapipe, instance count, buffer size,
    the classifier function (serialized via ``serialize_fn`` only when dill is
    available, otherwise stored as-is), and the drop-``None`` flag.

    If ``IterDataPipe.getstate_hook`` is set, its result is returned instead.
    """
    if IterDataPipe.getstate_hook is not None:
        return IterDataPipe.getstate_hook(self)
    if DILL_AVAILABLE:
        fn = serialize_fn(self.classifier_fn)
    else:
        fn = self.classifier_fn
    return (
        self.main_datapipe,
        self.num_instances,
        self.buffer_size,
        fn,
        self.drop_none,
    )
def __getstate__(self):
    """Return the picklable state: source datapipe, every grouping parameter
    (key fn, buffer/group sizes, guaranteed group size, drop flag, wrapper
    class), plus the iterator-tracking fields used for snapshot restoration.

    Delegates to ``IterDataPipe.getstate_hook`` when one is installed.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)
    return (
        self.datapipe,
        self.group_key_fn,
        self.max_buffer_size,
        self.group_size,
        self.guaranteed_group_size,
        self.drop_remaining,
        self.wrapper_class,
        self._valid_iterator_id,
        self._number_of_samples_yielded,
    )