Code example #1
    def test_cond_with_variable(self):
        pred = self.device.pack(
            [constant_op.constant(True),
             constant_op.constant(False)])
        capture = self.device.pack(
            [constant_op.constant([1.]),
             constant_op.constant([2.])])
        with self.device:
            v = None

            @def_function.function
            def true_branch():
                nonlocal v
                if v is None:
                    v = variables.Variable(constant_op.constant(2.))
                return v * capture

            result = control_flow_ops.cond(
                pred, true_branch, def_function.function(lambda: capture * 4.))
        self.assertAllClose([[2.], [8.]], self.device.unpack(result))
        self.assertAllClose([2., 2.], self.device.unpack(v))
        # There are two unique variable handles with separate storage.
        h1, _ = self.device.unpack(v.handle)
        gen_resource_variable_ops.assign_variable_op(h1,
                                                     constant_op.constant(3.))
        self.assertAllClose([3., 2.], self.device.unpack(v))
Code example #2
 def restore(self, tensors, restored_shapes=None):
     with ops.device(self._handle.device):
         # Combine the restored tensors into one parallel tensor to assign.
         bundled = self._parallel_device.pack(tensors)
         gen_resource_variable_ops.assign_variable_op(
             resource=self._handle,
             # Squeeze out the dummy first axis we added when saving.
             value=array_ops.squeeze(bundled, axis=0))
Code example #3
    def write_var_in_while():
      gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype, name="read1")

      result = build_functional_op(v)
      gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype, name="read2")
      gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
      return result
Code example #4
 def testNoControlDepsBetweenVariableReads(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies():
             read_op1 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             read_op2 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
         self.assertNotIn(read_op1, read_op2.control_inputs)
         self.assertNotIn(read_op2, read_op1.control_inputs)
Code example #5
  def testVariableMultipleReadsAndWrites(self):
    with context.graph_mode(), self.cached_session():
      v = resource_variable_ops.ResourceVariable(1.0)
      self.evaluate(variables.global_variables_initializer())
      with acd.AutomaticControlDependencies() as c:
        # 2 reads -> 2 writes -> 2 reads -> 2 writes.
        read_op1 = gen_resource_variable_ops.read_variable_op(
            v.handle, v.dtype).op
        read_op2 = gen_resource_variable_ops.read_variable_op(
            v.handle, v.dtype).op
        assign_op1 = gen_resource_variable_ops.assign_variable_op(
            v.handle, v + 1)
        assign_op2 = gen_resource_variable_ops.assign_variable_op(
            v.handle, v + 1)
        read_op3 = gen_resource_variable_ops.read_variable_op(
            v.handle, v.dtype).op
        read_op4 = gen_resource_variable_ops.read_variable_op(
            v.handle, v.dtype).op
        assign_op3 = gen_resource_variable_ops.assign_variable_op(
            v.handle, v + 1)
        assign_op4 = gen_resource_variable_ops.assign_variable_op(
            v.handle, v + 1)
        # Read ops get added to control outputs only if they have consumers.
        c.mark_as_return(read_op1.outputs[0])
        c.mark_as_return(read_op2.outputs[0])
        c.mark_as_return(read_op3.outputs[0])
        c.mark_as_return(read_op4.outputs[0])

      # Verify the control edges.
      self.assertIn(read_op1, assign_op1.control_inputs)
      self.assertIn(read_op2, assign_op1.control_inputs)
      self.assertIn(assign_op1, assign_op2.control_inputs)
      self.assertIn(assign_op2, read_op3.control_inputs)
      self.assertIn(assign_op2, read_op4.control_inputs)
      self.assertIn(read_op3, assign_op3.control_inputs)
      self.assertIn(read_op4, assign_op3.control_inputs)
      self.assertIn(assign_op3, assign_op4.control_inputs)

      # There should be no control deps between reads.
      read_ops = [read_op1, read_op2, read_op3, read_op4]
      for src_op, tgt_op in itertools.product(read_ops, read_ops):
        self.assertNotIn(src_op, tgt_op.control_inputs)

      # Reads must be in `ops_which_must_run`.
      self.assertIn(read_op1, c.ops_which_must_run)
      self.assertIn(read_op2, c.ops_which_must_run)
      self.assertIn(read_op3, c.ops_which_must_run)
      self.assertIn(read_op4, c.ops_which_must_run)
      # Last write must be in `ops_which_must_run`.
      self.assertIn(assign_op4, c.ops_which_must_run)
Code example #6
    def test_shared_variable(self):
        x = gen_resource_variable_ops.var_handle_op(dtype=tf.float32,
                                                    shape=(1, 2),
                                                    shared_name="variable_1")
        gen_resource_variable_ops.assign_variable_op(x,
                                                     tf.constant([[1.0, 2.0]]))
        y = gen_resource_variable_ops.var_handle_op(dtype=tf.float32,
                                                    shape=(1, 2),
                                                    shared_name="variable_1")
        gen_resource_variable_ops.assign_variable_op(y,
                                                     tf.constant([[2.0, 3.0]]))
        read_x = gen_resource_variable_ops.read_variable_op(x,
                                                            dtype=tf.float32)
        read_y = gen_resource_variable_ops.read_variable_op(y,
                                                            dtype=tf.float32)
        self.assertTrue(tensor_equal(read_x, read_y))

        x = gen_resource_variable_ops.var_handle_op(
            dtype=tf.float32, shape=(1, 2), shared_name=context.shared_name())
        gen_resource_variable_ops.assign_variable_op(x,
                                                     tf.constant([[1.0, 2.0]]))
        y = gen_resource_variable_ops.var_handle_op(
            dtype=tf.float32, shape=(1, 2), shared_name=context.shared_name())
        gen_resource_variable_ops.assign_variable_op(y,
                                                     tf.constant([[2.0, 3.0]]))
        read_x = gen_resource_variable_ops.read_variable_op(x,
                                                            dtype=tf.float32)
        read_y = gen_resource_variable_ops.read_variable_op(y,
                                                            dtype=tf.float32)
        self.assertFalse(tensor_equal(read_x, read_y))
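
A hedged sketch, not part of the example above: the same shared_name behavior can be reproduced in eager TF 2.x through the public tf.raw_ops wrappers, assuming they map onto the generated ops used here; the name "demo_var" is just an illustrative label.

import tensorflow as tf

# Two handles created with the same shared_name alias the same storage.
h1 = tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=(1, 2), shared_name="demo_var")
h2 = tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=(1, 2), shared_name="demo_var")
tf.raw_ops.AssignVariableOp(resource=h1, value=tf.constant([[1.0, 2.0]]))
tf.raw_ops.AssignVariableOp(resource=h2, value=tf.constant([[2.0, 3.0]]))
# Both reads observe the last write because the handles point at one variable.
print(tf.raw_ops.ReadVariableOp(resource=h1, dtype=tf.float32).numpy())  # [[2. 3.]]
print(tf.raw_ops.ReadVariableOp(resource=h2, dtype=tf.float32).numpy())  # [[2. 3.]]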
Code example #7
 def testIdentityPassThrough(self):
   with context.graph_mode(), self.cached_session():
     v = resource_variable_ops.ResourceVariable(1.0)
     self.evaluate(variables.global_variables_initializer())
     with acd.AutomaticControlDependencies():
       gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
       identity_handle = gen_array_ops.identity(v.handle)
       assign_op2 = gen_resource_variable_ops.assign_variable_op(
           v.handle, v + 1)
       read_op = gen_resource_variable_ops.read_variable_op(
           identity_handle, v.dtype).op
     # Read should have a control dep from the last write even
     # with Identity applied to resource.
     self.assertIn(assign_op2, read_op.control_inputs)
Code example #8
  def initialize(self):
    with ops.name_scope(self._name, "Variable", skip_on_eager=False) as name:
      with ops.colocate_with(self._handle), ops.name_scope("Initializer"):
        if callable(self._initial_value):
          initial_value = self._initial_value()
        else:
          initial_value = self._initial_value

        if not initial_value.shape.is_compatible_with(self._shape):
          raise ValueError(
              f"In this `tf.Variable` creation, the initial value's shape "
              f"({initial_value.shape}) is not compatible with "
              f"the explicitly supplied `shape` argument ({self._shape}).")
        assert self._dtype is initial_value.dtype.base_dtype
      gen_resource_variable_ops.assign_variable_op(self._handle, initial_value)
Code example #9
  def assign(self, value, use_locking=None, name=None, read_value=True):
    """Assigns a new value to this variable.

    Args:
      value: A `Tensor`. The new value for this variable.
      use_locking: If `True`, use locking during the assignment.
      name: The name to use for the assignment.
      read_value: A `bool`. Whether to read and return the new value of the
          variable or not.

    Returns:
      If `read_value` is `True`, this method will return the new value of the
      variable after the assignment has completed. Otherwise, when in graph mode
      it will return the `Operation` that does the assignment, and when in eager
      mode it will return `None`.
    """
    # Note: not depending on the cached value here since this can be used to
    # initialize the variable.
    with _handle_graph(self.handle):
      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
      self._shape.assert_is_compatible_with(value_tensor.shape)
      assign_op = gen_resource_variable_ops.assign_variable_op(
          self.handle, value_tensor, name=name)
      if read_value:
        return self._lazy_read(assign_op)
    return assign_op
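
A minimal usage sketch of the read_value flag under TF 2.x eager execution, consistent with the docstring above; it is not from the original source.

import tensorflow as tf

v = tf.Variable(1.0)
updated = v.assign(2.0)                    # read_value=True by default: returns the new value
print(updated.numpy())                     # 2.0
result = v.assign(3.0, read_value=False)   # returns None when executing eagerly,
print(result, v.numpy())                   # or the assignment Operation in graph mode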
Code example #10
  def assign(self, value, use_locking=None, name=None, read_value=False):
    """Assign `value` to all replicas.

    Outside of the tpu.rewrite context, assign explicitly to all replicas.
    Inside of the tpu.rewrite context, assigns to the local replica.

    Arguments:
      value: Tensor to assign
      use_locking: ignored
      name: ignored
      read_value: return the value from the assignment
    Returns:
      Assignment operation, or new value of the variable if `read_value` is True
    """
    del use_locking
    if _enclosing_tpu_context() is None:
      assign_ops = []
      with self._assign_dependencies():
        for var in self._vars:
          assign_ops.append(var.assign(value, use_locking=None, name=name))

        if read_value:
          with ops.control_dependencies(assign_ops):
            return self.read_value()
        else:
          return control_flow_ops.group(assign_ops)

    with _handle_graph(self.handle), self._assign_dependencies():
      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
      assign_op = gen_resource_variable_ops.assign_variable_op(
          self.handle, value_tensor, name=name)
    if read_value:
      return self._read_variable_op()
    return assign_op
Code example #11
    def _read_variable_op(self, do_prefetch=True):
        resource_variable_ops.variable_accessed(self)
        if self.model_mode == "train":
            if do_prefetch:
                with ops.control_dependencies([
                        gen_resource_variable_ops.assign_variable_op(
                            self._handle,
                            self.prefetch_values(),
                            name="AssignBeforeReadVariable")
                ]):
                    _result = gen_resource_variable_ops.read_variable_op(
                        self._handle, self._dtype)
            else:
                _result = gen_resource_variable_ops.read_variable_op(
                    self._handle, self._dtype)
        else:
            _result = self.prefetch_values()

        if not context.executing_eagerly():
            # Note that if a control flow context is active the input of the read op
            # might not actually be the handle. This line bypasses it.
            tape.record_operation("ReadVariableOp", [_result], [self._handle],
                                  lambda x: [x])
        result = self.transform(_result)
        return result
Code example #12
 def assign(self, value, use_locking=None, name=None):
     value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
     self._shape.assert_is_compatible_with(value_tensor.shape)
     return self._lazy_read(
         gen_resource_variable_ops.assign_variable_op(self.handle,
                                                      value_tensor,
                                                      name=name))
Code example #13
  def assign(self, value, use_locking=None, name=None, read_value=True):
    """
    Assigns a new value to this variable.
    To distinguish it from `ResourceVariable`, the shadow always uses a
    variant space to hold the temporary embedding lookup buffer.

    Args:
      value: A `Tensor`. The new value for this variable.
      use_locking: If `True`, use locking during the assignment.
      name: The name to use for the assignment.
      read_value: A `bool`. Whether to read and return the new value of the
        variable or not.

    Returns:
      If `read_value` is `True`, this method will return the new value of the
      variable after the assignment has completed. Otherwise, when in graph mode
      it will return the `Operation` that does the assignment, and when in eager
      mode it will return `None`.
    """
    # Note: not depending on the cached value here since this can be used to
    # initialize the variable.
    with resource_variable_ops._handle_graph(self.handle):
      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
      assign_op = gen_resource_variable_ops.assign_variable_op(self.handle,
                                                               value_tensor,
                                                               name=name)
      if read_value:
        return self._lazy_read(assign_op)
    return assign_op
Code example #14
def shape_safe_assign_variable_handle(handle, shape, value, name=None):
  """Helper that checks shape compatibility and assigns variable."""
  value_tensor = ops.convert_to_tensor(value)
  shape.assert_is_compatible_with(value_tensor.shape)
  return gen_resource_variable_ops.assign_variable_op(handle,
                                                      value_tensor,
                                                      name=name)
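
A hedged sketch of the same shape guard from the user-facing side, assuming TF 2.x: tf.Variable.assign runs a compatibility check before the underlying AssignVariableOp, so a mismatched shape is rejected before anything is written.

import tensorflow as tf

v = tf.Variable([1.0, 2.0])
v.assign([3.0, 4.0])            # compatible shape: assignment succeeds
try:
    v.assign([1.0, 2.0, 3.0])   # incompatible shape
except ValueError as err:
    print("rejected:", err)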
Code example #15
  def assign(self, value, use_locking=None, name=None, read_value=True):
    """Assigns a new value to this variable.

    Args:
      value: A `Tensor`. The new value for this variable.
      use_locking: If `True`, use locking during the assignment.
      name: The name to use for the assignment.
      read_value: A `bool`. Whether to read and return the new value of the
          variable or not.

    Returns:
      If `read_value` is `True`, this method will return the new value of the
      variable after the assignment has completed. Otherwise, when in graph mode
      it will return the `Operation` that does the assignment, and when in eager
      mode it will return `None`.
    """
    # Note: not depending on the cached value here since this can be used to
    # initialize the variable.
    with _handle_graph(self.handle):
      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
      self._shape.assert_is_compatible_with(value_tensor.shape)
      assign_op = gen_resource_variable_ops.assign_variable_op(
          self.handle, value_tensor, name=name)
      if read_value:
        return self._lazy_read(assign_op)
    return assign_op
Code example #16
  def assign(self, value, use_locking=None, name=None, read_value=False):
    """Assign `value` to all replicas.

    Outside of the tpu.rewrite context, assign explicitly to all replicas.
    Inside of the tpu.rewrite context, assigns to the local replica.

    Arguments:
      value: Tensor to assign
      use_locking: ignored
      name: ignored
      read_value: return the value from the assignment
    Returns:
      Assignment operation, or new value of the variable if `read_value` is True
    """
    del use_locking
    if _enclosing_tpu_context() is None:
      assign_ops = []
      with self._assign_dependencies():
        for var in self._vars:
          assign_ops.append(var.assign(value, use_locking=None, name=name))

        if read_value:
          with ops.control_dependencies(assign_ops):
            return self.read_value()
        else:
          return control_flow_ops.group(assign_ops)

    with _handle_graph(self.handle), self._assign_dependencies():
      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
      assign_op = gen_resource_variable_ops.assign_variable_op(
          self.handle, value_tensor, name=name)
    if read_value:
      return self._read_variable_op()
    return assign_op
Code example #17
def shape_safe_assign_variable_handle(handle, shape, value, name=None):
    """Helper that checks shape compatibility and assigns variable."""
    value_tensor = ops.convert_to_tensor(value)
    shape.assert_is_compatible_with(value_tensor.shape)
    return gen_resource_variable_ops.assign_variable_op(handle,
                                                        value_tensor,
                                                        name=name)
Code example #18
 def assign(self, value, use_locking=None, name=None):
   value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
   self._shape.assert_is_compatible_with(value_tensor.shape)
   return self._lazy_read(
       gen_resource_variable_ops.assign_variable_op(
           self.handle,
           value_tensor,
           name=name))
Code example #19
 def assign(self, value, use_locking=None, name=None):
   with ops.control_dependencies([
       gen_resource_variable_ops.assign_variable_op(
           self.handle,
           ops.convert_to_tensor(value, dtype=self.dtype),
           name=name)
   ]):
     return self.read_value()
Code example #20
File: resource_variable_ops.py  Project: lengjia/RRL
 def assign(self, value, use_locking=None, name=None):
     with ops.control_dependencies([
             gen_resource_variable_ops.assign_variable_op(
                 self.handle,
                 ops.convert_to_tensor(value, dtype=self.dtype),
                 name=name)
     ]):
         return self.read_value()
Code example #21
 def assign(self, value, use_locking=None, name=None, read_value=False):
   del use_locking
   with _handle_graph(self.handle), self._assign_dependencies():
     value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
     assign_op = gen_resource_variable_ops.assign_variable_op(
         self.handle, value_tensor, name=name)
   if read_value:
     return self._read_variable_op()
   return assign_op
Code example #22
 def assign(self, value, use_locking=None, name=None):
     value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
     self._shape.assert_is_compatible_with(value_tensor.shape)
     with ops.control_dependencies([
             gen_resource_variable_ops.assign_variable_op(self.handle,
                                                          value_tensor,
                                                          name=name)
     ]):
         return self.read_value()
Code example #23
 def assign(self, value, use_locking=None, name=None):
   value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
   self._shape.assert_is_compatible_with(value_tensor.shape)
   with ops.control_dependencies([
       gen_resource_variable_ops.assign_variable_op(
           self.handle,
           value_tensor,
           name=name)
   ]):
     return self.read_value()
Code example #24
 def testVariableWriteThenRead(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies():
             assign_op = gen_resource_variable_ops.assign_variable_op(
                 v.handle, v + 1)
             read_op1 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             read_op2 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
         # Reads should have a control dep from the last write.
         self.assertIn(assign_op, read_op1.control_inputs)
         self.assertIn(assign_op, read_op2.control_inputs)
         # There should be no control deps between reads.
         self.assertNotIn(read_op1, read_op2.control_inputs)
         self.assertNotIn(read_op2, read_op1.control_inputs)
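
A hedged sketch of the same write-then-read ordering from the tf.function side, assuming TF 2.x: automatic control dependencies sequence the read after the preceding assign, so the function observes the updated value.

import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def write_then_read():
    v.assign(v + 1.0)         # write
    return v.read_value()     # ACD adds a control dependency on the write

print(write_then_read().numpy())  # 2.0 on the first call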
Code example #25
 def testVariableReadsNotInOpsWithMustRun(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies() as c:
             read_op1 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             read_op2 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             assign_op = gen_resource_variable_ops.assign_variable_op(
                 v.handle, v + 1)
         # Reads must not be in `ops_which_must_run` since those get added to the
         # `control_outputs`.
         self.assertNotIn(read_op1, c.ops_which_must_run)
         self.assertNotIn(read_op2, c.ops_which_must_run)
         # Last write must be in `ops_which_must_run`.
         self.assertIn(assign_op, c.ops_which_must_run)
Code example #26
 def testVariableReadThenWrite(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies():
             read_op1 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             read_op2 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             assign_op = gen_resource_variable_ops.assign_variable_op(
                 v.handle, v + 1)
         # Writes should have control deps from "all" reads since last write
         # or start of the code block.
         self.assertIn(read_op1, assign_op.control_inputs)
         self.assertIn(read_op2, assign_op.control_inputs)
         # There should be no control deps between reads.
         self.assertNotIn(read_op1, read_op2.control_inputs)
         self.assertNotIn(read_op2, read_op1.control_inputs)
Code example #27
 def testManualControlDepMonitoringAttrNotAdded(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies():
             read_op1 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             read_op2 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             assign_op = gen_resource_variable_ops.assign_variable_op(
                 v.handle, v + 1)
         # Writes should have control deps automatically added from "all" reads
         # since last write or start of the code block.
         self.assertIn(read_op1, assign_op.control_inputs)
         self.assertIn(read_op2, assign_op.control_inputs)
         # But, we shouldn't add the monitoring attribute in this case.
         with self.assertRaises(ValueError):
             assign_op.get_attr("_has_manual_control_dependencies")
         with self.assertRaises(ValueError):
             read_op1.get_attr("_has_manual_control_dependencies")
         with self.assertRaises(ValueError):
             read_op2.get_attr("_has_manual_control_dependencies")
Code example #28
 def then_branch():
     gen_resource_variable_ops.assign_variable_op(
         v.handle, v + 1)
     return gen_resource_variable_ops.read_variable_op(
         v.handle, v.dtype)
Code example #29
 def fn_with_write():
     gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
     return gen_resource_variable_ops.read_variable_op(
         v.handle, v.dtype)
Code example #30
 def f():
   gen_resource_variable_ops.assign_variable_op(v.handle, 1)
   ops.get_default_graph().experimental_acd_manager.run_independently(
       gen_resource_variable_ops.assign_variable_op(v.handle, 2))
Code example #31
 def inner_fn():
     gen_resource_variable_ops.assign_variable_op(
         v.handle, v + 1)
     return gen_resource_variable_ops.read_variable_op(
         v.handle, v.dtype)
Code example #32
  def _init_from_args(self,
                      initial_value=None,
                      trainable=True,
                      collections=None,
                      validate_shape=True,
                      caching_device=None,
                      name=None,
                      dtype=None,
                      constraint=None):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
       a Tensor) or float32 will be used (if it is a Python object convertible
       to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.

    @compatibility(eager)
    When Eager Execution is enabled, variables are never added to collections.
    It is not implicitly added to the `GLOBAL_VARIABLES` or
    `TRAINABLE_VARIABLES` collections, and the `collections` argument is
    ignored.
    @end_compatibility
    """
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    if isinstance(initial_value, checkpointable.CheckpointInitialValue):
      self._maybe_initialize_checkpointable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    self._trainable = trainable
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    self._save_slice_info = None
    # Store the graph key so optimizers know how to only retrieve variables from
    # this graph.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with ops.init_scope():
      self._in_graph_mode = context.in_graph_mode()
      with ops.name_scope(name, "Variable", []
                          if init_from_fn else [initial_value]) as name:
        # pylint: disable=protected-access
        handle_name = ops._name_from_scope_name(name)
        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
          if self._in_graph_mode:
            attr = attr_value_pb2.AttrValue(
                list=attr_value_pb2.AttrValue.ListValue(
                    s=[compat.as_bytes("loc:@%s" % handle_name)]))
            with ops.get_default_graph()._attr_scope({"_class": attr}):
              with ops.name_scope("Initializer"), ops.device(None):
                initial_value = ops.convert_to_tensor(
                    initial_value(), name="initial_value", dtype=dtype)
              self._handle = _eager_safe_variable_handle(
                  shape=initial_value.get_shape(),
                  dtype=initial_value.dtype.base_dtype,
                  shared_name=handle_name,
                  name=name,
                  graph_mode=self._in_graph_mode)
              self._handle_device = (
                  self._handle.device if self._in_graph_mode else
                  context.get_default_context().device_name)
              self._shape = initial_value.get_shape()
          else:
            initial_value = initial_value()
            with ops.name_scope("Initializer"):
              initial_value = ops.convert_to_tensor(
                  initial_value, name="initial_value", dtype=dtype)
            self._handle = _eager_safe_variable_handle(
                shape=initial_value.get_shape(),
                dtype=initial_value.dtype.base_dtype,
                shared_name=handle_name,
                name=name,
                graph_mode=False)
            self._handle_device = (
                self._handle.device if self._in_graph_mode else
                context.get_default_context().device_name)
            self._shape = initial_value.get_shape()
        # pylint: enable=protected-access

        # Or get the initial value from a Tensor or Python object.
        else:
          with ops.name_scope("Initializer"):
            initial_value = ops.convert_to_tensor(
                initial_value, name="initial_value", dtype=dtype)
          # pylint: disable=protected-access
          if (self._in_graph_mode and initial_value is not None and
              initial_value.op._get_control_flow_context() is not None):
            raise ValueError(
                "Initializer for variable %s is from inside a control-flow "
                "construct, such as a loop or conditional. When creating a "
                "variable inside a loop or conditional, use a lambda as the "
                "initializer." % name)
          # pylint: enable=protected-access
          self._handle = _eager_safe_variable_handle(
              shape=initial_value.get_shape(),
              dtype=initial_value.dtype.base_dtype,
              shared_name=handle_name,
              name=name,
              graph_mode=self._in_graph_mode)
          self._handle_device = (self._handle.device if self._in_graph_mode else
                                 context.get_default_context().device_name)
          self._shape = initial_value.get_shape()

        self._initial_value = initial_value if self._in_graph_mode else None
        self._handle_name = handle_name + ":0"
        self._dtype = initial_value.dtype.base_dtype
        self._constraint = constraint

        if self._in_graph_mode:
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                gen_resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              self._initializer_op = (
                  gen_resource_variable_ops.assign_variable_op(
                      self._handle,
                      self._try_guard_against_uninitialized_dependencies(
                          initial_value),
                      name=n))
          with ops.name_scope("Read"), ops.colocate_with(self._handle):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(self._handle_device):
              value = self._read_variable_op()
            self._graph_element = value
            if caching_device is not None:
              # Variables may be created in a tf.device() or ops.colocate_with()
              # context. At the same time, users would expect caching device to
              # be independent of this context, and/or would not expect the
              # current device context to be merged with the caching device
              # spec.  Therefore we reset the colocation stack before creating
              # the cached value. Note that resetting the colocation stack will
              # also reset the device stack.
              with ops.colocate_with(None, ignore_existing=True):
                with ops.device(caching_device):
                  self._cached_value = array_ops.identity(value)
            else:
              self._cached_value = None
        else:
          gen_resource_variable_ops.assign_variable_op(self._handle,
                                                       initial_value)
          self._is_initialized_op = None
          self._initializer_op = None
          self._graph_element = None
          if caching_device:
            with ops.device(caching_device):
              self._cached_value = self._read_variable_op()
          else:
            self._cached_value = None
        if context.in_graph_mode():
          ops.add_to_collections(collections, self)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)

    if not self._in_graph_mode:
      # After the handle has been created, set up a way to clean it up when
      # executing eagerly. We'll hold the only reference to the deleter, so that
      # when this object is garbage collected the deleter will be too. This
      # means ResourceVariables can be part of reference cycles without those
      # cycles being uncollectable, and means that no __del__ will be defined at
      # all in graph mode.
      self._handle_deleter = EagerResourceDeleter(
          handle=self._handle, handle_device=self._handle_device)
Code example #33
 def body(_):
     gen_resource_variable_ops.assign_variable_op(
         v.handle, v + 1)
     return gen_resource_variable_ops.read_variable_op(
         v.handle, v.dtype)
Code example #34
    def _init_from_args(self,
                        initial_value=None,
                        trainable=True,
                        collections=None,
                        validate_shape=True,
                        caching_device=None,
                        name=None,
                        dtype=None):
        """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
       a Tensor) or float32 will be used (if it is a Python object convertible
       to a Tensor).

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
        if initial_value is None:
            raise ValueError("initial_value must be specified.")
        init_from_fn = callable(initial_value)

        if collections is None:
            collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        if not isinstance(collections, (list, tuple, set)):
            raise ValueError(
                "collections argument to Variable constructor must be a list, tuple, "
                "or set. Got %s of type %s" % (collections, type(collections)))
        if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
            collections = list(collections) + [
                ops.GraphKeys.TRAINABLE_VARIABLES
            ]
        self._save_slice_info = None
        with ops.control_dependencies(None):
            with ops.name_scope(
                    name, "Variable",
                [] if init_from_fn else [initial_value]) as name:
                if init_from_fn:
                    # Use attr_scope and device(None) to simulate the behavior of
                    # colocate_with when the variable we want to colocate with doesn't
                    # yet exist.
                    # pylint: disable=protected-access
                    true_name = ops._name_from_scope_name(name)
                    attr = attr_value_pb2.AttrValue(
                        list=attr_value_pb2.AttrValue.ListValue(
                            s=[compat.as_bytes("loc:@%s" % true_name)]))
                    # pylint: disable=protected-access
                    with ops.get_default_graph()._attr_scope({"_class": attr}):
                        with ops.name_scope("Initializer"), ops.device(None):
                            self._initial_value = ops.convert_to_tensor(
                                initial_value(),
                                name="initial_value",
                                dtype=dtype)
                        self._handle = gen_resource_variable_ops.var_handle_op(
                            shape=self._initial_value.get_shape(),
                            dtype=self._initial_value.dtype.base_dtype,
                            shared_name=name,
                            name=name)

                # Or get the initial value from a Tensor or Python object.
                else:
                    self._initial_value = ops.convert_to_tensor(
                        initial_value, name="initial_value", dtype=dtype)
                    self._handle = gen_resource_variable_ops.var_handle_op(
                        shape=self._initial_value.get_shape(),
                        dtype=self._initial_value.dtype.base_dtype,
                        shared_name=name,
                        name=name)

                self._dtype = self._initial_value.dtype.base_dtype

                with ops.name_scope("IsInitialized"):
                    self._is_initialized_op = (
                        gen_resource_variable_ops.var_is_initialized_op(
                            self._handle))
                if initial_value is not None:
                    with ops.name_scope("Assign") as n, ops.colocate_with(
                            self._handle):
                        self._initialize_op = gen_resource_variable_ops.assign_variable_op(
                            self._handle, self._initial_value, name=n)
                with ops.name_scope("Read"), ops.colocate_with(self._handle):
                    value = gen_resource_variable_ops.read_variable_op(
                        self._handle, dtype=self._dtype)
                    self._graph_element = value
                    if caching_device is not None:
                        with ops.device(caching_device):
                            self._cached_value = array_ops.identity(value)
                    else:
                        self._cached_value = None
                    ops.add_to_collections(collections, self)
Code example #35
  def __init__(self,
               initial_value=None,
               name=None,
               caching_device=None,
               trainable=True,
               collections=None,
               dtype=None,
               shape=None):

    """Creates a variable.

    Args:
      initial_value: A `Tensor` or Python object convertible to a `Tensor`
        representing the initial value of this variable.
      name: The name of this variable. Automatically uniquified.
      caching_device: Device where the variable's value is read by default.
      trainable: Whether the global read of this variable will be used for
        training.
      collections: Additional collections to which the `read` operation for
        this variable is to be added. Defaults to [].
      dtype: The type of this variable. Can be omitted if it can be deduced
        from the initial_value. If different from the type of the initial
        value it will be cast to this type.
      shape: The shape of this variable. Only specify if there is no initial
        value but shape inference is desired.
    """
    if initial_value is not None:
      if callable(initial_value):
        initial_value = initial_value()
      initial_value = ops.convert_to_tensor(initial_value)
    if dtype is None:
      assert initial_value is not None, ("Trying to create a resource variable "
                                         "with no dtype or initial value. At"
                                         " least one of these must be set.")
      dtype = initial_value.dtype
    elif initial_value is not None:
      initial_value = math_ops.cast(initial_value, dtype)
    if shape is None:
      if initial_value is not None:
        shape = initial_value.get_shape().as_proto()
      else:
        shape = tensor_shape.unknown_shape()
    else:
      shape = tensor_shape.as_shape(shape)

    self._dtype = dtype
    with ops.name_scope(name, "Variable", [initial_value]) as name:
      self._handle = gen_resource_variable_ops.var_handle_op(shared_name=name,
                                                             name=name,
                                                             dtype=dtype,
                                                             shape=shape)

      with ops.name_scope("IsInitialized"):
        self._is_initialized_op = (
            gen_resource_variable_ops.var_is_initialized_op(self._handle))
      if initial_value is not None:
        with ops.name_scope("Create"):
          self._initialize_op = gen_resource_variable_ops.assign_variable_op(
              self._handle, initial_value)
        resources.register_resource(self._handle,
                                    self._initialize_op,
                                    self._is_initialized_op)

      with ops.name_scope("Read"):
        if caching_device is not None:
          with ops.device(caching_device):
            self._value = gen_resource_variable_ops.read_variable_op(
                self._handle, dtype=self._dtype)
        else:
          self._value = gen_resource_variable_ops.read_variable_op(
              self._handle, dtype=self._dtype)
        # TODO(apassos) this is terrible
        self._value.initializer = self._initialize_op
      _register_variable_read(
          self._value, trainable=trainable, collections=collections)
Code example #36
    def _init_from_args(self,
                        initial_value=None,
                        trainable=True,
                        collections=None,
                        validate_shape=True,
                        caching_device=None,
                        name=None,
                        dtype=None,
                        constraint=None):
        """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
       a Tensor) or float32 will be used (if it is a Python object convertible
       to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.

    @compatibility(eager)
    When Eager Execution is enabled, variables are never added to collections.
    It is not implicitly added to the `GLOBAL_VARIABLES` or
    `TRAINABLE_VARIABLES` collections, and the `collections` argument is
    ignored.
    @end_compatibility
    """
        if initial_value is None:
            raise ValueError("initial_value must be specified.")
        init_from_fn = callable(initial_value)

        if collections is None:
            collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        if not isinstance(collections, (list, tuple, set)):
            raise ValueError(
                "collections argument to Variable constructor must be a list, tuple, "
                "or set. Got %s of type %s" % (collections, type(collections)))
        if constraint is not None and not callable(constraint):
            raise ValueError("The `constraint` argument must be a callable.")

        self._trainable = trainable
        if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
            collections = list(collections) + [
                ops.GraphKeys.TRAINABLE_VARIABLES
            ]
        self._save_slice_info = None
        # Store the graph key so optimizers know how to only retrieve variables from
        # this graph.
        self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
        with ops.init_scope():
            self._in_graph_mode = context.in_graph_mode()
            with ops.name_scope(
                    name, "Variable",
                [] if init_from_fn else [initial_value]) as name:
                # pylint: disable=protected-access
                handle_name = ops._name_from_scope_name(name)
                if init_from_fn:
                    # Use attr_scope and device(None) to simulate the behavior of
                    # colocate_with when the variable we want to colocate with doesn't
                    # yet exist.
                    if self._in_graph_mode:
                        attr = attr_value_pb2.AttrValue(
                            list=attr_value_pb2.AttrValue.ListValue(
                                s=[compat.as_bytes("loc:@%s" % handle_name)]))
                        with ops.get_default_graph()._attr_scope(
                            {"_class": attr}):
                            with ops.name_scope("Initializer"), ops.device(
                                    None):
                                initial_value = ops.convert_to_tensor(
                                    initial_value(),
                                    name="initial_value",
                                    dtype=dtype)
                            self._handle = _eager_safe_variable_handle(
                                shape=initial_value.get_shape(),
                                dtype=initial_value.dtype.base_dtype,
                                shared_name=handle_name,
                                name=name,
                                graph_mode=self._in_graph_mode)
                            self._handle_device = (
                                self._handle.device if self._in_graph_mode else
                                context.get_default_context().device_name)
                            self._shape = initial_value.get_shape()
                    else:
                        initial_value = initial_value()
                        with ops.name_scope("Initializer"):
                            initial_value = ops.convert_to_tensor(
                                initial_value,
                                name="initial_value",
                                dtype=dtype)
                        self._handle = _eager_safe_variable_handle(
                            shape=initial_value.get_shape(),
                            dtype=initial_value.dtype.base_dtype,
                            shared_name=handle_name,
                            name=name,
                            graph_mode=False)
                        self._handle_device = (
                            self._handle.device if self._in_graph_mode else
                            context.get_default_context().device_name)
                        self._shape = initial_value.get_shape()
                # pylint: enable=protected-access

                # Or get the initial value from a Tensor or Python object.
                else:
                    with ops.name_scope("Initializer"):
                        initial_value = ops.convert_to_tensor(
                            initial_value, name="initial_value", dtype=dtype)
                    # pylint: disable=protected-access
                    if (self._in_graph_mode and initial_value is not None
                            and initial_value.op._get_control_flow_context()
                            is not None):
                        raise ValueError(
                            "Initializer for variable %s is from inside a control-flow "
                            "construct, such as a loop or conditional. When creating a "
                            "variable inside a loop or conditional, use a lambda as the "
                            "initializer." % name)
                    # pylint: enable=protected-access
                    self._handle = _eager_safe_variable_handle(
                        shape=initial_value.get_shape(),
                        dtype=initial_value.dtype.base_dtype,
                        shared_name=handle_name,
                        name=name,
                        graph_mode=self._in_graph_mode)
                    self._handle_device = (
                        self._handle.device if self._in_graph_mode else
                        context.get_default_context().device_name)
                    self._shape = initial_value.get_shape()

                self._initial_value = initial_value if self._in_graph_mode else None
                self._handle_name = handle_name + ":0"
                self._dtype = initial_value.dtype.base_dtype
                self._constraint = constraint

                if self._in_graph_mode:
                    with ops.name_scope("IsInitialized"):
                        self._is_initialized_op = (
                            gen_resource_variable_ops.var_is_initialized_op(
                                self._handle))
                    if initial_value is not None:
                        with ops.name_scope("Assign") as n, ops.colocate_with(
                                self._handle):
                            self._initializer_op = (
                                gen_resource_variable_ops.assign_variable_op(
                                    self._handle,
                                    self._try_guard_against_uninitialized_dependencies(
                                        initial_value),
                                    name=n))
                    with ops.name_scope("Read"), ops.colocate_with(
                            self._handle):
                        # Manually assign reads to the handle's device to avoid log
                        # messages.
                        with ops.device(self._handle_device):
                            value = self._read_variable_op()
                        self._graph_element = value
                        if caching_device is not None:
                            # Variables may be created in a tf.device() or ops.colocate_with()
                            # context. At the same time, users would expect caching device to
                            # be independent of this context, and/or would not expect the
                            # current device context to be merged with the caching device
                            # spec.  Therefore we reset the colocation stack before creating
                            # the cached value. Note that resetting the colocation stack will
                            # also reset the device stack.
                            with ops.colocate_with(None, ignore_existing=True):
                                with ops.device(caching_device):
                                    self._cached_value = array_ops.identity(
                                        value)
                        else:
                            self._cached_value = None
                else:
                    gen_resource_variable_ops.assign_variable_op(
                        self._handle, initial_value)
                    self._is_initialized_op = None
                    self._initializer_op = None
                    self._graph_element = None
                    if caching_device:
                        with ops.device(caching_device):
                            self._cached_value = self._read_variable_op()
                    else:
                        self._cached_value = None
                if context.in_graph_mode():
                    ops.add_to_collections(collections, self)
                elif ops.GraphKeys.GLOBAL_STEP in collections:
                    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)

        if not self._in_graph_mode:
            # After the handle has been created, set up a way to clean it up when
            # executing eagerly. We'll hold the only reference to the deleter, so that
            # when this object is garbage collected the deleter will be too. This
            # means ResourceVariables can be part of reference cycles without those
            # cycles being uncollectable, and means that no __del__ will be defined at
            # all in graph mode.
            self._handle_deleter = EagerResourceDeleter(
                handle=self._handle, handle_device=self._handle_device)
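
The comment above captures a cleanup idiom that is easy to miss inside the constructor: the variable keeps the only reference to a small deleter object, so the backing resource is released when the variable itself is garbage collected, and `__del__` never has to be defined on the variable. A minimal sketch of the same idea in plain Python (the names below are illustrative, not TensorFlow APIs):

class _Deleter(object):
  """Releases a resource handle when garbage collected."""

  def __init__(self, handle, release_fn):
    self._handle = handle
    self._release_fn = release_fn

  def __del__(self):
    self._release_fn(self._handle)


class Owner(object):
  """Owns a handle and holds the sole reference to its deleter."""

  def __init__(self, handle, release_fn):
    self._handle = handle
    # Dropping the last reference to `Owner` also drops the deleter, whose
    # __del__ then releases the handle. `Owner` itself defines no __del__,
    # so it can sit in reference cycles without complicating collection.
    self._deleter = _Deleter(handle, release_fn)
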
Code example #37
    def _init_from_args(self,
                        initial_value=None,
                        trainable=None,
                        collections=None,
                        caching_device=None,
                        name=None,
                        dtype=None,
                        constraint=None,
                        synchronization=None,
                        aggregation=None,
                        distribute_strategy=None,
                        shape=None):
        """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
        Defaults to `True`, unless `synchronization` is set to `ON_READ`, in
        which case it defaults to `False`.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      distribute_strategy: DistributionStrategy under which this variable
        was created.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the variable
        can be assigned values of different shapes.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.

    @compatibility(eager)
    When Eager Execution is enabled, variables are never added to collections.
    It is not implicitly added to the `GLOBAL_VARIABLES` or
    `TRAINABLE_VARIABLES` collections, and the `collections` argument is
    ignored.
    @end_compatibility
    """
        synchronization, aggregation, trainable = (
            variables.validate_synchronization_aggregation_trainable(
                synchronization, aggregation, trainable, name))
        if initial_value is None:
            raise ValueError("initial_value must be specified.")
        init_from_fn = callable(initial_value)

        if isinstance(initial_value, ops.Tensor) and hasattr(
                initial_value,
                "graph") and initial_value.graph.building_function:
            raise ValueError(
                "Tensor-typed variable initializers must either be "
                "wrapped in an init_scope or callable "
                "(e.g., `tf.Variable(lambda : "
                "tf.truncated_normal([10, 40]))`) when building "
                "functions. Please file a feature request if this "
                "restriction inconveniences you.")

        if collections is None:
            collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        if not isinstance(collections, (list, tuple, set)):
            raise ValueError(
                "collections argument to Variable constructor must be a list, tuple, "
                "or set. Got %s of type %s" % (collections, type(collections)))
        if constraint is not None and not callable(constraint):
            raise ValueError("The `constraint` argument must be a callable.")

        if isinstance(initial_value, trackable.CheckpointInitialValue):
            self._maybe_initialize_trackable()
            self._update_uid = initial_value.checkpoint_position.restore_uid
            initial_value = initial_value.wrapped_value

        if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
            collections = list(collections) + [
                ops.GraphKeys.TRAINABLE_VARIABLES
            ]
        with ops.init_scope():
            self._in_graph_mode = not context.executing_eagerly()
            with ops.name_scope(
                    name, "TrainableWrapper",
                [] if init_from_fn else [initial_value]) as name:
                # pylint: disable=protected-access
                handle_name = ops.name_from_scope_name(name)
                handle_name = (handle_name or "TrainableWrapperHandle")
                if self._in_graph_mode:
                    shared_name = handle_name
                    unique_id = shared_name
                else:
                    # When in eager mode use a uid for the shared_name, to prevent
                    # accidental sharing.
                    unique_id = "%s_%d" % (handle_name, ops.uid())
                    shared_name = None  #context.shared_name()
                # Use attr_scope and device(None) to simulate the behavior of
                # colocate_with when the variable we want to colocate with doesn't
                # yet exist.
                device_context_manager = (ops.device if self._in_graph_mode
                                          else ops.NullContextmanager)
                attr = attr_value_pb2.AttrValue(
                    list=attr_value_pb2.AttrValue.ListValue(
                        s=[compat.as_bytes("loc:@%s" % handle_name)]))
                with ops.get_default_graph()._attr_scope({"_class": attr}):
                    with ops.name_scope("Initializer"), device_context_manager(
                            None):
                        initial_value = ops.convert_to_tensor(
                            initial_value() if init_from_fn else initial_value,
                            name="initial_value",
                            dtype=dtype)
                    if shape is None:
                        shape = initial_value.shape
                    handle = resource_variable_ops.eager_safe_variable_handle(
                        initial_value=initial_value,
                        shape=None,  # shape,
                        shared_name=shared_name,
                        name=name,
                        graph_mode=self._in_graph_mode)
                # pylint: disable=protected-access
                if (self._in_graph_mode and initial_value is not None
                        and initial_value.op._get_control_flow_context()
                        is not None):
                    raise ValueError(
                        "Initializer for variable %s is from inside a control-flow "
                        "construct, such as a loop or conditional. When creating a "
                        "variable inside a loop or conditional, use a lambda as the "
                        "initializer." % name)
                # pylint: enable=protected-access
                dtype = initial_value.dtype.base_dtype

                if self._in_graph_mode:
                    with ops.name_scope("IsInitialized"):
                        is_initialized_op = (
                            gen_resource_variable_ops.var_is_initialized_op(handle))
                    if initial_value is not None:
                        # pylint: disable=g-backslash-continuation
                        with ops.name_scope("Assign") as n, \
                            ops.colocate_with(None, ignore_existing=True), \
                            ops.device(handle.device):
                            # pylint: disable=protected-access
                            initializer_op = (
                                gen_resource_variable_ops.assign_variable_op(
                                    handle,
                                    variables._try_guard_against_uninitialized_dependencies(
                                        name, initial_value),
                                    name=n))
                            # pylint: enable=protected-access
                        # pylint: enable=g-backslash-continuation
                    with ops.name_scope("Read"):
                        # Manually assign reads to the handle's device to avoid log
                        # messages.
                        with ops.device(handle.device):
                            with ops.control_dependencies([
                                    gen_resource_variable_ops.assign_variable_op(
                                        handle,
                                        self.prefetch_values(),
                                        name="AssignBeforeInitRead")
                            ]):
                                value = gen_resource_variable_ops.read_variable_op(
                                    handle, dtype)
                        graph_element = value
                        if caching_device is not None:
                            # Variables may be created in a tf.device() or ops.colocate_with()
                            # context. At the same time, users would expect caching device to
                            # be independent of this context, and/or would not expect the
                            # current device context to be merged with the caching device
                            # spec.  Therefore we reset the colocation stack before creating
                            # the cached value. Note that resetting the colocation stack will
                            # also reset the device stack.
                            with ops.colocate_with(None, ignore_existing=True):
                                with ops.device(caching_device):
                                    cached_value = array_ops.identity(value)
                        else:
                            cached_value = None
                else:
                    gen_resource_variable_ops.assign_variable_op(
                        handle, initial_value)
                    is_initialized_op = None
                    initializer_op = None
                    graph_element = None
                    if caching_device:
                        with ops.device(caching_device):
                            with ops.control_dependencies([
                                    gen_resource_variable_ops.assign_variable_op(
                                        handle,
                                        self.prefetch_values(),
                                        name="AssignBeforeInitRead")
                            ]):
                                cached_value = gen_resource_variable_ops.read_variable_op(
                                    handle, dtype)
                    else:
                        cached_value = None
                if not context.executing_eagerly():
                    # Eager variables are only added to collections if they are part of an
                    # eager variable store (otherwise in an interactive session they would
                    # hog memory and cause OOM). This is done in ops/variable_scope.py.
                    ops.add_to_collections(collections, self)
                elif ops.GraphKeys.GLOBAL_STEP in collections:
                    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
            initial_value = initial_value if self._in_graph_mode else None
            super(resource_variable_ops.ResourceVariable,
                  self).__init__(trainable=trainable,
                                 shape=shape,
                                 dtype=dtype,
                                 handle=handle,
                                 synchronization=synchronization,
                                 constraint=constraint,
                                 aggregation=aggregation,
                                 distribute_strategy=distribute_strategy,
                                 name=name,
                                 unique_id=unique_id,
                                 handle_name=handle_name,
                                 graph_element=graph_element,
                                 initial_value=initial_value,
                                 initializer_op=initializer_op,
                                 is_initialized_op=is_initialized_op,
                                 cached_value=cached_value)
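
Code example #37 puts an extra `assign_variable_op` in front of every read: under a control dependency it first copies `self.prefetch_values()` into the dense handle ("AssignBeforeInitRead"), then reads the handle, so reads always see freshly fetched values. A hedged, self-contained sketch of that ordering using the public TF1 API, where the `slot` variable and `new_values` tensor stand in for the wrapped handle and `prefetch_values()`:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

slot = tf.get_variable("slot", initializer=tf.zeros([4]), use_resource=True)
new_values = tf.ones([4])  # stand-in for prefetch_values()

# Force the refresh to run before the read, mirroring "AssignBeforeInitRead".
with tf.control_dependencies([slot.assign(new_values)]):
  refreshed = slot.read_value()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(refreshed))  # [1. 1. 1. 1.]
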
Code example #38
  def _init_from_args(self,
                      initial_value=None,
                      trainable=True,
                      collections=None,
                      validate_shape=True,
                      caching_device=None,
                      name=None,
                      dtype=None):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    self._save_slice_info = None
    with ops.control_dependencies(None):
      with ops.name_scope(name, "Variable", [] if init_from_fn else
                          [initial_value]) as name:
        # pylint: disable=protected-access
        true_name = ops._name_from_scope_name(name)
        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
          attr = attr_value_pb2.AttrValue(
              list=attr_value_pb2.AttrValue.ListValue(
                  s=[compat.as_bytes("loc:@%s" % true_name)]))
          with ops.get_default_graph()._attr_scope({"_class": attr}):
            with ops.name_scope("Initializer"), ops.device(None):
              self._initial_value = ops.convert_to_tensor(
                  initial_value(), name="initial_value", dtype=dtype)
            self._handle = gen_resource_variable_ops.var_handle_op(
                shape=self._initial_value.get_shape(),
                dtype=self._initial_value.dtype.base_dtype,
                shared_name=true_name, name=name)
        # pylint: enable=protected-access

        # Or get the initial value from a Tensor or Python object.
        else:
          self._initial_value = ops.convert_to_tensor(
              initial_value, name="initial_value", dtype=dtype)
          self._handle = gen_resource_variable_ops.var_handle_op(
              shape=self._initial_value.get_shape(),
              dtype=self._initial_value.dtype.base_dtype,
              shared_name=true_name, name=name)

        self._dtype = self._initial_value.dtype.base_dtype

        with ops.name_scope("IsInitialized"):
          self._is_initialized_op = (
              gen_resource_variable_ops.var_is_initialized_op(self._handle))
        if initial_value is not None:
          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
            self._initialize_op = gen_resource_variable_ops.assign_variable_op(
                self._handle, self._initial_value, name=n)
        with ops.name_scope("Read"), ops.colocate_with(self._handle):
          # Manually assign reads to the handle's device to avoid log messages.
          with ops.device(self._handle.device):
            value = gen_resource_variable_ops.read_variable_op(
                self._handle, dtype=self._dtype)
          self._graph_element = value
          if caching_device is not None:
            # Variables may be created in a tf.device() or ops.colocate_with()
            # context. At the same time, users would expect caching device to be
            # independent of this context, and/or would not expect the current
            # device context to be merged with the caching device spec.
            # Therefore we reset the colocation stack before creating the cached
            # value. Note that resetting the colocation stack will also reset
            # the device stack.
            with ops.colocate_with(None, ignore_existing=True):
              with ops.device(caching_device):
                self._cached_value = array_ops.identity(value)
          else:
            self._cached_value = None
          ops.add_to_collections(collections, self)
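
Code example #38 builds the three graph pieces a resource variable needs: a `VarIsInitializedOp`, an `AssignVariableOp` initializer, and a colocated read that becomes the graph element. The lifecycle those pieces imply, sketched with the public TF1 API (the variable name below is illustrative):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

v = tf.Variable(tf.zeros([2]), use_resource=True, name="example_var")

with tf.Session() as sess:
  print(sess.run(tf.is_variable_initialized(v)))  # False until the initializer runs
  sess.run(v.initializer)                         # runs the AssignVariableOp initializer
  print(sess.run(v.read_value()))                 # [0. 0.]
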
Code example #39
 def f():
   # Snippet from a larger test: `v` is assumed to be a ResourceVariable created
   # by the enclosing code, which also exposes the graph's ACD manager.
   for i in math_ops.range(3):
     # Each assign runs independently of the automatic control-dependency
     # ordering, so the three writes are not serialized against one another.
     ops.get_default_graph().experimental_acd_manager.run_independently(
         gen_resource_variable_ops.assign_variable_op(v.handle, i))