def testManyAssigns(self):
    """Reads observe writes in the order imposed by control dependencies."""
    int32 = dtypes.int32
    with self.test_session() as sess:
        var = resource_variable_ops.var_handle_op(dtype=int32, shape=[])
        init_op = resource_variable_ops.assign_variable_op(
            var, constant_op.constant(1, dtype=int32))
        # The first read may only run once the initial assignment is done.
        with ops.control_dependencies([init_op]):
            before = resource_variable_ops.read_variable_op(var, dtype=int32)
        # The overwrite may only run once the first read has completed.
        with ops.control_dependencies([before]):
            update_op = resource_variable_ops.assign_variable_op(
                var, constant_op.constant(2, dtype=int32))
        # The second read must see the overwritten value.
        with ops.control_dependencies([update_op]):
            after = resource_variable_ops.read_variable_op(var, dtype=int32)
        before_val, after_val = sess.run([before, after])
        self.assertEqual(before_val, 1)
        self.assertEqual(after_val, 2)
  def testSharedName(self):
    """A handle created with the same shared_name sees the variable's value."""
    var = resource_variable_ops.ResourceVariable(300.0, name="var4")
    self.evaluate(variables.global_variables_initializer())
    base_dtype = var.dtype.base_dtype

    # A second handle with shared_name "var4" aliases the same resource.
    alias = resource_variable_ops.var_handle_op(
        dtype=base_dtype, shape=var.get_shape(), shared_name="var4")
    self.assertEqual(
        300.0,
        self.evaluate(resource_variable_ops.read_variable_op(alias, base_dtype)))

    # Nothing was ever created under "var5", so reading it must fail.
    missing = resource_variable_ops.var_handle_op(
        dtype=base_dtype, shape=var.get_shape(), shared_name="var5")
    with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
      self.evaluate(
          resource_variable_ops.read_variable_op(missing, base_dtype))
  def testSharedName(self):
    """Shared names match exactly; "var1/" does not resolve to "var1"."""
    with self.test_session():
      var = resource_variable_ops.ResourceVariable(300.0, name="var1")
      var.initializer.run()
      base_dtype = var.dtype.base_dtype
      var_shape = var.get_shape()

      same = resource_variable_ops.var_handle_op(
          dtype=base_dtype, shape=var_shape, shared_name="var1")
      self.assertEqual(
          300.0, resource_variable_ops.read_variable_op(same, base_dtype).eval())

      # A trailing slash yields a different, never-initialized resource.
      other = resource_variable_ops.var_handle_op(
          dtype=base_dtype, shape=var_shape, shared_name="var1/")
      other_read = resource_variable_ops.read_variable_op(other, base_dtype)
      with self.assertRaisesOpError("Resource .*/var1//.* does not exist"):
        other_read.eval()
 def testAssignAdd(self):
     """Assigning 1 and then adding 1 leaves the variable holding 2."""
     with self.test_session():
         var = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
         # Run the plain assign first, then the increment, each with value 1.
         for op_fn in (resource_variable_ops.assign_variable_op,
                       resource_variable_ops.assign_add_variable_op):
             op_fn(var, constant_op.constant(1, dtype=dtypes.int32)).run()
         value = resource_variable_ops.read_variable_op(var, dtype=dtypes.int32)
         self.assertEqual(value.eval(), 2)
 def testCreateRead(self):
   """A value assigned through a fresh handle can be read back."""
   var = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
   initial = constant_op.constant(1, dtype=dtypes.int32)
   self.evaluate(resource_variable_ops.assign_variable_op(var, initial))
   read_op = resource_variable_ops.read_variable_op(var, dtype=dtypes.int32)
   self.assertAllEqual(1, self.evaluate(read_op))
 def testReadVariableDtypeMismatchEager(self):
   """Reading an int32 variable as float32 raises InvalidArgumentError.

   The variable is assigned before the read so that the resource exists.
   Reading a never-assigned handle eagerly can instead fail with
   NotFoundError ("Attempted to read a nonexistent variable.") before any
   dtype check runs, which would hide the error this test is about.
   """
   with context.eager_mode():
     handle = resource_variable_ops.var_handle_op(
         dtype=dtypes.int32, shape=[1], name="foo")
     # Materialize the variable with a value matching its declared dtype/shape.
     resource_variable_ops.assign_variable_op(
         handle, constant_op.constant([1], dtype=dtypes.int32))
     with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                  "Trying to read variable with wrong dtype. "
                                  "Expected float got int32."):
       _ = resource_variable_ops.read_variable_op(handle, dtype=dtypes.float32)
 def testCreateRead(self):
   """Assign a scalar int32 through a handle, then read it back in a session."""
   with self.test_session():
     var = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
     assign_op = resource_variable_ops.assign_variable_op(
         var, constant_op.constant(1, dtype=dtypes.int32))
     assign_op.run()
     read_op = resource_variable_ops.read_variable_op(var, dtype=dtypes.int32)
     self.assertAllEqual(1, read_op.eval())
  def testSharedName(self):
    """Handles with a matching shared_name and container alias a variable."""
    with self.cached_session():
      var = resource_variable_ops.ResourceVariable(300.0, name="var4")
      variables.global_variables_initializer().run()
      base_dtype = var.dtype.base_dtype

      def make_handle(shared_name):
        return resource_variable_ops.var_handle_op(
            dtype=base_dtype, shape=var.get_shape(), shared_name=shared_name,
            # Needed in Eager since we get a unique container name by default.
            container=ops.get_default_graph()._container)

      alias = make_handle("var4")
      alias_read = resource_variable_ops.read_variable_op(alias, base_dtype)
      self.assertEqual(300.0, self.evaluate(alias_read))

      missing = make_handle("var5")
      with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
        resource_variable_ops.read_variable_op(missing, base_dtype).eval()
