def _update(self, update_fn, value, **kwargs):
    """Routes an update to the right mechanism for the current context.

    Three situations are distinguished:

    * Replica context: after asserting that the replica context belongs to
      this variable's strategy, `_update_replica` performs the update.
    * Cross-replica context inside an update context (an update replica id
      is set): `update_fn` is applied directly to the component variable of
      that replica.
    * Plain cross-replica context: `_update_cross_replica` performs the
      update.

    Per the documented contract, when `read_value` (passed via `kwargs`) is
    True the updated Variable is returned, otherwise the update
    `tf.Operation`.

    Args:
      update_fn: A callable to pass to `strategy.extended.update` to update
        the variable. It should have the same signature as
        `Variable.assign()`.
      value: value to be passed to `update_fn`.
      **kwargs: keyword arguments to `update_fn`.

    Returns:
      Updated variable or `tf.Operation`.
    """
    with ds_context.enter_or_assert_strategy(self.distribute_strategy):
        if not ds_context.in_cross_replica_context():
            values_util.assert_replica_context(self.distribute_strategy)
            return self._update_replica(update_fn, value, **kwargs)

        target_replica = distribute_lib.get_update_replica_id()
        if target_replica is None:
            return self._update_cross_replica(update_fn, value, **kwargs)
        # Already inside an update context: touch only this replica's component.
        component = self._values[target_replica]
        return update_fn(component, value, **kwargs)
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts the variable into a tensor.

    The component returned by `self._get()` is converted with
    `ops.convert_to_tensor`, under this variable's distribution strategy.

    Args:
      dtype: Optional requested dtype of the resulting tensor.
      name: Optional name for the resulting tensor.
      as_ref: Forwarded unchanged to `ops.convert_to_tensor`.

    Returns:
      The result of `ops.convert_to_tensor` on the selected component.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
        component = self._get()
        return ops.convert_to_tensor(
            component, dtype=dtype, name=name, as_ref=as_ref)
def value(self):
    """Returns the variable's value appropriate for the current context.

    In cross-replica context this is `self._get_cross_replica()`. Otherwise
    it is the `value()` of the component Variable returned by
    `self._get_on_device_or_primary()`.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
        if not ds_context.in_cross_replica_context():
            # `_get_on_device_or_primary()` yields a Variable, so read its value.
            local_var = self._get_on_device_or_primary()
            return local_var.value()
        return self._get_cross_replica()
def __deepcopy__(self, memo):
    """Deep-copies the `AggregatingVariable` within its own strategy.

    A regular `tf.Variable` deepcopy may land on new devices. This copy
    instead reuses the original strategy and aggregation, and deep-copies
    only the wrapped variable `self._v`. To avoid confusion between the two
    behaviors, the copy is only permitted inside the originating strategy's
    scope.

    Args:
      memo: The memoization dict used by `copy.deepcopy`.

    Returns:
      A new instance of `type(self)` wrapping a deep copy of `self._v`.

    Raises:
      RuntimeError: If trying to deepcopy into a different strategy.
    """
    strategy = self._distribute_strategy
    with ds_context.enter_or_assert_strategy(strategy):
        inner_copy = copy.deepcopy(self._v, memo)
        duplicate = type(self)(
            strategy=strategy, v=inner_copy, aggregation=self._aggregation)
        # Register the copy so shared references resolve to the same object.
        memo[id(self)] = duplicate
        return duplicate
def skip(self, delta):
    """Advance the counter of a counter-based RNG.

    Args:
      delta: the amount of advancement. The state of the RNG after
        `skip(n)` will be the same as that after `normal([n])`
        (or any other distribution). The actual increment added to the
        counter is an unspecified implementation detail.

    Returns:
      A `Tensor` of type `int64`.
    """

    def update_fn(v):
        return self._skip_single_var(v, delta)

    # TODO(b/170515001): Always call strategy.extended.update after calling it
    # from both replica context and cross-replica context is supported.
    if values_util.is_saving_non_distributed():
        # Only the first replica is saved, so behave as replica 0 in replica
        # context.
        return update_fn(self.state)

    strategy = self._distribution_strategy
    if strategy is None:
        return update_fn(self.state)

    with ds_context.enter_or_assert_strategy(strategy):
        if ds_context.in_cross_replica_context():
            # Code touching every replica of a variable can't be saved
            # without retracing.
            values_util.mark_as_unsaveable()
            # Cross-replica context requires going through extended.update.
            return ds_context.get_strategy().extended.update(
                self.state, update_fn)
    # Replica context: the update is applied after leaving the scope.
    return update_fn(self.state)
def _get_cross_replica(self):
    """Returns the variable's value as seen from cross-replica context.

    With `ONLY_FIRST_REPLICA` aggregation, `self._primary` is returned as is.
    Otherwise the strategy reduces `self` with the reduce op derived from
    `self.aggregation`.
    """
    first_only = vs.VariableAggregation.ONLY_FIRST_REPLICA
    if self._aggregation == first_only:
        return self._primary

    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
        reduce_op = reduce_util.ReduceOp.from_variable_aggregation(
            self.aggregation)
        return self._distribute_strategy.reduce(reduce_op, self, axis=None)
def _mirrored_update(self, update_fn, **kwargs):
    """Applies a mirrored update, routing through TPU where required.

    In cross-replica context with an enclosing TPU context, the update goes
    through `strategy.extended.update`. In every other case it is delegated
    to `values.MirroredVariable._mirrored_update`.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
        # `enclosing_tpu_context()` is only queried in cross-replica context.
        tpu_cross_replica = (
            ds_context.in_cross_replica_context()
            and enclosing_tpu_context() is not None)
        if not tpu_cross_replica:
            return values.MirroredVariable._mirrored_update(
                self, update_fn, **kwargs)
        return self._distribute_strategy.extended.update(
            self, update_fn, kwargs=kwargs)
