def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
  """Creates a variable handle with information to do shape inference.

  Args:
    shape: Shape forwarded to `var_handle_op`.
    dtype: Dtype forwarded to `var_handle_op`.
    shared_name: Shared name forwarded to `var_handle_op`.
    name: Op name forwarded to `var_handle_op`.
    graph_mode: When true, the new handle is returned immediately without
      copying any handle data onto it.

  Returns:
    The handle built by `var_handle_op` in the default graph's container
    ("" when that container is None). When `graph_mode` is false, its
    `_handle_data` is copied from an identical handle built in a scratch graph.
  """
  container = ops.get_default_graph()._container  # pylint: disable=protected-access
  handle_args = dict(
      shape=shape,
      dtype=dtype,
      shared_name=shared_name,
      name=name,
      container="" if container is None else container)
  handle = resource_variable_ops.var_handle_op(**handle_args)
  if graph_mode:
    return handle

  # Tensor._handle_data tells shape inference the shape and dtype of the
  # variable behind a handle. Shape inference doesn't run in eager mode, so
  # build the same op in a throwaway graph and copy its handle data onto the
  # eager handle, for when that handle is captured by an eager mode function.
  with context.graph_mode(), ops.Graph().as_default() as scratch:
    twin = resource_variable_ops.var_handle_op(**handle_args)
    # pylint: disable=protected-access
    handle._handle_data = resource_variable_ops.get_resource_handle_data(twin)
    # pylint: enable=protected-access
  # Clean up op->graph->op reference cycles in the throwaway graph.
  ops.dismantle_graph(scratch)
  return handle
def testSharedName(self):
  """A handle whose shared_name matches "var4" reads that variable's value.

  A handle with shared_name "var5", which no variable was created under,
  must fail when read.
  """
  var = resource_variable_ops.ResourceVariable(300.0, name="var4")
  self.evaluate(variables.global_variables_initializer())
  base_dtype = var.dtype.base_dtype
  var_shape = var.get_shape()

  existing = resource_variable_ops.var_handle_op(
      dtype=base_dtype, shape=var_shape, shared_name="var4")
  existing_read = resource_variable_ops.read_variable_op(existing, base_dtype)
  self.assertEqual(300.0, self.evaluate(existing_read))

  missing = resource_variable_ops.var_handle_op(
      dtype=base_dtype, shape=var_shape, shared_name="var5")
  with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
    # Both building and evaluating the read sit inside the context, so an
    # error raised at either point is caught.
    missing_read = resource_variable_ops.read_variable_op(missing, base_dtype)
    self.evaluate(missing_read)
def testSharedName(self):
  """Reads a variable through a shared_name handle; "var1/" finds nothing."""
  with self.test_session():
    var = resource_variable_ops.ResourceVariable(300.0, name="var1")
    var.initializer.run()
    base_dtype = var.dtype.base_dtype
    var_shape = var.get_shape()

    alias = resource_variable_ops.var_handle_op(
        dtype=base_dtype, shape=var_shape, shared_name="var1")
    alias_value = resource_variable_ops.read_variable_op(alias, base_dtype)
    self.assertEqual(300.0, alias_value.eval())

    # A trailing slash makes a distinct shared name with no backing resource.
    bad_alias = resource_variable_ops.var_handle_op(
        dtype=base_dtype, shape=var_shape, shared_name="var1/")
    bad_value = resource_variable_ops.read_variable_op(bad_alias, base_dtype)
    with self.assertRaisesOpError("Resource .*/var1//.* does not exist"):
      bad_value.eval()
def testCreateRead(self):
  """Assigns 1 to a scalar int32 resource variable and reads it back."""
  int32 = dtypes.int32
  var_handle = resource_variable_ops.var_handle_op(dtype=int32, shape=[])
  assign = resource_variable_ops.assign_variable_op(
      var_handle, constant_op.constant(1, dtype=int32))
  self.evaluate(assign)
  read_back = resource_variable_ops.read_variable_op(var_handle, dtype=int32)
  self.assertAllEqual(1, self.evaluate(read_back))