예제 #9
0
  def testReturnMultipleResourceHandles(self):
    """A defun can return several handles mixed with derived values."""
    with self.test_scope():
      first = resource_variable_ops.ResourceVariable(1.25)
      second = resource_variable_ops.ResourceVariable(2.0)

      @function.defun
      def f(v):
        return v.handle, 3.0 * v, second.handle, v + second

      h1, tripled, h2, total = f(first)

      def read(handle):
        return resource_variable_ops.read_variable_op(
            handle, dtypes.float32).numpy()

      self.assertAllEqual(first.numpy(), read(h1))
      self.assertEqual(3.75, tripled.numpy())  # 3.0 * 1.25
      self.assertAllEqual(second.numpy(), read(h2))
      self.assertEqual(3.25, total.numpy())  # 1.25 + 2.0
 def testScatterAdd(self):
     """resource_scatter_add adds [[2]] into row 0 of [[1]], giving [[3]]."""
     with self.test_session():
         handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[1, 1])
         resource_variable_ops.assign_variable_op(
             handle, constant_op.constant([[1]], dtype=dtypes.int32)).run()
         resource_variable_ops.resource_scatter_add(
             handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)).run()
         read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
         # Compare element-wise with a shape check. assertEqual on an ndarray
         # only passes here because the result happens to hold one element.
         self.assertAllEqual(read.eval(), [[3]])
 def testAssignAdd(self):
   """Assign 1, add 1, and expect to read back 2."""
   var = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
   for op_fn in (resource_variable_ops.assign_variable_op,
                 resource_variable_ops.assign_add_variable_op):
     self.evaluate(op_fn(var, constant_op.constant(1, dtype=dtypes.int32)))
   result = self.evaluate(
       resource_variable_ops.read_variable_op(var, dtype=dtypes.int32))
   self.assertEqual(result, 2)
  def testSharedNameWithNamescope(self):
    """A variable created under name_scope "foo" is shared as "foo/var3"."""
    with ops.name_scope("foo"):
      var = resource_variable_ops.ResourceVariable(300.0, name="var3")
      self.evaluate(variables.global_variables_initializer())

    base_dtype = var.dtype.base_dtype
    scoped = resource_variable_ops.var_handle_op(
        dtype=base_dtype, shape=var.get_shape(), shared_name="foo/var3")
    self.assertEqual(
        300.0,
        self.evaluate(resource_variable_ops.read_variable_op(scoped, base_dtype)))
 def testScatterAdd(self):
   """resource_scatter_add adds [[2]] into row 0 of [[1]], giving [[3]]."""
   handle = resource_variable_ops.var_handle_op(
       dtype=dtypes.int32, shape=[1, 1])
   self.evaluate(resource_variable_ops.assign_variable_op(
       handle, constant_op.constant([[1]], dtype=dtypes.int32)))
   self.evaluate(resource_variable_ops.resource_scatter_add(
       handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
   read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
   # Compare element-wise with a shape check. assertEqual on an ndarray only
   # passes here because the result happens to hold a single element.
   self.assertAllEqual(self.evaluate(read), [[3]])
 def testScatterUpdateString(self):
   """resource_scatter_update replaces a string element: "a" becomes "b"."""
   string_type = dtypes.string
   var = resource_variable_ops.var_handle_op(dtype=string_type, shape=[1, 1])
   self.evaluate(resource_variable_ops.assign_variable_op(
       var, constant_op.constant([["a"]], dtype=string_type)))
   self.evaluate(resource_variable_ops.resource_scatter_update(
       var, [0], constant_op.constant([["b"]], dtype=string_type)))
   value = self.evaluate(
       resource_variable_ops.read_variable_op(var, dtype=string_type))
   self.assertEqual(compat.as_bytes(value[0][0]), compat.as_bytes("b"))
  def testSharedName(self):
    """Reading via a known shared_name works; an unknown name fails."""
    var = resource_variable_ops.ResourceVariable(300.0, name="var1")
    self.evaluate(variables.global_variables_initializer())
    base_dtype = var.dtype.base_dtype

    existing = resource_variable_ops.var_handle_op(
        dtype=base_dtype, shape=var.get_shape(), shared_name="var1")
    existing_read = resource_variable_ops.read_variable_op(existing, base_dtype)
    self.assertEqual(300.0, self.evaluate(existing_read))

    unknown = resource_variable_ops.var_handle_op(
        dtype=base_dtype, shape=var.get_shape(), shared_name="var2")
    if not context.in_graph_mode():
      # Eagerly, creating the read op executes it and raises immediately.
      with self.assertRaisesRegexp(errors.NotFoundError,
                                   "Attempted to read a nonexistent variable."):
        _ = resource_variable_ops.read_variable_op(unknown, base_dtype)
      return
    # In graph mode the failure only surfaces when the read is evaluated.
    with self.assertRaisesOpError("Resource .*/var2/.* does not exist"):
      self.evaluate(resource_variable_ops.read_variable_op(unknown, base_dtype))
  def testSharedNameWithNamescope(self):
    """shared_name lookup must use the name-scoped name "foo/var1"."""
    with self.test_session():
      with ops.name_scope("foo"):
        var = resource_variable_ops.ResourceVariable(300.0, name="var1")
        var.initializer.run()

      base_dtype = var.dtype.base_dtype
      scoped = resource_variable_ops.var_handle_op(
          dtype=base_dtype, shape=var.get_shape(), shared_name="foo/var1")
      scoped_read = resource_variable_ops.read_variable_op(scoped, base_dtype)
      self.assertEqual(300.0, scoped_read.eval())
예제 #17
0
 def testScatterDiv(self):
   """resource_scatter_div divides row 0 in place: 6 / 3 == 2."""
   with self.test_session() as sess, self.test_scope():
     int32 = dtypes.int32
     var = resource_variable_ops.var_handle_op(dtype=int32, shape=[1, 1])
     init = resource_variable_ops.assign_variable_op(
         var, constant_op.constant([[6]], dtype=int32))
     sess.run(init)
     divide = resource_variable_ops.resource_scatter_div(
         var, [0], constant_op.constant([[3]], dtype=int32))
     sess.run(divide)
     self.assertAllEqual(
         sess.run(resource_variable_ops.read_variable_op(var, dtype=int32)),
         [[2]])
예제 #18
0
  def testReturnResourceHandle(self):
    """A defun may return a variable's resource handle directly."""
    with self.test_scope():
      var = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]])

      @function.defun
      def f(v):
        return v.handle

      returned = f(var)
      expected = var.numpy()
      via_handle = resource_variable_ops.read_variable_op(
          returned, dtypes.float32).numpy()
      self.assertAllEqual(expected, via_handle)