def on_read_assign_add_cross_replica(var, value, read_value=True):
    """Adds `value` to every component of `var` in cross-replica context.

    Args:
      var: The sync-on-read distributed variable to update.
      value: The value added on each device.
      read_value: Forwarded to `assign_on_each_device`.

    Returns:
      The result of `assign_on_each_device`, or `None` when called outside
      cross-replica context (nothing is done in that case).

    Raises:
      ValueError: If `var` uses `VariableAggregation.SUM`.
    """
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
        if not ds_context.in_cross_replica_context():
            return None
        if var.aggregation == vs.VariableAggregation.SUM:
            raise ValueError(
                "SyncOnReadVariable does not support `assign_add` in "
                "cross-replica context when aggregation is set to "
                "`tf.VariableAggregation.SUM`.")
        return assign_on_each_device(
            var, assign_add_on_device, value, read_value)
def _assign_func(self, *args, **kwargs):
    """Runs an assign function, routing through TPU where required.

    In cross-replica context with an enclosing TPU context, the callable
    `kwargs["f"]` is popped and run through `strategy.extended.update`.
    Otherwise every argument, including `f`, is forwarded untouched to
    `values.MirroredVariable._assign_func`.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
        tpu_cross_replica = (
            ds_context.in_cross_replica_context()
            and enclosing_tpu_context() is not None)
        if not tpu_cross_replica:
            # `f` is intentionally left in `kwargs` for the parent call.
            return values.MirroredVariable._assign_func(
                self, *args, **kwargs)
        assign_fn = kwargs.pop("f")
        return self._distribute_strategy.extended.update(
            self, assign_fn, args=args, kwargs=kwargs)
def assign(self, value, use_locking=False, name=None, read_value=True):
    """Assigns `value` to this sync-on-read variable.

    In cross-replica context, `value` is written to every component via
    `self._assign_on_each_device`. With `VariableAggregation.SUM` the value
    is first divided by the strategy's `num_replicas_in_sync`. The intent is
    that a subsequent cross-replica read, which reduces over all replicas in
    sync, returns `value` again (this preserves the sum across
    save/restore). Dividing by `len(self._values)` would count only the
    local components and over-count the sum in multi-worker setups.

    In replica context the call is delegated to the parent class.

    Args:
      value: The value to assign.
      use_locking: Forwarded to the parent `assign` in replica context.
      name: Forwarded to the parent `assign` in replica context.
      read_value: Whether to return the updated value; forwarded in both
        contexts.

    Returns:
      The result of `_assign_on_each_device` (cross-replica context) or of
      the parent `assign` (replica context).
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
        if ds_context.in_cross_replica_context():
            if self._aggregation == vs.VariableAggregation.SUM:
                # Divide by the global replica count, not the local component
                # count, so the later all-replica reduce restores `value`.
                num_replicas = self._distribute_strategy.num_replicas_in_sync
                value = math_ops.cast(value / num_replicas, self.dtype)
            return self._assign_on_each_device(values_util.assign_on_device,
                                               value, read_value)
        return super(SyncOnReadVariable, self).assign(value, use_locking,
                                                      name, read_value)
