예제 #1
0
def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
  """Builds a VarHandleOp, attaching shape-inference data outside graph mode.

  Args:
    shape: Shape forwarded to `var_handle_op`.
    dtype: Dtype forwarded to `var_handle_op`.
    shared_name: Shared name forwarded to `var_handle_op`.
    name: Op name forwarded to `var_handle_op`.
    graph_mode: When true, the handle is returned as soon as it is created.

  Returns:
    The handle created by `var_handle_op` in the calling context. When
    `graph_mode` is false, its `_handle_data` is copied from an equivalent
    handle built inside a throwaway graph.
  """
  # pylint: disable=protected-access
  graph_container = ops.get_default_graph()._container
  # pylint: enable=protected-access
  handle_args = dict(
      shape=shape,
      dtype=dtype,
      shared_name=shared_name,
      name=name,
      container="" if graph_container is None else graph_container)
  handle = resource_variable_ops.var_handle_op(**handle_args)
  if graph_mode:
    return handle

  # Shape inference does not run in eager mode, so build the same op in a
  # scratch graph and copy its handle data (the shape and dtype of the variable
  # the handle points to) for when the handle is captured by an eager mode
  # function.
  with context.graph_mode(), ops.Graph().as_default() as scratch_graph:
    graph_handle = resource_variable_ops.var_handle_op(**handle_args)
    # pylint: disable=protected-access
    handle._handle_data = resource_variable_ops.get_resource_handle_data(
        graph_handle)
    # pylint: enable=protected-access
  # Break op->graph->op reference cycles left behind by the scratch graph.
  ops.dismantle_graph(scratch_graph)
  return handle
  def testSharedName(self):
    """A handle built from an existing shared_name reads that variable."""
    var = resource_variable_ops.ResourceVariable(300.0, name="var4")
    self.evaluate(variables.global_variables_initializer())
    base_dtype = var.dtype.base_dtype

    alias = resource_variable_ops.var_handle_op(
        dtype=base_dtype, shape=var.get_shape(), shared_name="var4")
    alias_read = resource_variable_ops.read_variable_op(alias, base_dtype)
    self.assertEqual(300.0, self.evaluate(alias_read))

    # Nothing was ever created under "var5", so reading it must fail.
    missing = resource_variable_ops.var_handle_op(
        dtype=base_dtype, shape=var.get_shape(), shared_name="var5")
    with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
      missing_read = resource_variable_ops.read_variable_op(missing, base_dtype)
      self.evaluate(missing_read)
  def testSharedName(self):
    """shared_name must match exactly; a trailing slash names a new resource."""
    with self.test_session():
      var = resource_variable_ops.ResourceVariable(300.0, name="var1")
      var.initializer.run()
      base_dtype = var.dtype.base_dtype

      same = resource_variable_ops.var_handle_op(
          dtype=base_dtype, shape=var.get_shape(), shared_name="var1")
      same_read = resource_variable_ops.read_variable_op(same, base_dtype)
      self.assertEqual(300.0, same_read.eval())

      slashed = resource_variable_ops.var_handle_op(
          dtype=base_dtype, shape=var.get_shape(), shared_name="var1/")
      slashed_read = resource_variable_ops.read_variable_op(slashed, base_dtype)
      with self.assertRaisesOpError("Resource .*/var1//.* does not exist"):
        slashed_read.eval()
 def testCreateRead(self):
   """Assigning 1 to a fresh scalar int32 handle and reading it yields 1."""
   int_handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
   one = constant_op.constant(1, dtype=dtypes.int32)
   self.evaluate(resource_variable_ops.assign_variable_op(int_handle, one))
   read_op = resource_variable_ops.read_variable_op(
       int_handle, dtype=dtypes.int32)
   self.assertAllEqual(1, self.evaluate(read_op))