예제 #19
0
 def testScatterSub(self):
   """Subtracting 2 from row 1 maps [[4], [1]] to [[4], [-1]]."""
   with self.test_session() as sess, self.test_scope():
     var = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[2, 1])
     initial = constant_op.constant([[4], [1]], dtype=dtypes.int32)
     sess.run(resource_variable_ops.assign_variable_op(var, initial))
     delta = constant_op.constant([[2]], dtype=dtypes.int32)
     sess.run(resource_variable_ops.resource_scatter_sub(var, [1], delta))
     value = self.evaluate(
         resource_variable_ops.read_variable_op(var, dtype=dtypes.int32))
     self.assertAllEqual(value, [[4], [-1]])
예제 #20
0
 def testScatterMaxScalar(self):
   """resource_scatter_max with a scalar update keeps the larger value.

   The stored [[6]] is compared against the scalar update 3, so the variable
   should still read [[6]].
   """
   with self.test_session() as sess, self.test_scope():
     handle = resource_variable_ops.var_handle_op(
         dtype=dtypes.int32, shape=[1, 1])
     sess.run(
         resource_variable_ops.assign_variable_op(
             handle, constant_op.constant([[6]], dtype=dtypes.int32)))
     sess.run(
         resource_variable_ops.resource_scatter_max(
             handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
     read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
     # Element-wise, shape-checked comparison. assertEqual on an ndarray only
     # works by accident for single-element results.
     self.assertAllEqual(self.evaluate(read), [[6]])
  def testSharedNameWithNamescope(self):
    """The name scope is reflected in both the shared name and the var name."""
    with self.test_session():
      with ops.name_scope("foo"):
        var = resource_variable_ops.ResourceVariable(300.0, name="var6")
        self.assertEqual("foo/var6", var._shared_name)  # pylint: disable=protected-access
        self.assertEqual("foo/var6:0", var.name)
        self.evaluate(variables.global_variables_initializer())

      base_dtype = var.dtype.base_dtype
      # Needed in Eager since we get a unique container name by default.
      container = ops.get_default_graph()._container
      scoped = resource_variable_ops.var_handle_op(
          dtype=base_dtype, shape=var.get_shape(), shared_name="foo/var6",
          container=container)
      self.assertEqual(
          300.0,
          self.evaluate(resource_variable_ops.read_variable_op(scoped, base_dtype)))
예제 #22
0
 def testScatterUpdate(self):
     """resource_scatter_update overwrites row 0: [[6]] becomes [[3]]."""
     with self.test_session() as sess, self.test_scope():
         handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32,
                                                      shape=[1, 1])
         sess.run(
             resource_variable_ops.assign_variable_op(
                 handle, constant_op.constant([[6]], dtype=dtypes.int32)))
         sess.run(
             resource_variable_ops.resource_scatter_update(
                 handle, [0], constant_op.constant([[3]],
                                                   dtype=dtypes.int32)))
         read = resource_variable_ops.read_variable_op(handle,
                                                       dtype=dtypes.int32)
         # Element-wise, shape-checked comparison. assertEqual on an ndarray
         # only passes because the result holds a single element.
         self.assertAllEqual(sess.run(read), [[3]])
예제 #23
0
 def testScatterNdAddOps(self):
   """resource_scatter_nd_add adds each update at its index of a ones vector."""
   with self.test_session() as sess, self.test_scope():
     var = resource_variable_ops.var_handle_op(dtype=dtypes.float32, shape=[8])
     ones = constant_op.constant([1] * 8, dtype=dtypes.float32)
     sess.run(resource_variable_ops.assign_variable_op(var, ones))
     # index -> update: 4 -> 9, 3 -> 10, 1 -> 11, 7 -> 12 (each added to 1).
     idx = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
     upd = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
     want = np.array([1, 12, 1, 11, 10, 1, 1, 13])
     sess.run(gen_state_ops.resource_scatter_nd_add(var, idx, upd))
     got = sess.run(
         resource_variable_ops.read_variable_op(var, dtype=dtypes.float32))
     self.assertAllClose(want, got)
예제 #24
0
    def testReturnResourceHandle(self):
        """Reading the handle a defun returns yields the variable's value."""
        with self.test_scope():
            var = resource_variable_ops.ResourceVariable(
                [[1.0, 2.0], [3.0, 4.0]])

            def f(v):
                return v.handle

            traced = function.defun(f)
            returned_handle = traced(var)
            expected = var.numpy()
            actual = resource_variable_ops.read_variable_op(
                returned_handle, dtypes.float32).numpy()
            self.assertAllEqual(expected, actual)
예제 #25
0
 def testScatterMin(self):
     """resource_scatter_min keeps the smaller value: min(6, 3) == 3."""
     with ops.device("cpu:0"):
         handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32,
                                                      shape=[1, 1])
         self.evaluate(
             resource_variable_ops.assign_variable_op(
                 handle, constant_op.constant([[6]], dtype=dtypes.int32)))
         self.evaluate(
             resource_variable_ops.resource_scatter_min(
                 handle, [0], constant_op.constant([[3]],
                                                   dtype=dtypes.int32)))
         read = resource_variable_ops.read_variable_op(handle,
                                                       dtype=dtypes.int32)
         # Element-wise, shape-checked comparison. assertEqual on an ndarray
         # only passes because the result holds a single element.
         self.assertAllEqual(self.evaluate(read), [[3]])