def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32,
                   initializer=None, regularizer=None, reuse=None,
                   trainable=True, collections=None,
                   caching_device=None,  # pylint: disable=redefined-outer-name
                   partitioner=None, validate_shape=True, use_resource=None):
  """Variable getter backed by captured resource variables.

  If `name` is already in `self.tf_variables`, this returns that variable's
  `initialized_value()` when `reuse` is truthy and raises otherwise.
  Otherwise it records a new `_CapturedVariable` in `self.variables`. It then
  assigns the initial value to a `var_handle_op` resource whose shared_name is
  `name`, and wraps that resource in a `_VariableFromResource`.

  Only `name`, `shape`, `dtype`, `initializer`, `reuse` and `trainable` are
  used; the remaining arguments are accepted and discarded.

  Raises:
    ValueError: if `name` is already in `self.tf_variables` and `reuse` is
      falsy.
  """
  del getter, regularizer, collections, caching_device, partitioner
  del use_resource, validate_shape

  if name in self.tf_variables:
    if not reuse:
      raise ValueError("Specified reuse=%s but tried to reuse variables." %
                       reuse)
    return self.tf_variables[name].initialized_value()

  # TODO(apassos): ensure this is on the same device as above
  captured = _CapturedVariable(name, initializer, shape, dtype, trainable)
  self.variables[name] = captured

  resource = resource_variable_ops.var_handle_op(
      shared_name=name, shape=shape, dtype=dtype)
  init_fn = (initializer if initializer is not None
             else _default_initializer(name, shape, dtype))
  resource_variable_ops.assign_variable_op(resource, init_fn(shape, dtype))
  return _VariableFromResource(resource, dtype, name, shape=captured.shape)
def testAssignAdd(self):
  """Assigning 1 and then assign-adding 1 leaves the variable equal to 2."""
  int32 = dtypes.int32
  with self.test_session():
    counter = resource_variable_ops.var_handle_op(dtype=int32, shape=[])
    initial = constant_op.constant(1, dtype=int32)
    resource_variable_ops.assign_variable_op(counter, initial).run()
    increment = constant_op.constant(1, dtype=int32)
    resource_variable_ops.assign_add_variable_op(counter, increment).run()
    total = resource_variable_ops.read_variable_op(counter, dtype=int32)
    self.assertEqual(total.eval(), 2)
def testHandleDtypeShapeMatch(self):
  """Assigning a value whose dtype or shape disagrees with the handle fails."""
  with self.test_session():
    scalar_int = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[])
    # float32 value into an int32 variable.
    with self.assertRaises(ValueError):
      wrong_dtype = constant_op.constant(0.0, dtype=dtypes.float32)
      resource_variable_ops.assign_variable_op(scalar_int, wrong_dtype).run()
    # Rank-1 value into a scalar variable.
    with self.assertRaises(ValueError):
      wrong_shape = constant_op.constant([0], dtype=dtypes.int32)
      resource_variable_ops.assign_variable_op(scalar_int, wrong_shape).run()
    # Matching dtype and shape is accepted.
    matching = constant_op.constant(0, dtype=dtypes.int32)
    resource_variable_ops.assign_variable_op(scalar_int, matching).run()
def testDtypeSurvivesIdentity(self):
  """Assigns an int32 scalar through an `identity` of the variable handle."""
  with self.test_session():
    original = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[])
    forwarded = array_ops.identity(original)
    zero = constant_op.constant(0, dtype=dtypes.int32)
    resource_variable_ops.assign_variable_op(forwarded, zero).run()
def testReadVariableDtypeMismatchEager(self):
  """Eagerly reading an int32 variable as float32 raises InvalidArgumentError."""
  with context.eager_mode():
    int_var = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[1], name="foo")
    expected_message = ("Trying to read variable with wrong dtype. "
                        "Expected float got int32.")
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 expected_message):
      _ = resource_variable_ops.read_variable_op(int_var, dtype=dtypes.float32)
def testCreateRead(self):
  """Round-trips the value 1 through a scalar int32 resource variable."""
  with self.test_session():
    int32 = dtypes.int32
    var_handle = resource_variable_ops.var_handle_op(dtype=int32, shape=[])
    init_op = resource_variable_ops.assign_variable_op(
        var_handle, constant_op.constant(1, dtype=int32))
    init_op.run()
    read_op = resource_variable_ops.read_variable_op(var_handle, dtype=int32)
    self.assertAllEqual(1, read_op.eval())
def testSharedName(self):
  """Handles in the default graph's container find "var4" but not "var5"."""
  with self.cached_session():
    var = resource_variable_ops.ResourceVariable(300.0, name="var4")
    variables.global_variables_initializer().run()
    # Needed in Eager since we get a unique container name by default.
    container = ops.get_default_graph()._container  # pylint: disable=protected-access
    base_dtype = var.dtype.base_dtype
    var_shape = var.get_shape()

    found = resource_variable_ops.var_handle_op(
        dtype=base_dtype, shape=var_shape, shared_name="var4",
        container=container)
    found_read = resource_variable_ops.read_variable_op(found, base_dtype)
    self.assertEqual(300.0, self.evaluate(found_read))

    absent = resource_variable_ops.var_handle_op(
        dtype=base_dtype, shape=var_shape, shared_name="var5",
        container=container)
    with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
      resource_variable_ops.read_variable_op(absent, base_dtype).eval()
