def merge_fn(strategy, value, use_locking=False, name=None, read_value=True):
    """Aggregate `value` and apply `f` to the variable through `strategy`.

    `self` and `f` are free variables here; they are presumably captured
    from the enclosing scope (not visible in this chunk).

    Args:
      strategy: Object whose `extended.update` is used to run the update.
      value: The value handed to `values_util.apply_aggregation`.
      use_locking: Forwarded unchanged to `f` via the update kwargs.
      name: Forwarded to `f`. When it is a `values.PerReplica`, only its
        first component is forwarded.
      read_value: Forwarded unchanged to `f` via the update kwargs.

    Returns:
      Whatever `strategy.extended.update` returns.
    """
    aggregated = values_util.apply_aggregation(
        strategy, value, self._aggregation, self)

    # A per-replica name holds several components; forward only the first.
    if name and isinstance(name, values.PerReplica):
        name = name.values[0]

    update_kwargs = {
        "use_locking": use_locking,
        "name": name,
        "read_value": read_value,
    }
    return strategy.extended.update(
        self, f, args=(aggregated,), kwargs=update_kwargs)
def merge_fn(strategy, value, **kwargs):
    """Aggregate `value`, then update the variable in cross-replica context.

    `self` and `update_fn` are free variables here; they are presumably
    captured from the enclosing scope (not visible in this chunk).

    Args:
      strategy: Asserted to equal `self.distribute_strategy`.
      value: The value to aggregate before updating.
      **kwargs: Forwarded unchanged to `self._update_cross_replica`.

    Returns:
      Whatever `self._update_cross_replica` returns.

    Raises:
      ValueError: If the aggregation is MEAN, the variable dtype is not
        floating, and `value` is a `PerReplica`.
    """
    # MEAN on a non-float dtype is rejected because it may silently lose
    # precision: Python 3 and NumPy upcast integers to float on division,
    # whereas the variable's type should always be preserved.
    #
    # For backward compatibility the check only fires when `value` is a
    # PerReplica; a value that is always identical on every replica is
    # still accepted. See regroup() for how values are grouped.
    mean_on_non_float_per_replica = (
        self._aggregation == vs.VariableAggregation.MEAN
        and not self.dtype.is_floating
        and isinstance(value, PerReplica)
    )
    if mean_on_non_float_per_replica:
        raise ValueError(
            "Cannot update non-float variables with "
            "tf.VariableAggregation.MEAN aggregation in replica context. "
            "Either change the variable dtype to float or update it in "
            "cross-replica context.")

    assert strategy == self.distribute_strategy

    aggregated = values_util.apply_aggregation(
        strategy, value, self.aggregation, self)
    return self._update_cross_replica(update_fn, aggregated, **kwargs)