예제 #26
0
 def testScatterNdAddOps(self):
   """Scatter-nd-add four updates into a length-8 vector of ones."""
   with self.test_session() as sess, self.test_scope():
     float32 = dtypes.float32
     var = resource_variable_ops.var_handle_op(dtype=float32, shape=[8])
     sess.run(
         resource_variable_ops.assign_variable_op(
             var, constant_op.constant([1] * 8, dtype=float32)))
     positions = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
     deltas = constant_op.constant([9, 10, 11, 12], dtype=float32)
     # Positions 1, 3, 4, 7 become 1 + 11, 1 + 10, 1 + 9, 1 + 12.
     reference = np.array([1, 12, 1, 11, 10, 1, 1, 13])
     sess.run(gen_state_ops.resource_scatter_nd_add(var, positions, deltas))
     current = resource_variable_ops.read_variable_op(var, dtype=float32)
     self.assertAllClose(reference, self.evaluate(current))
  def testSharedNameWithNamescope(self):
    """A name-scoped variable is reachable through its scoped shared name."""
    with self.cached_session():
      with ops.name_scope("foo"):
        var = resource_variable_ops.ResourceVariable(300.0, name="var6")
        self.assertEqual("foo/var6", var._shared_name)  # pylint: disable=protected-access
        self.assertEqual("foo/var6:0", var.name)
        self.evaluate(variables.global_variables_initializer())

      scoped = resource_variable_ops.var_handle_op(
          dtype=var.dtype.base_dtype,
          shape=var.get_shape(),
          shared_name="foo/var6",
          # Needed in Eager since we get a unique container name by default.
          container=ops.get_default_graph()._container)
      scoped_read = resource_variable_ops.read_variable_op(
          scoped, var.dtype.base_dtype)
      self.assertEqual(300.0, self.evaluate(scoped_read))
