示例#1
0
 def merge_fn(strategy,
              value,
              use_locking=False,
              name=None,
              read_value=True):
   """Aggregates `value` and applies `f` to this variable via `strategy`.

   `self` and `f` are free variables captured from the enclosing scope, which
   is not visible here.

   Args:
     strategy: object whose `extended.update` performs the actual update.
     value: the update value handed to `values_util.apply_aggregation`
       together with `self._aggregation` (presumably combined across replicas
       there -- confirm against that helper).
     use_locking: passed through to `f` as the `use_locking` keyword.
     name: passed through to `f` as the `name` keyword; a truthy
       `values.PerReplica` is replaced by its first component.
     read_value: passed through to `f` as the `read_value` keyword.

   Returns:
     The result of `strategy.extended.update(self, f, ...)`.
   """
   aggregated = values_util.apply_aggregation(
       strategy, value, self._aggregation, self)
   # When `name` is a PerReplica, only the first replica's entry is forwarded.
   if name:
     if isinstance(name, values.PerReplica):
       name = name.values[0]
   update_kwargs = {
       "use_locking": use_locking,
       "name": name,
       "read_value": read_value,
   }
   return strategy.extended.update(
       self, f, args=(aggregated,), kwargs=update_kwargs)
示例#2
0
    def merge_fn(strategy, value, **kwargs):
      """Aggregates `value` and applies `update_fn` in cross-replica context.

      `self` and `update_fn` are free variables captured from the enclosing
      scope, which is not visible here.

      Args:
        strategy: the strategy this merge function is invoked with; it is
          asserted to equal `self.distribute_strategy`.
        value: the update value. It is either a `PerReplica`, or a value that
          is the same on every replica.
        **kwargs: forwarded unchanged to `self._update_cross_replica`.

      Returns:
        The result of `self._update_cross_replica(update_fn, ...)`.

      Raises:
        ValueError: if the aggregation is MEAN, the variable dtype is not
          floating point, and `value` is a `PerReplica`.
      """
      # MEAN is rejected for non-float dtypes because it may silently lose
      # precision. Python 3 and NumPy upcast integers to float on division,
      # but the variable's dtype must always be preserved.
      #
      # For backward compatibility, a value that is *always* identical on
      # every replica (i.e. not a PerReplica) is still accepted. See
      # regroup() for how values are grouped.
      uses_mean = self._aggregation == vs.VariableAggregation.MEAN
      if (uses_mean and not self.dtype.is_floating and
          isinstance(value, PerReplica)):
        raise ValueError(
            "Cannot update non-float variables with "
            "tf.VariableAggregation.MEAN aggregation in replica context. "
            "Either change the variable dtype to float or update it in "
            "cross-replica context.")

      assert strategy == self.distribute_strategy
      # NOTE(review): the check above reads `self._aggregation` while this
      # call reads `self.aggregation`; presumably the same value -- confirm.
      aggregated = values_util.apply_aggregation(
          strategy, value, self.aggregation, self)
      return self._update_cross_replica(update_fn, aggregated, **kwargs)