def _preprocess_key(self, key):
    """Mixes the current replica id into an RNG key.

    Without a distribution strategy, or when `get_replica_id()` returns
    `None`, `key` is returned unchanged. Otherwise a new uint64 key of shape
    [1] is derived with `stateless_random_uniform_full_int_v2`, using
    `[replica_id, 0]` as the counter.
    """
    strategy = self._distribution_strategy
    if strategy is None:
        return key

    with ds_context.enter_or_assert_strategy(strategy):
        replica_id = get_replica_id()
        if replica_id is None:
            return key
        counter = math_ops.cast(
            array_ops.stack([replica_id, 0], axis=0), dtypes.uint64)
        # Conceptually: key = hash(key, replica_id)
        return gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
            shape=[1],
            key=key,
            counter=counter,
            dtype=dtypes.uint64,
            alg=self.algorithm)
def assign_add(self, value, use_locking=False, name=None, read_value=True):
    """Adds `value` to this sync-on-read variable.

    In replica context the call is delegated to the parent class. In
    cross-replica context `value` is added to every component, except that
    SUM aggregation is rejected.

    Raises:
      ValueError: In cross-replica context when aggregation is
        `VariableAggregation.SUM`.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
        if not ds_context.in_cross_replica_context():
            return super(SyncOnReadVariable, self).assign_add(
                value, use_locking, name, read_value)
        if self._aggregation == vs.VariableAggregation.SUM:
            raise ValueError(
                "SyncOnReadVariable does not support `assign_add` in "
                "cross-replica context when aggregation is set to "
                "`tf.VariableAggregation.SUM`.")
        return self._assign_on_each_device(
            values_util.assign_add_on_device, value, read_value)
def on_read_assign_cross_replica(var, value, read_value=True):
    """Assigns `value` to every component of `var` in cross-replica context.

    With SUM aggregation the value is divided by the strategy's
    `num_replicas_in_sync` before being written, so that the total is
    preserved across save and restore.

    Returns:
      The result of `assign_on_each_device`, or `None` when called outside
      cross-replica context (nothing is done in that case).
    """
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
        if not ds_context.in_cross_replica_context():
            return None
        tensor = value
        if var.aggregation == vs.VariableAggregation.SUM:
            # A summed variable must be split back across all replicas on
            # restore.
            replica_count = var._distribute_strategy.num_replicas_in_sync  # pylint: disable=protected-access
            tensor = math_ops.cast(tensor / replica_count, var.dtype)
        return assign_on_each_device(var, assign_on_device, tensor, read_value)
def on_read_assign_cross_replica(var, value, read_value=True):
    """Assigns `value` to every component of `var` in cross-replica context.

    With SUM aggregation the value is divided by the strategy's
    `num_replicas_in_sync` before being written. A cross-replica read of the
    variable calls `reduce` over all replicas in sync, so dividing by the
    number of local components (`len(var._values)`) would over-count the
    restored total in multi-worker setups.

    Args:
      var: The sync-on-read distributed variable to assign.
      value: The value the variable should read as after assignment.
      read_value: Forwarded to `assign_on_each_device`.

    Returns:
      The result of `assign_on_each_device`, or `None` when called outside
      cross-replica context (nothing is done in that case).
    """
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
        if ds_context.in_cross_replica_context():
            # To preserve the sum across save and restore, we have to divide
            # the total across all replicas in sync when restoring a variable
            # that was summed when saving.
            tensor = value
            if var.aggregation == vs.VariableAggregation.SUM:
                num_replicas = var.distribute_strategy.num_replicas_in_sync
                tensor = math_ops.cast(tensor / num_replicas, var.dtype)
            return assign_on_each_device(var, assign_on_device, tensor,
                                         read_value)
def _assign_func(self, *args, **kwargs):
    """Runs an assign-style function `kwargs["f"]` on the aggregating variable.

    * Cross-replica context inside an update context: `f` is applied
      directly to the wrapped variable `self._v`.
    * Plain cross-replica context: `f` is wrapped in
      `strategy.extended.update`.
    * Replica context: the value is aggregated through `merge_call`
      (`values_util.apply_aggregation`) and the aggregated value is then
      applied with `strategy.extended.update`.

    Raises:
      ValueError: In replica context when aggregation is
        `VariableAggregation.NONE`.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
        f = kwargs.pop("f")

        if ds_context.in_cross_replica_context():
            if distribute_lib.get_update_replica_id() is not None:
                # Already inside an update context.
                return f(self._v, *args, **kwargs)
            # Cross-replica call: wrap it in an update.
            return self._distribute_strategy.extended.update(
                self, f, args=args, kwargs=kwargs)

        # Replica context: aggregate the value before assigning/adding/
        # subtracting it.
        replica_context = ds_context.get_replica_context()
        assert replica_context
        if self._aggregation == vs.VariableAggregation.NONE:
            raise ValueError(
                values_util.aggregation_error_msg.format(
                    variable_type="AggregatingVariable"))

        def merge_fn(strategy,
                     value,
                     use_locking=False,
                     name=None,
                     read_value=True):
            aggregated = values_util.apply_aggregation(
                strategy, value, self._aggregation, self)
            # A per-replica name collapses to the first replica's entry.
            if name and isinstance(name, values.PerReplica):
                name = name.values[0]
            update_kwargs = {
                "use_locking": use_locking,
                "name": name,
                "read_value": read_value
            }
            return strategy.extended.update(
                self, f, args=(aggregated,), kwargs=update_kwargs)

        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)