예제 #5
0
    def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32,  # pylint: disable=missing-docstring
                       initializer=None, regularizer=None, reuse=None,
                       trainable=True, collections=None, caching_device=None,  # pylint: disable=redefined-outer-name
                       partitioner=None, validate_shape=True,
                       use_resource=None):
      """Variable getter that records new variables and backs them by a handle.

      If `name` is already in `self.tf_variables`, its `initialized_value()`
      is returned when `reuse` is truthy; otherwise a ValueError is raised.
      New names are recorded as a `_CapturedVariable` in `self.variables`, and
      a `var_handle_op` with `shared_name=name` is assigned the initializer's
      value.

      Returns:
        The reused value, or a `_VariableFromResource` wrapping the new handle.
      """
      # These arguments are not used by this getter.
      del getter, regularizer, collections, caching_device, partitioner
      del use_resource, validate_shape
      if name in self.tf_variables:
        if not reuse:
          raise ValueError("Specified reuse=%s but tried to reuse variables."
                           % reuse)
        return self.tf_variables[name].initialized_value()

      # TODO(apassos): ensure this is on the same device as above
      captured = _CapturedVariable(name, initializer, shape, dtype, trainable)
      self.variables[name] = captured

      resource = resource_variable_ops.var_handle_op(
          shared_name=name, shape=shape, dtype=dtype)
      init_fn = initializer
      if init_fn is None:
        init_fn = _default_initializer(name, shape, dtype)
      # NOTE(review): the assign op is neither returned nor used as a control
      # dependency here; confirm the caller ensures it actually runs.
      resource_variable_ops.assign_variable_op(resource, init_fn(shape, dtype))
      return _VariableFromResource(resource, dtype, name, shape=captured.shape)
 def testAssignAdd(self):
     """assign_add of 1 onto a variable holding 1 leaves it at 2."""
     with self.test_session():
         int32 = dtypes.int32
         handle = resource_variable_ops.var_handle_op(dtype=int32, shape=[])
         init = resource_variable_ops.assign_variable_op(
             handle, constant_op.constant(1, dtype=int32))
         init.run()
         bump = resource_variable_ops.assign_add_variable_op(
             handle, constant_op.constant(1, dtype=int32))
         bump.run()
         total = resource_variable_ops.read_variable_op(handle, dtype=int32)
         self.assertEqual(total.eval(), 2)
 def testHandleDtypeShapeMatch(self):
     """A scalar int32 handle rejects float32 values and non-scalar values."""
     with self.test_session():
         handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
         mismatched = (
             lambda: constant_op.constant(0.0, dtype=dtypes.float32),  # dtype
             lambda: constant_op.constant([0], dtype=dtypes.int32),  # shape
         )
         for make_value in mismatched:
             with self.assertRaises(ValueError):
                 resource_variable_ops.assign_variable_op(
                     handle, make_value()).run()
         resource_variable_ops.assign_variable_op(
             handle, constant_op.constant(0, dtype=dtypes.int32)).run()
 def testDtypeSurvivesIdentity(self):
   """An int32 assign through identity(handle) is accepted."""
   with self.test_session():
     alias = array_ops.identity(
         resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]))
     zero = constant_op.constant(0, dtype=dtypes.int32)
     resource_variable_ops.assign_variable_op(alias, zero).run()
 def testReadVariableDtypeMismatchEager(self):
   """Eagerly reading an int32 handle as float32 raises InvalidArgumentError."""
   expected_msg = ("Trying to read variable with wrong dtype. "
                   "Expected float got int32.")
   with context.eager_mode():
     int_handle = resource_variable_ops.var_handle_op(
         dtype=dtypes.int32, shape=[1], name="foo")
     with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_msg):
       resource_variable_ops.read_variable_op(int_handle, dtype=dtypes.float32)
 def testCreateRead(self):
   """Round-trips a scalar int32 through assign/read in a graph session."""
   with self.test_session():
     handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
     assign = resource_variable_ops.assign_variable_op(
         handle, constant_op.constant(1, dtype=dtypes.int32))
     assign.run()
     read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
     self.assertAllEqual(1, read.eval())
  def testSharedName(self):
    """shared_name lookups use the default graph's container."""
    with self.cached_session():
      var = resource_variable_ops.ResourceVariable(300.0, name="var4")
      variables.global_variables_initializer().run()
      base_dtype = var.dtype.base_dtype

      alias = resource_variable_ops.var_handle_op(
          dtype=base_dtype, shape=var.get_shape(), shared_name="var4",
          # Needed in Eager since we get a unique container name by default.
          container=ops.get_default_graph()._container)
      self.assertEqual(
          300.0,
          self.evaluate(
              resource_variable_ops.read_variable_op(alias, base_dtype)))

      missing = resource_variable_ops.var_handle_op(
          dtype=base_dtype, shape=var.get_shape(), shared_name="var5",
          container=ops.get_default_graph()._container)
      with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
        resource_variable_ops.read_variable_op(missing, base_dtype).eval()
 def testScatterAdd(self):
     """resource_scatter_add of [[2]] into a [[1]] variable yields [[3]]."""
     with self.test_session():
         handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[1, 1])
         resource_variable_ops.assign_variable_op(handle, constant_op.constant([[1]], dtype=dtypes.int32)).run()
         resource_variable_ops.resource_scatter_add(
             handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)
         ).run()
         read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
         # Compare element-wise: assertEqual on an ndarray only works because
         # this array has a single element, and it reports mismatches poorly.
         self.assertAllEqual(read.eval(), [[3]])
  def testSharedNameWithNamescope(self):
    """A variable made under name_scope("foo") is shared as "foo/<name>"."""
    with ops.name_scope("foo"):
      var = resource_variable_ops.ResourceVariable(300.0, name="var3")
      self.evaluate(variables.global_variables_initializer())

    scoped = resource_variable_ops.var_handle_op(
        dtype=var.dtype.base_dtype, shape=var.get_shape(),
        shared_name="foo/var3")
    value = self.evaluate(
        resource_variable_ops.read_variable_op(scoped, var.dtype.base_dtype))
    self.assertEqual(300.0, value)
 def testScatterAdd(self):
   """resource_scatter_add of [[2]] into a [[1]] variable yields [[3]]."""
   handle = resource_variable_ops.var_handle_op(
       dtype=dtypes.int32, shape=[1, 1])
   self.evaluate(resource_variable_ops.assign_variable_op(
       handle, constant_op.constant([[1]], dtype=dtypes.int32)))
   self.evaluate(resource_variable_ops.resource_scatter_add(
       handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
   read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
   # Element-wise comparison; assertEqual on an ndarray only works when the
   # array has exactly one element.
   self.assertAllEqual(self.evaluate(read), [[3]])
 def testAssignAdd(self):
   """Assigning 1 then assign-adding 1 leaves the variable at 2."""
   handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
   for update_fn in (resource_variable_ops.assign_variable_op,
                     resource_variable_ops.assign_add_variable_op):
     self.evaluate(
         update_fn(handle, constant_op.constant(1, dtype=dtypes.int32)))
   total = self.evaluate(
       resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32))
   self.assertEqual(total, 2)
  def testSharedName(self):
    """shared_name aliasing, with mode-specific errors for a missing name."""
    var = resource_variable_ops.ResourceVariable(300.0, name="var1")
    self.evaluate(variables.global_variables_initializer())
    base_dtype = var.dtype.base_dtype

    alias = resource_variable_ops.var_handle_op(
        dtype=base_dtype, shape=var.get_shape(), shared_name="var1")
    alias_read = resource_variable_ops.read_variable_op(alias, base_dtype)
    self.assertEqual(300.0, self.evaluate(alias_read))

    missing = resource_variable_ops.var_handle_op(
        dtype=base_dtype, shape=var.get_shape(), shared_name="var2")
    if not context.in_graph_mode():
      # In eager mode the read op itself raises.
      with self.assertRaisesRegexp(errors.NotFoundError,
                                   "Attempted to read a nonexistent variable."):
        resource_variable_ops.read_variable_op(missing, base_dtype)
      return
    with self.assertRaisesOpError("Resource .*/var2/.* does not exist"):
      missing_read = resource_variable_ops.read_variable_op(missing, base_dtype)
      self.evaluate(missing_read)
  def testSharedNameWithNamescope(self):
    """The enclosing name scope is prefixed onto the variable's shared name."""
    with self.test_session():
      with ops.name_scope("foo"):
        var = resource_variable_ops.ResourceVariable(300.0, name="var1")
        var.initializer.run()

      scoped = resource_variable_ops.var_handle_op(
          dtype=var.dtype.base_dtype, shape=var.get_shape(),
          shared_name="foo/var1")
      scoped_read = resource_variable_ops.read_variable_op(
          scoped, var.dtype.base_dtype)
      self.assertEqual(300.0, scoped_read.eval())
 def testScatterUpdateString(self):
   """resource_scatter_update overwrites a string entry ("a" -> "b")."""
   str_dtype = dtypes.string
   handle = resource_variable_ops.var_handle_op(dtype=str_dtype, shape=[1, 1])
   initial = constant_op.constant([["a"]], dtype=str_dtype)
   self.evaluate(resource_variable_ops.assign_variable_op(handle, initial))
   replacement = constant_op.constant([["b"]], dtype=str_dtype)
   self.evaluate(
       resource_variable_ops.resource_scatter_update(handle, [0], replacement))
   read = resource_variable_ops.read_variable_op(handle, dtype=str_dtype)
   first = self.evaluate(read)[0][0]
   self.assertEqual(compat.as_bytes(first), compat.as_bytes("b"))
 def testAssignVariableDtypeMismatchEager(self):
   """Eagerly assigning float32 data to an int32 variable is rejected."""
   mismatch_msg = ("Trying to assign variable with wrong "
                   "dtype. Expected int32 got float.")
   with context.eager_mode():
     int_handle = resource_variable_ops.var_handle_op(
         dtype=dtypes.int32, shape=[1], name="foo")
     resource_variable_ops.assign_variable_op(
         int_handle, constant_op.constant([1]))
     with self.assertRaisesRegexp(errors.InvalidArgumentError, mismatch_msg):
       resource_variable_ops.assign_variable_op(
           int_handle, constant_op.constant([1.], dtype=dtypes.float32))
 def testDestroyResource(self):
   """Destroyed variables fail to read; missing resources can be ignored."""
   var = resource_variable_ops.ResourceVariable(3.0, name="var0")
   self.evaluate(variables.global_variables_initializer())
   self.assertEqual(3.0, self.evaluate(var.value()))
   destroy = resource_variable_ops.destroy_resource_op
   self.evaluate(destroy(var.handle))
   with self.assertRaises(errors.FailedPreconditionError):
     self.evaluate(var.value())
   # No resource was ever created behind this handle; with
   # ignore_lookup_error=True destroying it must not raise.
   never_created = resource_variable_ops.var_handle_op(
       dtype=dtypes.int32, shape=[])
   self.evaluate(destroy(never_created, ignore_lookup_error=True))
예제 #21
0
 def testScatterMaxScalar(self):
   """A scalar update to resource_scatter_max keeps max(6, 3) == 6."""
   with self.test_session() as sess, self.test_scope():
     handle = resource_variable_ops.var_handle_op(
         dtype=dtypes.int32, shape=[1, 1])
     sess.run(
         resource_variable_ops.assign_variable_op(
             handle, constant_op.constant([[6]], dtype=dtypes.int32)))
     sess.run(
         resource_variable_ops.resource_scatter_max(
             handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
     read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
     # Element-wise comparison; assertEqual on an ndarray only works by
     # accident for single-element arrays.
     self.assertAllEqual(self.evaluate(read), [[6]])
예제 #22
0
 def testScatterSub(self):
   """resource_scatter_sub changes only the indexed row: [[4],[1]] -> [[4],[-1]]."""
   with self.test_session() as sess, self.test_scope():
     int32 = dtypes.int32
     handle = resource_variable_ops.var_handle_op(dtype=int32, shape=[2, 1])
     init_value = constant_op.constant([[4], [1]], dtype=int32)
     sess.run(resource_variable_ops.assign_variable_op(handle, init_value))
     delta = constant_op.constant([[2]], dtype=int32)
     sess.run(resource_variable_ops.resource_scatter_sub(handle, [1], delta))
     result = self.evaluate(
         resource_variable_ops.read_variable_op(handle, dtype=int32))
     self.assertAllEqual(result, [[4], [-1]])
예제 #23
0
 def testScatterDiv(self):
   """resource_scatter_div divides the indexed entry: 6 / 3 == 2."""
   with self.test_session() as sess, self.test_scope():
     int_type = dtypes.int32
     var_handle = resource_variable_ops.var_handle_op(
         dtype=int_type, shape=[1, 1])
     sess.run(resource_variable_ops.assign_variable_op(
         var_handle, constant_op.constant([[6]], dtype=int_type)))
     sess.run(resource_variable_ops.resource_scatter_div(
         var_handle, [0], constant_op.constant([[3]], dtype=int_type)))
     quotient = sess.run(
         resource_variable_ops.read_variable_op(var_handle, dtype=int_type))
     self.assertAllEqual(quotient, [[2]])
 def testDestroyResource(self):
   """Reads after destroy raise NotFoundError; missing handles can be ignored."""
   with self.test_session() as sess:
     destroy = resource_variable_ops.destroy_resource_op
     var = resource_variable_ops.ResourceVariable(3.0)
     variables.global_variables_initializer().run()
     self.assertEqual(3.0, var.value().eval())
     sess.run(destroy(var.handle))
     with self.assertRaises(errors.NotFoundError):
       var.value().eval()
     # Destroying a never-created resource must not raise when lookup errors
     # are ignored.
     dangling = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
     sess.run(destroy(dangling, ignore_lookup_error=True))
 def testManyAssigns(self):
     """Control dependencies order a write/read/write/read chain on one handle."""
     with self.test_session() as session:
         int32 = dtypes.int32
         handle = resource_variable_ops.var_handle_op(dtype=int32, shape=[])

         def write(value):
             return resource_variable_ops.assign_variable_op(
                 handle, constant_op.constant(value, dtype=int32))

         def read():
             return resource_variable_ops.read_variable_op(handle, dtype=int32)

         create = write(1)
         with ops.control_dependencies([create]):
             first_read = read()
         with ops.control_dependencies([first_read]):
             overwrite = write(2)
         with ops.control_dependencies([overwrite]):
             second_read = read()
         first, second = session.run([first_read, second_read])
         self.assertEqual(first, 1)
         self.assertEqual(second, 2)
예제 #26
0
 def testScatterUpdate(self):
     """resource_scatter_update replaces [[6]] with the update [[3]]."""
     with self.test_session() as sess, self.test_scope():
         handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32,
                                                      shape=[1, 1])
         sess.run(
             resource_variable_ops.assign_variable_op(
                 handle, constant_op.constant([[6]], dtype=dtypes.int32)))
         sess.run(
             resource_variable_ops.resource_scatter_update(
                 handle, [0], constant_op.constant([[3]],
                                                   dtype=dtypes.int32)))
         read = resource_variable_ops.read_variable_op(handle,
                                                       dtype=dtypes.int32)
         # Compare element-wise: assertEqual on an ndarray only works by
         # accident for single-element arrays and gives poor diffs.
         self.assertAllEqual(sess.run(read), [[3]])
 def testScatterAdd(self):
     """On CPU, resource_scatter_add of [[2]] into [[1]] yields [[3]]."""
     with ops.device("cpu:0"):
         handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32,
                                                      shape=[1, 1])
         self.evaluate(
             resource_variable_ops.assign_variable_op(
                 handle, constant_op.constant([[1]], dtype=dtypes.int32)))
         self.evaluate(
             resource_variable_ops.resource_scatter_add(
                 handle, [0], constant_op.constant([[2]],
                                                   dtype=dtypes.int32)))
         read = resource_variable_ops.read_variable_op(handle,
                                                       dtype=dtypes.int32)
         # Element-wise comparison; assertEqual on an ndarray relies on the
         # array being a single element to be truth-testable.
         self.assertAllEqual(self.evaluate(read), [[3]])
 def testScatterUpdateString(self):
     """Scatter-updating a [1, 1] string variable replaces "a" with "b"."""
     str_type = dtypes.string
     handle = resource_variable_ops.var_handle_op(dtype=str_type, shape=[1, 1])
     self.evaluate(resource_variable_ops.assign_variable_op(
         handle, constant_op.constant([["a"]], dtype=str_type)))
     self.evaluate(resource_variable_ops.resource_scatter_update(
         handle, [0], constant_op.constant([["b"]], dtype=str_type)))
     stored = self.evaluate(
         resource_variable_ops.read_variable_op(handle, dtype=str_type))
     self.assertEqual(compat.as_bytes(stored[0][0]), compat.as_bytes("b"))
예제 #29
0
 def testScatterNdAddOps(self):
   """resource_scatter_nd_add accumulates updates at the given indices."""
   with self.test_session() as sess, self.test_scope():
     f32 = dtypes.float32
     handle = resource_variable_ops.var_handle_op(dtype=f32, shape=[8])
     ones = constant_op.constant([1] * 8, dtype=f32)
     sess.run(resource_variable_ops.assign_variable_op(handle, ones))
     indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
     updates = constant_op.constant([9, 10, 11, 12], dtype=f32)
     # Each updated slot becomes 1 + its update; untouched slots stay at 1.
     expected = np.array([1, 12, 1, 11, 10, 1, 1, 13])
     sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates))
     read = resource_variable_ops.read_variable_op(handle, dtype=f32)
     self.assertAllClose(expected, self.evaluate(read))
  def testSharedNameWithNamescope(self):
    """Variables made under a name scope are shared under the scoped name."""
    with self.cached_session():
      with ops.name_scope("foo"):
        var = resource_variable_ops.ResourceVariable(300.0, name="var6")
        # pylint: disable=protected-access
        self.assertEqual("foo/var6", var._shared_name)
        # pylint: enable=protected-access
        self.assertEqual("foo/var6:0", var.name)
        self.evaluate(variables.global_variables_initializer())

      scoped_handle = resource_variable_ops.var_handle_op(
          dtype=var.dtype.base_dtype,
          shape=var.get_shape(),
          shared_name="foo/var6",
          # Needed in Eager since we get a unique container name by default.
          container=ops.get_default_graph()._container)
      value = self.evaluate(resource_variable_ops.read_variable_op(
          scoped_handle, var.dtype.base_dtype))
      self.assertEqual(300.0, value)
