Пример #1
0
    def assign(self, value, use_locking=False, name=None, read_value=True):
        """Assign ``value`` to this variable, choosing the update path by context.

        The paths are tried in order:

        1. The variable is replicated/sharded to logical cores and there is no
           enclosing TPU context: delegate to ``self._update`` with a
           per-component ``assign``.
        2. There is an enclosing TPU context and ``self.aggregation`` is
           ``VariableAggregation.NONE``: emit a raw ``assign_variable_op`` built
           by ``tpu_util.make_raw_assign_fn``.
        3. Otherwise: fall back to the module-level ``assign`` helper.

        Args:
          value: The new value to assign.
          use_locking: Forwarded unchanged to the selected assign path.
          name: Optional op name, forwarded unchanged.
          read_value: Forwarded unchanged; presumably controls whether the
            updated value is returned (TODO confirm against callers).

        Returns:
          Whatever the selected assign path returns.
        """
        # Look the TPU context up once; both conditions below depend on it.
        tpu_context = tpu_util.enclosing_tpu_context()

        if (self._is_replicated_or_sharded_to_logical_cores()
                and tpu_context is None):

            def assign_fn(v, *args, **kwargs):
                return v.assign(*args, **kwargs)

            return self._update(update_fn=assign_fn,
                                value=value,
                                use_locking=use_locking,
                                name=name,
                                read_value=read_value)

        if (tpu_context
                and self.aggregation == variable_scope.VariableAggregation.NONE):
            raw_assign = tpu_util.make_raw_assign_fn(
                gen_resource_variable_ops.assign_variable_op)
            return raw_assign(self,
                              value=value,
                              use_locking=use_locking,
                              name=name,
                              read_value=read_value)

        # Inside a method body the bare name `assign` resolves to a global
        # (module-level) name, not to this method.
        return assign(self,
                      value,
                      use_locking=use_locking,
                      name=name,
                      read_value=read_value)
Пример #2
0
 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
     """Converts a variable to a tensor.

     If there is an enclosing TPU context, the raw ``_read_variable_op()`` is
     used. Otherwise the public ``read_value()`` path is taken.

     NOTE(review): ``dtype``, ``name`` and ``as_ref`` are accepted but not used
     by this implementation. Confirm that no caller relies on a dtype
     conversion or on a reference being returned.
     """
     if tpu_util.enclosing_tpu_context() is not None:
         return self._read_variable_op()  # pylint: disable=protected-access
     return self.read_value()
Пример #3
0
 def __getattr__(self, name):
     """Forward attribute lookup to the next class after TPUVariableMixin.

     Raises:
       AttributeError: if called while an enclosing TPU context is active.
     """
     if tpu_util.enclosing_tpu_context() is not None:
         raise AttributeError(
             f"`TPUVariableMixin.{name}` not accessible within a TPU context."
         )
     return super(TPUVariableMixin, self).__getattr__(name)
Пример #4
0
 def get(self):
     """Return the parent class's ``get()`` result outside a TPU context.

     Raises:
       NotImplementedError: if called while an enclosing TPU context is active.
     """
     if tpu_util.enclosing_tpu_context() is not None:
         raise NotImplementedError(
             "`TPUVariableMixin.get()` is not supported within a TPU context."
         )
     return super(TPUVariableMixin, self).get()
Пример #5
0
 def assign(self, *args, **kwargs):
     """Assign to this sync-on-read variable.

     Outside a TPU context this defers to ``values.SyncOnReadVariable.assign``.
     Inside one, a raw ``assign_variable_op`` is applied to ``self`` with the
     same positional and keyword arguments.
     """
     if tpu_util.enclosing_tpu_context() is not None:
         raw_assign = _make_raw_assign_fn(
             gen_resource_variable_ops.assign_variable_op)
         return raw_assign(self, *args, **kwargs)
     return values.SyncOnReadVariable.assign(self, *args, **kwargs)
