def testManyAssigns(self):
  """Reads ordered around a write observe the old and then the new value."""
  with self.test_session() as sess:
    int32 = dtypes.int32
    var = resource_variable_ops.var_handle_op(dtype=int32, shape=[])
    init_op = resource_variable_ops.assign_variable_op(
        var, constant_op.constant(1, dtype=int32))
    # Chain create -> read -> write -> read through control dependencies so
    # the first read must precede the write and the second must follow it.
    with ops.control_dependencies([init_op]):
      before = resource_variable_ops.read_variable_op(var, dtype=int32)
    with ops.control_dependencies([before]):
      update_op = resource_variable_ops.assign_variable_op(
          var, constant_op.constant(2, dtype=int32))
    with ops.control_dependencies([update_op]):
      after = resource_variable_ops.read_variable_op(var, dtype=int32)
    before_val, after_val = sess.run([before, after])
    self.assertEqual(before_val, 1)
    self.assertEqual(after_val, 2)
def testSharedName(self):
  """A handle built from a shared_name aliases the existing variable."""
  var = resource_variable_ops.ResourceVariable(300.0, name="var4")
  self.evaluate(variables.global_variables_initializer())
  dtype = var.dtype.base_dtype
  shape = var.get_shape()

  # Same shared_name as the variable above: the read yields its value.
  alias = resource_variable_ops.var_handle_op(
      dtype=dtype, shape=shape, shared_name="var4")
  alias_read = resource_variable_ops.read_variable_op(alias, dtype)
  self.assertEqual(300.0, self.evaluate(alias_read))

  # Nothing was created under "var5". The read op is built inside the
  # context manager so the error is caught whether it surfaces when the op
  # is constructed or when it is evaluated.
  missing = resource_variable_ops.var_handle_op(
      dtype=dtype, shape=shape, shared_name="var5")
  with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
    missing_read = resource_variable_ops.read_variable_op(missing, dtype)
    self.evaluate(missing_read)
def testSharedName(self):
  """Only an exact shared_name match resolves to the existing variable."""
  with self.test_session():
    var = resource_variable_ops.ResourceVariable(300.0, name="var1")
    var.initializer.run()
    dtype = var.dtype.base_dtype
    shape = var.get_shape()

    same = resource_variable_ops.var_handle_op(
        dtype=dtype, shape=shape, shared_name="var1")
    same_read = resource_variable_ops.read_variable_op(same, dtype)
    self.assertEqual(300.0, same_read.eval())

    # A trailing slash names a different resource that was never created;
    # the read op is built up front and expected to fail when evaluated.
    other = resource_variable_ops.var_handle_op(
        dtype=dtype, shape=shape, shared_name="var1/")
    other_read = resource_variable_ops.read_variable_op(other, dtype)
    with self.assertRaisesOpError("Resource .*/var1//.* does not exist"):
      other_read.eval()
def testAssignAdd(self):
  """assign_add on an int32 scalar resource adds to the stored value."""
  with self.test_session():
    int32 = dtypes.int32
    var = resource_variable_ops.var_handle_op(dtype=int32, shape=[])
    resource_variable_ops.assign_variable_op(
        var, constant_op.constant(1, dtype=int32)).run()
    resource_variable_ops.assign_add_variable_op(
        var, constant_op.constant(1, dtype=int32)).run()
    total = resource_variable_ops.read_variable_op(var, dtype=int32).eval()
    self.assertEqual(total, 2)
def testCreateRead(self):
  """A value assigned through a fresh handle can be read back."""
  int32 = dtypes.int32
  var = resource_variable_ops.var_handle_op(dtype=int32, shape=[])
  initial = constant_op.constant(1, dtype=int32)
  self.evaluate(resource_variable_ops.assign_variable_op(var, initial))
  read_op = resource_variable_ops.read_variable_op(var, dtype=int32)
  self.assertAllEqual(1, self.evaluate(read_op))
def testReadVariableDtypeMismatchEager(self):
  """Reading an int32 handle as float32 raises InvalidArgumentError."""
  expected_msg = ("Trying to read variable with wrong dtype. "
                  "Expected float got int32.")
  with context.eager_mode():
    int_var = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[1], name="foo")
    with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_msg):
      resource_variable_ops.read_variable_op(int_var, dtype=dtypes.float32)