def testScatterAdd(self):
  """Scatter-adding [[2]] into row 0 of a [[1]] variable yields [[3]]."""
  with self.test_session():
    handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[1, 1])
    resource_variable_ops.assign_variable_op(
        handle, constant_op.constant([[1]], dtype=dtypes.int32)).run()
    resource_variable_ops.resource_scatter_add(
        handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)).run()
    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
    # assertAllEqual checks the shape and every element of the fetched array.
    # assertEqual only worked because the array had a single element, whose
    # elementwise comparison happened to be truthy.
    self.assertAllEqual(read.eval(), [[3]])
def testSharedNameWithNamescope(self):
  """A variable named "var3" under name_scope "foo" is found as "foo/var3"."""
  with ops.name_scope("foo"):
    scoped_var = resource_variable_ops.ResourceVariable(300.0, name="var3")
    self.evaluate(variables.global_variables_initializer())
  base_dtype = scoped_var.dtype.base_dtype
  lookup = resource_variable_ops.var_handle_op(
      dtype=base_dtype, shape=scoped_var.get_shape(), shared_name="foo/var3")
  value = resource_variable_ops.read_variable_op(lookup, base_dtype)
  self.assertEqual(300.0, self.evaluate(value))
def testScatterAdd(self):
  """Scatter-adding [[2]] into row 0 of a [[1]] variable gives [[3]]."""
  int32 = dtypes.int32
  handle = resource_variable_ops.var_handle_op(dtype=int32, shape=[1, 1])
  self.evaluate(resource_variable_ops.assign_variable_op(
      handle, constant_op.constant([[1]], dtype=int32)))
  self.evaluate(resource_variable_ops.resource_scatter_add(
      handle, [0], constant_op.constant([[2]], dtype=int32)))
  read = resource_variable_ops.read_variable_op(handle, dtype=int32)
  # assertAllEqual checks the shape and every element of the fetched array;
  # assertEqual only passed because the array has a single element.
  self.assertAllEqual(self.evaluate(read), [[3]])
def testAssignAdd(self):
  """An int32 scalar set to 1 and then incremented by 1 reads back as 2."""
  int32 = dtypes.int32
  counter = resource_variable_ops.var_handle_op(dtype=int32, shape=[])
  # First assign 1, then assign-add 1, evaluating each op in turn.
  for update_fn in (resource_variable_ops.assign_variable_op,
                    resource_variable_ops.assign_add_variable_op):
    self.evaluate(update_fn(counter, constant_op.constant(1, dtype=int32)))
  total = self.evaluate(
      resource_variable_ops.read_variable_op(counter, dtype=int32))
  self.assertEqual(total, 2)
def testSharedName(self):
  """Reads "var1" through a shared_name handle; reading "var2" must fail.

  In graph mode the failure is an op error naming the missing resource. In
  eager mode the read raises NotFoundError.
  """
  var = resource_variable_ops.ResourceVariable(300.0, name="var1")
  self.evaluate(variables.global_variables_initializer())
  base_dtype = var.dtype.base_dtype
  var_shape = var.get_shape()

  same = resource_variable_ops.var_handle_op(
      dtype=base_dtype, shape=var_shape, shared_name="var1")
  same_read = resource_variable_ops.read_variable_op(same, base_dtype)
  self.assertEqual(300.0, self.evaluate(same_read))

  other = resource_variable_ops.var_handle_op(
      dtype=base_dtype, shape=var_shape, shared_name="var2")
  if not context.in_graph_mode():
    with self.assertRaisesRegexp(errors.NotFoundError,
                                 "Attempted to read a nonexistent variable."):
      _ = resource_variable_ops.read_variable_op(other, base_dtype)
    return
  with self.assertRaisesOpError("Resource .*/var2/.* does not exist"):
    other_read = resource_variable_ops.read_variable_op(other, base_dtype)
    self.evaluate(other_read)
def testSharedNameWithNamescope(self):
  """A variable created under name_scope "foo" is shared as "foo/var1"."""
  with self.test_session():
    with ops.name_scope("foo"):
      scoped = resource_variable_ops.ResourceVariable(300.0, name="var1")
      scoped.initializer.run()
    base_dtype = scoped.dtype.base_dtype
    alias = resource_variable_ops.var_handle_op(
        dtype=base_dtype, shape=scoped.get_shape(), shared_name="foo/var1")
    alias_read = resource_variable_ops.read_variable_op(alias, base_dtype)
    self.assertEqual(300.0, alias_read.eval())