Пример #6
0
 def assign(self, var, *args, **kwargs):
     """Assign to ``var`` under the on-read policy.

     Outside a TPU context this defers to the parent policy's ``assign``.
     Inside one, a raw ``assign_variable_op`` is applied to ``var`` with the
     remaining arguments.
     """
     if tpu_util.enclosing_tpu_context() is not None:
         raw_assign = _make_raw_assign_fn(
             gen_resource_variable_ops.assign_variable_op)
         return raw_assign(var, *args, **kwargs)
     return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs)
Пример #7
0
 def _device_scope(self):
     """Return the device context that this variable's ops should run under.

     A no-op ``ops.NullContextmanager()`` is returned in any of these cases:
     there is no packed handle, ``values_util.is_saving_non_distributed()`` is
     true, an enclosing TPU context exists, or the current canonical device
     already has an entry in ``self._device_to_handle``. Otherwise the
     returned scope places ops on the primary handle's device.
     """
     if self._packed_handle is None:
         return ops.NullContextmanager()
     if values_util.is_saving_non_distributed():
         return ops.NullContextmanager()
     if tpu_util.enclosing_tpu_context() is not None:
         return ops.NullContextmanager()
     current_device = device_util.canonicalize(device_util.current())
     if current_device not in self._device_to_handle:
         return ops.device(self._primary_handle.device)
     return ops.NullContextmanager()
Пример #8
0
 def assign(self, var, value, use_locking=False, name=None, read_value=True):
   """Assign ``value`` to ``var``.

   If an enclosing TPU context exists, a raw ``assign_variable_op`` is emitted.
   Otherwise the global ``assign`` helper (a module-level name, not this
   method) is used. All keyword options are forwarded unchanged.
   """
   if not tpu_util.enclosing_tpu_context():
     return assign(
         var, value, use_locking=use_locking, name=name, read_value=read_value)
   raw_assign = _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)
   return raw_assign(
       var,
       value=value,
       use_locking=use_locking,
       name=name,
       read_value=read_value)
Пример #9
0
 def assign(self, var, value, use_locking=False, name=None, read_value=True):
   """Assign ``value`` to ``var``.

   A raw ``assign_variable_op`` is used only when an enclosing TPU context
   exists and ``var.aggregation`` is ``VariableAggregation.NONE``. In every
   other case the global ``assign`` helper (a module-level name, not this
   method) handles the update.
   """
   # `var.aggregation` is only inspected when a TPU context is present.
   use_raw_op = (
       tpu_util.enclosing_tpu_context() and
       var.aggregation == variable_scope.VariableAggregation.NONE)
   if not use_raw_op:
     return assign(
         var, value, use_locking=use_locking, name=name, read_value=read_value)
   raw_assign = _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)
   return raw_assign(
       var,
       value=value,
       use_locking=use_locking,
       name=name,
       read_value=read_value)
Пример #10
0
 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
   """Converts a variable to a tensor.

   Outside a TPU context, this defers to the next class after
   TPUVariableMixin. Inside one, the behavior is:
     * if ``dtype`` is given and differs from ``self.dtype``, return
       ``math_ops.cast(self.read_value(), dtype)``;
     * otherwise, if ``as_ref`` is set, return ``self.handle``;
     * otherwise, return ``self.read_value()``.
   ``name`` is only forwarded on the non-TPU path.
   """
   if tpu_util.enclosing_tpu_context() is None:
     return super(TPUVariableMixin, self)._dense_var_to_tensor(  # pylint: disable=protected-access
         dtype=dtype, name=name, as_ref=as_ref)
   needs_cast = dtype is not None and dtype != self.dtype
   if needs_cast:
     return math_ops.cast(self.read_value(), dtype)
   if as_ref:
     return self.handle
   return self.read_value()
  def handle(self):
    """Return the resource handle for this replicated variable.

    In a save context, or when executing eagerly, the first component's handle
    is returned. Inside a TPU context, the component handles are combined
    under ``tpu_util.outside_or_skip_tpu_context()``: first with
    ``tpu_partitioned_input`` (``partition_dim=-1``), then with
    ``xla_sharding.replicate``.

    Raises:
      NotImplementedError: when not executing eagerly and outside both a save
        context and a TPU context.
    """
    if save_context.in_save_context() or context.executing_eagerly():
      return self._vars[0].handle

    if tpu_util.enclosing_tpu_context() is not None:
      with tpu_util.outside_or_skip_tpu_context():
        component_handles = [v.handle for v in self._vars]
        partitioned = tpu_partition_ops.tpu_partitioned_input(
            component_handles, partition_dim=-1)
        return xla_sharding.replicate(partitioned)

    raise NotImplementedError('TPUReplicatedVariable.handle is not available '
                              'outside tpu context or save context')