예제 #31
0
 def testDestroyResource(self):
     """destroy_resource_op invalidates a live variable and, with
     ignore_lookup_error=True, tolerates a handle with no resource."""
     var = resource_variable_ops.ResourceVariable(3.0, name="var0")
     self.evaluate(variables.global_variables_initializer())
     self.assertEqual(3.0, self.evaluate(var.value()))

     self.evaluate(resource_variable_ops.destroy_resource_op(var.handle))
     with self.assertRaises(errors.FailedPreconditionError):
         self.evaluate(var.value())

     # No resource was ever created for this handle; destroying it must not
     # raise.
     orphan = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
     destroy_orphan = resource_variable_ops.destroy_resource_op(
         orphan, ignore_lookup_error=True)
     self.evaluate(destroy_orphan)
예제 #32
0
 def testScatterNdAddOps(self):
   """Scatter-nd-add [9, 10, 11, 12] at rows [4, 3, 1, 7] of an all-ones vector."""
   with self.test_session() as sess, self.test_scope():
     handle = resource_variable_ops.var_handle_op(
         dtype=dtypes.float32, shape=[8])
     init = resource_variable_ops.assign_variable_op(
         handle, constant_op.constant([1] * 8, dtype=dtypes.float32))
     sess.run(init)
     scatter = gen_state_ops.resource_scatter_nd_add(
         handle,
         constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32),
         constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32))
     sess.run(scatter)
     expected = np.array([1, 12, 1, 11, 10, 1, 1, 13])
     result = self.evaluate(
         resource_variable_ops.read_variable_op(handle, dtype=dtypes.float32))
     self.assertAllClose(expected, result)
