def testDtypeSurvivesIdentity(self):
  """Create a variable through an identity-forwarded resource handle.

  The handle is declared int32 with shape `[]`, passed through
  `array_ops.identity`, and the forwarded tensor is then used to create the
  variable from an int32 scalar. The test passes if no exception is raised;
  presumably it checks that the handle's dtype/shape info survives identity.
  """
  # NOTE(review): another `testDtypeSurvivesIdentity` is defined later in
  # this chunk; if both are in the same class, this definition is shadowed.
  with self.test_session():
    var_handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[])
    forwarded_handle = array_ops.identity(var_handle)
    zero = constant_op.constant(0, dtype=dtypes.int32)
    resource_variable_ops.create_variable_op(forwarded_handle, zero).run()
def testCreateRead(self):
  """Create an int32 scalar resource variable holding 1 and read it back."""
  # NOTE(review): another `testCreateRead` is defined later in this chunk;
  # if both are in the same class, this definition is shadowed.
  with self.test_session():
    var_handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[])
    one = constant_op.constant(1, dtype=dtypes.int32)
    resource_variable_ops.create_variable_op(var_handle, one).run()
    read_op = resource_variable_ops.read_variable_op(
        var_handle, dtype=dtypes.int32)
    observed = read_op.eval()
    self.assertAllEqual(1, observed)
def testAssignAdd(self):
  """Adding 1 to a variable initialized to 1 must produce 2.

  This variant evaluates the result of `assign_add_variable_op` directly, so
  it relies on that op returning something with `.eval()` (a tensor holding
  the updated value).
  """
  # NOTE(review): another `testAssignAdd` (run-then-read variant) is defined
  # later in this chunk; if both are in the same class, this one is shadowed.
  with self.test_session():
    var_handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[])
    initial = constant_op.constant(1, dtype=dtypes.int32)
    resource_variable_ops.create_variable_op(var_handle, initial).run()
    increment = constant_op.constant(1, dtype=dtypes.int32)
    updated = resource_variable_ops.assign_add_variable_op(
        var_handle, increment)
    self.assertEqual(updated.eval(), 2)
def testCreateRead(self):
  """Initialize an int32 scalar variable to 1, then read and compare it."""
  with self.test_session():
    var_handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[])
    one = constant_op.constant(1, dtype=dtypes.int32)
    resource_variable_ops.create_variable_op(var_handle, one).run()
    read_op = resource_variable_ops.read_variable_op(
        var_handle, dtype=dtypes.int32)
    observed = read_op.eval()
    self.assertAllEqual(1, observed)
def testScatterAdd(self):
  """Scatter-add [[2]] into row 0 of a [[1]] variable; expect [[3]]."""
  # NOTE(review): another `testScatterAdd` is defined later in this chunk;
  # if both are in the same class, this definition is shadowed.
  with self.test_session():
    var_handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[1, 1])
    initial = constant_op.constant([[1]], dtype=dtypes.int32)
    resource_variable_ops.create_variable_op(var_handle, initial).run()
    delta = constant_op.constant([[2]], dtype=dtypes.int32)
    resource_variable_ops.resource_scatter_add(var_handle, [0], delta).run()
    result = resource_variable_ops.read_variable_op(
        var_handle, dtype=dtypes.int32)
    self.assertEqual(result.eval(), [[3]])
def testAssignAdd(self):
  """Increment a variable from 1 by 1 and verify a later read sees 2.

  The assign-add op is executed with `.run()` and its effect is observed via
  a separate `read_variable_op`, rather than by evaluating the op itself.
  """
  with self.test_session():
    var_handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[])
    initial = constant_op.constant(1, dtype=dtypes.int32)
    resource_variable_ops.create_variable_op(var_handle, initial).run()
    increment = constant_op.constant(1, dtype=dtypes.int32)
    resource_variable_ops.assign_add_variable_op(var_handle, increment).run()
    current = resource_variable_ops.read_variable_op(
        var_handle, dtype=dtypes.int32)
    self.assertEqual(current.eval(), 2)
def testScatterAdd(self):
  """Adding [[2]] at index 0 of a 1x1 variable holding [[1]] gives [[3]]."""
  with self.test_session():
    var_handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[1, 1])
    initial = constant_op.constant([[1]], dtype=dtypes.int32)
    resource_variable_ops.create_variable_op(var_handle, initial).run()
    delta = constant_op.constant([[2]], dtype=dtypes.int32)
    resource_variable_ops.resource_scatter_add(var_handle, [0], delta).run()
    result = resource_variable_ops.read_variable_op(
        var_handle, dtype=dtypes.int32)
    self.assertEqual(result.eval(), [[3]])
