Пример #1
0
    def _read_variable_op(self):
        """Build a read of this variable with a control dependency on ``self._parent_op``.

        Returns:
          The result of ``gen_resource_variable_ops.read_variable_op`` for
          ``self._handle`` / ``self._dtype``.
        """
        deps = [self._parent_op]
        with ops.control_dependencies(deps):
            value = gen_resource_variable_ops.read_variable_op(
                self._handle, self._dtype)
        return value

    def set_shape(self, shape):
        """Store ``shape`` on ``self._shape``, replacing any previous value."""
        new_shape = shape
        self._shape = new_shape

    @property
    def op(self):
        """The op for this variable, i.e. ``self._parent_op``."""
        parent = self._parent_op
        return parent


ops.register_tensor_conversion_function(_UnreadVariable, _dense_var_to_tensor)
# Also register _UnreadVariable as a dense-tensor-like type.
ops.register_dense_tensor_like_type(_UnreadVariable)

# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.

# Note: the Variable registration must come after the ResourceVariable one.
# Otherwise inheritance leads to the wrong conversion function being used
# (presumably because ResourceVariable derives from Variable -- confirm).
ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor)
ops.register_tensor_conversion_function(
    variables.Variable, variables.Variable._TensorConversionFunction)  # pylint: disable=protected-access

# pylint: disable=protected-access
# _OverloadAllOperators is defined outside this view; judging by its name it
# installs Python operator overloads on ResourceVariable -- confirm.
ResourceVariable._OverloadAllOperators()
ops.register_dense_tensor_like_type(ResourceVariable)

Пример #2
0
  def _read_variable_op(self):
    """Build a read of this variable with a control dependency on ``self._parent_op``.

    Returns:
      The result of ``gen_resource_variable_ops.read_variable_op`` for
      ``self._handle`` / ``self._dtype``.
    """
    deps = [self._parent_op]
    with ops.control_dependencies(deps):
      value = gen_resource_variable_ops.read_variable_op(
          self._handle, self._dtype)
    return value

  def set_shape(self, shape):
    """Store ``shape`` on ``self._shape``, replacing any previous value."""
    new_shape = shape
    self._shape = new_shape

  @property
  def op(self):
    """The op for this variable, i.e. ``self._parent_op``."""
    parent = self._parent_op
    return parent

ops.register_tensor_conversion_function(_UnreadVariable, _dense_var_to_tensor)
# Also register _UnreadVariable as a dense-tensor-like type.
ops.register_dense_tensor_like_type(_UnreadVariable)

# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.

# Note: the Variable registration must come after the ResourceVariable one.
# Otherwise inheritance leads to the wrong conversion function being used
# (presumably because ResourceVariable derives from Variable -- confirm).
ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor)
ops.register_tensor_conversion_function(
    variables.Variable, variables.Variable._TensorConversionFunction)  # pylint: disable=protected-access

# pylint: disable=protected-access
# _OverloadAllOperators is defined outside this view; judging by its name it
# installs Python operator overloads on ResourceVariable -- confirm.
ResourceVariable._OverloadAllOperators()
ops.register_dense_tensor_like_type(ResourceVariable)

Пример #3
0
            # See https://docs.python.org/3/library/constants.html#NotImplemented
            return NotImplemented

    def __rmatmul__(self, o):
        """Reflected ``@``: delegate to ``self.read_value().__rmatmul__(o)``.

        Any AttributeError raised while delegating (e.g. the read value has no
        ``__rmatmul__``) turns into NotImplemented, so Python can try the other
        operand.
        """
        try:
            value = self.read_value()
            reflected = value.__rmatmul__
            return reflected(o)
        except AttributeError:
            # See https://docs.python.org/3/library/constants.html#NotImplemented
            return NotImplemented

    # pylint: enable=multiple-statements


# Tensor conversion for AutoCastVariable uses its own _dense_var_to_tensor.
# The unbound function is registered; presumably the converted value is passed
# as the first positional argument and so binds to ``self`` -- confirm.
ops.register_tensor_conversion_function(AutoCastVariable,
                                        AutoCastVariable._dense_var_to_tensor)  # pylint:disable=protected-access