예제 #33
0
 def testHandleDtypeShapeMatch(self):
     """assign_variable_op checks the value's dtype and shape against the handle."""
     with self.test_session():
         handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32,
                                                      shape=[])

         def assign(value, value_dtype):
             return resource_variable_ops.assign_variable_op(
                 handle, constant_op.constant(value, dtype=value_dtype))

         with self.assertRaises(ValueError):
             assign(0.0, dtypes.float32).run()  # dtype mismatch
         with self.assertRaises(ValueError):
             assign([0], dtypes.int32).run()  # shape mismatch
         assign(0, dtypes.int32).run()
  def testSharedNameWithNamescope(self):
    """A scoped variable's shared name and tensor name carry the scope prefix."""
    with self.test_session():
      with ops.name_scope("foo"):
        var = resource_variable_ops.ResourceVariable(300.0, name="var6")
        self.assertEqual("foo/var6", var._shared_name)  # pylint: disable=protected-access
        self.assertEqual("foo/var6:0", var.name)
        self.evaluate(variables.global_variables_initializer())

      base_dtype = var.dtype.base_dtype
      lookup = resource_variable_ops.var_handle_op(
          dtype=base_dtype, shape=var.get_shape(), shared_name="foo/var6",
          # Needed in Eager since we get a unique container name by default.
          container=ops.get_default_graph()._container)
      lookup_read = resource_variable_ops.read_variable_op(lookup, base_dtype)
      self.assertEqual(300.0, self.evaluate(lookup_read))
 def testDestroyResource(self):
     """After destroy_resource_op, reading the variable raises NotFoundError."""
     with self.test_session() as sess:
         var = resource_variable_ops.ResourceVariable(3.0)
         variables.global_variables_initializer().run()
         self.assertEqual(3.0, var.value().eval())

         sess.run(resource_variable_ops.destroy_resource_op(var.handle))
         with self.assertRaises(errors.NotFoundError):
             var.value().eval()

         # This handle never had a resource created; with
         # ignore_lookup_error=True, destroying it must succeed silently.
         phantom = resource_variable_ops.var_handle_op(dtype=dtypes.int32,
                                                       shape=[])
         sess.run(resource_variable_ops.destroy_resource_op(
             phantom, ignore_lookup_error=True))
