Esempio n. 1
0
    def _update(self, update_fn, value, **kwargs):
        """Routes an update to the handler that matches the current context.

        Outside cross-replica context the call is asserted to be in replica
        context and forwarded to `_update_replica`. In cross-replica context,
        `update_fn` is applied directly to a single component when an update
        replica id is set (update context); otherwise the call is forwarded to
        `_update_cross_replica`.

        Args:
          update_fn: A callable to pass to `strategy.extended.update` to update
            the variable. It should have the same signature as
            `Variable.assign()`.
          value: value to be passed to `update_fn`.
          **kwargs: keyword arguments to `update_fn`.

        Returns:
          Updated variable or `tf.Operation`.
        """
        with ds_context.enter_or_assert_strategy(self.distribute_strategy):
            if not ds_context.in_cross_replica_context():
                values_util.assert_replica_context(self.distribute_strategy)
                return self._update_replica(update_fn, value, **kwargs)

            replica_id = distribute_lib.get_update_replica_id()
            if replica_id is None:
                # Plain cross-replica context: no specific component selected.
                return self._update_cross_replica(update_fn, value, **kwargs)

            # Update context: operate only on the selected component.
            return update_fn(self._values[replica_id], value, **kwargs)
Esempio n. 2
0
 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
     """Converts the value returned by `self._get()` to a tensor.

     The conversion runs inside this variable's strategy scope; `dtype`,
     `name` and `as_ref` are forwarded unchanged to `ops.convert_to_tensor`.
     """
     strategy_scope = ds_context.enter_or_assert_strategy(
         self._distribute_strategy)
     with strategy_scope:
         component = self._get()
         return ops.convert_to_tensor(
             component, dtype=dtype, name=name, as_ref=as_ref)
Esempio n. 3
0
 def value(self):
     """Returns the variable's value for the current distribution context.

     In cross-replica context this is the result of `_get_cross_replica()`;
     otherwise it is `value()` of the on-device (or primary) component.
     """
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
         if not ds_context.in_cross_replica_context():
             # _get_on_device_or_primary() returns a Variable.
             local_variable = self._get_on_device_or_primary()
             return local_variable.value()
         return self._get_cross_replica()
Esempio n. 4
0
  def __deepcopy__(self, memo):
    """Deep-copies this `AggregatingVariable` within its own strategy.

    A regular `tf.Variable` deepcopy copies into new devices; this one keeps
    the original strategy and devices instead. To avoid confusing the two
    behaviors, the copy is only allowed inside the originating strategy
    scope.

    Args:
      memo: The memoization object for `deepcopy`.

    Returns:
      A deep copy of the current `AggregatingVariable`.

    Raises:
      RuntimeError: If trying to deepcopy into a different strategy.
    """
    strategy = self._distribute_strategy

    # Only the wrapped variable is copied under the strategy scope.
    with ds_context.enter_or_assert_strategy(strategy):
      wrapped_copy = copy.deepcopy(self._v, memo)

    clone = type(self)(
        strategy=strategy, v=wrapped_copy, aggregation=self._aggregation)
    # Register the clone so later references to `self` reuse it.
    memo[id(self)] = clone
    return clone
    def skip(self, delta):
        """Advance the counter of a counter-based RNG.

        Args:
          delta: the amount of advancement. The state of the RNG after
            `skip(n)` will be the same as that after `normal([n])`
            (or any other distribution). The actual increment added to the
            counter is an unspecified implementation detail.

        Returns:
          A `Tensor` of type `int64`.
        """
        def update_fn(state_var):
            return self._skip_single_var(state_var, delta)

        # TODO(b/170515001): Always call strategy.extended.update after calling
        #   it from both replica context and cross-replica context is
        #   supported.
        #
        # While saving non-distributed, replica context with replica_id=0 is
        # assumed (only the first replica is saved), so the strategy path is
        # skipped entirely and the state is updated directly below.
        if (not values_util.is_saving_non_distributed()
                and self._distribution_strategy is not None):
            with ds_context.enter_or_assert_strategy(
                    self._distribution_strategy):
                if ds_context.in_cross_replica_context():
                    # Code that operates on all replicas of a variable cannot
                    # be saved without retracing.
                    values_util.mark_as_unsaveable()
                    # In cross-replica context we need to use
                    # strategy.extended.update.
                    strategy = ds_context.get_strategy()
                    return strategy.extended.update(self.state, update_fn)
        return update_fn(self.state)