ops.register_dense_tensor_like_type(AutoCastVariable)


def create_autocast_variable(variable):
    """Creates an AutoCastVariable that wraps another variable.

  This typically just returns `AutoCastVariable(variable)`. But, if the variable
  is a DistributedVariable or one of its subclasses, we instead dynamically
  create a class that subclasses from both AutoCastVariable and
  variable.__class__. This is so the returned variable will still pass
  `isinstance(variable, variable.__class__)`, which is required for
  DistributedVariables and its subclasses to work properly.

  Args:
    variable: A floating-point resource variable to wrap.
Пример #4
0
                array_ops.pack([state_ops.is_variable_initialized(v) for v in var_list])
            )
            # Get a 1-D string tensor containing all the variable names.
            variable_names_tensor = array_ops.constant([s.op.name for s in var_list])
            # Return a 1-D tensor containing all the names of uninitialized variables.
            return array_ops.boolean_mask(variable_names_tensor, variables_mask)


# pylint: disable=protected-access
# Variable and PartitionedVariable convert to tensors through their own
# _TensorConversionFunction.
ops.register_tensor_conversion_function(Variable, Variable._TensorConversionFunction)
# Defined outside this view; judging by its name it installs Python operator
# overloads on Variable -- confirm.
Variable._OverloadAllOperators()

ops.register_tensor_conversion_function(PartitionedVariable, PartitionedVariable._TensorConversionFunction)
# pylint: enable=protected-access

ops.register_dense_tensor_like_type(Variable)
# Members of the GLOBAL_VARIABLES and TRAINABLE_VARIABLES collections are
# (de)serialized as variable_pb2.VariableDef via Variable.to_proto /
# Variable.from_proto.
ops.register_proto_function(
    ops.GraphKeys.GLOBAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=Variable.to_proto,
    from_proto=Variable.from_proto,
)
ops.register_proto_function(
    ops.GraphKeys.TRAINABLE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=Variable.to_proto,
    from_proto=Variable.from_proto,
)
ops.register_proto_function(
    ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
Пример #5
0
    @_handle_name.setter
    def _handle_name(self, handle_name):
        # Writes are forwarded to the wrapped latent variable.
        latent = self.latent_variable
        latent._handle_name = handle_name

    @property
    def _initializer_op(self):
        """Initializer op, read from and written to ``self.latent_variable``."""
        latent = self.latent_variable
        return latent._initializer_op

    @_initializer_op.setter
    def _initializer_op(self, initializer_op):
        latent = self.latent_variable
        latent._initializer_op = initializer_op

    def _as_graph_element(self):
        """Return the graph element representing this variable.

        When a quantizer is set and ``context.should_quantize()`` is true, this
        returns ``self.quantizer(self.latent_variable)``. Otherwise it returns
        the latent variable's own ``_as_graph_element()``, falling back to
        ``self._op`` when that is None.
        """
        if not (self.quantizer and context.should_quantize()):
            element = self.latent_variable._as_graph_element()
            return self._op if element is None else element
        return self.quantizer(self.latent_variable)


# Defined outside this view; judging by its name it installs Python operator
# overloads on QuantizedVariable -- confirm.
QuantizedVariable._OverloadAllOperators()
# Instances convert to tensors through QuantizedVariable._dense_var_to_tensor.
tf.register_tensor_conversion_function(
    QuantizedVariable, QuantizedVariable._dense_var_to_tensor
)
# Best-effort registration: an AttributeError is swallowed, presumably because
# some TensorFlow versions lack ops.register_dense_tensor_like_type -- confirm.
# NOTE(review): this also hides an AttributeError raised *inside* the call.
try:
    ops.register_dense_tensor_like_type(QuantizedVariable)
except AttributeError:
    pass
Пример #6
0
        return self.get().op

    @property
    def _in_graph_mode(self):
        """Mirror the ``_in_graph_mode`` flag of the primary component variable."""
        primary = self._primary_var
        return primary._in_graph_mode  # pylint: disable=protected-access

    def read_value(self):
        """Read this variable through the current distribution strategy's ``read_var``."""
        strategy = distribution_strategy_context.get_distribution_strategy()
        return strategy.read_var(self)

    def _should_act_as_resource_variable(self):
        """Exists so instances pass the resource_variable_ops.is_resource_variable check.

        The body intentionally does nothing.
        """


