def skip(self, delta):
    """Move the counter of a counter-based RNG forward.

    Args:
      delta: how far to advance. After `skip(n)` the generator's state is the
        same as it would be after `normal([n])` (or any other distribution).
        The exact increment applied to the counter is an unspecified
        implementation detail.

    Returns:
      A `Tensor` of type `int64`. NOTE(review): in a cross-replica context the
      value is whatever `strategy.extended.update` returns — confirm.
    """
    def advance(var):
        # Advance one state variable by `delta`.
        return self._skip_single_var(var, delta)

    # TODO(b/170515001): Always call strategy.extended.update after calling it
    #   from both replica context and cross-replica context is supported.
    #
    # While saving non-distributed we fall straight through to the plain
    # per-variable update: this assumes replica context with replica_id=0,
    # since only the first replica is saved.
    saving_non_distributed = values_util.is_saving_non_distributed()
    if not saving_non_distributed and self._distribution_strategy is not None:
        with ds_context.enter_or_assert_strategy(
                self._distribution_strategy):
            if ds_context.in_cross_replica_context():
                # Code touching every replica of a variable cannot be saved
                # without retracing, so flag it before updating.
                values_util.mark_as_unsaveable()
                # Cross-replica context requires strategy.extended.update.
                strategy = ds_context.get_strategy()
                return strategy.extended.update(self.state, advance)
    # Replica context, no distribution strategy, or non-distributed saving.
    return advance(self.state)
Esempio n. 2
0
    def _skip(self, delta):
        """Advance the RNG counter by `delta`, honoring the distribution context.

        Args:
          delta: amount of advancement; forwarded unchanged to
            `self._skip_single_var`.

        Returns:
          The result of `self._skip_single_var(self.state, delta)`, or, in a
          cross-replica context, the result of `strategy.extended.update`
          applied to `self.state`.
        """
        def skip_var(var):
            return self._skip_single_var(var, delta)

        # TODO(b/170515001): Always call strategy.extended.update after calling it
        #   from both replica context and cross-replica context is supported.
        #
        # When saving non-distributed, assume replica context with
        # replica_id=0 (only the first replica is saved). Without a strategy
        # there is nothing to dispatch through either.
        if (values_util.is_saving_non_distributed()
                or self._distribution_strategy is None):
            return skip_var(self.state)

        with ds_context.enter_or_assert_strategy(self._distribution_strategy):
            in_cross_replica = ds_context.in_cross_replica_context()
            if in_cross_replica:
                # Operating on all replicas of a variable cannot be saved
                # without retracing; in this context the update must go
                # through strategy.extended.update.
                values_util.mark_as_unsaveable()
                return ds_context.get_strategy().extended.update(
                    self.state, skip_var)
        # Replica context: update this replica's variable directly.
        return skip_var(self.state)