def testScatterUpdateString(self):
  """Scatter-updating row 0 of a [["a"]] string variable with [["b"]]."""
  str_dtype = dtypes.string
  handle = resource_variable_ops.var_handle_op(dtype=str_dtype, shape=[1, 1])
  initial = constant_op.constant([["a"]], dtype=str_dtype)
  self.evaluate(resource_variable_ops.assign_variable_op(handle, initial))
  replacement = constant_op.constant([["b"]], dtype=str_dtype)
  self.evaluate(resource_variable_ops.resource_scatter_update(
      handle, [0], replacement))
  read = resource_variable_ops.read_variable_op(handle, dtype=str_dtype)
  first_cell = self.evaluate(read)[0][0]
  self.assertEqual(compat.as_bytes(first_cell), compat.as_bytes("b"))
def testAssignVariableDtypeMismatchEager(self):
  """Eagerly assigning float32 to an int32 variable raises an error."""
  with context.eager_mode():
    int_var = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[1], name="foo")
    # Initial assignment with an integer constant.
    resource_variable_ops.assign_variable_op(
        int_var, constant_op.constant([1]))
    expected_message = ("Trying to assign variable with wrong "
                        "dtype. Expected int32 got float.")
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 expected_message):
      float_value = constant_op.constant([1.], dtype=dtypes.float32)
      resource_variable_ops.assign_variable_op(int_var, float_value)
def testDestroyResource(self):
  """Destroying a variable makes later reads fail.

  Destroying a handle whose resource was never created, with
  `ignore_lookup_error=True`, must not raise.
  """
  destroy = resource_variable_ops.destroy_resource_op
  var = resource_variable_ops.ResourceVariable(3.0, name="var0")
  self.evaluate(variables.global_variables_initializer())
  self.assertEqual(3.0, self.evaluate(var.value()))
  self.evaluate(destroy(var.handle))
  with self.assertRaises(errors.FailedPreconditionError):
    self.evaluate(var.value())

  # Handle to a resource that was never actually created.
  orphan = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
  self.evaluate(destroy(orphan, ignore_lookup_error=True))  # Must not raise.
def testScatterMaxScalar(self):
  """Scatter-max of the scalar 3 into a [[6]] variable leaves it at [[6]]."""
  int32 = dtypes.int32
  with self.test_session() as sess, self.test_scope():
    handle = resource_variable_ops.var_handle_op(dtype=int32, shape=[1, 1])
    sess.run(resource_variable_ops.assign_variable_op(
        handle, constant_op.constant([[6]], dtype=int32)))
    sess.run(resource_variable_ops.resource_scatter_max(
        handle, [0], constant_op.constant(3, dtype=int32)))
    read = resource_variable_ops.read_variable_op(handle, dtype=int32)
    # assertAllEqual checks the shape and every element of the fetched array;
    # assertEqual only passed because the array has a single element.
    self.assertAllEqual(self.evaluate(read), [[6]])
def testScatterSub(self):
  """Scatter-subtracting [[2]] from row 1 of [[4], [1]] gives [[4], [-1]]."""
  int32 = dtypes.int32
  with self.test_session() as sess, self.test_scope():
    var = resource_variable_ops.var_handle_op(dtype=int32, shape=[2, 1])
    init = resource_variable_ops.assign_variable_op(
        var, constant_op.constant([[4], [1]], dtype=int32))
    sess.run(init)
    subtract = resource_variable_ops.resource_scatter_sub(
        var, [1], constant_op.constant([[2]], dtype=int32))
    sess.run(subtract)
    value = resource_variable_ops.read_variable_op(var, dtype=int32)
    self.assertAllEqual(self.evaluate(value), [[4], [-1]])
def testScatterDiv(self):
  """Scatter-dividing row 0 of a [[6]] variable by [[3]] gives [[2]]."""
  int32 = dtypes.int32
  with self.test_session() as sess, self.test_scope():
    var = resource_variable_ops.var_handle_op(dtype=int32, shape=[1, 1])
    init = resource_variable_ops.assign_variable_op(
        var, constant_op.constant([[6]], dtype=int32))
    sess.run(init)
    divide = resource_variable_ops.resource_scatter_div(
        var, [0], constant_op.constant([[3]], dtype=int32))
    sess.run(divide)
    value = resource_variable_ops.read_variable_op(var, dtype=int32)
    self.assertAllEqual(sess.run(value), [[2]])
