Ejemplo n.º 1
0
def _simple_graph_snapshot_restoration(datapipe: IterDataPipe,
                                       n_iterations: int,
                                       rng=None) -> None:
    r"""
    Restore a snapshot by fast-forwarding the given DataPipe by ``n_iterations`` samples.

    Fast-forwarding the given DataPipe also fast-forwards its parent DataPipes, at the cost of
    re-doing every computation. For instance, applying this function to the final DataPipe of a
    graph will restore the snapshot (via fast-forward) of every DataPipe within the graph.

    After you deserialize a DataPipe, you can use its `_number_of_samples_yielded` attribute as the input
    to this function to forward the DataPipe.

    A DataPipe cannot be restored twice in a row unless there is an iteration started between the restoration
    attempts.

    Note:
        This is the simplest but least efficient way to fast-forward a DataPipe. Usage of other fast-forwarding
        methods (custom ones if necessary) is recommended.

    Args:
        datapipe: IterDataPipe to be fast-forwarded
        n_iterations: number of iterations to fast-forward
        rng: ``Optional[torch.Generator]``. If not ``None``, this RNG will be used for shuffling. The generator
            should be in its `initial` state as it was first passed into ``DataLoader`` or ``ReadingService``.

    Raises:
        RuntimeError: if ``datapipe`` is already in the ``Restored`` snapshot state, or if the DataPipe is
            exhausted before ``n_iterations`` samples have been consumed.
    """
    if datapipe._snapshot_state == _SnapshotState.Restored:
        raise RuntimeError(
            "Snapshot restoration cannot be applied. You can only restore simple snapshot to the graph "
            "if your graph has not been restored.")

    # For this snapshot restoration function, we want the DataPipe to be at its initial state prior to
    # simple fast-forwarding. Therefore, we need to call `reset` twice, because if `SnapshotState` is `Restored`,
    # the first reset will not actually reset.
    datapipe.reset()  # This ensures `SnapshotState` is `Iterating` by this point, even if it was `Restored`.
    apply_shuffle_seed(datapipe, rng)

    remainder = n_iterations
    it = iter(datapipe)  # This always resets the DataPipe if it hasn't already.
    while remainder > 0:
        # Keep the `try` body limited to the single call that can raise `StopIteration`.
        try:
            next(it)
        except StopIteration as e:
            # Chain explicitly so the traceback shows exhaustion as the direct cause
            # rather than as an unrelated error raised during handling.
            raise RuntimeError(
                f"Fast-forward {datapipe} by {n_iterations} iterations "
                "exceeds the number of samples available.") from e
        remainder -= 1
    datapipe._fast_forward_iterator = it
    # While the DataPipe has `_fast_forward_iterator`, `next()` will get result from there instead of elsewhere.

    # This will prevent the DataPipe from resetting in the `iter()` call
    # If another DataPipe is consuming it, it won't have to start over again
    datapipe._snapshot_state = _SnapshotState.Restored
Ejemplo n.º 2
0
    def __getstate__(self):
        """Return the state to serialize for this DataPipe.

        When ``IterDataPipe.getstate_hook`` is set, the hook's result for this
        instance is returned instead. Otherwise the state is the tuple
        ``(datapipes, length)``.
        """
        hook = IterDataPipe.getstate_hook
        if hook is not None:
            return hook(self)
        return (self.datapipes, self.length)
Ejemplo n.º 3
0
    def __getstate__(self):
        """Return the state to serialize for this DataPipe.

        The default state is ``(main_datapipe, num_instances, buffer_size)``.
        A hook installed on ``IterDataPipe.getstate_hook`` takes precedence and
        its return value is used as-is.
        """
        hook = IterDataPipe.getstate_hook
        if hook is None:
            return (self.main_datapipe, self.num_instances, self.buffer_size)
        return hook(self)
