def _update(self, update_fn, value, **kwargs):
  """Routes an update to the handler matching the current distribution context.

  All work happens inside `ds_context.enter_or_assert_strategy` for this
  variable's `distribute_strategy`, and is dispatched as follows:

  * Replica context: `values_util.assert_replica_context` is invoked on the
    strategy, then the call is forwarded to `self._update_replica`.
  * Cross-replica context inside an update context (that is,
    `distribute_lib.get_update_replica_id()` is not `None`): `update_fn` is
    applied directly to the component `self._values[<that replica id>]`.
  * Any other cross-replica context: the call is forwarded to
    `self._update_cross_replica`.

  A `read_value` keyword carried in `kwargs` selects the kind of result: True
  yields the updated Variable, False yields the update `tf.Operation`.

  Args:
    update_fn: Callable handed to `strategy.extended.update` to perform the
      update; its signature should match `Variable.assign()`.
    value: The value forwarded to `update_fn`.
    **kwargs: Extra keyword arguments forwarded to `update_fn`.

  Returns:
    The updated variable or a `tf.Operation`.
  """
  with ds_context.enter_or_assert_strategy(self.distribute_strategy):
    if not ds_context.in_cross_replica_context():
      values_util.assert_replica_context(self.distribute_strategy)
      return self._update_replica(update_fn, value, **kwargs)

    replica_id = distribute_lib.get_update_replica_id()
    if replica_id is None:
      return self._update_cross_replica(update_fn, value, **kwargs)
    # Already in an update context: touch only this replica's component.
    return update_fn(self._values[replica_id], value, **kwargs)
def get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or `None`.

  Inside a replica context the id is read from the context's
  `replica_id_in_sync_group`. If that is not already a Python `int`, it is
  passed through `tensor_util.constant_value` to resolve it statically. That
  call presumably yields `None` when the value is not statically known (TODO
  confirm), and may return a numpy value rather than a Python `int`. Outside
  a replica context the result is `distribute_lib.get_update_replica_id()`.
  """
  ctx = ds_context.get_replica_context()
  if not ctx:
    # No replica context: fall back to the id of any in-progress update.
    return distribute_lib.get_update_replica_id()
  rid = ctx.replica_id_in_sync_group
  return rid if isinstance(rid, int) else tensor_util.constant_value(rid)
def get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or `None`.

  Inside a replica context the id is read from the context's protected
  `_replica_id` attribute. A non-`int` id is resolved statically with
  `tensor_util.constant_value`, which presumably returns `None` when that is
  not possible (TODO confirm). Outside a replica context the result is
  `distribute_lib.get_update_replica_id()`.
  """
  replica_ctx = ds_context.get_replica_context()
  if not replica_ctx:
    # No replica context: fall back to the id of any in-progress update.
    return distribute_lib.get_update_replica_id()
  current_id = replica_ctx._replica_id  # pylint: disable=protected-access
  if isinstance(current_id, int):
    return current_id
  # Tensor-valued id: try to fold it to a concrete value.
  return tensor_util.constant_value(current_id)
def _update(strategy, var, update_fn, args):
  """Applies `update_fn` to `var`, choosing the mechanism by context.

  Must be called in cross-replica context; this is checked with an `assert`.
  When an update replica id is set (we are already inside an update context),
  `update_fn(var, *args)` is called directly and `var` is trusted to pick the
  right component itself. Otherwise the update is routed through
  `strategy.extended.update(var, update_fn, args)`.

  Args:
    strategy: Strategy object whose `extended.update` is used outside an
      update context.
    var: The variable to update.
    update_fn: Callable invoked as `update_fn(var, *args)`.
    args: Positional arguments for `update_fn`.

  Returns:
    Whatever the selected update path returns.
  """
  assert distribution_strategy_context.in_cross_replica_context(), (
      "_update can only be called in cross-replica context")
  if distribute_lib.get_update_replica_id() is None:
    return strategy.extended.update(var, update_fn, args)
  # Inside an update context, delegate to `var`. For example, a
  # MirroredVariable is expected to select its component variable based on
  # `update_replica_id` and update only that one.
  return update_fn(var, *args)
def _assign_func(self, *args, **kwargs):
  """Runs the assign-style function passed as `kwargs["f"]` on this variable.

  `f` is popped from `kwargs`, and the remaining positional and keyword
  arguments are forwarded. Behavior depends on the distribution context:

  * Cross-replica context inside an update context
    (`distribute_lib.get_update_replica_id()` is not `None`): calls
    `f(self._v, *args, **kwargs)` directly on the wrapped variable.
  * Other cross-replica contexts: wraps the call in
    `self._distribute_strategy.extended.update`.
  * Replica context: rejects `VariableAggregation.NONE`. Otherwise it issues
    a `merge_call` whose merge function aggregates the incoming value with
    `values_util.apply_aggregation` and applies `f` to the aggregated value
    through `strategy.extended.update`.

  Raises:
    ValueError: In replica context, when `self._aggregation` is
      `VariableAggregation.NONE`.
  """
  with ds_context.enter_or_assert_strategy(self._distribute_strategy):
    assign_fn = kwargs.pop("f")

    if ds_context.in_cross_replica_context():
      if distribute_lib.get_update_replica_id() is not None:
        # Already inside an update context: apply to the wrapped variable.
        return assign_fn(self._v, *args, **kwargs)
      # Plain cross-replica context: route through the strategy's update.
      return self._distribute_strategy.extended.update(
          self, assign_fn, args=args, kwargs=kwargs)

    # Replica context from here on.
    replica_ctx = ds_context.get_replica_context()
    assert replica_ctx
    if self._aggregation == vs.VariableAggregation.NONE:
      raise ValueError(
          values_util.aggregation_error_msg.format(
              variable_type="AggregatingVariable"))

    # Parameter names must stay as-is: merge_call forwards the caller's
    # keyword arguments (use_locking / name / read_value) by name.
    def merge_fn(strategy, value, use_locking=False, name=None,
                 read_value=True):
      aggregated = values_util.apply_aggregation(
          strategy, value, self._aggregation, self)
      # A per-replica `name` collapses to its first replica's value.
      if name and isinstance(name, values.PerReplica):
        name = name.values[0]
      update_kwargs = {
          "use_locking": use_locking,
          "name": name,
          "read_value": read_value
      }
      return strategy.extended.update(
          self, assign_fn, args=(aggregated,), kwargs=update_kwargs)

    return replica_ctx.merge_call(merge_fn, args=args, kwargs=kwargs)
def in_replica_update_context():
  """Reports whether an update replica id is currently set.

  Returns:
    True when `distribute_lib.get_update_replica_id()` returns a value other
    than `None`, False otherwise.
  """
  update_replica_id = distribute_lib.get_update_replica_id()
  return update_replica_id is not None