def testDestroyResource(self):
  """Reading a destroyed variable raises NotFoundError.

  Destroying a never-created resource with `ignore_lookup_error=True` must
  not raise.
  """
  with self.test_session() as sess:
    var = resource_variable_ops.ResourceVariable(3.0)
    variables.global_variables_initializer().run()
    self.assertEqual(3.0, var.value().eval())
    sess.run(resource_variable_ops.destroy_resource_op(var.handle))
    with self.assertRaises(errors.NotFoundError):
      var.value().eval()

    # Handle to a resource that was never actually created.
    phantom = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[])
    sess.run(resource_variable_ops.destroy_resource_op(
        phantom, ignore_lookup_error=True))
def testManyAssigns(self):
  """Reads chained by control dependencies see the writes ordered before them.

  The chain is: assign 1 -> first read -> assign 2 -> second read.
  """
  int32 = dtypes.int32
  with self.test_session() as session:
    handle = resource_variable_ops.var_handle_op(dtype=int32, shape=[])

    def _read():
      return resource_variable_ops.read_variable_op(handle, dtype=int32)

    create = resource_variable_ops.assign_variable_op(
        handle, constant_op.constant(1, dtype=int32))
    with ops.control_dependencies([create]):
      first_read = _read()
    with ops.control_dependencies([first_read]):
      # The constant is built inside the context on purpose, as in the
      # original, so it also carries the control dependency.
      write = resource_variable_ops.assign_variable_op(
          handle, constant_op.constant(2, dtype=int32))
    with ops.control_dependencies([write]):
      second_read = _read()

    first, second = session.run([first_read, second_read])
    self.assertEqual(first, 1)
    self.assertEqual(second, 2)
def testScatterUpdate(self):
  """Scatter-updating row 0 of a [[6]] variable with [[3]] gives [[3]]."""
  int32 = dtypes.int32
  with self.test_session() as sess, self.test_scope():
    handle = resource_variable_ops.var_handle_op(dtype=int32, shape=[1, 1])
    sess.run(resource_variable_ops.assign_variable_op(
        handle, constant_op.constant([[6]], dtype=int32)))
    sess.run(resource_variable_ops.resource_scatter_update(
        handle, [0], constant_op.constant([[3]], dtype=int32)))
    read = resource_variable_ops.read_variable_op(handle, dtype=int32)
    # assertAllEqual checks the shape and every element of the fetched array;
    # assertEqual only passed because the array has a single element.
    self.assertAllEqual(sess.run(read), [[3]])
def testScatterAdd(self):
  """On CPU, scatter-adding [[2]] into row 0 of [[1]] gives [[3]]."""
  int32 = dtypes.int32
  with ops.device("cpu:0"):
    handle = resource_variable_ops.var_handle_op(dtype=int32, shape=[1, 1])
    self.evaluate(resource_variable_ops.assign_variable_op(
        handle, constant_op.constant([[1]], dtype=int32)))
    self.evaluate(resource_variable_ops.resource_scatter_add(
        handle, [0], constant_op.constant([[2]], dtype=int32)))
    read = resource_variable_ops.read_variable_op(handle, dtype=int32)
    # assertAllEqual checks the shape and every element of the fetched array;
    # assertEqual only passed because the array has a single element.
    self.assertAllEqual(self.evaluate(read), [[3]])
def testScatterUpdateString(self):
  """After scatter-updating row 0 with [["b"]], cell [0][0] reads as b"b"."""
  text = dtypes.string
  var = resource_variable_ops.var_handle_op(dtype=text, shape=[1, 1])
  assign_op = resource_variable_ops.assign_variable_op(
      var, constant_op.constant([["a"]], dtype=text))
  self.evaluate(assign_op)
  update_op = resource_variable_ops.resource_scatter_update(
      var, [0], constant_op.constant([["b"]], dtype=text))
  self.evaluate(update_op)
  fetched = self.evaluate(
      resource_variable_ops.read_variable_op(var, dtype=text))
  self.assertEqual(compat.as_bytes(fetched[0][0]), compat.as_bytes("b"))
def testScatterNdAddOps(self):
  """scatter_nd_add of [9, 10, 11, 12] at rows 4, 3, 1, 7 into eight ones."""
  f32 = dtypes.float32
  with self.test_session() as sess, self.test_scope():
    var = resource_variable_ops.var_handle_op(dtype=f32, shape=[8])
    sess.run(resource_variable_ops.assign_variable_op(
        var, constant_op.constant([1] * 8, dtype=f32)))
    positions = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
    increments = constant_op.constant([9, 10, 11, 12], dtype=f32)
    # Each entry is 1 plus whatever increment was scattered to that index.
    expected = np.array([1, 12, 1, 11, 10, 1, 1, 13])
    sess.run(gen_state_ops.resource_scatter_nd_add(var, positions, increments))
    result = resource_variable_ops.read_variable_op(var, dtype=f32)
    self.assertAllClose(expected, self.evaluate(result))
