def _call_flat(self, args, captured_inputs, cancellation_manager=None):
    """Call the parent `_call_flat` after rewriting distributed captures.

    Every element of `captured_inputs` that
    `distribute_utils.is_distributed_variable` recognizes is replaced before
    the call; all other elements pass through untouched.

    Args:
      args: Forwarded unchanged to the parent `_call_flat`.
      captured_inputs: Captured values; distributed variables among them are
        substituted as described below.
      cancellation_manager: Forwarded unchanged to the parent `_call_flat`.

    Returns:
      Whatever the parent class's `_call_flat` returns.
    """

    def resolve_handle(value):
        # Distributed variables become their `.handle`; others are kept.
        if distribute_utils.is_distributed_variable(value):
            return value.handle
        return value

    def placeholder_handle(value):
        # Distributed variables become `_unused_handle()` (defined elsewhere
        # in this file); others are kept.
        if distribute_utils.is_distributed_variable(value):
            return _unused_handle()
        return value

    # In the replica context, or when saving a non-distributed version of the
    # model, captured variables are resolved through `var.handle`. Its meaning
    # differs between the two situations:
    # - In the replica context it resolves to the replica-local variable
    #   handle when the variable is replicated.
    # - When saving a non-distributed model it resolves to the primary
    #   variable handle, because only one copy of a replicated variable is
    #   saved.
    # Otherwise (cross-replica context) a placeholder handle is used instead.
    use_real_handles = (ds_context.get_replica_context() is not None or
                        values_util.is_saving_non_distributed())
    convert = resolve_handle if use_real_handles else placeholder_handle
    captured_inputs = [convert(value) for value in captured_inputs]
    return super(_WrapperFunction, self)._call_flat(
        args, captured_inputs, cancellation_manager)