def testCreateRead(self):
  """Assign then read back an int32 scalar inside a graph session."""
  with self.test_session():
    int32 = dtypes.int32
    var = resource_variable_ops.var_handle_op(dtype=int32, shape=[])
    assign = resource_variable_ops.assign_variable_op(
        var, constant_op.constant(1, dtype=int32))
    assign.run()
    got = resource_variable_ops.read_variable_op(var, dtype=int32).eval()
    self.assertAllEqual(1, got)
def testSharedName(self):
  """Handles resolve by (container, shared_name); unknown names fail."""
  with self.cached_session():
    var = resource_variable_ops.ResourceVariable(300.0, name="var4")
    variables.global_variables_initializer().run()
    dtype = var.dtype.base_dtype
    shape = var.get_shape()
    # Needed in Eager since we get a unique container name by default.
    container = ops.get_default_graph()._container  # pylint: disable=protected-access

    alias = resource_variable_ops.var_handle_op(
        dtype=dtype, shape=shape, shared_name="var4", container=container)
    alias_read = resource_variable_ops.read_variable_op(alias, dtype)
    self.assertEqual(300.0, self.evaluate(alias_read))

    missing = resource_variable_ops.var_handle_op(
        dtype=dtype, shape=shape, shared_name="var5", container=container)
    with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
      resource_variable_ops.read_variable_op(missing, dtype).eval()
def testReturnMultipleResourceHandles(self):
  """A defun may return several resource handles mixed with tensors."""
  with self.test_scope():
    v1 = resource_variable_ops.ResourceVariable(1.25)
    v2 = resource_variable_ops.ResourceVariable(2.0)

    def make_outputs(var):
      return var.handle, 3.0 * var, v2.handle, var + v2

    h1, tripled, h2, total = function.defun(make_outputs)(v1)

    def read_back(handle):
      return resource_variable_ops.read_variable_op(
          handle, dtypes.float32).numpy()

    self.assertAllEqual(v1.numpy(), read_back(h1))
    self.assertEqual(3.75, tripled.numpy())
    self.assertAllEqual(v2.numpy(), read_back(h2))
    self.assertEqual(3.25, total.numpy())
def testScatterAdd(self):
  """resource_scatter_add adds [[2]] into row 0 of a [[1]] int32 variable."""
  with self.test_session():
    handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[1, 1])
    resource_variable_ops.assign_variable_op(
        handle, constant_op.constant([[1]], dtype=dtypes.int32)).run()
    resource_variable_ops.resource_scatter_add(
        handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)).run()
    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
    # assertAllEqual compares shape as well as values. assertEqual on an
    # ndarray only passes through single-element truthiness and would also
    # accept a wrongly shaped result via broadcasting.
    self.assertAllEqual(read.eval(), [[3]])
def testAssignAdd(self):
  """Assigning 1 and then adding 1 leaves the variable at 2."""
  int32 = dtypes.int32
  var = resource_variable_ops.var_handle_op(dtype=int32, shape=[])
  # Run the plain assign first, then the add, each with its own constant 1.
  for make_op in (resource_variable_ops.assign_variable_op,
                  resource_variable_ops.assign_add_variable_op):
    self.evaluate(make_op(var, constant_op.constant(1, dtype=int32)))
  value = self.evaluate(
      resource_variable_ops.read_variable_op(var, dtype=int32))
  self.assertEqual(value, 2)
def testSharedNameWithNamescope(self):
  """A variable created under a name scope is reachable as "foo/var3"."""
  with ops.name_scope("foo"):
    var = resource_variable_ops.ResourceVariable(300.0, name="var3")
    self.evaluate(variables.global_variables_initializer())
  dtype = var.dtype.base_dtype
  alias = resource_variable_ops.var_handle_op(
      dtype=dtype, shape=var.get_shape(), shared_name="foo/var3")
  value = self.evaluate(resource_variable_ops.read_variable_op(alias, dtype))
  self.assertEqual(300.0, value)