예제 #28
0
 def testScatterUpdateString(self):
     """Overwrite a 1x1 string variable via scatter update and check it."""
     var = resource_variable_ops.var_handle_op(dtype=dtypes.string,
                                               shape=[1, 1])
     start = constant_op.constant([["a"]], dtype=dtypes.string)
     self.evaluate(resource_variable_ops.assign_variable_op(var, start))
     replacement = constant_op.constant([["b"]], dtype=dtypes.string)
     self.evaluate(
         resource_variable_ops.resource_scatter_update(var, [0], replacement))
     read_op = resource_variable_ops.read_variable_op(var, dtype=dtypes.string)
     element = self.evaluate(read_op)[0][0]
     self.assertEqual(compat.as_bytes(element), compat.as_bytes("b"))
예제 #29
0
  def testIndependentOpsInLoop(self):
    """run_independently() can wrap an assign created inside a traced loop.

    Each of the three loop iterations assigns the loop counter to `v`; the
    test then expects the final read to observe the last counter value, 2.
    """
    v = resource_variable_ops.ResourceVariable(0)
    self.evaluate(variables.global_variables_initializer())

    @def_function.function
    def f():
      for i in math_ops.range(3):
        # NOTE(review): the default graph is looked up inside the loop body.
        # Presumably under autograph this is the loop-body graph, not f's
        # outer graph -- confirm before hoisting this lookup out of the loop.
        ops.get_default_graph().experimental_acd_manager.run_independently(
            gen_resource_variable_ops.assign_variable_op(v.handle, i))

    self.evaluate(f())
    # TODO(mdan): Find a more robust way to test in loops.
    self.assertEqual(
        self.evaluate(
            resource_variable_ops.read_variable_op(v.handle, dtypes.int32)), 2)
예제 #30
0
 def testScatterAdd(self):
   """Scatter-adding [[2]] into row 0 of [[1], [7]] leaves row 1 untouched."""
   with self.test_session() as sess, self.test_scope():
     int32 = dtypes.int32
     var = resource_variable_ops.var_handle_op(dtype=int32, shape=[2, 1])
     initial = constant_op.constant([[1], [7]], dtype=int32)
     sess.run(resource_variable_ops.assign_variable_op(var, initial))
     increment = constant_op.constant([[2]], dtype=int32)
     sess.run(resource_variable_ops.resource_scatter_add(var, [0], increment))
     result = sess.run(resource_variable_ops.read_variable_op(var, dtype=int32))
     self.assertAllEqual(result, [[3], [7]])