예제 #36
0
 def testScatterSub(self):
     """Subtracting [[2]] from row 1 of [[4], [1]] yields [[4], [-1]]."""
     with self.session() as sess, self.test_scope():
         handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32,
                                                      shape=[2, 1])
         assign = resource_variable_ops.assign_variable_op(
             handle, constant_op.constant([[4], [1]], dtype=dtypes.int32))
         sess.run(assign)
         scatter = resource_variable_ops.resource_scatter_sub(
             handle, [1], constant_op.constant([[2]], dtype=dtypes.int32))
         sess.run(scatter)
         value = self.evaluate(
             resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32))
         self.assertAllEqual(value, [[4], [-1]])
예제 #37
0
 def testScatterAdd(self):
     """Adding [[2]] to row 0 of [[1], [7]] yields [[3], [7]]."""
     with self.test_session() as sess, self.test_scope():
         int32 = dtypes.int32
         handle = resource_variable_ops.var_handle_op(dtype=int32, shape=[2, 1])
         start = constant_op.constant([[1], [7]], dtype=int32)
         sess.run(resource_variable_ops.assign_variable_op(handle, start))
         increment = constant_op.constant([[2]], dtype=int32)
         sess.run(
             resource_variable_ops.resource_scatter_add(handle, [0], increment))
         read = resource_variable_ops.read_variable_op(handle, dtype=int32)
         self.assertAllEqual(sess.run(read), [[3], [7]])