Esempio n. 6
0
  def _get_cross_replica(self):
    """Returns the cross-replica value of this variable.

    With `ONLY_FIRST_REPLICA` aggregation the primary component is returned
    as is. For any other aggregation the variable is reduced through the
    strategy, using the reduce op derived from its aggregation.
    """
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      return self._primary

    strategy = self._distribute_strategy
    with ds_context.enter_or_assert_strategy(strategy):
      reduce_op = reduce_util.ReduceOp.from_variable_aggregation(
          self.aggregation)
      return strategy.reduce(reduce_op, self, axis=None)
Esempio n. 7
0
 def _mirrored_update(self, update_fn, **kwargs):
     """Applies `update_fn`, going through the strategy when inside a TPU context.

     When called in cross-replica context with an enclosing TPU context, the
     update is issued via `strategy.extended.update`. In every other case the
     call is delegated to `values.MirroredVariable._mirrored_update`.
     """
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
         use_strategy_update = (
             ds_context.in_cross_replica_context()
             and enclosing_tpu_context() is not None)
         if not use_strategy_update:
             return values.MirroredVariable._mirrored_update(
                 self, update_fn, **kwargs)
         return self._distribute_strategy.extended.update(
             self, update_fn, kwargs=kwargs)
Esempio n. 8
0
def on_read_assign_add_cross_replica(var, value, read_value=True):
    """Runs `assign_add` on each device of `var` in cross-replica context.

    Args:
      var: the variable to update; its `distribute_strategy` and
        `aggregation` attributes are read.
      value: the value forwarded to `assign_add_on_device`.
      read_value: forwarded to `assign_on_each_device`.

    Returns:
      The result of `assign_on_each_device`, or `None` when not called in
      cross-replica context.

    Raises:
      ValueError: if `var.aggregation` is `VariableAggregation.SUM`.
    """
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
        if not ds_context.in_cross_replica_context():
            return None
        if var.aggregation == vs.VariableAggregation.SUM:
            raise ValueError(
                "SyncOnReadVariable does not support `assign_add` in "
                "cross-replica context when aggregation is set to "
                "`tf.VariableAggregation.SUM`.")
        return assign_on_each_device(
            var, assign_add_on_device, value, read_value)
Esempio n. 9
0
 def _assign_func(self, *args, **kwargs):
     """Runs an assign-style function, via the strategy when inside a TPU context.

     In cross-replica context with an enclosing TPU context, the callable is
     taken from the `"f"` keyword argument and issued through
     `strategy.extended.update` with the remaining arguments. In every other
     case all arguments (including `"f"`) are passed on untouched to
     `values.MirroredVariable._assign_func`.
     """
     with ds_context.enter_or_assert_strategy(self._distribute_strategy):
         use_strategy_update = (
             ds_context.in_cross_replica_context()
             and enclosing_tpu_context() is not None)
         if not use_strategy_update:
             return values.MirroredVariable._assign_func(
                 self, *args, **kwargs)

         assign_fn = kwargs.pop("f")
         return self._distribute_strategy.extended.update(
             self, assign_fn, args=args, kwargs=kwargs)
Esempio n. 10
0
 def assign(self, value, use_locking=False, name=None, read_value=True):
   """Assigns `value`, per device when called in cross-replica context.

   In replica context the call is delegated to the parent class. In
   cross-replica context the value is assigned on each device; with `SUM`
   aggregation it is first divided by the number of components.
   """
   with ds_context.enter_or_assert_strategy(self._distribute_strategy):
     if not ds_context.in_cross_replica_context():
       return super(SyncOnReadVariable,
                    self).assign(value, use_locking, name, read_value)

     if self._aggregation == vs.VariableAggregation.SUM:
       # To preserve the sum across save and restore, we have to divide the
       # total across all devices when restoring a variable that was summed
       # when saving.
       num_components = len(self._values)
       value = math_ops.cast(value / num_components, self.dtype)
     return self._assign_on_each_device(
         values_util.assign_on_device, value, read_value)