def testSharedNameWithNamescope(self):
  """A variable under name_scope "foo" gets shared name "foo/var6".

  A handle with that shared name, in the default graph's container, reads
  the variable's value.
  """
  with self.cached_session():
    with ops.name_scope("foo"):
      scoped = resource_variable_ops.ResourceVariable(300.0, name="var6")
      self.assertEqual("foo/var6", scoped._shared_name)  # pylint: disable=protected-access
      self.assertEqual("foo/var6:0", scoped.name)
      self.evaluate(variables.global_variables_initializer())

    base_dtype = scoped.dtype.base_dtype
    alias = resource_variable_ops.var_handle_op(
        dtype=base_dtype,
        shape=scoped.get_shape(),
        shared_name="foo/var6",
        # Needed in Eager since we get a unique container name by default.
        container=ops.get_default_graph()._container)  # pylint: disable=protected-access
    alias_read = resource_variable_ops.read_variable_op(alias, base_dtype)
    self.assertEqual(300.0, self.evaluate(alias_read))
def testDestroyResource(self):
  """Reads of a destroyed variable raise FailedPreconditionError.

  Destroying a never-created resource with `ignore_lookup_error=True` is a
  silent no-op from the caller's perspective.
  """
  victim = resource_variable_ops.ResourceVariable(3.0, name="var0")
  self.evaluate(variables.global_variables_initializer())
  self.assertEqual(3.0, self.evaluate(victim.value()))

  destroy_op = resource_variable_ops.destroy_resource_op(victim.handle)
  self.evaluate(destroy_op)
  with self.assertRaises(errors.FailedPreconditionError):
    self.evaluate(victim.value())

  # Handle to a resource not actually created; destroying it must not raise.
  unused = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
  tolerant_destroy = resource_variable_ops.destroy_resource_op(
      unused, ignore_lookup_error=True)
  self.evaluate(tolerant_destroy)
def testHandleDtypeShapeMatch(self):
  """Assigns with the wrong dtype or shape raise ValueError; a match works."""
  with self.test_session():
    handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
    mismatches = (
        (0.0, dtypes.float32),  # float32 value for an int32 variable.
        ([0], dtypes.int32),    # rank-1 value for a scalar variable.
    )
    for value, value_dtype in mismatches:
      with self.assertRaises(ValueError):
        resource_variable_ops.assign_variable_op(
            handle, constant_op.constant(value, dtype=value_dtype)).run()
    resource_variable_ops.assign_variable_op(
        handle, constant_op.constant(0, dtype=dtypes.int32)).run()
def testSharedNameWithNamescope(self):
  """Checks that "var6" under name_scope "foo" is shared as "foo/var6"."""
  with self.test_session():
    with ops.name_scope("foo"):
      var = resource_variable_ops.ResourceVariable(300.0, name="var6")
      # pylint: disable=protected-access
      self.assertEqual("foo/var6", var._shared_name)
      # pylint: enable=protected-access
      self.assertEqual("foo/var6:0", var.name)
      self.evaluate(variables.global_variables_initializer())

    # Needed in Eager since we get a unique container name by default.
    container = ops.get_default_graph()._container  # pylint: disable=protected-access
    handle = resource_variable_ops.var_handle_op(
        dtype=var.dtype.base_dtype,
        shape=var.get_shape(),
        shared_name="foo/var6",
        container=container)
    value = resource_variable_ops.read_variable_op(handle, var.dtype.base_dtype)
    self.assertEqual(300.0, self.evaluate(value))
def testDestroyResource(self):
  """Destroyed variables raise NotFoundError when read.

  A lookup-tolerant destroy of a never-created resource must succeed.
  """
  destroy = resource_variable_ops.destroy_resource_op
  with self.test_session() as sess:
    victim = resource_variable_ops.ResourceVariable(3.0)
    variables.global_variables_initializer().run()
    self.assertEqual(3.0, victim.value().eval())
    sess.run(destroy(victim.handle))
    with self.assertRaises(errors.NotFoundError):
      victim.value().eval()
    # Handle to a resource not actually created; should raise no exception.
    never_made = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[])
    sess.run(destroy(never_made, ignore_lookup_error=True))