def testScatterAdd(self):
  """resource_scatter_add adds [[2]] into row 0 of a [[1]] int32 variable."""
  handle = resource_variable_ops.var_handle_op(
      dtype=dtypes.int32, shape=[1, 1])
  self.evaluate(resource_variable_ops.assign_variable_op(
      handle, constant_op.constant([[1]], dtype=dtypes.int32)))
  self.evaluate(resource_variable_ops.resource_scatter_add(
      handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
  read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
  # assertAllEqual also checks the shape. assertEqual(ndarray, list) passes
  # only through single-element truthiness and tolerates broadcasting.
  self.assertAllEqual(self.evaluate(read), [[3]])
def testScatterUpdateString(self):
  """resource_scatter_update replaces a row of a string variable."""
  str_dtype = dtypes.string
  var = resource_variable_ops.var_handle_op(dtype=str_dtype, shape=[1, 1])
  self.evaluate(resource_variable_ops.assign_variable_op(
      var, constant_op.constant([["a"]], dtype=str_dtype)))
  self.evaluate(resource_variable_ops.resource_scatter_update(
      var, [0], constant_op.constant([["b"]], dtype=str_dtype)))
  result = self.evaluate(
      resource_variable_ops.read_variable_op(var, dtype=str_dtype))
  # compat.as_bytes normalizes both sides, so the check does not depend on
  # whether the element comes back as bytes or str.
  self.assertEqual(compat.as_bytes(result[0][0]), compat.as_bytes("b"))
def testSharedName(self):
  """Reading through a handle works only for an existing shared_name."""
  var = resource_variable_ops.ResourceVariable(300.0, name="var1")
  self.evaluate(variables.global_variables_initializer())
  dtype = var.dtype.base_dtype
  shape = var.get_shape()

  alias = resource_variable_ops.var_handle_op(
      dtype=dtype, shape=shape, shared_name="var1")
  self.assertEqual(
      300.0,
      self.evaluate(resource_variable_ops.read_variable_op(alias, dtype)))

  missing = resource_variable_ops.var_handle_op(
      dtype=dtype, shape=shape, shared_name="var2")
  if not context.in_graph_mode():
    # Eager: building the read op itself is expected to raise.
    with self.assertRaisesRegexp(errors.NotFoundError,
                                 "Attempted to read a nonexistent variable."):
      resource_variable_ops.read_variable_op(missing, dtype)
    return
  # Graph mode: expect an op error naming the missing resource.
  with self.assertRaisesOpError("Resource .*/var2/.* does not exist"):
    self.evaluate(resource_variable_ops.read_variable_op(missing, dtype))
def testSharedNameWithNamescope(self):
  """A scoped variable is reachable through the scoped shared_name."""
  with self.test_session():
    with ops.name_scope("foo"):
      var = resource_variable_ops.ResourceVariable(300.0, name="var1")
      var.initializer.run()
    dtype = var.dtype.base_dtype
    alias = resource_variable_ops.var_handle_op(
        dtype=dtype, shape=var.get_shape(), shared_name="foo/var1")
    alias_read = resource_variable_ops.read_variable_op(alias, dtype)
    self.assertEqual(300.0, alias_read.eval())
def testScatterDiv(self):
  """resource_scatter_div divides row 0 of [[6]] by [[3]]."""
  with self.test_session() as sess, self.test_scope():
    int32 = dtypes.int32
    var = resource_variable_ops.var_handle_op(dtype=int32, shape=[1, 1])
    init_op = resource_variable_ops.assign_variable_op(
        var, constant_op.constant([[6]], dtype=int32))
    sess.run(init_op)
    div_op = resource_variable_ops.resource_scatter_div(
        var, [0], constant_op.constant([[3]], dtype=int32))
    sess.run(div_op)
    read_op = resource_variable_ops.read_variable_op(var, dtype=int32)
    self.assertAllEqual(sess.run(read_op), [[2]])
def testReturnResourceHandle(self):
  """A defun may return a variable's resource handle unchanged."""
  with self.test_scope():
    var = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]])

    def get_handle(v):
      return v.handle

    returned = function.defun(get_handle)(var)
    expected = var.numpy()
    actual = resource_variable_ops.read_variable_op(
        returned, dtypes.float32).numpy()
    self.assertAllEqual(expected, actual)