Esempio n. 11
0
 def _preprocess_key(self, key):
   """Mixes the current replica id into `key` when a strategy is set.

   The key is returned unchanged when there is no distribution strategy or
   when `get_replica_id()` returns `None`.
   """
   strategy = self._distribution_strategy
   if strategy is None:
     return key

   with ds_context.enter_or_assert_strategy(strategy):
     replica_id = get_replica_id()
     if replica_id is None:
       return key

     counter = array_ops.stack([replica_id, 0], axis=0)
     counter = math_ops.cast(counter, dtypes.uint64)
     # Conceptually: key = hash(key, replica_id)
     return gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
         shape=[1],
         key=key,
         counter=counter,
         dtype=dtypes.uint64,
         alg=self.algorithm)
Esempio n. 12
0
 def assign_add(self, value, use_locking=False, name=None, read_value=True):
   """Adds `value`, per device when called in cross-replica context.

   In replica context the call is delegated to the parent class. In
   cross-replica context the addition runs on each device, unless the
   aggregation is `SUM`, which is rejected.

   Raises:
     ValueError: in cross-replica context when aggregation is
       `VariableAggregation.SUM`.
   """
   with ds_context.enter_or_assert_strategy(self._distribute_strategy):
     if not ds_context.in_cross_replica_context():
       return super(SyncOnReadVariable,
                    self).assign_add(value, use_locking, name, read_value)

     if self._aggregation == vs.VariableAggregation.SUM:
       raise ValueError(
           "SyncOnReadVariable does not support `assign_add` in "
           "cross-replica context when aggregation is set to "
           "`tf.VariableAggregation.SUM`.")
     return self._assign_on_each_device(
         values_util.assign_add_on_device, value, read_value)
Esempio n. 13
0
def on_read_assign_cross_replica(var, value, read_value=True):
    """Assigns `value` on each device of `var` in cross-replica context.

    With `SUM` aggregation the value is first divided by the strategy's
    `num_replicas_in_sync` and cast to `var.dtype`.

    Args:
      var: the variable to assign to.
      value: the value to assign.
      read_value: forwarded to `assign_on_each_device`.

    Returns:
      The result of `assign_on_each_device`, or `None` when not called in
      cross-replica context.
    """
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
        if not ds_context.in_cross_replica_context():
            return None

        if var.aggregation == vs.VariableAggregation.SUM:
            # To preserve the sum across save and restore, we have to divide
            # the total across all devices when restoring a variable that was
            # summed when saving.
            strategy = var._distribute_strategy  # pylint: disable=protected-access
            num_replicas = strategy.num_replicas_in_sync
            per_device_value = math_ops.cast(value / num_replicas, var.dtype)
        else:
            per_device_value = value
        return assign_on_each_device(
            var, assign_on_device, per_device_value, read_value)
Esempio n. 14
0
def on_read_assign_cross_replica(var, value, read_value=True):
    """Assigns `value` on each device of `var` in cross-replica context.

    With `SUM` aggregation the value is first divided by the number of
    components in `var._values` and cast to `var.dtype`.

    Args:
      var: the variable to assign to.
      value: the value to assign.
      read_value: forwarded to `assign_on_each_device`.

    Returns:
      The result of `assign_on_each_device`, or `None` when not called in
      cross-replica context.
    """
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
        if not ds_context.in_cross_replica_context():
            return None

        # TODO(anjs): Should this be over all the replicas in sync since we
        # call `reduce` on the variable during read?
        if var.aggregation == vs.VariableAggregation.SUM:
            # To preserve the sum across save and restore, we have to divide
            # the total across all devices when restoring a variable that was
            # summed when saving.
            num_components = len(var._values)  # pylint: disable=protected-access
            per_device_value = math_ops.cast(value / num_components, var.dtype)
        else:
            per_device_value = value
        return assign_on_each_device(
            var, assign_on_device, per_device_value, read_value)
