示例#1
0
    def _call_flat(self, args, captured_inputs, cancellation_manager=None):
        """Resolve distributed captured inputs, then delegate to the parent.

        Every captured input that is a distributed variable is swapped for a
        concrete handle before the call; all other captures pass through
        unchanged.

        * In a replica context, or while saving a non-distributed version of
          the model, the variable is replaced by `var.handle`. That property
          behaves differently in the two situations: in a replica context it
          resolves to the replica-local handle of a replicated variable, and
          when saving a non-distributed model it resolves to the primary
          handle, since only one copy of a replicated variable is saved.
        * Otherwise (cross-replica context) the variable is replaced by the
          result of `_unused_handle()`.

        Args:
          args: positional arguments forwarded unchanged to the parent call.
          captured_inputs: captured inputs; distributed variables among them
            are resolved as described above.
          cancellation_manager: forwarded unchanged to the parent call.

        Returns:
          Whatever the parent class's `_call_flat` returns.
        """
        resolve_to_handle = (ds_context.get_replica_context() is not None
                             or values_util.is_saving_non_distributed())

        def resolve(captured):
            # Non-distributed captures are passed through untouched.
            if not distribute_utils.is_distributed_variable(captured):
                return captured
            if resolve_to_handle:
                return captured.handle
            # Cross-replica context.
            return _unused_handle()

        resolved_inputs = [resolve(captured) for captured in captured_inputs]
        return super(_WrapperFunction, self)._call_flat(
            args, resolved_inputs, cancellation_manager)
示例#2
0
 def scatter_max(self, sparse_delta, use_locking=False, name=None):
     """Scatter-max `sparse_delta` into this variable.

     While a non-distributed copy of the model is being saved, the call is
     forwarded to the primary variable. Otherwise it is delegated to the
     variable policy, which receives this variable as its first argument.

     Returns:
       The result of the underlying `scatter_max` call.
     """
     if not values_util.is_saving_non_distributed():
         return self._policy.scatter_max(
             self, sparse_delta, use_locking=use_locking, name=name)
     return self._primary.scatter_max(sparse_delta, use_locking, name)
示例#3
0
 def op(self):
     """Return the op that represents this variable.

     While saving a non-distributed model this is simply the primary
     variable's op. Otherwise a `values.DistributedVarOp` is built from the
     primary op's name, graph, traceback and type.
     """
     if values_util.is_saving_non_distributed():
         return self._primary.op
     primary_op = self._primary.op
     return values.DistributedVarOp(
         primary_op.name,
         primary_op.graph,
         primary_op.traceback,
         primary_op.type,
     )
    def skip(self, delta):
        """Advance the counter of a counter-based RNG.

        Args:
          delta: the amount of advancement. After `skip(n)` the state of the
            RNG is the same as after `normal([n])` (or any other
            distribution). The actual increment added to the counter is an
            unspecified implementation detail.

        Returns:
          A `Tensor` of type `int64`.
        """
        def advance(var):
            return self._skip_single_var(var, delta)

        # TODO(b/170515001): Always call strategy.extended.update after calling it
        #   from both replica context and cross-replica context is supported.
        if values_util.is_saving_non_distributed():
            # Only the first replica is saved, so act as replica context with
            # replica_id=0.
            return advance(self.state)
        strategy = self._distribution_strategy
        if strategy is None:
            return advance(self.state)
        with ds_context.enter_or_assert_strategy(strategy):
            if ds_context.in_cross_replica_context():
                # Code that touches every replica of a variable cannot be
                # saved without retracing.
                values_util.mark_as_unsaveable()
                # Cross-replica updates must go through
                # strategy.extended.update.
                extended = ds_context.get_strategy().extended
                return extended.update(self.state, advance)
        # Replica context: apply the update directly, outside the strategy
        # scope.
        return advance(self.state)
示例#5
0
 def assign(self, value, use_locking=False, name=None, read_value=True):
     """Assign `value` to this variable.

     While a non-distributed copy of the model is being saved, the
     assignment goes straight to the primary variable. Otherwise it is
     delegated to the variable policy, which receives this variable as its
     first argument.

     Returns:
       The result of the underlying `assign` call.
     """
     if not values_util.is_saving_non_distributed():
         return self._policy.assign(
             self,
             value,
             use_locking=use_locking,
             name=name,
             read_value=read_value,
         )
     return self._primary.assign(value, use_locking, name, read_value)
示例#6
0
 def _device_scope(self):
     """Return the device scope to use for operations on this variable.

     A no-op context is returned in any of these cases:
     * there is no packed handle;
     * a non-distributed model is being saved;
     * the code runs inside an enclosing TPU context;
     * the current device already has its own handle.
     In every other case the scope pins ops to the primary handle's device.
     """
     # The conditions short-circuit in the same order as they are listed.
     no_override = (self._packed_handle is None
                    or values_util.is_saving_non_distributed()
                    or tpu_util.enclosing_tpu_context() is not None)
     if no_override:
         return ops.NullContextmanager()
     current_device = device_util.canonicalize(device_util.current())
     if current_device not in self._device_to_handle:
         return ops.device(self._primary_handle.device)
     return ops.NullContextmanager()