Пример #12
0
 def assign_add(self, value, use_locking=False, name=None, read_value=True):
     """Add ``value`` to this variable.

     When an enclosing TPU context exists and execution is not eager, a raw
     ``assign_add_variable_op`` is emitted on ``self``. Otherwise the add is
     routed through ``self._update`` with a per-component ``assign_add``.
     """
     # `executing_eagerly()` is only consulted when a TPU context is present.
     in_tpu_graph = (tpu_util.enclosing_tpu_context() is not None
                     and not context.executing_eagerly())
     if in_tpu_graph:
         raw_assign_add = tpu_util.make_raw_assign_fn(
             gen_resource_variable_ops.assign_add_variable_op)
         return raw_assign_add(self,
                               value=value,
                               use_locking=use_locking,
                               name=name,
                               read_value=read_value)

     def assign_add_fn(var, *args, **kwargs):
         return var.assign_add(*args, **kwargs)

     return self._update(assign_add_fn,
                         value=value,
                         use_locking=use_locking,
                         name=name,
                         read_value=read_value)
Пример #13
0
    def handle(self):
        """Return the resource handle used to access this variable.

        Inside a TPU context (and not executing eagerly), e.g. within
        ``tpu.rewrite()``, the replicated handle is obtained from the context.
        Otherwise the on-device-or-primary component's handle is returned.
        """
        tpu_context = tpu_util.enclosing_tpu_context()
        if tpu_context is not None and not context.executing_eagerly():
            # A packed variable stands in for all per-replica values.
            is_packed = self._packed_var is not None
            handle_values = [self._packed_var] if is_packed else self._values
            return tpu_context.get_replicated_var_handle(
                self._handle_id, handle_values, self._is_mirrored(), is_packed)

        component = self._get_on_device_or_primary()
        if isinstance(component, packed.PackedVarAndDevice):
            return component.on_device_handle()
        return component.handle
Пример #14
0
 def handle(self):
     """Return the resource handle appropriate to the current context.

     Resolution order:
       * saving non-distributed: the primary handle;
       * TPU context present and not executing eagerly: a replicated handle
         from ``tpu_context.get_replicated_var_handle``;
       * packed handle present and not executing eagerly: the packed handle;
       * otherwise: the handle registered for the current canonical device,
         falling back to the primary handle.
     """
     if values_util.is_saving_non_distributed():
         return self._primary_handle

     tpu_context = tpu_util.enclosing_tpu_context()
     if tpu_context and not context.executing_eagerly():
         # Treated as mirrored unless the first component syncs ON_READ.
         is_mirrored = (self._variables[0].synchronization !=
                        variables_lib.VariableSynchronization.ON_READ)
         is_packed = self._packed_handle is not None
         handles = ([self._packed_handle] if is_packed
                    else [v.handle for v in self._variables])
         return tpu_context.get_replicated_var_handle(
             self._unique_id, handles, is_mirrored, is_packed)

     if self._packed_handle is not None and not context.executing_eagerly():
         return self._packed_handle

     current_device = device_util.canonicalize(device_util.current())
     return self._device_to_handle.get(current_device, self._primary_handle)