def testScatterSub(self):
  """resource_scatter_sub subtracts [[2]] from row 1 of [[4], [1]]."""
  with self.test_session() as sess, self.test_scope():
    int32 = dtypes.int32
    var = resource_variable_ops.var_handle_op(dtype=int32, shape=[2, 1])
    init_op = resource_variable_ops.assign_variable_op(
        var, constant_op.constant([[4], [1]], dtype=int32))
    sess.run(init_op)
    sub_op = resource_variable_ops.resource_scatter_sub(
        var, [1], constant_op.constant([[2]], dtype=int32))
    sess.run(sub_op)
    read_op = resource_variable_ops.read_variable_op(var, dtype=int32)
    self.assertAllEqual(self.evaluate(read_op), [[4], [-1]])
def testScatterMaxScalar(self):
  """resource_scatter_max with a scalar update 3 leaves [[6]] unchanged."""
  with self.test_session() as sess, self.test_scope():
    handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[1, 1])
    sess.run(
        resource_variable_ops.assign_variable_op(
            handle, constant_op.constant([[6]], dtype=dtypes.int32)))
    sess.run(
        resource_variable_ops.resource_scatter_max(
            handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
    # assertAllEqual also checks the shape. assertEqual(ndarray, list) passes
    # only through single-element truthiness and tolerates broadcasting.
    self.assertAllEqual(self.evaluate(read), [[6]])
def testSharedNameWithNamescope(self):
  """The scope prefix is part of both the shared_name and the tensor name."""
  with self.test_session():
    with ops.name_scope("foo"):
      var = resource_variable_ops.ResourceVariable(300.0, name="var6")
      self.assertEqual("foo/var6", var._shared_name)  # pylint: disable=protected-access
      self.assertEqual("foo/var6:0", var.name)
      self.evaluate(variables.global_variables_initializer())
    dtype = var.dtype.base_dtype
    alias = resource_variable_ops.var_handle_op(
        dtype=dtype,
        shape=var.get_shape(),
        shared_name="foo/var6",
        # Needed in Eager since we get a unique container name by default.
        container=ops.get_default_graph()._container)  # pylint: disable=protected-access
    alias_read = resource_variable_ops.read_variable_op(alias, dtype)
    self.assertEqual(300.0, self.evaluate(alias_read))
def testScatterUpdate(self):
  """resource_scatter_update overwrites row 0 of [[6]] with [[3]]."""
  with self.test_session() as sess, self.test_scope():
    handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[1, 1])
    sess.run(
        resource_variable_ops.assign_variable_op(
            handle, constant_op.constant([[6]], dtype=dtypes.int32)))
    sess.run(
        resource_variable_ops.resource_scatter_update(
            handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
    # assertAllEqual also checks the shape. assertEqual(ndarray, list) passes
    # only through single-element truthiness and tolerates broadcasting.
    self.assertAllEqual(sess.run(read), [[3]])
def testScatterNdAddOps(self):
  """resource_scatter_nd_add adds updates at the listed 1-D indices."""
  with self.test_session() as sess, self.test_scope():
    f32 = dtypes.float32
    var = resource_variable_ops.var_handle_op(dtype=f32, shape=[8])
    sess.run(
        resource_variable_ops.assign_variable_op(
            var, constant_op.constant([1] * 8, dtype=f32)))
    indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
    updates = constant_op.constant([9, 10, 11, 12], dtype=f32)
    # Every slot starts at 1; slots 4, 3, 1, 7 become 1 + 9/10/11/12.
    expected = np.array([1, 12, 1, 11, 10, 1, 1, 13])
    sess.run(gen_state_ops.resource_scatter_nd_add(var, indices, updates))
    read_op = resource_variable_ops.read_variable_op(var, dtype=f32)
    self.assertAllClose(expected, sess.run(read_op))
def testReturnResourceHandle(self):
  """Reading the handle returned by a defun yields the variable's value."""
  with self.test_scope():
    matrix = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]])

    def handle_of(variable):
      return variable.handle

    traced = function.defun(handle_of)
    out_handle = traced(matrix)
    self.assertAllEqual(
        matrix.numpy(),
        resource_variable_ops.read_variable_op(
            out_handle, dtypes.float32).numpy())