# Register DistributedVariable as a dense-tensor-like type.
ops.register_dense_tensor_like_type(DistributedVariable)


class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
    """Saveable describing how a MirroredVariable is restored.

    The base saveable is initialized with the primary component variable and an
    empty second argument (presumably the slice spec). ``restore`` writes the
    single restored tensor into every component of the mirrored variable.
    """

    def __init__(self, mirrored_variable, primary_variable, name):
        self._mirrored_variable = mirrored_variable
        super(_MirroredSaveable, self).__init__(primary_variable, "", name)

    def restore(self, restored_tensors, restored_shapes):
        """Assign the restored tensor to every mirrored component.

        Args:
          restored_tensors: Sequence holding exactly one tensor; unpacking
            fails otherwise.
          restored_shapes: Unused.

        Returns:
          ``control_flow_ops.group`` over one ``_assign_on_device`` per entry
          of the mirrored variable's ``_index``.
        """
        (tensor,) = restored_tensors
        components = self._mirrored_variable._index  # pylint: disable=protected-access
        assignments = [
            _assign_on_device(device, component, tensor)
            for device, component in six.iteritems(components)
        ]
        return control_flow_ops.group(assignments)
Пример #7
0
                                         self._variable.dtype, name,
                                         as_ref=False)
    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(val.device):
        return math_ops.cast(val, self.dtype)

  def _should_act_as_resource_variable(self):
    """Exists so instances pass the resource_variable_ops.is_resource_variable check.

    The body intentionally does nothing.
    """

  # TODO(reedwm): Define operator overloads.


# Tensor conversion for AutoCastVariable uses its own _dense_var_to_tensor.
# The unbound function is registered; presumably the converted value is passed
# as the first positional argument and so binds to ``self`` -- confirm.
ops.register_tensor_conversion_function(
    AutoCastVariable, AutoCastVariable._dense_var_to_tensor)  # pylint:disable=protected-access
ops.register_dense_tensor_like_type(AutoCastVariable)