Esempio n. 15
0
    def _assign_func(self, *args, **kwargs):
        """Dispatches an assign-style function based on the current context.

        The callable is taken from the `"f"` keyword argument, then:

        * update context (cross-replica with an update replica id set): it is
          applied directly to the wrapped variable `self._v`;
        * cross-replica context: it is wrapped in `strategy.extended.update`;
        * replica context: the value is aggregated inside a `merge_call` and
          then applied through `strategy.extended.update`.

        Raises:
          ValueError: in replica context when aggregation is
            `VariableAggregation.NONE`.
        """
        with ds_context.enter_or_assert_strategy(self._distribute_strategy):
            assign_fn = kwargs.pop("f")

            if ds_context.in_cross_replica_context():
                if distribute_lib.get_update_replica_id() is not None:
                    # We are calling an assign function in an update context.
                    return assign_fn(self._v, *args, **kwargs)
                # We are calling an assign function in cross replica context,
                # wrap it in an update call.
                return self._distribute_strategy.extended.update(
                    self, assign_fn, args=args, kwargs=kwargs)

            # We are calling an assign function in replica context.
            replica_context = ds_context.get_replica_context()
            assert replica_context
            # We reduce the value we want to assign/add/sub. More details about
            # how we handle the different use cases can be found in the _reduce
            # method. We call the function with the reduced value.
            if self._aggregation == vs.VariableAggregation.NONE:
                raise ValueError(
                    values_util.aggregation_error_msg.format(
                        variable_type="AggregatingVariable"))

            def merge_fn(strategy,
                         value,
                         use_locking=False,
                         name=None,
                         read_value=True):
                aggregated = values_util.apply_aggregation(
                    strategy, value, self._aggregation, self)
                # A per-replica name is collapsed to its first value.
                if name and isinstance(name, values.PerReplica):
                    name = name.values[0]
                update_kwargs = {
                    "use_locking": use_locking,
                    "name": name,
                    "read_value": read_value,
                }
                return strategy.extended.update(
                    self, assign_fn, args=(aggregated,), kwargs=update_kwargs)

            return replica_context.merge_call(
                merge_fn, args=args, kwargs=kwargs)
Esempio n. 16
0
    def _skip(self, delta):
        """Applies `_skip_single_var` with `delta` to this generator's state.

        The state is updated directly, except in cross-replica context of a
        distribution strategy (and not while saving non-distributed), where the
        update is issued through `strategy.extended.update`.
        """
        def update_fn(state_var):
            return self._skip_single_var(state_var, delta)

        # TODO(b/170515001): Always call strategy.extended.update after calling
        #   it from both replica context and cross-replica context is
        #   supported.
        if values_util.is_saving_non_distributed():
            # Assumes replica context with replica_id=0, since we only save the
            # first replica.
            return update_fn(self.state)

        if self._distribution_strategy is None:
            return update_fn(self.state)

        with ds_context.enter_or_assert_strategy(self._distribution_strategy):
            if ds_context.in_cross_replica_context():
                # Code that operates on all replicas of a variable cannot be
                # saved without retracing.
                values_util.mark_as_unsaveable()
                # In cross-replica context we need to use
                # strategy.extended.update.
                active_strategy = ds_context.get_strategy()
                return active_strategy.extended.update(self.state, update_fn)
        return update_fn(self.state)
Esempio n. 17
0
 def _as_graph_element(self):
     """Returns the graph element that represents this variable.

     In cross-replica context this is the cross-replica value converted to a
     tensor. Otherwise it is the graph element of the component returned by
     `self._get()`, which is looked up after the strategy scope has exited.
     """
     strategy_scope = ds_context.enter_or_assert_strategy(
         self._distribute_strategy)
     with strategy_scope:
         if ds_context.in_cross_replica_context():
             cross_replica_value = self._get_cross_replica()
             return ops.convert_to_tensor(cross_replica_value)

     component = self._get()
     return component._as_graph_element()  # pylint: disable=protected-access
Esempio n. 18
0
 def read_value(self):
     """Returns `array_ops.identity` of `self._get()` under the strategy scope."""
     strategy_scope = ds_context.enter_or_assert_strategy(
         self._distribute_strategy)
     with strategy_scope:
         current = self._get()
         return array_ops.identity(current)