def testScatterMin(self):
  """resource_scatter_min of [[6]] with [[3]] on CPU yields [[3]]."""
  with ops.device("cpu:0"):
    handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[1, 1])
    self.evaluate(
        resource_variable_ops.assign_variable_op(
            handle, constant_op.constant([[6]], dtype=dtypes.int32)))
    self.evaluate(
        resource_variable_ops.resource_scatter_min(
            handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
    # assertAllEqual also checks the shape. assertEqual(ndarray, list) passes
    # only through single-element truthiness and tolerates broadcasting.
    self.assertAllEqual(self.evaluate(read), [[3]])
def testScatterNdAddOps(self):
  """Scatter-nd-add of four updates into an all-ones length-8 vector."""
  with self.test_session() as sess, self.test_scope():
    float32 = dtypes.float32
    vec = resource_variable_ops.var_handle_op(dtype=float32, shape=[8])
    ones = constant_op.constant([1] * 8, dtype=float32)
    sess.run(resource_variable_ops.assign_variable_op(vec, ones))
    idx = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
    deltas = constant_op.constant([9, 10, 11, 12], dtype=float32)
    expected = np.array([1, 12, 1, 11, 10, 1, 1, 13])
    sess.run(gen_state_ops.resource_scatter_nd_add(vec, idx, deltas))
    result = resource_variable_ops.read_variable_op(vec, dtype=float32)
    self.assertAllClose(expected, self.evaluate(result))
def testSharedNameWithNamescope(self):
  """A scoped variable can be re-opened via its scoped shared_name."""
  with self.cached_session():
    with ops.name_scope("foo"):
      scoped = resource_variable_ops.ResourceVariable(300.0, name="var6")
      self.assertEqual("foo/var6", scoped._shared_name)  # pylint: disable=protected-access
      self.assertEqual("foo/var6:0", scoped.name)
      self.evaluate(variables.global_variables_initializer())
    base_dtype = scoped.dtype.base_dtype
    # Needed in Eager since we get a unique container name by default.
    graph_container = ops.get_default_graph()._container  # pylint: disable=protected-access
    reopened = resource_variable_ops.var_handle_op(
        dtype=base_dtype,
        shape=scoped.get_shape(),
        shared_name="foo/var6",
        container=graph_container)
    value = self.evaluate(
        resource_variable_ops.read_variable_op(reopened, base_dtype))
    self.assertEqual(300.0, value)
def testScatterUpdateString(self):
  """Scatter-update "a" -> "b" in a 1x1 string resource variable."""
  handle = resource_variable_ops.var_handle_op(
      dtype=dtypes.string, shape=[1, 1])
  initial = constant_op.constant([["a"]], dtype=dtypes.string)
  replacement = constant_op.constant([["b"]], dtype=dtypes.string)
  self.evaluate(resource_variable_ops.assign_variable_op(handle, initial))
  self.evaluate(resource_variable_ops.resource_scatter_update(
      handle, [0], replacement))
  read_op = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string)
  first_elem = self.evaluate(read_op)[0][0]
  self.assertEqual(compat.as_bytes(first_elem), compat.as_bytes("b"))
def testIndependentOpsInLoop(self):
  """Assigns marked run_independently inside a tf.function loop still run.

  Assigns i = 0, 1, 2 to an int variable inside a `math_ops.range` loop in a
  `def_function.function`, wrapping each assign in
  `experimental_acd_manager.run_independently`, then expects the variable to
  read 2 afterwards.
  """
  v = resource_variable_ops.ResourceVariable(0)
  self.evaluate(variables.global_variables_initializer())

  @def_function.function
  def f():
    # Loop over a tensor range inside the function; presumably converted to
    # a graph loop by autograph (TODO confirm).
    for i in math_ops.range(3):
      # Mark this assign to run independently, presumably without the usual
      # automatic control dependencies (NOTE(review): confirm semantics).
      ops.get_default_graph().experimental_acd_manager.run_independently(
          gen_resource_variable_ops.assign_variable_op(v.handle, i))

  self.evaluate(f())
  # TODO(mdan): Find a more robust way to test in loops.
  # The last iteration assigns 2; the test expects that to be the final value.
  self.assertEqual(
      self.evaluate(
          resource_variable_ops.read_variable_op(v.handle, dtypes.int32)), 2)
def testScatterAdd(self):
  """resource_scatter_add adds [[2]] into row 0 of [[1], [7]]."""
  with self.test_session() as sess, self.test_scope():
    int32 = dtypes.int32
    var = resource_variable_ops.var_handle_op(dtype=int32, shape=[2, 1])
    init_op = resource_variable_ops.assign_variable_op(
        var, constant_op.constant([[1], [7]], dtype=int32))
    sess.run(init_op)
    add_op = resource_variable_ops.resource_scatter_add(
        var, [0], constant_op.constant([[2]], dtype=int32))
    sess.run(add_op)
    read_op = resource_variable_ops.read_variable_op(var, dtype=int32)
    self.assertAllEqual(sess.run(read_op), [[3], [7]])
def testScatterSub(self):
  """Subtracting [[2]] from row 1 of [[4], [1]] gives [[4], [-1]]."""
  with self.session() as sess, self.test_scope():
    handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[2, 1])
    start = constant_op.constant([[4], [1]], dtype=dtypes.int32)
    sess.run(resource_variable_ops.assign_variable_op(handle, start))
    delta = constant_op.constant([[2]], dtype=dtypes.int32)
    sess.run(resource_variable_ops.resource_scatter_sub(handle, [1], delta))
    current = resource_variable_ops.read_variable_op(
        handle, dtype=dtypes.int32)
    self.assertAllEqual(self.evaluate(current), [[4], [-1]])
