def assign(self, value, use_locking=False, name=None, read_value=True):
  tpu_context = tpu_util.enclosing_tpu_context()
  if (self._is_replicated_or_sharded_to_logical_cores() and
      tpu_context is None):
    assign_fn = lambda v, *a, **ka: v.assign(*a, **ka)
    return self._update(
        update_fn=assign_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)
  if (tpu_context and
      self.aggregation == variable_scope.VariableAggregation.NONE):
    return tpu_util.make_raw_assign_fn(
        gen_resource_variable_ops.assign_variable_op)(
            self,
            value=value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)
  return assign(self, value, use_locking=use_locking, name=name,
                read_value=read_value)
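# The assign paths above and below delegate to `tpu_util.make_raw_assign_fn`
# (or the module-private `_make_raw_assign_fn`) when inside a TPU context.
# Below is a minimal sketch of what such a wrapper is assumed to do; it is an
# illustration only, not the actual TensorFlow implementation: it binds a raw
# resource-variable op (e.g. assign_variable_op, assign_add_variable_op) to
# the variable's handle and optionally reads the updated value back.
from tensorflow.python.framework import ops


def _sketch_make_raw_assign_fn(raw_assign_fn):
  """Hypothetical helper mirroring the raw-assign wrappers used here."""

  def assign_fn(var, value, use_locking=False, name=None, read_value=True):
    del use_locking  # The raw resource ops do not take a locking flag.
    op = raw_assign_fn(
        var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
    if read_value:
      with ops.control_dependencies([op]):
        return var.read_value()
    return op

  return assign_fn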
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
  """Converts a variable to a tensor."""
  # pylint: disable=protected-access
  if tpu_util.enclosing_tpu_context() is None:
    return self.read_value()
  else:
    return self._read_variable_op()
def __getattr__(self, name):
  if tpu_util.enclosing_tpu_context() is None:
    return super(TPUVariableMixin, self).__getattr__(name)
  else:
    raise AttributeError(
        f"`TPUVariableMixin.{name}` not accessible within a TPU context.")
def get(self):
  if tpu_util.enclosing_tpu_context() is None:
    return super(TPUVariableMixin, self).get()
  else:
    raise NotImplementedError(
        "`TPUVariableMixin.get()` is not supported within a TPU context.")
def assign(self, *args, **kwargs):
  if tpu_util.enclosing_tpu_context() is None:
    return values.SyncOnReadVariable.assign(self, *args, **kwargs)
  else:
    return _make_raw_assign_fn(
        gen_resource_variable_ops.assign_variable_op)(self, *args, **kwargs)
def assign(self, var, *args, **kwargs):
  if tpu_util.enclosing_tpu_context() is None:
    return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs)
  else:
    return _make_raw_assign_fn(
        gen_resource_variable_ops.assign_variable_op)(var, *args, **kwargs)
def _device_scope(self):
  if (self._packed_handle is None or
      values_util.is_saving_non_distributed() or
      tpu_util.enclosing_tpu_context() is not None):
    return ops.NullContextmanager()
  device = device_util.canonicalize(device_util.current())
  if device in self._device_to_handle:
    return ops.NullContextmanager()
  return ops.device(self._primary_handle.device)
def assign(self, var, value, use_locking=False, name=None, read_value=True):
  if tpu_util.enclosing_tpu_context():
    return _make_raw_assign_fn(
        gen_resource_variable_ops.assign_variable_op)(
            var,
            value=value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)
  return assign(
      var, value, use_locking=use_locking, name=name, read_value=read_value)
def assign(self, var, value, use_locking=False, name=None, read_value=True):
  if (tpu_util.enclosing_tpu_context() and
      var.aggregation == variable_scope.VariableAggregation.NONE):
    return _make_raw_assign_fn(
        gen_resource_variable_ops.assign_variable_op)(
            var,
            value=value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)
  return assign(
      var, value, use_locking=use_locking, name=name, read_value=read_value)
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
  """Converts a variable to a tensor."""
  # pylint: disable=protected-access
  if tpu_util.enclosing_tpu_context() is None:
    return super(TPUVariableMixin, self)._dense_var_to_tensor(
        dtype=dtype, name=name, as_ref=as_ref)
  # pylint: enable=protected-access
  elif dtype is not None and dtype != self.dtype:
    return math_ops.cast(self.read_value(), dtype)
  else:
    return self.handle if as_ref else self.read_value()
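# Illustrative sketch (an assumption, not copied from the library): a
# `_dense_var_to_tensor` hook like the one above is what implicit conversion
# such as `ops.convert_to_tensor(var)` ends up calling once the variable
# class is registered as a conversion target, roughly:
#
#   ops.register_tensor_conversion_function(
#       DistributedVariable,
#       lambda var, dtype=None, name=None, as_ref=False:
#           var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref))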
def handle(self):
  if save_context.in_save_context() or context.executing_eagerly():
    return self._vars[0].handle

  if tpu_util.enclosing_tpu_context() is None:
    raise NotImplementedError(
        'TPUReplicatedVariable.handle is not available '
        'outside tpu context or save context')
  else:
    with tpu_util.outside_or_skip_tpu_context():
      return xla_sharding.replicate(
          tpu_partition_ops.tpu_partitioned_input(
              [v.handle for v in self._vars], partition_dim=-1))
def assign_add(self, value, use_locking=False, name=None, read_value=True):
  if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
    assign_add_fn = lambda var, *a, **ka: var.assign_add(*a, **ka)
    return self._update(
        assign_add_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)
  else:
    return tpu_util.make_raw_assign_fn(
        gen_resource_variable_ops.assign_add_variable_op)(
            self,
            value=value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)
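# Illustrative usage sketch (an assumption; it needs real TPU hardware to
# run): `assign_add` called inside `strategy.run` resolves to the raw
# assign_add_variable_op path above, while the eager call outside a replica
# context goes through the per-replica `_update` path.
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
  counter = tf.Variable(0)


@tf.function
def bump():
  # Runs on every replica inside the TPU context (tpu.rewrite).
  strategy.run(lambda: counter.assign_add(1))


bump()                 # in-replica path: raw assign_add_variable_op
counter.assign_add(1)  # eager path outside a replica context: _update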
def handle(self):
  """The handle by which this variable can be accessed."""
  # If we're in a tpu.rewrite(), return the replicated handle.
  tpu_context = tpu_util.enclosing_tpu_context()
  if tpu_context is None or context.executing_eagerly():
    var = self._get_on_device_or_primary()
    if isinstance(var, packed.PackedVarAndDevice):
      return var.on_device_handle()
    else:
      return var.handle
  else:
    is_packed = self._packed_var is not None
    val = self._values
    if is_packed:
      val = [self._packed_var]
    return tpu_context.get_replicated_var_handle(
        self._handle_id, val, self._is_mirrored(), is_packed)
def handle(self):
  if values_util.is_saving_non_distributed():
    return self._primary_handle
  tpu_context = tpu_util.enclosing_tpu_context()
  if tpu_context and not context.executing_eagerly():
    is_mirrored = (
        self._variables[0].synchronization !=
        variables_lib.VariableSynchronization.ON_READ)
    if self._packed_handle is None:
      handles = [v.handle for v in self._variables]
      is_packed = False
    else:
      handles = [self._packed_handle]
      is_packed = True
    return tpu_context.get_replicated_var_handle(
        self._unique_id, handles, is_mirrored, is_packed)
  if self._packed_handle is not None and not context.executing_eagerly():
    return self._packed_handle
  device = device_util.canonicalize(device_util.current())
  return self._device_to_handle.get(device, self._primary_handle)
def _scatter_xxx(self,
                 raw_scatter_xxx_fn,
                 op_name,
                 var,
                 sparse_delta,
                 use_locking=False,
                 name=None):
  scatter_xxx_fn = _make_raw_scatter_xxx_fn(raw_scatter_xxx_fn)
  if tpu_util.enclosing_tpu_context():
    if self._aggregation != variable_scope.VariableAggregation.NONE:
      raise NotImplementedError(
          _scatter_error_msg.format(
              op_name=op_name, aggregation=self._aggregation))
    return scatter_xxx_fn(
        var, sparse_delta=sparse_delta, use_locking=use_locking, name=name)
  else:
    return var._update(  # pylint: disable=protected-access
        update_fn=scatter_xxx_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)
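# A minimal sketch (an assumption, mirroring the assign wrapper sketch
# earlier) of what a `_make_raw_scatter_xxx_fn` helper is taken to do: adapt
# a raw resource scatter op (e.g. gen_resource_variable_ops.
# resource_scatter_add) to the `(var, sparse_delta, use_locking, name)`
# calling convention used by `_scatter_xxx` above, where `sparse_delta` is an
# IndexedSlices carrying `.indices` and `.values`.
from tensorflow.python.framework import ops


def _sketch_make_raw_scatter_xxx_fn(raw_scatter_xxx_fn):
  """Hypothetical helper used only to illustrate the call shape."""

  def scatter_xxx_fn(var, sparse_delta, use_locking=False, name=None):
    del use_locking  # The raw resource scatter ops take no locking flag.
    op = raw_scatter_xxx_fn(
        var.handle,
        sparse_delta.indices,
        ops.convert_to_tensor(sparse_delta.values, dtype=var.dtype),
        name=name)
    with ops.control_dependencies([op]):
      return var.read_value()

  return scatter_xxx_fn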
def handle(self):
  if values_util.is_saving_non_distributed():
    return self._primary_handle
  tpu_context = tpu_util.enclosing_tpu_context()
  if tpu_context and not context.executing_eagerly():
    is_mirrored = (
        self._variables[0].synchronization !=
        variables_lib.VariableSynchronization.ON_READ)
    if self._packed_handle is None:
      handles = [v.handle for v in self._variables]
      is_packed = False
    else:
      handles = [self._packed_handle]
      is_packed = True
    common_name = self._handle_name
    # BaseResourceVariable appends ":0" to the handle name, which makes it
    # not a valid root scope name.
    if ":" in common_name:
      common_name = common_name.split(":")[0]
    return tpu_context.get_replicated_var_handle(
        common_name, self._unique_id, handles, is_mirrored, is_packed)
  if self._packed_handle is not None and not context.executing_eagerly():
    return self._packed_handle
  device = device_util.canonicalize(device_util.current())
  return self._device_to_handle.get(device, self._primary_handle)
def _as_graph_element(self):
  if tpu_util.enclosing_tpu_context() is None:
    return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
  else:
    return None
def value(self):
  if tpu_util.enclosing_tpu_context() is None:
    return super(TPUVariableMixin, self).value()
  else:
    return self._read_variable_op()
def device(self):
  if (self._is_replicated_or_sharded_to_logical_cores() and
      tpu_util.enclosing_tpu_context() is None):
    return self._primary.device
  return super(TPUMirroredVariable, self).device