def testScatterSub(self):
  """Subtracting [[2]] at row 1 of [[4], [1]] yields [[4], [-1]]."""
  as_int = lambda value: constant_op.constant(value, dtype=dtypes.int32)
  with self.session() as sess, self.test_scope():
    var = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[2, 1])
    sess.run(resource_variable_ops.assign_variable_op(var, as_int([[4], [1]])))
    sess.run(resource_variable_ops.resource_scatter_sub(
        var, [1], as_int([[2]])))
    result = resource_variable_ops.read_variable_op(var, dtype=dtypes.int32)
    self.assertAllEqual(self.evaluate(result), [[4], [-1]])
def testScatterAdd(self):
  """Adding [[2]] at row 0 of [[1], [7]] yields [[3], [7]]."""
  as_int = lambda value: constant_op.constant(value, dtype=dtypes.int32)
  with self.test_session() as sess, self.test_scope():
    var = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[2, 1])
    sess.run(resource_variable_ops.assign_variable_op(var, as_int([[1], [7]])))
    sess.run(resource_variable_ops.resource_scatter_add(
        var, [0], as_int([[2]])))
    result = resource_variable_ops.read_variable_op(var, dtype=dtypes.int32)
    self.assertAllEqual(sess.run(result), [[3], [7]])
def testSharedName(self):
  """In the default graph's container, "var4" is readable; "var5" is not."""
  with self.test_session():
    var = resource_variable_ops.ResourceVariable(300.0, name="var4")
    variables.global_variables_initializer().run()

    def _handle(shared_name):
      # Needed in Eager since we get a unique container name by default.
      return resource_variable_ops.var_handle_op(
          dtype=var.dtype.base_dtype,
          shape=var.get_shape(),
          shared_name=shared_name,
          container=ops.get_default_graph()._container)  # pylint: disable=protected-access

    known = _handle("var4")
    known_read = resource_variable_ops.read_variable_op(
        known, var.dtype.base_dtype)
    self.assertEqual(300.0, known_read.eval())

    unknown = _handle("var5")
    with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
      resource_variable_ops.read_variable_op(
          unknown, var.dtype.base_dtype).eval()
def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
  """Creates a variable handle with information to do shape inference.

  Args:
    shape: Shape forwarded to `var_handle_op`.
    dtype: Dtype forwarded to `var_handle_op`.
    shared_name: Shared name forwarded to `var_handle_op`.
    name: Op name forwarded to `var_handle_op`.
    graph_mode: When true, the new handle is returned without handle data
      being copied onto it.

  Returns:
    The handle built by `var_handle_op` in the default graph's container
    ("" when that container is None). Outside graph mode its `_handle_data` is
    taken from an identical handle built in a scratch graph. When
    `ops._USE_C_SHAPES` is set, that data comes from the C API; otherwise it
    comes from the Python-side shape data.
  """
  # pylint: disable=protected-access
  graph_container = ops.get_default_graph()._container
  op_args = dict(
      shape=shape,
      dtype=dtype,
      shared_name=shared_name,
      name=name,
      container="" if graph_container is None else graph_container)
  handle = resource_variable_ops.var_handle_op(**op_args)
  if graph_mode:
    return handle

  # Tensor._handle_data lets shape inference know the shape and dtype of the
  # variable a handle points to. Shape inference doesn't run in eager mode, so
  # copy it from a graph-mode twin for when the handle is captured by an eager
  # mode function.
  with context.graph_mode(), ops.Graph().as_default() as scratch:
    twin = resource_variable_ops.var_handle_op(**op_args)
    if ops._USE_C_SHAPES:
      handle_data = resource_variable_ops.get_resource_handle_data(twin)
    else:
      if twin._handle_data is None:
        ops.set_shape_and_handle_data_for_outputs(twin.op)
      handle_data = twin._handle_data
    handle._handle_data = handle_data
  # pylint: enable=protected-access
  # Clean up op->graph->op reference cycles in the throwaway graph.
  ops.dismantle_graph(scratch)
  return handle
def testScatterMin(self):
  """On CPU, scatter-min of [[3]] into row 0 of a [[6]] variable gives [[3]]."""
  int32 = dtypes.int32
  with ops.device("cpu:0"):
    handle = resource_variable_ops.var_handle_op(dtype=int32, shape=[1, 1])
    self.evaluate(resource_variable_ops.assign_variable_op(
        handle, constant_op.constant([[6]], dtype=int32)))
    self.evaluate(resource_variable_ops.resource_scatter_min(
        handle, [0], constant_op.constant([[3]], dtype=int32)))
    read = resource_variable_ops.read_variable_op(handle, dtype=int32)
    # assertAllEqual checks the shape and every element of the fetched array;
    # assertEqual only passed because the array has a single element.
    self.assertAllEqual(self.evaluate(read), [[3]])