Пример #15
0
 def _scatter_xxx(self,
                  raw_scater_xxx_fn,
                  op_name,
                  var,
                  sparse_delta,
                  use_locking=False,
                  name=None):
   """Apply a scatter-style update to ``var``.

   Args:
     raw_scater_xxx_fn: Raw scatter op, wrapped via ``_make_raw_scatter_xxx_fn``.
     op_name: Public op name, used only to format the error message.
     var: The variable being updated.
     sparse_delta: The update; passed as ``sparse_delta`` inside a TPU context
       and as ``value`` to ``var._update`` otherwise.
     use_locking: Forwarded unchanged.
     name: Optional op name, forwarded unchanged.

   Raises:
     NotImplementedError: inside a TPU context when ``self._aggregation`` is not
       ``VariableAggregation.NONE``.
   """
   scatter_fn = _make_raw_scatter_xxx_fn(raw_scater_xxx_fn)
   if not tpu_util.enclosing_tpu_context():
     return var._update(  # pylint: disable=protected-access
         update_fn=scatter_fn,
         value=sparse_delta,
         use_locking=use_locking,
         name=name)

   aggregation = self._aggregation
   if aggregation != variable_scope.VariableAggregation.NONE:
     raise NotImplementedError(
         _scatter_error_msg.format(op_name=op_name, aggregation=aggregation))
   return scatter_fn(
       var, sparse_delta=sparse_delta, use_locking=use_locking, name=name)
Пример #16
0
 def handle(self):
     """Return the resource handle appropriate to the current context.

     Resolution order:
       * saving non-distributed: the primary handle;
       * TPU context present and not executing eagerly: a replicated handle
         from ``tpu_context.get_replicated_var_handle``, keyed by a scope name
         derived from ``self._handle_name``;
       * packed handle present and not executing eagerly: the packed handle;
       * otherwise: the handle registered for the current canonical device,
         falling back to the primary handle.
     """
     if values_util.is_saving_non_distributed():
         return self._primary_handle

     tpu_context = tpu_util.enclosing_tpu_context()
     if tpu_context and not context.executing_eagerly():
         # Treated as mirrored unless the first component syncs ON_READ.
         is_mirrored = (self._variables[0].synchronization !=
                        variables_lib.VariableSynchronization.ON_READ)
         is_packed = self._packed_handle is not None
         handles = ([self._packed_handle] if is_packed
                    else [v.handle for v in self._variables])
         # BaseResourceVariable suffixes handle names with ":0", which is not a
         # valid root scope name; keep only the text before the first ":".
         scope_name = self._handle_name
         if ":" in scope_name:
             scope_name = scope_name.partition(":")[0]
         return tpu_context.get_replicated_var_handle(
             scope_name, self._unique_id, handles, is_mirrored, is_packed)

     if self._packed_handle is not None and not context.executing_eagerly():
         return self._packed_handle

     current_device = device_util.canonicalize(device_util.current())
     return self._device_to_handle.get(current_device, self._primary_handle)
Пример #17
0
 def _as_graph_element(self):
     """Return the parent's graph element, or ``None`` inside a TPU context."""
     if tpu_util.enclosing_tpu_context() is not None:
         return None
     return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
Пример #18
0
 def value(self):
     """Return this variable's value.

     Inside a TPU context the raw ``_read_variable_op()`` is used. Otherwise
     this defers to the next class after TPUVariableMixin.
     """
     if tpu_util.enclosing_tpu_context() is not None:
         return self._read_variable_op()
     return super(TPUVariableMixin, self).value()
Пример #19
0
 def device(self):
     """Return the device this variable reports.

     A variable replicated/sharded to logical cores reports its primary
     component's device when there is no enclosing TPU context. In all other
     cases the parent class's ``device`` is used.
     """
     # Checked first so the TPU context is only queried when it matters.
     on_logical_cores = self._is_replicated_or_sharded_to_logical_cores()
     if on_logical_cores and tpu_util.enclosing_tpu_context() is None:
         return self._primary.device
     return super(TPUMirroredVariable, self).device