예제 #38
0
    def testSharedName(self):
        """Handles resolve by shared_name within the default graph's container."""
        with self.test_session():
            var = resource_variable_ops.ResourceVariable(300.0, name="var4")
            variables.global_variables_initializer().run()
            var_dtype = var.dtype.base_dtype

            def lookup(shared_name):
                # Needed in Eager since we get a unique container name by default.
                return resource_variable_ops.var_handle_op(
                    dtype=var_dtype,
                    shape=var.get_shape(),
                    shared_name=shared_name,
                    container=ops.get_default_graph()._container)

            found = resource_variable_ops.read_variable_op(
                lookup("var4"), var_dtype)
            self.assertEqual(300.0, found.eval())

            absent = lookup("var5")
            with self.assertRaisesOpError(
                    "Resource .*/var5/.* does not exist"):
                resource_variable_ops.read_variable_op(absent,
                                                       var_dtype).eval()
예제 #39
0
def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
    """Creates a variable handle with information to do shape inference.

    Args:
        shape: Shape forwarded to `var_handle_op`.
        dtype: Dtype forwarded to `var_handle_op`.
        shared_name: Shared name forwarded to `var_handle_op`.
        name: Op name forwarded to `var_handle_op`.
        graph_mode: When true, the handle is returned right after creation,
            without copying any handle data onto it.

    Returns:
        The handle created by `var_handle_op` in the calling context. When
        `graph_mode` is false, its `_handle_data` is copied from an equivalent
        handle built inside a temporary graph.
    """
    # Use the default graph's container, falling back to "" when it is unset.
    container = ops.get_default_graph()._container  # pylint: disable=protected-access
    if container is None:
        container = ""
    handle = resource_variable_ops.var_handle_op(shape=shape,
                                                 dtype=dtype,
                                                 shared_name=shared_name,
                                                 name=name,
                                                 container=container)
    if graph_mode:
        return handle

    # Build an identical handle in a throwaway graph so its handle data can be
    # computed there and copied onto the handle returned to the caller.
    with context.graph_mode(), ops.Graph().as_default() as graph:
        h = resource_variable_ops.var_handle_op(shape=shape,
                                                dtype=dtype,
                                                shared_name=shared_name,
                                                name=name,
                                                container=container)

        # Tensor._handle_data contains information for the shape-inference code to
        # know the shape and dtype of the variable pointed to by a handle. Since
        # shape inference doesn't run in eager mode we copy this data here for when
        # the handle is captured by an eager mode function.
        # pylint: disable=protected-access
        if ops._USE_C_SHAPES:
            # C shapes enabled: fetch the handle data via the dedicated helper.
            handle._handle_data = resource_variable_ops.get_resource_handle_data(
                h)
        else:
            # Python shapes path: populate h's handle data on demand if it has
            # not been set yet, then copy it.
            # NOTE(review): presumably set_shape_and_handle_data_for_outputs
            # fills in h._handle_data; confirm against ops.py.
            if h._handle_data is None:
                ops.set_shape_and_handle_data_for_outputs(h.op)
            handle._handle_data = h._handle_data
        # pylint: enable=protected-access
    # Clean up op->graph->op reference cycles.
    ops.dismantle_graph(graph)
    return handle
 def testScatterMin(self):
   """On CPU, resource_scatter_min keeps min(6, 3) == 3."""
   with ops.device("cpu:0"):
     handle = resource_variable_ops.var_handle_op(
         dtype=dtypes.int32, shape=[1, 1])
     self.evaluate(
         resource_variable_ops.assign_variable_op(handle,
                                                  constant_op.constant(
                                                      [[6]],
                                                      dtype=dtypes.int32)))
     self.evaluate(
         resource_variable_ops.resource_scatter_min(handle, [0],
                                                    constant_op.constant(
                                                        [[3]],
                                                        dtype=dtypes.int32)))
     read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
     # Element-wise comparison; assertEqual on an ndarray only works by
     # accident for single-element arrays.
     self.assertAllEqual(self.evaluate(read), [[3]])
 def testManyAssigns(self):
   """Reads observe writes in the order imposed by control dependencies."""
   handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])

   def int_const(value):
     return constant_op.constant(value, dtype=dtypes.int32)

   create = resource_variable_ops.assign_variable_op(handle, int_const(1))
   with ops.control_dependencies([create]):
     first_read = resource_variable_ops.read_variable_op(
         handle, dtype=dtypes.int32)
   with ops.control_dependencies([first_read]):
     write = resource_variable_ops.assign_variable_op(handle, int_const(2))
   with ops.control_dependencies([write]):
     second_read = resource_variable_ops.read_variable_op(
         handle, dtype=dtypes.int32)
   first, second = self.evaluate([first_read, second_read])
   self.assertEqual(first, 1)
   self.assertEqual(second, 2)