def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Apply a scatter-max update to this variable.

    Args:
      sparse_delta: The sparse update to apply.
      use_locking: Forwarded to the underlying implementation.
      name: Optional op name, forwarded to the underlying implementation.

    Returns:
      The result of the delegate's `scatter_max`. The delegate is
      `self._policy` normally and `self._primary` while saving a
      non-distributed model.
    """
    if not values_util.is_saving_non_distributed():
        return self._policy.scatter_max(
            self, sparse_delta, use_locking=use_locking, name=name)
    # Saving a non-distributed model: operate on the primary copy only.
    return self._primary.scatter_max(sparse_delta, use_locking, name)
def op(self):
    """Return the op associated with this variable.

    While saving a non-distributed model this is the primary variable's own
    `op`. Otherwise a `values.DistributedVarOp` is built from the primary
    op's name, graph, traceback and type.
    """
    if values_util.is_saving_non_distributed():
        return self._primary.op
    primary_op = self._primary.op
    return values.DistributedVarOp(primary_op.name, primary_op.graph,
                                   primary_op.traceback, primary_op.type)
def skip(self, delta):
    """Advance the counter of a counter-based RNG.

    Args:
      delta: the amount of advancement. The state of the RNG after
        `skip(n)` will be the same as that after `normal([n])`
        (or any other distribution). The actual increment added to the
        counter is an unspecified implementation detail.

    Returns:
      A `Tensor` of type `int64`.
    """

    def advance(state_var):
        return self._skip_single_var(state_var, delta)

    # TODO(b/170515001): Always call strategy.extended.update after calling it
    # from both replica context and cross-replica context is supported.
    if values_util.is_saving_non_distributed():
        # Only the first replica is saved, so act as if in the replica
        # context of replica_id=0.
        return advance(self.state)

    strategy = self._distribution_strategy
    if strategy is not None:
        with ds_context.enter_or_assert_strategy(strategy):
            if ds_context.in_cross_replica_context():
                # Code touching every replica of a variable cannot be saved
                # without retracing, so flag the trace as unsaveable.
                values_util.mark_as_unsaveable()
                # Cross-replica updates have to go through
                # strategy.extended.update.
                return ds_context.get_strategy().extended.update(
                    self.state, advance)
    return advance(self.state)
def assign(self, value, use_locking=False, name=None, read_value=True):
    """Assign `value` to this variable.

    Args:
      value: The new value.
      use_locking: Forwarded to the underlying implementation.
      name: Optional op name, forwarded to the underlying implementation.
      read_value: Forwarded to the underlying implementation.

    Returns:
      The result of the delegate's `assign`. The delegate is `self._policy`
      normally and `self._primary` while saving a non-distributed model.
    """
    if not values_util.is_saving_non_distributed():
        return self._policy.assign(
            self, value, use_locking=use_locking, name=name,
            read_value=read_value)
    # Saving a non-distributed model: only the primary copy is involved.
    return self._primary.assign(value, use_locking, name, read_value)
def _device_scope(self):
    """Pick the device context to use around this variable's handle.

    Returns `ops.NullContextmanager()` (presumably a no-op context) in any of
    these cases:
    - there is no packed handle;
    - a non-distributed model is being saved;
    - we are inside an enclosing TPU context;
    - the current canonical device already has an entry in
      `self._device_to_handle`.

    Otherwise returns `ops.device(...)` for the primary handle's device.
    """
    # The conditions are evaluated left to right and stop at the first true
    # one, so the current device is only looked up when really needed.
    no_placement_needed = (
        self._packed_handle is None
        or values_util.is_saving_non_distributed()
        or tpu_util.enclosing_tpu_context() is not None
        or device_util.canonicalize(device_util.current())
        in self._device_to_handle)
    if no_placement_needed:
        return ops.NullContextmanager()
    return ops.device(self._primary_handle.device)
def handle(self):
    """Return the resource handle to use in the current context.

    Resolution order:
    1. While saving a non-distributed model: the primary handle.
    2. Inside a TPU context in non-eager mode: a replicated handle obtained
       from `tpu_context.get_replicated_var_handle`.
    3. In non-eager mode with a packed handle: the packed handle.
    4. Otherwise: the handle registered for the current canonical device,
       falling back to the primary handle.
    """
    if values_util.is_saving_non_distributed():
        return self._primary_handle

    tpu_context = tpu_util.enclosing_tpu_context()
    if tpu_context and not context.executing_eagerly():
        # Any synchronization other than ON_READ is treated as mirrored.
        mirrored = (self._variables[0].synchronization !=
                    variables_lib.VariableSynchronization.ON_READ)
        packed = self._packed_handle is not None
        if packed:
            replica_handles = [self._packed_handle]
        else:
            replica_handles = [var.handle for var in self._variables]
        return tpu_context.get_replicated_var_handle(
            self._unique_id, replica_handles, mirrored, packed)

    if self._packed_handle is not None and not context.executing_eagerly():
        return self._packed_handle

    current_device = device_util.canonicalize(device_util.current())
    return self._device_to_handle.get(current_device, self._primary_handle)
def _skip(self, delta):
    """Advance the RNG state by `delta` via `self._skip_single_var`.

    While saving a non-distributed model, or when no distribution strategy
    is attached, the state is updated directly. Inside the strategy's
    cross-replica context, the update goes through
    `strategy.extended.update` and the trace is marked unsaveable.

    Args:
      delta: Amount to advance, passed to `self._skip_single_var`.

    Returns:
      The result of `_skip_single_var`, or of `extended.update` on the
      cross-replica path.
    """

    def advance(state_var):
        return self._skip_single_var(state_var, delta)

    # TODO(b/170515001): Always call strategy.extended.update after calling it
    # from both replica context and cross-replica context is supported.
    if values_util.is_saving_non_distributed():
        # Only the first replica is saved; act as replica_id=0.
        return advance(self.state)

    strategy = self._distribution_strategy
    if strategy is None:
        return advance(self.state)

    with ds_context.enter_or_assert_strategy(strategy):
        if ds_context.in_cross_replica_context():
            # Cross-replica code on a variable cannot be saved without
            # retracing.
            values_util.mark_as_unsaveable()
            return ds_context.get_strategy().extended.update(
                self.state, advance)
    return advance(self.state)
def handle(self):
    """Return the resource handle to use in the current context.

    Resolution order:
    1. While saving a non-distributed model: the primary handle.
    2. Inside a TPU context in non-eager mode: a replicated handle from
       `tpu_context.get_replicated_var_handle`, named after the handle name
       with any ":..." suffix removed.
    3. In non-eager mode with a packed handle: the packed handle.
    4. Otherwise: the handle registered for the current canonical device,
       falling back to the primary handle.
    """
    if values_util.is_saving_non_distributed():
        return self._primary_handle

    tpu_context = tpu_util.enclosing_tpu_context()
    if tpu_context and not context.executing_eagerly():
        # Any synchronization other than ON_READ is treated as mirrored.
        mirrored = (self._variables[0].synchronization !=
                    variables_lib.VariableSynchronization.ON_READ)
        packed = self._packed_handle is not None
        if packed:
            replica_handles = [self._packed_handle]
        else:
            replica_handles = [var.handle for var in self._variables]
        # BaseResourceVariable appends ":0" to the handle name, which is not
        # a valid root scope name, so keep only the part before the colon.
        scope_name = self._handle_name
        if ":" in scope_name:
            scope_name = scope_name.partition(":")[0]
        return tpu_context.get_replicated_var_handle(
            scope_name, self._unique_id, replica_handles, mirrored, packed)

    if self._packed_handle is not None and not context.executing_eagerly():
        return self._packed_handle

    current_device = device_util.canonicalize(device_util.current())
    return self._device_to_handle.get(current_device, self._primary_handle)
def scatter_update(self, *args, **kwargs):
    """Scatter-update the variable; supported only while saving.

    While saving a non-distributed model, all arguments are forwarded to
    `self._primary.scatter_update`. Otherwise NotImplementedError is raised.
    """
    if not values_util.is_saving_non_distributed():
        raise NotImplementedError
    return self._primary.scatter_update(*args, **kwargs)
def initializer(self):
    """Return the initializer of this variable.

    While saving a non-distributed model, the first component variable's
    initializer is returned; otherwise the parent class's `initializer`.
    """
    if not values_util.is_saving_non_distributed():
        return super().initializer
    return self._variables[0].initializer
def name(self):
    """Return the name of this variable.

    While saving a non-distributed model, the first component variable's
    name is reported; otherwise the parent class's `name`.
    """
    saving_single_copy = values_util.is_saving_non_distributed()
    if saving_single_copy:
        return self._variables[0].name
    return super().name