def testHandleDtypeShapeMatch(self):
  """create_variable_op rejects initial values that mismatch the handle.

  The handle is declared int32 with shape `[]`. A float32 scalar and an int32
  vector `[0]` are each expected to raise ValueError; an int32 scalar is then
  accepted.
  """
  # NOTE(review): another `testHandleDtypeShapeMatch` is defined later in
  # this chunk; if both are in the same class, this definition is shadowed.
  with self.test_session():
    var_handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[])
    with self.assertRaises(ValueError):
      wrong_dtype = constant_op.constant(0.0, dtype=dtypes.float32)
      resource_variable_ops.create_variable_op(var_handle, wrong_dtype).run()
    with self.assertRaises(ValueError):
      wrong_shape = constant_op.constant([0], dtype=dtypes.int32)
      resource_variable_ops.create_variable_op(var_handle, wrong_shape).run()
    matching = constant_op.constant(0, dtype=dtypes.int32)
    resource_variable_ops.create_variable_op(var_handle, matching).run()
def testManyAssigns(self):
  """Reads ordered around an assignment observe the old and new values.

  Control dependencies chain: create(1) -> first read -> assign(2) -> second
  read. Fetching both reads in one `session.run` must yield 1 and then 2.
  """
  # NOTE(review): another `testManyAssigns` is defined later in this chunk;
  # if both are in the same class, this definition is shadowed.
  with self.test_session() as sess:
    var_handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[])
    create_op = resource_variable_ops.create_variable_op(
        var_handle, constant_op.constant(1, dtype=dtypes.int32))
    # Each later op is built inside a control_dependencies scope so it is
    # forced to execute after the previous step.
    with ops.control_dependencies([create_op]):
      read_before = resource_variable_ops.read_variable_op(
          var_handle, dtype=dtypes.int32)
    with ops.control_dependencies([read_before]):
      assign_op = resource_variable_ops.assign_variable_op(
          var_handle, constant_op.constant(2, dtype=dtypes.int32))
    with ops.control_dependencies([assign_op]):
      read_after = resource_variable_ops.read_variable_op(
          var_handle, dtype=dtypes.int32)
    before_value, after_value = sess.run([read_before, read_after])
    self.assertEqual(before_value, 1)
    self.assertEqual(after_value, 2)
def testManyAssigns(self):
  """A read before and a read after an assign see 1 and 2 respectively.

  Ordering is enforced with control dependencies: create(1) -> first read ->
  assign(2) -> second read; both reads are fetched in a single session.run.
  """
  with self.test_session() as sess:
    var_handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[])
    create_op = resource_variable_ops.create_variable_op(
        var_handle, constant_op.constant(1, dtype=dtypes.int32))
    # Build each subsequent op under a control dependency on the prior step.
    with ops.control_dependencies([create_op]):
      read_before = resource_variable_ops.read_variable_op(
          var_handle, dtype=dtypes.int32)
    with ops.control_dependencies([read_before]):
      assign_op = resource_variable_ops.assign_variable_op(
          var_handle, constant_op.constant(2, dtype=dtypes.int32))
    with ops.control_dependencies([assign_op]):
      read_after = resource_variable_ops.read_variable_op(
          var_handle, dtype=dtypes.int32)
    before_value, after_value = sess.run([read_before, read_after])
    self.assertEqual(before_value, 1)
    self.assertEqual(after_value, 2)
def testHandleDtypeShapeMatch(self):
  """Initial values must match the handle's declared dtype and shape.

  With an int32 handle of shape `[]`, a float32 scalar and an int32 vector
  `[0]` must each raise ValueError, while an int32 scalar succeeds.
  """
  with self.test_session():
    var_handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[])
    with self.assertRaises(ValueError):
      wrong_dtype = constant_op.constant(0.0, dtype=dtypes.float32)
      resource_variable_ops.create_variable_op(var_handle, wrong_dtype).run()
    with self.assertRaises(ValueError):
      wrong_shape = constant_op.constant([0], dtype=dtypes.int32)
      resource_variable_ops.create_variable_op(var_handle, wrong_shape).run()
    matching = constant_op.constant(0, dtype=dtypes.int32)
    resource_variable_ops.create_variable_op(var_handle, matching).run()
def testDtypeSurvivesIdentity(self):
  """Use an identity of an int32 scalar handle to create the variable.

  Passes if creating the variable through the forwarded handle raises no
  exception; presumably this guards dtype/shape info propagating through
  `array_ops.identity`.
  """
  with self.test_session():
    var_handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[])
    forwarded_handle = array_ops.identity(var_handle)
    zero = constant_op.constant(0, dtype=dtypes.int32)
    resource_variable_ops.create_variable_op(forwarded_handle, zero).run()