예제 #31
0
 def testScatterSub(self):
   """resource_scatter_sub on row 1 only: [[4], [1]] - 2 -> [[4], [-1]]."""
   with self.session() as sess, self.test_scope():
     int32 = dtypes.int32
     var = resource_variable_ops.var_handle_op(dtype=int32, shape=[2, 1])
     sess.run(
         resource_variable_ops.assign_variable_op(
             var, constant_op.constant([[4], [1]], dtype=int32)))
     sess.run(
         resource_variable_ops.resource_scatter_sub(
             var, [1], constant_op.constant([[2]], dtype=int32)))
     after = resource_variable_ops.read_variable_op(var, dtype=int32)
     self.assertAllEqual(self.evaluate(after), [[4], [-1]])
예제 #32
0
    def testSharedName(self):
        """Handles sharing a name and container alias the same variable."""
        with self.test_session():
            var = resource_variable_ops.ResourceVariable(300.0, name="var4")
            variables.global_variables_initializer().run()
            base_dtype = var.dtype.base_dtype

            def handle_for(shared_name):
                return resource_variable_ops.var_handle_op(
                    dtype=base_dtype,
                    shape=var.get_shape(),
                    shared_name=shared_name,
                    # Needed in Eager since we get a unique container name by default.
                    container=ops.get_default_graph()._container)

            alias_read = resource_variable_ops.read_variable_op(
                handle_for("var4"), base_dtype)
            self.assertEqual(300.0, alias_read.eval())

            missing = handle_for("var5")
            with self.assertRaisesOpError(
                    "Resource .*/var5/.* does not exist"):
                resource_variable_ops.read_variable_op(
                    missing, base_dtype).eval()
 def testScatterMin(self):
   """resource_scatter_min keeps the smaller value: min(6, 3) == 3."""
   with ops.device("cpu:0"):
     handle = resource_variable_ops.var_handle_op(
         dtype=dtypes.int32, shape=[1, 1])
     self.evaluate(
         resource_variable_ops.assign_variable_op(
             handle, constant_op.constant([[6]], dtype=dtypes.int32)))
     self.evaluate(
         resource_variable_ops.resource_scatter_min(
             handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
     read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
     # Element-wise, shape-checked comparison. assertEqual on an ndarray only
     # passes because the result holds a single element.
     self.assertAllEqual(self.evaluate(read), [[3]])
예제 #34
0
    def testIndependentOpsRunInParallel(self):
        """Ops marked with run_independently() may run out of program order.

        `f` assigns 1 and then 2 to the same variable, with the second
        assignment wrapped in run_independently(). Over many calls, the test
        expects to observe both 1 and 2 as the final variable value.
        """
        v = resource_variable_ops.ResourceVariable(1)
        self.evaluate(variables.global_variables_initializer())

        @def_function.function
        def f():
            gen_resource_variable_ops.assign_variable_op(v.handle, 1)
            ops.get_default_graph().experimental_acd_manager.run_independently(
                gen_resource_variable_ops.assign_variable_op(v.handle, 2))

        # A function with two unordered writes to the same variable, which
        # should cause a data race in most conditions.
        # NOTE(review): the assertion depends on that race actually being
        # observed within 1000 calls, so this test is inherently
        # nondeterministic and may be flaky.
        var_values = set()
        for _ in range(1000):
            self.evaluate(f())
            var_values.add(
                self.evaluate(
                    resource_variable_ops.read_variable_op(
                        v.handle, dtypes.int32)))
        # With regular control dependencies, the function should always run the
        # first assign first, and the value 1 should never be seen.
        self.assertSetEqual(var_values, set((1, 2)))
예제 #35
0
 def inner(var1, var2):
   """Return the float32 sum of the values behind two resource handles."""
   first = resource_variable_ops.read_variable_op(var1, dtypes.float32)
   second = resource_variable_ops.read_variable_op(var2, dtypes.float32)
   return first + second
예제 #36
0
 def inner(var1, var2):
     """Read both handles as float32 tensors and add them."""
     values = [resource_variable_ops.read_variable_op(handle, dtypes.float32)
               for handle in (var1, var2)]
     return values[0] + values[1]