def _skip(self, delta):
    """Advances the RNG state by `delta` via `self._skip_single_var`.

    When saving non-distributed, or when there is no distribution strategy,
    the state is advanced directly. In cross-replica context the variable is
    marked unsaveable and the advance goes through
    `strategy.extended.update`. In replica context it is applied directly
    after leaving the strategy scope.
    """

    def update_fn(v):
        return self._skip_single_var(v, delta)

    # TODO(b/170515001): Always call strategy.extended.update after calling it
    # from both replica context and cross-replica context is supported.
    saving_single = values_util.is_saving_non_distributed()
    if not saving_single and self._distribution_strategy is not None:
        with ds_context.enter_or_assert_strategy(self._distribution_strategy):
            if ds_context.in_cross_replica_context():
                # Code that operates on all replicas of a variable cannot be
                # saved without retracing.
                values_util.mark_as_unsaveable()
                return ds_context.get_strategy().extended.update(
                    self.state, update_fn)
    # Saving non-distributed behaves as replica 0, since only the first
    # replica is saved.
    return update_fn(self.state)
def _as_graph_element(self):
    """Returns a graph element representing this variable.

    In cross-replica context this is a tensor converted from
    `self._get_cross_replica()`. Otherwise the graph element of the
    component returned by `self._get()` is used.
    """
    # pylint: disable=protected-access
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
        if not ds_context.in_cross_replica_context():
            return self._get()._as_graph_element()
        cross_value = self._get_cross_replica()
        return ops.convert_to_tensor(cross_value)
def read_value(self):
    """Returns an identity of the component selected by `self._get()`.

    The read runs under this variable's distribution strategy.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
        current = self._get()
        return array_ops.identity(current)