예제 #42
0
        def _custom_getter(
                getter=None,
                name=None,
                shape=None,
                dtype=dtypes.float32,  # pylint: disable=missing-docstring
                initializer=None,
                regularizer=None,
                reuse=None,
                trainable=True,
                collections=None,
                caching_device=None,  # pylint: disable=redefined-outer-name
                partitioner=None,
                validate_shape=True,
                use_resource=None):
            """Variable getter that records new variables and backs them by a handle.

            If `name` is already in `self.tf_variables`, its
            `initialized_value()` is returned when `reuse` is truthy; otherwise
            a ValueError is raised. New names are recorded as a
            `_CapturedVariable` in `self.variables`. A `var_handle_op` with
            `shared_name=name` is assigned the initializer's value, and the
            returned handle is an identity of `v.variable.handle` that carries
            a control dependency on that assignment.

            Returns:
                The reused value, or a `_VariableFromResource` wrapping the
                identity handle.
            """
            # These arguments are not used by this getter.
            del getter, regularizer, collections, caching_device, partitioner
            del use_resource, validate_shape
            if name in self.tf_variables:
                if reuse:
                    return self.tf_variables[name].initialized_value()
                else:
                    raise ValueError(
                        "Specified reuse=%s but tried to reuse variables." %
                        reuse)
            # TODO(apassos): ensure this is on the same device as above
            v = _CapturedVariable(name, initializer, shape, dtype, trainable)
            self.variables[name] = v

            graph_mode_resource = resource_variable_ops.var_handle_op(
                shared_name=name, shape=shape, dtype=dtype)
            if initializer is None:
                initializer = _default_initializer(name, shape, dtype)
            # The identity op depends on the assign op, so anything that
            # consumes the returned handle also runs the initialization.
            with tf_ops.control_dependencies([
                    resource_variable_ops.assign_variable_op(
                        graph_mode_resource, initializer(shape, dtype))
            ]):
                handle = array_ops.identity(v.variable.handle)
            return _VariableFromResource(handle, dtype, name, shape=v.shape)
 def testDtypeSurvivesIdentity(self):
   """An int32 assign through identity(handle) is accepted."""
   id_handle = array_ops.identity(
       resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]))
   zero = constant_op.constant(0, dtype=dtypes.int32)
   self.evaluate(resource_variable_ops.assign_variable_op(id_handle, zero))
 def testDtypeSurvivesIdentity(self):
   """Assigning a matching int32 scalar through an identity of the handle works."""
   source = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
   passthrough = array_ops.identity(source)
   assign_op = resource_variable_ops.assign_variable_op(
       passthrough, constant_op.constant(0, dtype=dtypes.int32))
   self.evaluate(assign_op)