Ejemplo n.º 4
0
    def __getstate__(self):
        """Return the state to serialize for this DataPipe.

        If ``IterDataPipe.getstate_hook`` is set, serialization is delegated to
        it. Otherwise a 4-tuple is returned: the source datapipes, the length,
        and the iterator bookkeeping fields ``_valid_iterator_id`` and
        ``_number_of_samples_yielded``.
        """
        hook = IterDataPipe.getstate_hook
        if hook is not None:
            return hook(self)

        sources_and_length = (self.datapipes, self.length)
        iteration_bookkeeping = (
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        return sources_and_length + iteration_bookkeeping
Ejemplo n.º 5
0
 def __getstate__(self):
     """Return the state to serialize for this DataPipe.

     A hook set on ``IterDataPipe.getstate_hook`` overrides the default.
     The default is the tuple ``(datapipe, buffer_size, _enabled, _seed,
     rng_state)``, where ``rng_state`` is the result of
     ``self._rng.getstate()``.
     """
     hook = IterDataPipe.getstate_hook
     if hook is not None:
         return hook(self)
     return (
         self.datapipe, self.buffer_size,
         self._enabled, self._seed,
         self._rng.getstate(),
     )
Ejemplo n.º 6
0
    def __getstate__(self):
        """Return the state to serialize for this DataPipe.

        Delegates to ``IterDataPipe.getstate_hook`` when one is set. Otherwise
        returns ``(main_datapipe, num_instances, buffer_size,
        _valid_iterator_id, _number_of_samples_yielded)``.
        """
        hook = IterDataPipe.getstate_hook
        if hook is None:
            return (
                self.main_datapipe, self.num_instances, self.buffer_size,
                self._valid_iterator_id, self._number_of_samples_yielded,
            )
        return hook(self)
Ejemplo n.º 7
0
    def __call__(self, cls):
        """Register ``cls`` under ``self.name`` and return it unchanged.

        For an ``IterDataPipe`` subclass, ``cls`` is first validated and then
        registered via ``IterDataPipe.register_datapipe_as_function``, passing
        along ``self.enable_df_api_tracing``. For a ``MapDataPipe`` subclass it
        is registered via ``MapDataPipe.register_datapipe_as_function`` with no
        extra validation. Anything else is returned without registration.

        Raises:
            TypeError: if ``cls`` passes the ``IterDataPipe`` subclass check but
                is neither a ``_DataPipeMeta`` type nor associated with a
                ``non_deterministic`` decorator.
        """
        if issubclass(cls, IterDataPipe):
            if isinstance(cls, Type):  # type: ignore[arg-type]
                decoratable = isinstance(cls, _DataPipeMeta)
            # with non_deterministic decorator
            else:
                # Accept either the decorator object itself or a callable
                # bound to one (checked through `__self__`).
                decoratable = isinstance(cls, non_deterministic) or (
                    hasattr(cls, '__self__') and
                    isinstance(cls.__self__, non_deterministic))
            if not decoratable:
                raise TypeError(
                    '`functional_datapipe` can only decorate IterDataPipe')
            IterDataPipe.register_datapipe_as_function(
                self.name, cls,
                enable_df_api_tracing=self.enable_df_api_tracing)
            return cls

        if issubclass(cls, MapDataPipe):
            MapDataPipe.register_datapipe_as_function(self.name, cls)
        return cls
Ejemplo n.º 8
0
 def __getstate__(self):
     """Return the state to serialize for this DataPipe.

     If ``IterDataPipe.getstate_hook`` is set, its result is returned.
     Otherwise the state is a 7-tuple of the source datapipe followed by
     the grouping configuration attributes, in the order listed below.
     """
     hook = IterDataPipe.getstate_hook
     if hook is not None:
         return hook(self)
     return (
         self.datapipe, self.group_key_fn, self.max_buffer_size,
         self.group_size, self.guaranteed_group_size,
         self.drop_remaining, self.wrapper_class,
     )
Ejemplo n.º 9
0
    def __getstate__(self):
        """Return the state to serialize for this DataPipe.

        A hook on ``IterDataPipe.getstate_hook`` takes precedence. Otherwise
        ``classifier_fn`` is passed through ``serialize_fn`` and stored in the
        state tuple in place of the raw callable.
        """
        hook = IterDataPipe.getstate_hook
        if hook is not None:
            return hook(self)

        # Serialize the classifier first; the result replaces the callable
        # in the fourth slot of the state tuple.
        classifier_state = serialize_fn(self.classifier_fn)
        return (
            self.main_datapipe, self.num_instances, self.buffer_size,
            classifier_state,
            self.drop_none,
        )
Ejemplo n.º 10
0
 def __getstate__(self):
     """Return the state to serialize for this DataPipe.

     Delegates to ``IterDataPipe.getstate_hook`` when one is set. The
     default state holds the source datapipe, the buffer size, the
     ``_enabled`` and ``_seed`` settings, the iterator bookkeeping fields,
     and finally the result of ``self._rng.getstate()``.
     """
     hook = IterDataPipe.getstate_hook
     if hook is None:
         return (
             self.datapipe, self.buffer_size,
             self._enabled, self._seed,
             self._valid_iterator_id, self._number_of_samples_yielded,
             self._rng.getstate(),
         )
     return hook(self)
Ejemplo n.º 11
0
    def __getstate__(self):
        """Return the state to serialize for this DataPipe.

        If ``IterDataPipe.getstate_hook`` is set, its result is returned.
        Otherwise the state tuple stores ``classifier_fn`` in its fourth slot:
        passed through ``serialize_fn`` when ``DILL_AVAILABLE`` is truthy, or
        as the callable itself when it is not.
        """
        hook = IterDataPipe.getstate_hook
        if hook is not None:
            return hook(self)

        if DILL_AVAILABLE:
            classifier_state = serialize_fn(self.classifier_fn)
        else:
            classifier_state = self.classifier_fn
        return (
            self.main_datapipe, self.num_instances, self.buffer_size,
            classifier_state,
            self.drop_none,
        )
Ejemplo n.º 12
0
 def __getstate__(self):
     """Return the state to serialize for this DataPipe.

     A hook set on ``IterDataPipe.getstate_hook`` overrides the default.
     The default is a 9-tuple: the source datapipe, the grouping
     configuration attributes, then ``_valid_iterator_id`` and
     ``_number_of_samples_yielded``.
     """
     hook = IterDataPipe.getstate_hook
     if hook is not None:
         return hook(self)

     grouping_config = (
         self.datapipe, self.group_key_fn, self.max_buffer_size,
         self.group_size, self.guaranteed_group_size,
         self.drop_remaining, self.wrapper_class,
     )
     iteration_bookkeeping = (
         self._valid_iterator_id,
         self._number_of_samples_yielded,
     )
     return grouping_config + iteration_bookkeeping