示例#7
0
 def handle(self):
     """Return the resource handle appropriate for the current context.

     The checks are made in this order:
     1. Saving a non-distributed model: the primary handle.
     2. Inside a TPU context while not executing eagerly: a replicated handle
        from the TPU context, built from the packed handle if one exists,
        otherwise from every component variable's handle.
     3. A packed handle exists and execution is not eager: the packed handle.
     4. Otherwise: the handle registered for the current device, falling
        back to the primary handle.
     """
     if values_util.is_saving_non_distributed():
         return self._primary_handle
     tpu_ctx = tpu_util.enclosing_tpu_context()
     if tpu_ctx and not context.executing_eagerly():
         # Anything that is not ON_READ-synchronized counts as mirrored.
         mirrored = (self._variables[0].synchronization !=
                     variables_lib.VariableSynchronization.ON_READ)
         packed = self._packed_handle is not None
         var_handles = ([self._packed_handle] if packed
                        else [var.handle for var in self._variables])
         return tpu_ctx.get_replicated_var_handle(
             self._unique_id, var_handles, mirrored, packed)
     if self._packed_handle is not None and not context.executing_eagerly():
         return self._packed_handle
     current_device = device_util.canonicalize(device_util.current())
     return self._device_to_handle.get(current_device, self._primary_handle)
示例#8
0
    def _skip(self, delta):
        """Apply `self._skip_single_var(..., delta)` to `self.state`.

        The update is routed according to the current context:
        * saving a non-distributed model: applied directly;
        * no distribution strategy: applied directly;
        * cross-replica context of the strategy: routed through
          `strategy.extended.update`, and the function is marked unsaveable;
        * replica context: applied directly, outside the strategy scope.
        """
        def advance(var):
            return self._skip_single_var(var, delta)

        # TODO(b/170515001): Always call strategy.extended.update after calling it
        #   from both replica context and cross-replica context is supported.
        if values_util.is_saving_non_distributed():
            # Only the first replica is saved, so act as replica context with
            # replica_id=0.
            return advance(self.state)
        strategy = self._distribution_strategy
        if strategy is None:
            return advance(self.state)
        with ds_context.enter_or_assert_strategy(strategy):
            if ds_context.in_cross_replica_context():
                # Code that touches every replica of a variable cannot be
                # saved without retracing.
                values_util.mark_as_unsaveable()
                # Cross-replica updates must go through
                # strategy.extended.update.
                extended = ds_context.get_strategy().extended
                return extended.update(self.state, advance)
        return advance(self.state)
示例#9
0
 def handle(self):
     """Return the resource handle appropriate for the current context.

     The checks are made in this order:
     1. Saving a non-distributed model: the primary handle.
     2. Inside a TPU context while not executing eagerly: a replicated handle
        from the TPU context, keyed by the handle name with any ":" suffix
        stripped, and built from the packed handle if one exists, otherwise
        from every component variable's handle.
     3. A packed handle exists and execution is not eager: the packed handle.
     4. Otherwise: the handle registered for the current device, falling
        back to the primary handle.
     """
     if values_util.is_saving_non_distributed():
         return self._primary_handle
     tpu_ctx = tpu_util.enclosing_tpu_context()
     if tpu_ctx and not context.executing_eagerly():
         # Anything that is not ON_READ-synchronized counts as mirrored.
         mirrored = (self._variables[0].synchronization !=
                     variables_lib.VariableSynchronization.ON_READ)
         packed = self._packed_handle is not None
         var_handles = ([self._packed_handle] if packed
                        else [var.handle for var in self._variables])
         scope_name = self._handle_name
         # BaseResourceVariable appends ":0" to the handle name, which makes
         # it not a valid root scope name; keep only the part before ":".
         if ":" in scope_name:
             scope_name = scope_name.partition(":")[0]
         return tpu_ctx.get_replicated_var_handle(
             scope_name, self._unique_id, var_handles, mirrored, packed)
     if self._packed_handle is not None and not context.executing_eagerly():
         return self._packed_handle
     current_device = device_util.canonicalize(device_util.current())
     return self._device_to_handle.get(current_device, self._primary_handle)
示例#10
0
 def scatter_update(self, *args, **kwargs):
     """Forward `scatter_update` to the primary variable, when permitted.

     The call is only supported while a non-distributed copy of the model is
     being saved. In that case all arguments are passed through unchanged.

     Raises:
       NotImplementedError: when not saving a non-distributed model.
     """
     if not values_util.is_saving_non_distributed():
         raise NotImplementedError
     return self._primary.scatter_update(*args, **kwargs)
示例#11
0
 def initializer(self):
     """Return the initializer for this variable.

     While saving a non-distributed model, this is the first component
     variable's initializer; otherwise the parent class's `initializer`.
     """
     if not values_util.is_saving_non_distributed():
         return super().initializer
     return self._variables[0].initializer
示例#12
0
 def name(self):
     """Return this variable's name.

     While saving a non-distributed model, this is the first component
     variable's name; otherwise the parent class's `name`.
     """
     if not values_util.is_saving_non_distributed():
         return super().name
     return self._variables[0].name