def testManyAssigns(self):
  """Control dependencies order reads and writes on one variable.

  The chain is: assign 1 -> first read -> assign 2 -> second read. The reads
  see 1 and 2 respectively.
  """
  int32 = dtypes.int32
  handle = resource_variable_ops.var_handle_op(dtype=int32, shape=[])

  def _assign(value):
    return resource_variable_ops.assign_variable_op(
        handle, constant_op.constant(value, dtype=int32))

  def _read():
    return resource_variable_ops.read_variable_op(handle, dtype=int32)

  create = _assign(1)
  with ops.control_dependencies([create]):
    first_read = _read()
  with ops.control_dependencies([first_read]):
    write = _assign(2)
  with ops.control_dependencies([write]):
    second_read = _read()
  first, second = self.evaluate([first_read, second_read])
  self.assertEqual(first, 1)
  self.assertEqual(second, 2)
def _custom_getter(
    getter=None,
    name=None,
    shape=None,
    dtype=dtypes.float32,
    initializer=None,
    regularizer=None,
    reuse=None,
    trainable=True,
    collections=None,
    caching_device=None,  # pylint: disable=redefined-outer-name
    partitioner=None,
    validate_shape=True,
    use_resource=None):
  """Variable getter backed by captured resource variables.

  If `name` is already in `self.tf_variables`, this returns that variable's
  `initialized_value()` when `reuse` is truthy and raises otherwise.
  Otherwise it records a new `_CapturedVariable` in `self.variables` and
  assigns the initial value to a `var_handle_op` resource with shared_name
  `name`. It then returns a `_VariableFromResource` over an identity of the
  captured variable's handle, gated on that assignment.

  Only `name`, `shape`, `dtype`, `initializer`, `reuse` and `trainable` are
  used; the remaining arguments are accepted and discarded.

  Raises:
    ValueError: if `name` is already in `self.tf_variables` and `reuse` is
      falsy.
  """
  del getter, regularizer, collections, caching_device, partitioner
  del use_resource, validate_shape

  if name in self.tf_variables:
    if not reuse:
      raise ValueError(
          "Specified reuse=%s but tried to reuse variables." % reuse)
    return self.tf_variables[name].initialized_value()

  # TODO(apassos): ensure this is on the same device as above
  captured = _CapturedVariable(name, initializer, shape, dtype, trainable)
  self.variables[name] = captured

  resource = resource_variable_ops.var_handle_op(
      shared_name=name, shape=shape, dtype=dtype)
  init_fn = (initializer if initializer is not None
             else _default_initializer(name, shape, dtype))
  init_op = resource_variable_ops.assign_variable_op(
      resource, init_fn(shape, dtype))
  # NOTE(review): the identity wraps `captured.variable.handle`, not
  # `resource`. Presumably both refer to the same shared resource `name`;
  # confirm.
  with tf_ops.control_dependencies([init_op]):
    gated_handle = array_ops.identity(captured.variable.handle)
  return _VariableFromResource(gated_handle, dtype, name, shape=captured.shape)
def testDtypeSurvivesIdentity(self):
  """An int32 scalar can be assigned through an identity of the handle."""
  int32 = dtypes.int32
  source = resource_variable_ops.var_handle_op(dtype=int32, shape=[])
  passthrough = array_ops.identity(source)
  zero = constant_op.constant(0, dtype=int32)
  self.evaluate(resource_variable_ops.assign_variable_op(passthrough, zero))
def testDtypeSurvivesIdentity(self):
  """Runs an int32 scalar assign through an `identity` of a variable handle."""
  int32 = dtypes.int32
  with self.test_session():
    base = resource_variable_ops.var_handle_op(dtype=int32, shape=[])
    relayed = array_ops.identity(base)
    assign_op = resource_variable_ops.assign_variable_op(
        relayed, constant_op.constant(0, dtype=int32))
    assign_op.run()
def testFetchHandle(self):
  """Evaluating a variable handle in a session yields a non-empty value."""
  with self.test_session():
    handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[1], name="foo")
    fetched = handle.eval()
    self.assertGreater(len(fetched), 0)
def testUnprintableHandle(self):
  """Both str() and repr() of an eager variable handle show "<unprintable>"."""
  with context.eager_mode():
    handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[1], name="foo")
    for render in (str, repr):
      self.assertIn("<unprintable>", render(handle))