Example #1
    def _read_variable_op(self, do_prefetch=True):
        resource_variable_ops.variable_accessed(self)
        if self.model_mode == "train":
            if do_prefetch:
                with ops.control_dependencies([
                        gen_resource_variable_ops.assign_variable_op(
                            self._handle,
                            self.prefetch_values(),
                            name="AssignBeforeReadVariable")
                ]):
                    _result = gen_resource_variable_ops.read_variable_op(
                        self._handle, self._dtype)
            else:
                _result = gen_resource_variable_ops.read_variable_op(
                    self._handle, self._dtype)
        else:
            _result = self.prefetch_values()

        if not context.executing_eagerly():
            # Note that if a control flow context is active the input of the read op
            # might not actually be the handle. This line bypasses it.
            tape.record_operation("ReadVariableOp", [_result], [self._handle],
                                  lambda x: [x])
        result = self.transform(_result)
        return result
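Example #1 guards the read with an AssignVariableOp control dependency so that prefetched values are written before they are read back. Below is a minimal sketch of the same assign-before-read pattern using only public raw ops; it is not part of the example above, and the variable and values are illustrative (assumes TF 2.x).

import tensorflow as tf

v = tf.Variable([0.0, 0.0])

@tf.function
def assign_then_read(new_value):
    assign = tf.raw_ops.AssignVariableOp(resource=v.handle, value=new_value)
    # The control dependency forces the read to observe the assignment above.
    with tf.control_dependencies([assign]):
        return tf.raw_ops.ReadVariableOp(resource=v.handle, dtype=tf.float32)

print(assign_then_read(tf.constant([1.0, 2.0])).numpy())  # [1. 2.]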
Example #2
    def test_shared_variable(self):
        x = gen_resource_variable_ops.var_handle_op(dtype=tf.float32,
                                                    shape=(1, 2),
                                                    shared_name="variable_1")
        gen_resource_variable_ops.assign_variable_op(x,
                                                     tf.constant([[1.0, 2.0]]))
        y = gen_resource_variable_ops.var_handle_op(dtype=tf.float32,
                                                    shape=(1, 2),
                                                    shared_name="variable_1")
        gen_resource_variable_ops.assign_variable_op(y,
                                                     tf.constant([[2.0, 3.0]]))
        read_x = gen_resource_variable_ops.read_variable_op(x,
                                                            dtype=tf.float32)
        read_y = gen_resource_variable_ops.read_variable_op(y,
                                                            dtype=tf.float32)
        self.assertTrue(tensor_equal(read_x, read_y))

        x = gen_resource_variable_ops.var_handle_op(
            dtype=tf.float32, shape=(1, 2), shared_name=context.shared_name())
        gen_resource_variable_ops.assign_variable_op(x,
                                                     tf.constant([[1.0, 2.0]]))
        y = gen_resource_variable_ops.var_handle_op(
            dtype=tf.float32, shape=(1, 2), shared_name=context.shared_name())
        gen_resource_variable_ops.assign_variable_op(y,
                                                     tf.constant([[2.0, 3.0]]))
        read_x = gen_resource_variable_ops.read_variable_op(x,
                                                            dtype=tf.float32)
        read_y = gen_resource_variable_ops.read_variable_op(y,
                                                            dtype=tf.float32)
        self.assertFalse(tensor_equal(read_x, read_y))
Example #3
    def write_var_in_while():
      gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype, name="read1")

      result = build_functional_op(v)
      gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype, name="read2")
      gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
      return result
Example #4
 def add_op_to_graph(num_ops):
     with func_graph.FuncGraph("resource").as_default():
         handle = resource_variable_ops.var_handle_op(
             dtype=dtypes.int32, shape=[])
         resource_variable_ops.assign_variable_op(
             handle, constant_op.constant(1, dtype=dtypes.int32))
         for _ in range(num_ops):
             gen_resource_variable_ops.read_variable_op(
                 handle, dtype=dtypes.int32)
Example #5
 def testNoControlDepsBetweenVariableReads(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies():
             read_op1 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             read_op2 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
         self.assertNotIn(read_op1, read_op2.control_inputs)
         self.assertNotIn(read_op2, read_op1.control_inputs)
Example #6
  def _read_variable_op(self):
    """Reads the value of this variable."""
    if self.trainable:
      tape.variable_accessed(self)

    handle = self.handle
    if getattr(handle, "is_packed", False):
      # Add a device scope for a packed variable handle.
      with ops.device(self._get_on_device_or_primary().device):
        return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
    else:
      return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
Example #7
  def testVariableMultipleReadsAndWrites(self):
    with context.graph_mode(), self.cached_session():
      v = resource_variable_ops.ResourceVariable(1.0)
      self.evaluate(variables.global_variables_initializer())
      with acd.AutomaticControlDependencies() as c:
        # 2 reads -> 2 writes -> 2 reads -> 2 writes.
        read_op1 = gen_resource_variable_ops.read_variable_op(
            v.handle, v.dtype).op
        read_op2 = gen_resource_variable_ops.read_variable_op(
            v.handle, v.dtype).op
        assign_op1 = gen_resource_variable_ops.assign_variable_op(
            v.handle, v + 1)
        assign_op2 = gen_resource_variable_ops.assign_variable_op(
            v.handle, v + 1)
        read_op3 = gen_resource_variable_ops.read_variable_op(
            v.handle, v.dtype).op
        read_op4 = gen_resource_variable_ops.read_variable_op(
            v.handle, v.dtype).op
        assign_op3 = gen_resource_variable_ops.assign_variable_op(
            v.handle, v + 1)
        assign_op4 = gen_resource_variable_ops.assign_variable_op(
            v.handle, v + 1)
        # Read ops get added to control outputs only if they have consumers.
        c.mark_as_return(read_op1.outputs[0])
        c.mark_as_return(read_op2.outputs[0])
        c.mark_as_return(read_op3.outputs[0])
        c.mark_as_return(read_op4.outputs[0])

      # Verify the control edges.
      self.assertIn(read_op1, assign_op1.control_inputs)
      self.assertIn(read_op2, assign_op1.control_inputs)
      self.assertIn(assign_op1, assign_op2.control_inputs)
      self.assertIn(assign_op2, read_op3.control_inputs)
      self.assertIn(assign_op2, read_op4.control_inputs)
      self.assertIn(read_op3, assign_op3.control_inputs)
      self.assertIn(read_op4, assign_op3.control_inputs)
      self.assertIn(assign_op3, assign_op4.control_inputs)

      # There should be no control deps between reads.
      read_ops = [read_op1, read_op2, read_op3, read_op4]
      for src_op, tgt_op in itertools.product(read_ops, read_ops):
        self.assertNotIn(src_op, tgt_op.control_inputs)

      # Reads must be in `ops_which_must_run`.
      self.assertIn(read_op1, c.ops_which_must_run)
      self.assertIn(read_op2, c.ops_which_must_run)
      self.assertIn(read_op3, c.ops_which_must_run)
      self.assertIn(read_op4, c.ops_which_must_run)
      # Last write must be in `ops_which_must_run`.
      self.assertIn(assign_op4, c.ops_which_must_run)
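Examples #5 and #7 poke at AutomaticControlDependencies directly with raw read and assign ops. Below is a minimal sketch, not from the source, of the user-visible effect they verify (assumes TF 2.x): inside tf.function, reads are ordered after earlier writes to the same variable without any manual control dependencies.

import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def update_and_read():
    v.assign_add(1.0)        # write to the resource variable
    return v.read_value()    # read is sequenced after the write by ACD

print(update_and_read().numpy())  # 2.0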
Example #8
def gen_read_var_op(var_op, dtype):
    """
    Given a var op, generate the op for reading its value.

    Args:
        var_op (Operation): The var op
        dtype (dtype): The dtype of the data to read

    Returns:
        Operation: The value-reading operation
    """
    var_op_tensor = var_op.outputs[0]
    if var_op.type == 'VarHandleOp':
        result = gen_resource_variable_ops.read_variable_op(
            var_op_tensor, dtype)
        _maybe_set_handle_data(dtype, var_op_tensor, result)
        if not context.executing_eagerly():
            # Note that if a control flow context is active the input of the read op
            # might not actually be the handle. This line bypasses it.
            tape.record_operation("ReadVariableOp", [result], [var_op_tensor],
                                  lambda x: [x])
        return result
    elif var_op.type == 'VariableV2' or is_read_var_op(var_op):
        return array_ops.identity(var_op_tensor)
    raise ValueError("Can't generate the variable reading tensor from '{}'. "
                     "It may not be a proper variable op.".format(var_op.name))
Example #9
 def _read_variable_op(self):
   if hasattr(self, "_trainable") and self._trainable:
     tape.watch_variable(self)
     return read_variable_op(self._handle, dtype=self._dtype)
   else:
     return gen_resource_variable_ops.read_variable_op(self._handle,
                                                       self._dtype)
Example #10
 def value(self):
   """A cached operation which reads the value of this variable."""
   if self._cached_value is not None:
     return self._cached_value
   with ops.device(self._handle.device):
     return gen_resource_variable_ops.read_variable_op(
         self._handle, dtype=self._dtype)
Example #11
 def _read_variable_op(self):
     if hasattr(self, "_trainable") and self._trainable:
         tape.watch(self._handle)
         return read_variable_op(self._handle, dtype=self._dtype)
     else:
         return gen_resource_variable_ops.read_variable_op(
             self._handle, self._dtype)
Example #12
 def testVariableWriteThenRead(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies():
             assign_op = gen_resource_variable_ops.assign_variable_op(
                 v.handle, v + 1)
             read_op1 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             read_op2 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
         # Reads should have a control dep from the last write.
         self.assertIn(assign_op, read_op1.control_inputs)
         self.assertIn(assign_op, read_op2.control_inputs)
         # There should be no control deps between reads.
         self.assertNotIn(read_op1, read_op2.control_inputs)
         self.assertNotIn(read_op2, read_op1.control_inputs)
Example #13
 def testVariableReadsInOpsWithMustRun(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies() as c:
             read_op = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
         self.assertIn(read_op, c.ops_which_must_run)
Example #14
 def testResourceTensorPlacement(self):
   with context.device('gpu:0'):
     v = resource_variable_ops.ResourceVariable(1.0)
   with context.device('cpu:0'):
     # Check that even though we specified the cpu device we'll run the read op
     # in the device where the handle is.
     self.assertAllEqual(
         gen_resource_variable_ops.read_variable_op(v.handle, v.dtype), 1.0)
Example #15
 def testVariableReadsNotInOpsWithMustRun(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies() as c:
             read_op1 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             read_op2 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             assign_op = gen_resource_variable_ops.assign_variable_op(
                 v.handle, v + 1)
         # Reads must not be in `ops_which_must_run` since those get added to the
         # `control_outputs`.
         self.assertNotIn(read_op1, c.ops_which_must_run)
         self.assertNotIn(read_op2, c.ops_which_must_run)
         # Last write must be in `ops_which_must_run`.
         self.assertIn(assign_op, c.ops_which_must_run)
Example #16
 def testResourceTensorPlacement(self):
   with context.device('gpu:0'):
     v = resource_variable_ops.ResourceVariable(1.0)
   with context.device('cpu:0'):
     # Check that even though we specified the cpu device we'll run the read op
     # in the device where the handle is.
     self.assertAllEqual(
         gen_resource_variable_ops.read_variable_op(v.handle, v.dtype), 1.0)
Example #17
 def testVariableReadThenWrite(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies():
             read_op1 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             read_op2 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             assign_op = gen_resource_variable_ops.assign_variable_op(
                 v.handle, v + 1)
         # Writes should have control deps from "all" reads since last write
         # or start of the code block.
         self.assertIn(read_op1, assign_op.control_inputs)
         self.assertIn(read_op2, assign_op.control_inputs)
         # There should be no control deps between reads.
         self.assertNotIn(read_op1, read_op2.control_inputs)
         self.assertNotIn(read_op2, read_op1.control_inputs)
Example #18
 def testVariableReadsInOpsWithMustRun(self):
   with context.graph_mode(), self.cached_session():
     v = resource_variable_ops.ResourceVariable(1.0)
     self.evaluate(variables.global_variables_initializer())
     with acd.AutomaticControlDependencies() as c:
       read_op = gen_resource_variable_ops.read_variable_op(v.handle,
                                                            v.dtype).op
       # Read ops get added to control outputs only if they have consumers.
       c.mark_as_return(read_op.outputs[0])
     self.assertIn(read_op, c.ops_which_must_run)
Example #19
def _read_component(handle, dtype, replica_id, parallel_device):
    """Read one component of a parallel variable and discard the rest."""
    with ops.device(handle.device):
        read = gen_resource_variable_ops.read_variable_op(resource=handle,
                                                          dtype=dtype)
    all_components = parallel_device.unpack(read)
    # We're pretending that parallel variables have a first axis with length
    # num_components, so we need to add a dummy first axis to the shape that gets
    # saved.
    return all_components[replica_id][None, ...]
Example #20
 def testManualControlDepMonitoringAttrNotAdded(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies():
             read_op1 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             read_op2 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             assign_op = gen_resource_variable_ops.assign_variable_op(
                 v.handle, v + 1)
         # Writes should have control deps automatically added from "all" reads
         # since last write or start of the code block.
         self.assertIn(read_op1, assign_op.control_inputs)
         self.assertIn(read_op2, assign_op.control_inputs)
         # But, we shouldn't add the monitoring attribute in this case.
         with self.assertRaises(ValueError):
             assign_op.get_attr("_has_manual_control_dependencies")
         with self.assertRaises(ValueError):
             read_op1.get_attr("_has_manual_control_dependencies")
         with self.assertRaises(ValueError):
             read_op2.get_attr("_has_manual_control_dependencies")
Example #21
 def testIdentityPassThrough(self):
   with context.graph_mode(), self.cached_session():
     v = resource_variable_ops.ResourceVariable(1.0)
     self.evaluate(variables.global_variables_initializer())
     with acd.AutomaticControlDependencies():
       gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
       identity_handle = gen_array_ops.identity(v.handle)
       assign_op2 = gen_resource_variable_ops.assign_variable_op(
           v.handle, v + 1)
       read_op = gen_resource_variable_ops.read_variable_op(
           identity_handle, v.dtype).op
     # Read should have a control dep from second last write even
     # with Identity applied to resource.
     self.assertIn(assign_op2, read_op.control_inputs)
Example #22
  def read_value(self):
    """Constructs an op which reads the value of this variable.

    Should be used when there are multiple reads, or when it is desirable to
    read the value only after some condition is true.

    Returns:
     the read operation.
    """
    with ops.name_scope("Read"):
      value = gen_resource_variable_ops.read_variable_op(
          self._handle, dtype=self._dtype)
    # Return an identity so it can get placed on whatever device the context
    # specifies instead of the device where the variable is.
    return array_ops.identity(value)
Example #23
    def read_value(self):
        """Constructs an op which reads the value of this variable.

    Should be used when there are multiple reads, or when it is desirable to
    read the value only after some condition is true.

    Returns:
     the read operation.
    """
        with ops.name_scope("Read"):
            value = gen_resource_variable_ops.read_variable_op(
                self._handle, dtype=self._dtype)
        # Return an identity so it can get placed on whatever device the context
        # specifies instead of the device where the variable is.
        return array_ops.identity(value)
Example #24
    def read_value(self, collections=None, trainable=True):
        """Constructs an op which reads the value of this variable.

    Should be used when there are multiple reads, or when it is desirable to
    read the value only after some condition is true.

    Args:
     collections: any collections in which this operation should be inserted.
     trainable: whether this read is to be used for training.

    Returns:
     the read operation.
    """
        with ops.name_scope("Read"):
            value = gen_resource_variable_ops.read_variable_op(self._handle, dtype=self._dtype)
        _register_variable_read(value, collections=collections, trainable=trainable)
        return array_ops.identity(value)
Example #25
  def read_value(self, collections=None, trainable=True):
    """Constructs an op which reads the value of this variable.

    Should be used when there are multiple reads, or when it is desirable to
    read the value only after some condition is true.

    Args:
     collections: any collections in which this operation should be inserted.
     trainable: whether this read is to be used for training.

    Returns:
     the read operation.
    """
    with ops.name_scope("Read"):
      value = gen_resource_variable_ops.read_variable_op(
          self._handle, dtype=self._dtype)
    _register_variable_read(value, collections=collections, trainable=trainable)
    return array_ops.identity(value)
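Examples #22 through #25 all return an identity of the raw read. Below is a minimal sketch, not from the source, of the device-placement point their comments make (assumes TF 2.x): the raw read runs on the variable's device, while the identity lets the result be placed wherever the surrounding device context asks.

import tensorflow as tf

v = tf.Variable([1.0, 2.0])                  # stays on the variable's device
with tf.device("/CPU:0"):
    value = tf.identity(v.read_value())      # copy of the value placed on CPU:0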
Example #26
def read_variable_op(handle, dtype):
  """Reads the value of a variable.

  The tensor returned by this operation is immutable.

  The value returned by this operation is guaranteed to be influenced by all the
  writes on which this operation depends directly or indirectly, and to not be
  influenced by any of the writes which depend directly or indirectly on this
  operation.

  Args:
    handle: A `Tensor` of type `resource`.
      handle to the resource in which to store the variable.
    dtype: A `tf.DType`. the dtype of the value.

  Returns:
    A `Tensor` of type `dtype`.
  """
  result = gen_resource_variable_ops.read_variable_op(handle, dtype)
  def grad(dresult):
    return dresult
  return result, grad
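The grad closure above gives the read a pass-through gradient. Below is a minimal sketch, not from the source, of the observable effect (assumes TF 2.x): differentiating through a variable read simply forwards the incoming gradient, so d(3 * v)/dv evaluates to 3.

import tensorflow as tf

v = tf.Variable(2.0)
with tf.GradientTape() as tape:
    y = 3.0 * v.read_value()
print(tape.gradient(y, v).numpy())  # 3.0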
Example #27
  def test_read_value(self):
    if not context.executing_eagerly():
      self.skipTest('Only test in eager mode.')

    params = de.get_variable('pn012', dim=2, initializer=0.1)
    params.upsert(
        constant_op.constant([1, 2, 3], dtype=dtypes.int64),
        constant_op.constant([[1., 1.], [2., 2.], [3., 3.]],
                             dtype=dtypes.float32))
    shadow = de.shadow_ops.ShadowVariable(params)

    val = shadow.read_value()
    self.assertAllEqual(val.numpy().tolist(), [])

    ids = constant_op.constant([2, 3, 4], dtype=dtypes.int64)
    shadow._reset_ids(ids)
    val = gen_resource_variable_ops.read_variable_op(shadow._handle,
                                                     dtypes.float32)
    self.assertAllEqual(val.numpy().tolist(), [])
    val = shadow.read_value(do_prefetch=False)
    self.assertAllEqual(val.numpy().tolist(), [])
    val = shadow.read_value(do_prefetch=True)
    self.assertAllEqual(
        val,
        constant_op.constant([[2., 2.], [3., 3.], [0.1, 0.1]],
                             dtype=dtypes.float32))

    ids = constant_op.constant([3, 4, 5], dtype=dtypes.int64)
    shadow._reset_ids(ids)
    val = shadow.read_value(do_prefetch=False)
    self.assertAllEqual(
        val,
        constant_op.constant([[2., 2.], [3., 3.], [0.1, 0.1]],
                             dtype=dtypes.float32))
    val = shadow.read_value(do_prefetch=True)
    self.assertAllEqual(
        val,
        constant_op.constant([[3., 3.], [0.1, 0.1], [0.1, 0.1]],
                             dtype=dtypes.float32))
Example #28
def read_variable_op(handle, dtype):
    """Reads the value of a variable.

  The tensor returned by this operation is immutable.

  The value returned by this operation is guaranteed to be influenced by all the
  writes on which this operation depends directly or indirectly, and to not be
  influenced by any of the writes which depend directly or indirectly on this
  operation.

  Args:
    handle: A `Tensor` of type `resource`.
      handle to the resource in which to store the variable.
    dtype: A `tf.DType`. the dtype of the value.

  Returns:
    A `Tensor` of type `dtype`.
  """
    result = gen_resource_variable_ops.read_variable_op(handle, dtype)

    def grad(dresult):
        return dresult

    return result, grad
Example #29
 def _read_variable_op(self):
   if self.trainable:
     tape.watch_variable(self)
   return gen_resource_variable_ops.read_variable_op(self._handle,
                                                     self._dtype)
Example #30
 def _read_variable_op(self):
     with ops.control_dependencies([self._parent_op]):
         return gen_resource_variable_ops.read_variable_op(
             self._handle, self._dtype)
Example #31
    def _init_from_args(self,
                        initial_value=None,
                        trainable=True,
                        collections=None,
                        validate_shape=True,
                        caching_device=None,
                        name=None,
                        dtype=None,
                        constraint=None):
        """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
       a Tensor) or float32 will be used (if it is a Python object convertible
       to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
        if initial_value is None:
            raise ValueError("initial_value must be specified.")
        init_from_fn = callable(initial_value)

        if collections is None:
            collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        if not isinstance(collections, (list, tuple, set)):
            raise ValueError(
                "collections argument to Variable constructor must be a list, tuple, "
                "or set. Got %s of type %s" % (collections, type(collections)))
        if constraint is not None and not callable(constraint):
            raise ValueError("The `constraint` argument must be a callable.")

        if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
            collections = list(collections) + [
                ops.GraphKeys.TRAINABLE_VARIABLES
            ]
        self._save_slice_info = None
        in_graph_mode = context.in_graph_mode()
        with ops.control_dependencies(None):
            with ops.name_scope(
                    name, "Variable",
                [] if init_from_fn else [initial_value]) as name:
                # pylint: disable=protected-access
                handle_name = ops._name_from_scope_name(name)
                if init_from_fn:
                    # Use attr_scope and device(None) to simulate the behavior of
                    # colocate_with when the variable we want to colocate with doesn't
                    # yet exist.
                    if in_graph_mode:
                        attr = attr_value_pb2.AttrValue(
                            list=attr_value_pb2.AttrValue.ListValue(
                                s=[compat.as_bytes("loc:@%s" % handle_name)]))
                        with ops.get_default_graph()._attr_scope(
                            {"_class": attr}):
                            with ops.name_scope("Initializer"), ops.device(
                                    None):
                                initial_value = ops.convert_to_tensor(
                                    initial_value(),
                                    name="initial_value",
                                    dtype=dtype)
                            self._handle = gen_resource_variable_ops.var_handle_op(
                                shape=initial_value.get_shape(),
                                dtype=initial_value.dtype.base_dtype,
                                shared_name=handle_name,
                                name=name)
                    else:
                        initial_value = initial_value()
                        self._handle = gen_resource_variable_ops.var_handle_op(
                            shape=initial_value.get_shape(),
                            dtype=initial_value.dtype.base_dtype,
                            shared_name=handle_name,
                            name=name,
                            container="")
                # pylint: enable=protected-access

                # Or get the initial value from a Tensor or Python object.
                else:
                    with ops.name_scope("Initializer"):
                        initial_value = ops.convert_to_tensor(
                            initial_value, name="initial_value", dtype=dtype)
                    # pylint: disable=protected-access
                    if (in_graph_mode and initial_value is not None
                            and initial_value.op._get_control_flow_context()
                            is not None):
                        raise ValueError(
                            "Initializer for variable %s is from inside a control-flow "
                            "construct, such as a loop or conditional. When creating a "
                            "variable inside a loop or conditional, use a lambda as the "
                            "initializer." % name)
                    # pylint: enable=protected-access
                    self._handle = gen_resource_variable_ops.var_handle_op(
                        shape=initial_value.get_shape(),
                        dtype=initial_value.dtype.base_dtype,
                        shared_name=handle_name,
                        name=name,
                        container="")

                self._initial_value = initial_value if in_graph_mode else None
                self._handle_name = handle_name + ":0"
                self._dtype = initial_value.dtype.base_dtype
                self._constraint = constraint

                if in_graph_mode:
                    with ops.name_scope("IsInitialized"):
                        self._is_initialized_op = (
                            gen_resource_variable_ops.var_is_initialized_op(
                                self._handle))
                    if initial_value is not None:
                        with ops.name_scope("Assign") as n, ops.colocate_with(
                                self._handle):
                            self._initializer_op = (
                                gen_resource_variable_ops.assign_variable_op(
                                    self._handle,
                                    self._build_initializer_expr(
                                        initial_value),
                                    name=n))
                    with ops.name_scope("Read"), ops.colocate_with(
                            self._handle):
                        # Manually assign reads to the handle's device to avoid log
                        # messages.
                        with ops.device(self._handle.device):
                            value = gen_resource_variable_ops.read_variable_op(
                                self._handle, dtype=self._dtype)
                        self._graph_element = value
                        if caching_device is not None:
                            # Variables may be created in a tf.device() or ops.colocate_with()
                            # context. At the same time, users would expect caching device to
                            # be independent of this context, and/or would not expect the
                            # current device context to be merged with the caching device
                            # spec.  Therefore we reset the colocation stack before creating
                            # the cached value. Note that resetting the colocation stack will
                            # also reset the device stack.
                            with ops.colocate_with(None, ignore_existing=True):
                                with ops.device(caching_device):
                                    self._cached_value = array_ops.identity(
                                        value)
                        else:
                            self._cached_value = None
                else:
                    gen_resource_variable_ops.assign_variable_op(
                        self._handle, initial_value)
                    self._is_initialized_op = None
                    self._initializer_op = None
                    self._graph_element = None
                    if caching_device:
                        with ops.device(caching_device):
                            self._cached_value = gen_resource_variable_ops.read_variable_op(
                                self._handle, dtype=self._dtype)
                    else:
                        self._cached_value = None
                ops.add_to_collections(collections, self)
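The constructor above wires caching_device into an Identity of the read placed on the requested device. Below is a minimal sketch, not from the source, of the public entry point for that behavior (assumes TF 1.x graph mode; the variable name is illustrative).

import tensorflow.compat.v1 as tf1

tf1.disable_eager_execution()
cached = tf1.get_variable(
    "cached", shape=[2], use_resource=True,
    caching_device="/CPU:0")   # reads are cached as an Identity on CPU:0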
Example #32
 def _read_variable_op(self):
   with ops.control_dependencies([self._parent_op]):
     return gen_resource_variable_ops.read_variable_op(self._handle,
                                                       self._dtype)
Example #33
  def __init__(self,
               initial_value=None,
               name=None,
               trainable=True,
               collections=None,
               dtype=None,
               shape=None):
    """Creates a variable.

    Args:
      initial_value: A `Tensor` or Python object convertible to a `Tensor`
        representing the initial value of this variable.
      name: The name of this variable. Automatically uniquified.
      trainable: Whether the global read of this variable will be used for
        training.
      collections: Additional collections to which the `read` operation for
        this variable is to be added. Defaults to [].
      dtype: The type of this variable. Can be omitted if it can be deduced
        from the initial_value. If different from the type of the initial
        value it will be cast to this type.
      shape: The shape of this variable. Only specify if there is no initial
        value but shape inference is desired.
    """
    if initial_value is not None:
      initial_value = ops.convert_to_tensor(initial_value)
    if dtype is None:
      assert initial_value is not None, ("Trying to create a resource variable "
                                         "with no dtype or initial value. At"
                                         " least one of these must be set.")
      dtype = initial_value.dtype
    elif initial_value is not None:
      initial_value = math_ops.cast(initial_value, dtype)
    if shape is None:
      if initial_value is not None:
        shape = initial_value.get_shape().as_proto()
      else:
        shape = tensor_shape.unknown_shape()
    else:
      shape = tensor_shape.as_shape(shape)

    self._dtype = dtype
    with ops.name_scope(name, "Variable", [initial_value]) as name:
      self._handle = gen_resource_variable_ops.var_handle_op(shared_name=name,
                                                             name=name,
                                                             dtype=dtype,
                                                             shape=shape)

      with ops.name_scope("IsInitialized"):
        self._is_initialized_op = (
            gen_resource_variable_ops.var_is_initialized_op(self._handle))
      if initial_value is not None:
        with ops.name_scope("Create"):
          self._initialize_op = gen_resource_variable_ops.create_variable_op(
              self._handle, initial_value)
        resources.register_resource(self._handle,
                                    self._initialize_op,
                                    self._is_initialized_op)

      with ops.name_scope("Read"):
        self._value = gen_resource_variable_ops.read_variable_op(
            self._handle, dtype=self._dtype)
      _register_variable_read(
          self._value, trainable=trainable, collections=collections)
Example #34
 def _read_variable_op(self):
   if _enclosing_tpu_context() is None:
     return self._primary_var.read_value()
   v = gen_resource_variable_ops.read_variable_op(self.handle, self._dtype)
   return v
Example #35
    def _init_from_args(self,
                        initial_value=None,
                        trainable=True,
                        collections=None,
                        validate_shape=True,
                        caching_device=None,
                        name=None,
                        dtype=None):
        """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
       a Tensor) or float32 will be used (if it is a Python object convertible
       to a Tensor).

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
        if initial_value is None:
            raise ValueError("initial_value must be specified.")
        init_from_fn = callable(initial_value)

        if collections is None:
            collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        if not isinstance(collections, (list, tuple, set)):
            raise ValueError(
                "collections argument to Variable constructor must be a list, tuple, "
                "or set. Got %s of type %s" % (collections, type(collections)))
        if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
            collections = list(collections) + [
                ops.GraphKeys.TRAINABLE_VARIABLES
            ]
        self._save_slice_info = None
        with ops.control_dependencies(None):
            with ops.name_scope(
                    name, "Variable",
                [] if init_from_fn else [initial_value]) as name:
                if init_from_fn:
                    # Use attr_scope and device(None) to simulate the behavior of
                    # colocate_with when the variable we want to colocate with doesn't
                    # yet exist.
                    # pylint: disable=protected-access
                    true_name = ops._name_from_scope_name(name)
                    attr = attr_value_pb2.AttrValue(
                        list=attr_value_pb2.AttrValue.ListValue(
                            s=[compat.as_bytes("loc:@%s" % true_name)]))
                    # pylint: disable=protected-access
                    with ops.get_default_graph()._attr_scope({"_class": attr}):
                        with ops.name_scope("Initializer"), ops.device(None):
                            self._initial_value = ops.convert_to_tensor(
                                initial_value(),
                                name="initial_value",
                                dtype=dtype)
                        self._handle = gen_resource_variable_ops.var_handle_op(
                            shape=self._initial_value.get_shape(),
                            dtype=self._initial_value.dtype.base_dtype,
                            shared_name=name,
                            name=name)

                # Or get the initial value from a Tensor or Python object.
                else:
                    self._initial_value = ops.convert_to_tensor(
                        initial_value, name="initial_value", dtype=dtype)
                    self._handle = gen_resource_variable_ops.var_handle_op(
                        shape=self._initial_value.get_shape(),
                        dtype=self._initial_value.dtype.base_dtype,
                        shared_name=name,
                        name=name)

                self._dtype = self._initial_value.dtype.base_dtype

                with ops.name_scope("IsInitialized"):
                    self._is_initialized_op = (
                        gen_resource_variable_ops.var_is_initialized_op(
                            self._handle))
                if initial_value is not None:
                    with ops.name_scope("Assign") as n, ops.colocate_with(
                            self._handle):
                        self._initialize_op = gen_resource_variable_ops.assign_variable_op(
                            self._handle, self._initial_value, name=n)
                with ops.name_scope("Read"), ops.colocate_with(self._handle):
                    value = gen_resource_variable_ops.read_variable_op(
                        self._handle, dtype=self._dtype)
                    self._graph_element = value
                    if caching_device is not None:
                        with ops.device(caching_device):
                            self._cached_value = array_ops.identity(value)
                    else:
                        self._cached_value = None
                    ops.add_to_collections(collections, self)
Example #36
 def _read_variable_op(self):
   if _enclosing_tpu_context() is None:
     return self._primary_var.read_value()
   v = gen_resource_variable_ops.read_variable_op(self.handle, self._dtype)
   return v
Example #37
  def _init_from_args(self,
                      initial_value=None,
                      trainable=True,
                      collections=None,
                      validate_shape=True,
                      caching_device=None,
                      name=None,
                      dtype=None):

    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
       a Tensor) or float32 will be used (if it is a Python object convertible
       to a Tensor).

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    self._save_slice_info = None
    with ops.control_dependencies(None):
      with ops.name_scope(name, "Variable", [] if init_from_fn else
                          [initial_value]) as name:
        # pylint: disable=protected-access
        true_name = ops._name_from_scope_name(name)
        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
          attr = attr_value_pb2.AttrValue(
              list=attr_value_pb2.AttrValue.ListValue(
                  s=[compat.as_bytes("loc:@%s" % true_name)]))
          with ops.get_default_graph()._attr_scope({"_class": attr}):
            with ops.name_scope("Initializer"), ops.device(None):
              self._initial_value = ops.convert_to_tensor(
                  initial_value(), name="initial_value", dtype=dtype)
            self._handle = gen_resource_variable_ops.var_handle_op(
                shape=self._initial_value.get_shape(),
                dtype=self._initial_value.dtype.base_dtype,
                shared_name=true_name, name=name)
        # pylint: enable=protected-access

        # Or get the initial value from a Tensor or Python object.
        else:
          self._initial_value = ops.convert_to_tensor(
              initial_value, name="initial_value", dtype=dtype)
          self._handle = gen_resource_variable_ops.var_handle_op(
              shape=self._initial_value.get_shape(),
              dtype=self._initial_value.dtype.base_dtype,
              shared_name=true_name, name=name)

        self._dtype = self._initial_value.dtype.base_dtype

        with ops.name_scope("IsInitialized"):
          self._is_initialized_op = (
              gen_resource_variable_ops.var_is_initialized_op(self._handle))
        if initial_value is not None:
          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
            self._initialize_op = gen_resource_variable_ops.assign_variable_op(
                self._handle, self._initial_value, name=n)
        with ops.name_scope("Read"), ops.colocate_with(self._handle):
          # Manually assign reads to the handle's device to avoid log messages.
          with ops.device(self._handle.device):
            value = gen_resource_variable_ops.read_variable_op(
                self._handle, dtype=self._dtype)
          self._graph_element = value
          if caching_device is not None:
            # Variables may be created in a tf.device() or ops.colocate_with()
            # context. At the same time, users would expect caching device to be
            # independent of this context, and/or would not expect the current
            # device context to be merged with the caching device spec.
            # Therefore we reset the colocation stack before creating the cached
            # value. Note that resetting the colocation stack will also reset
            # the device stack.
            with ops.colocate_with(None, ignore_existing=True):
              with ops.device(caching_device):
                self._cached_value = array_ops.identity(value)
          else:
            self._cached_value = None
          ops.add_to_collections(collections, self)
Example #38
 def then_branch():
     gen_resource_variable_ops.assign_variable_op(
         v.handle, v + 1)
     return gen_resource_variable_ops.read_variable_op(
         v.handle, v.dtype)
Example #39
 def body(_):
     gen_resource_variable_ops.assign_variable_op(
         v.handle, v + 1)
     return gen_resource_variable_ops.read_variable_op(
         v.handle, v.dtype)
Example #40
 def inner_fn():
     gen_resource_variable_ops.assign_variable_op(
         v.handle, v + 1)
     return gen_resource_variable_ops.read_variable_op(
         v.handle, v.dtype)
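Examples #38 through #40 are branch and body functions that write and then read a variable. Below is a minimal sketch, not from the source, of how such a branch is typically driven (assumes TF 2.x): tf.cond stitches the write and the read into the conditional, and automatic control dependencies order the read after the write.

import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def maybe_bump(flag):
    def then_branch():
        v.assign_add(1.0)
        return v.read_value()
    def else_branch():
        return v.read_value()
    return tf.cond(flag, then_branch, else_branch)

print(maybe_bump(tf.constant(True)).numpy())  # 2.0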