# We have DistributedVariable subclass to pass
# isinstance(..., DistributedVariable) checks when wrapping a
# DistributedVariable.
# TODO(reedwm): We should not wrap DistributedVariable, but instead have
# DistributedVariable wrap AutoCastVariable. Subclassing DistributedVariable is
# messy, because we do not fully implement the interface of DistributedVariable.
class AutoCastDistributedVariable(AutoCastVariable,
                                  distribute_values.DistributedVariable):
  """Version of AutoCastVariable that subclasses DistributedVariable."""

  def __init__(self, variable):
    if not isinstance(variable, distribute_values.DistributedValues):
      raise ValueError('variable must be of type DistributedValues, '
Пример #8
0
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor.

    Inside a TPU context (``_enclosing_tpu_context()`` is not None):
    - returns NotImplemented when ``dtype`` is given and differs from ``self.dtype``;
    - returns ``self.handle`` when ``as_ref`` is true;
    - returns ``self.read_value()`` otherwise.

    Outside a TPU context, conversion is delegated to ``self._primary_var``:
    through its own ``_dense_var_to_tensor`` when it has one, otherwise through
    ``ops.convert_to_tensor``. That call is not passed ``dtype``, ``name`` or
    ``as_ref``.
    """
    if _enclosing_tpu_context() is not None:
      if dtype is not None and dtype != self.dtype:
        return NotImplemented
      return self.handle if as_ref else self.read_value()
    primary = self._primary_var
    if not hasattr(primary, '_dense_var_to_tensor'):
      return ops.convert_to_tensor(primary)
    return primary._dense_var_to_tensor(dtype, name, as_ref)  # pylint: disable=protected-access


# Conversion function registered for the variable class below, so instances
# can be used wherever a tensor is expected.
def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
  """Delegate tensor conversion to ``var._dense_var_to_tensor``.

  All arguments are forwarded by keyword; the delegate's result is returned.
  """
  convert = var._dense_var_to_tensor  # pylint: disable=protected-access
  return convert(dtype=dtype, name=name, as_ref=as_ref)


# ReplicatedVariable converts to tensors through _tensor_conversion.
ops.register_tensor_conversion_function(ReplicatedVariable, _tensor_conversion)

# The dense-tensor-like registration only runs when TF_23 is falsy. Presumably
# TF_23 flags TensorFlow >= 2.3, where this registration is unavailable --
# confirm against TF_23's definition.
if not TF_23:
  ops.register_dense_tensor_like_type(ReplicatedVariable)
Пример #9
0
    """Pass resource_variable_ops.is_resource_variable check."""
    pass


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
  """Converts a DistributedVariable by converting ``var.get()``.

  Args:
    var: The variable being converted; ``var.get()`` is what gets converted.
    dtype: Optional dtype, forwarded to ``ops.internal_convert_to_tensor``.
    name: Optional name, forwarded to ``ops.internal_convert_to_tensor``.
    as_ref: Must be false; conversion to a reference is refused.

  Returns:
    The result of ``ops.internal_convert_to_tensor(var.get(), ...)``.

  Raises:
    ValueError: If ``as_ref`` is true.
  """
  # Try to avoid assignments to and other mutations of MirroredVariable
  # state except through a DistributionStrategy.update() call.
  # This is an explicit check rather than ``assert`` so the guard still holds
  # under ``python -O``, which strips assert statements.
  if as_ref:
    raise ValueError(
        "Cannot convert a DistributedVariable to a tensor reference "
        "(as_ref=True); mutate it through DistributionStrategy.update().")
  return ops.internal_convert_to_tensor(
      var.get(), dtype=dtype, name=name, as_ref=False)


# DistributedVariable converts to tensors through _tensor_conversion and is
# registered as a dense-tensor-like type.
ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion)
ops.register_dense_tensor_like_type(DistributedVariable)


class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
  """Saveable describing how a MirroredVariable is restored.

  The base saveable is initialized with the primary component variable and an
  empty second argument (presumably the slice spec). ``restore`` writes the
  single restored tensor into every component of the mirrored variable.
  """

  def __init__(self, mirrored_variable, primary_variable, name):
    self._mirrored_variable = mirrored_variable
    super(_MirroredSaveable, self).__init__(primary_variable, "", name)

  def restore(self, restored_tensors, restored_shapes):
    """Assign the restored tensor to every mirrored component.

    Args:
      restored_tensors: Sequence holding exactly one tensor; unpacking fails
        otherwise.
      restored_shapes: Unused.

    Returns:
      ``control_flow_ops.group`` over one ``_assign_on_device`` per entry of
      the mirrored variable's ``_index``.
    """
    (tensor,) = restored_tensors
    components = self._mirrored_variable._index  # pylint: disable=protected-access
    assignments = [
        _assign_on_device(device, component, tensor)
        for device, component in six.iteritems(components)
    ]
    return control_flow_ops.group(assignments)
Пример #10
0
        variable_names_tensor = array_ops.constant(
            [s.op.name for s in var_list])
        # Return a 1-D tensor containing all the names of uninitialized variables.
        return array_ops.boolean_mask(variable_names_tensor,
                                      variables_mask,
                                      name=name)


# pylint: disable=protected-access
# Variable and _PartitionedVariable convert to tensors through their own
# _TensorConversionFunction.
ops.register_tensor_conversion_function(Variable,
                                        Variable._TensorConversionFunction)
# Defined outside this view; judging by its name it installs Python operator
# overloads on Variable -- confirm.
Variable._OverloadAllOperators()

ops.register_tensor_conversion_function(
    _PartitionedVariable, _PartitionedVariable._TensorConversionFunction)
# pylint: enable=protected-access

ops.register_dense_tensor_like_type(Variable)
# Members of the VARIABLES, TRAINABLE_VARIABLES and MOVING_AVERAGE_VARIABLES
# collections are (de)serialized as variable_pb2.VariableDef via
# Variable.to_proto / Variable.from_proto.
ops.register_proto_function(ops.GraphKeys.VARIABLES,
                            proto_type=variable_pb2.VariableDef,
                            to_proto=Variable.to_proto,
                            from_proto=Variable.from_proto)
ops.register_proto_function(ops.GraphKeys.TRAINABLE_VARIABLES,
                            proto_type=variable_pb2.VariableDef,
                            to_proto=Variable.to_proto,
                            from_proto=Variable.from_proto)
ops.register_proto_function(ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                            proto_type=variable_pb2.VariableDef,
                            to_proto=Variable.to_proto,
                            from_proto=Variable.from_proto)
Пример #11
0

# Conversion function registered for ReplicatedVariable, so instances can be
# used wherever a tensor is expected.
def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
  """Delegate tensor conversion to ``var._dense_var_to_tensor``.

  All arguments are forwarded by keyword; the delegate's result is returned.
  """
  convert = var._dense_var_to_tensor  # pylint: disable=protected-access
  return convert(dtype=dtype, name=name, as_ref=as_ref)


def replicated_fetch_function(var):
  """Session-run fetch function registered for ReplicatedVariable.

  Returns:
    A pair: a one-element list holding ``var._dense_var_to_tensor()`` to
    fetch, and a function that returns element 0 of the fetched results.
  """
  def _first(fetched):
    return fetched[0]

  tensor = var._dense_var_to_tensor()  # pylint: disable=protected-access
  return [tensor], _first


# ReplicatedVariable converts to tensors through _tensor_conversion, is
# registered as a dense-tensor-like type, and is fetched in Session.run via
# replicated_fetch_function.
ops.register_tensor_conversion_function(ReplicatedVariable, _tensor_conversion)
ops.register_dense_tensor_like_type(ReplicatedVariable)
session_lib.register_session_run_conversion_functions(
    ReplicatedVariable, replicated_fetch_function)


def replicated_scope(num_replicas):
  """Variable scope for constructing replicated variables."""

  def _replicated_variable_getter(getter, name, *args, **kwargs):
    """Getter that constructs replicated variables."""
    collections = kwargs.pop("collections", None)
    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    kwargs["collections"] = []

    variables = []
Пример #12
0
def _tensor_conversion_function(tensor, dtype=None, name=None, as_ref=False):
    """Convert an RttTensor to a tf.Tensor via ``convert_from_rtttensor``.

    Only ``dtype`` is forwarded. ``name`` and ``as_ref`` exist to match the
    tensor-conversion-function signature and are currently ignored.
    """
    converted = convert_from_rtttensor(tensor, dtype=dtype)
    return converted


# Allow implicit conversion of rtt.RttTensor to tf.Tensor. The output dtype is
# determined by the calling context rather than by the RttTensor, so the
# conversion may lose data; that risk is accepted.
tf_ops.register_tensor_conversion_function(RttTensor,
                                           _tensor_conversion_function)

# This allows RttTensor to pass the tf.is_tensor test.
tf_ops.register_dense_tensor_like_type(RttTensor)

# This allows rtt.RttTensor to be plumbed through Keras layers, but it seems
# only truly useful in conjunction with `register_tensor_conversion_function`
# (registered above).
tf_utils.register_symbolic_tensor_type(RttTensor)


def _convert_numpy_tensor(tensor):
    """ convert numpy tensor to rtt tensor """

    if (np.issubdtype(tensor.dtype, np.int16)
            or np.issubdtype(tensor.dtype, np.int32)
            or np.issubdtype(tensor.dtype, np.int64)
            or np.issubdtype(tensor.dtype, np.float)
            or np.issubdtype(tensor.dtype, np.double)