def testSharedName(self):
  """Handles resolve by (container, shared_name); unknown names fail."""
  with self.test_session():
    var = resource_variable_ops.ResourceVariable(300.0, name="var4")
    variables.global_variables_initializer().run()
    dtype = var.dtype.base_dtype
    # Needed in Eager since we get a unique container name by default.
    container = ops.get_default_graph()._container  # pylint: disable=protected-access

    def handle_for(name):
      return resource_variable_ops.var_handle_op(
          dtype=dtype, shape=var.get_shape(), shared_name=name,
          container=container)

    existing_read = resource_variable_ops.read_variable_op(
        handle_for("var4"), dtype)
    self.assertEqual(300.0, existing_read.eval())

    missing = handle_for("var5")
    with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
      resource_variable_ops.read_variable_op(missing, dtype).eval()
def testScatterMin(self):
  """On CPU, scatter-min of [[6]] with [[3]] leaves [[3]] in the variable."""
  with ops.device("cpu:0"):
    handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[1, 1])
    self.evaluate(
        resource_variable_ops.assign_variable_op(
            handle, constant_op.constant([[6]], dtype=dtypes.int32)))
    self.evaluate(
        resource_variable_ops.resource_scatter_min(
            handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
    # Use assertAllEqual so the [1, 1] shape is verified too. assertEqual on
    # an ndarray passes via single-element truthiness and broadcasting.
    self.assertAllEqual(self.evaluate(read), [[3]])
def testIndependentOpsRunInParallel(self):
  """An assign marked run_independently is not ordered after a prior assign.

  Inside a tf.function, assigns 1 and then (marked with
  `experimental_acd_manager.run_independently`) assigns 2 to the same
  variable. The function runs 1000 times, the final value is recorded after
  each run, and the test expects to have observed both 1 and 2. This is a
  probabilistic check on the ordering being absent.
  """
  v = resource_variable_ops.ResourceVariable(1)
  self.evaluate(variables.global_variables_initializer())

  @def_function.function
  def f():
    gen_resource_variable_ops.assign_variable_op(v.handle, 1)
    ops.get_default_graph().experimental_acd_manager.run_independently(
        gen_resource_variable_ops.assign_variable_op(v.handle, 2))

  # A function with two assign ops writing different values (1 and 2) to the
  # same variable, with the second one unordered relative to the first.
  # This should cause a data race in most conditions.
  var_values = set()
  for _ in range(1000):
    self.evaluate(f())
    var_values.add(
        self.evaluate(
            resource_variable_ops.read_variable_op(
                v.handle, dtypes.int32)))
  # With regular control dependencies, the function should always run the
  # first assign first, and the value 1 should never be seen.
  self.assertSetEqual(var_values, set((1, 2)))
def inner(var1, var2):
  """Read var1 and var2 as float32 via read_variable_op and return the sum."""
  first = resource_variable_ops.read_variable_op(var1, dtypes.float32)
  second = resource_variable_ops.read_variable_op(var2, dtypes.float32)
  return first + second
def inner(var1, var2):
  """Return the float32 value read from var1 plus the one read from var2."""

  def _read(resource):
    return resource_variable_ops.read_variable_op(resource, dtypes.float32)

  return _read(var1) + _read(var2)