예제 #45
0
 def testDtypeSurvivesIdentity(self):
   """Assigning through identity(handle) works in a graph session."""
   with self.test_session():
     handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
     forwarded = array_ops.identity(handle)
     assign = resource_variable_ops.assign_variable_op(
         forwarded, constant_op.constant(0, dtype=dtypes.int32))
     assign.run()
 def testFetchHandle(self):
   """Evaluating a variable handle in a session yields a non-empty value."""
   with self.test_session():
     fetched = resource_variable_ops.var_handle_op(
         dtype=dtypes.int32, shape=[1], name="foo").eval()
     self.assertGreater(len(fetched), 0)
 def testUnprintableHandle(self):
   """Eager resource handles render as "<unprintable>" in str() and repr()."""
   with context.eager_mode():
     handle = resource_variable_ops.var_handle_op(
         dtype=dtypes.int32, shape=[1], name="foo")
     for render in (str, repr):
       self.assertIn("<unprintable>", render(handle))
 def testUnprintableHandle(self):
   """Both str() and repr() of an eager handle contain "<unprintable>"."""
   marker = "<unprintable>"
   with context.eager_mode():
     eager_handle = resource_variable_ops.var_handle_op(
         dtype=dtypes.int32, shape=[1], name="foo")
     self.assertIn(marker, str(eager_handle))
     self.assertIn(marker, repr(eager_handle))