def testOneShotIteratorInsideContainer(self):
  """Checks that one-shot iterators built in distinct containers are independent.

  Two sessions against the same local server each build a one-shot iterator
  inside their own `ops.container`. Each session must read 14 repetitions of
  the 7 squared component slices and then hit `OutOfRangeError`.
  """
  components = (np.arange(7),
                np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                np.array(37.0) * np.arange(7))

  def within_container():
    """Builds a squared, 14x-repeated one-shot iterator; returns get_next."""

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components)
        .map(_map_fn).repeat(14).make_one_shot_iterator())
    return iterator.get_next()

  server = server_lib.Server.create_local_server()

  # Create two iterators within unique containers, and run them to
  # make sure that the resources aren't shared.
  #
  # The test below would fail if cname were the same across both
  # sessions.
  # The outer index is `j` rather than `i` so that it is not shadowed by the
  # inner element loop below.
  for j in range(2):
    with session.Session(server.target) as sess:
      cname = "iteration%d" % j
      with ops.container(cname):
        get_next = within_container()

      for _ in range(14):
        for i in range(7):
          result = sess.run(get_next)
          for component, result_component in zip(components, result):
            self.assertAllEqual(component[i]**2, result_component)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
def testResetFails(self):
  """Covers Session.reset error cases and a reset of a non-existent container."""
  # `v0` lives in container "test0"; `v1` is created in the default container.
  with ops.container("test0"):
    v0 = variables.Variable(1.0, name="v0")
  v1 = variables.Variable(2.0, name="v1")
  expected_values = ((1.0, v0), (2.0, v1))

  # A target that does not exist cannot be reset.
  with self.assertRaises(errors_impl.NotFoundError):
    session.Session.reset("nonexistent", ["test0"])

  # Reset with a config: with no server behind grpc://localhost:0, the 5 ms
  # operation timeout is exceeded.
  short_timeout = config_pb2.ConfigProto(operation_timeout_in_ms=5)
  with self.assertRaises(errors_impl.DeadlineExceededError):
    session.Session.reset(
        "grpc://localhost:0", ["test0"], config=short_timeout)

  server = self._cached_server
  sess = session.Session(server.target)
  sess.run(variables.global_variables_initializer())
  for expected, var in expected_values:
    self.assertAllEqual(expected, sess.run(var))

  # Naming only the non-existent container "test1": no container is reset,
  # but the server is reset.
  session.Session.reset(server.target, ["test1"])

  # A fresh session still sees both variables with their values.
  sess = session.Session(server.target)
  for expected, var in expected_values:
    self.assertAllEqual(expected, sess.run(var))
def testResetFails(self):
  """Covers Session.reset error cases and a reset of a non-existent container."""
  # "test0" holds v0; v1 is created in the default container.
  with ops.container("test0"):
    v0 = variables.VariableV1(1.0, name="v0")
  v1 = variables.VariableV1(2.0, name="v1")

  def check_both(sess):
    """Asserts that v0 and v1 hold their initial values when run in `sess`."""
    self.assertAllEqual(1.0, sess.run(v0))
    self.assertAllEqual(2.0, sess.run(v1))

  # Resetting a target that does not exist reports NotFoundError.
  with self.assertRaises(errors_impl.NotFoundError):
    session.Session.reset("nonexistent", ["test0"])

  # Resetting with a config against a target that has no server times out
  # after the 5 ms operation timeout.
  with self.assertRaises(errors_impl.DeadlineExceededError):
    session.Session.reset(
        "grpc://localhost:0",
        ["test0"],
        config=config_pb2.ConfigProto(operation_timeout_in_ms=5))

  server = self._cached_server
  sess = session.Session(server.target)
  sess.run(variables.global_variables_initializer())
  check_both(sess)

  # Only the non-existent container "test1" is named, so no container is
  # reset, but the server is reset.
  session.Session.reset(server.target, ["test1"])

  # Both variables remain valid from a new session.
  sess = session.Session(server.target)
  check_both(sess)
def testOneShotIteratorInsideContainer(self):
  """Verifies one-shot iterators in separate containers keep separate state.

  Each of two sessions on one local server builds its iterator in its own
  container and must drain all 14 * 7 squared elements before seeing
  `OutOfRangeError`.
  """
  components = (
      np.arange(7),
      np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
      np.array(37.0) * np.arange(7),
  )

  def build_get_next():
    """Returns get_next for a squared, 14x-repeated one-shot iterator."""

    def _square_all(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    dataset = dataset.map(_square_all).repeat(14)
    return dataset.make_one_shot_iterator().get_next()

  server = server_lib.Server.create_local_server()

  # Each pass builds its iterator in a differently named container. If the
  # container name were the same on both passes, the iterator resource would
  # be shared and the second pass would fail.
  for run_index in range(2):
    with session.Session(server.target) as sess:
      with ops.container("iteration%d" % run_index):
        get_next = build_get_next()

      for _ in range(14):
        for elem_index in range(7):
          fetched = sess.run(get_next)
          for base, actual in zip(components, fetched):
            self.assertAllEqual(base[elem_index]**2, actual)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
def __call__(self, inputs, *args, **kwargs):
  """Runs the base-class `__call__` inside this network's container.

  The current variable scope is re-entered with `use_resource=True` while
  `ops.container(self._container)` is active. All arguments are forwarded
  unchanged to `super(Network, self).__call__`, and its result is returned.
  """
  # TODO(josh11b,ashankar,agarwal): Can we reduce the number of context
  # managers here and/or move some of the work into the constructor
  # for performance reasons?
  with ops.container(self._container), variable_scope.variable_scope(
      variable_scope.get_variable_scope(), use_resource=True):
    return super(Network, self).__call__(inputs, *args, **kwargs)
def __call__(self, inputs, *args, **kwargs):
  """Forwards to the base-class `__call__` under this network's container.

  Inside `ops.container(self._container)`, the variable scope that is current
  at that point is re-entered with `use_resource=True` before delegating.
  """
  # TODO(josh11b,ashankar,agarwal): Can we reduce the number of context
  # managers here and/or move some of the work into the constructor
  # for performance reasons?
  with ops.container(self._container):
    current_scope = variable_scope.get_variable_scope()
    with variable_scope.variable_scope(current_scope, use_resource=True):
      return super(Network, self).__call__(inputs, *args, **kwargs)
def testContainerEager(self):
  """Same-named eager resource variables in different containers are distinct."""
  with context.eager_mode():
    default_var = resource_variable_ops.ResourceVariable(
        initial_value=lambda: 1, name="same")
    with ops.container("different"):
      other_var = resource_variable_ops.ResourceVariable(
          initial_value=lambda: 0, name="same")
    # Writing to the "different" variable must not affect the default one.
    other_var.assign(2)
    for expected, var in ((1, default_var), (2, other_var)):
      self.assertEqual(expected, var.read_value().numpy())
def testContainerEager(self):
  """Eager variables named "same" stay independent across containers."""
  with context.eager_mode():
    outer = resource_variable_ops.ResourceVariable(
        initial_value=lambda: 1, name="same")
    with ops.container("different"):
      inner = resource_variable_ops.ResourceVariable(
          initial_value=lambda: 0, name="same")
    inner.assign(2)
    outer_value = outer.read_value().numpy()
    self.assertEqual(1, outer_value)
    inner_value = inner.read_value().numpy()
    self.assertEqual(2, inner_value)
def testContainer(self):
  """Ops take the innermost `ops.container` unless they set one explicitly."""
  with ops.Graph().as_default():
    var_default = variables.Variable([0])
    with ops.container("l1"):
      var_l1 = variables.Variable([1])
      with ops.container("l2"):
        var_l2 = variables.Variable([2])
        # The explicit container="l3" wins over the enclosing "l2".
        var_explicit = gen_state_ops.variable(
            shape=[1],
            dtype=dtypes.float32,
            name="VariableInL3",
            container="l3",
            shared_name="")
      # Leaving "l2" restores "l1".
      var_l1_again = variables.Variable([3])
    # Leaving "l1" restores the default (empty) container.
    var_default_again = variables.Variable([4])

  expectations = (
      ("", var_default),
      ("l1", var_l1),
      ("l2", var_l2),
      ("l3", var_explicit),
      ("l1", var_l1_again),
      ("", var_default_again),
  )
  for container_name, var in expectations:
    self.assertEqual(compat.as_bytes(container_name),
                     var.op.get_attr("container"))
def testMultipleContainers(self):
  """Resetting one container releases only that container's variables."""
  # Both variables are constructed with name="v0"; they differ in container.
  with ops.container("test0"):
    test0_var = variables.Variable(1.0, name="v0")
  with ops.container("test1"):
    test1_var = variables.Variable(2.0, name="v0")

  server = server_lib.Server.create_local_server()
  sess = session.Session(server.target)
  sess.run(variables.global_variables_initializer())
  for expected, var in ((1.0, test0_var), (2.0, test1_var)):
    self.assertAllEqual(expected, sess.run(var))

  # Resetting "test0" aborts the session that was open at the time.
  session.Session.reset(server.target, ["test0"])
  with self.assertRaises(errors_impl.AbortedError):
    sess.run(test1_var)

  # Connect to the same target again. Device memory for the "test0" variable
  # would have been released, so it reads as uninitialized, while the
  # "test1" variable is still valid.
  sess = session.Session(server.target)
  with self.assertRaises(errors_impl.FailedPreconditionError):
    sess.run(test0_var)
  self.assertAllEqual(2.0, sess.run(test1_var))
def false_fn():
  """Else-branch for cond; checks the container attrs of the ops it creates.

  It is expected to be traced while an enclosing 'l1' container is active.
  It uses `self` and the `container` helper from the enclosing scope.
  """
  before_var = variables.Variable([1])
  before_queue = data_flow_ops.FIFOQueue(1, dtypes.float32)
  with ops.container("l2f"):
    nested_var = variables.Variable([2])
    nested_queue = data_flow_ops.FIFOQueue(1, dtypes.float32)
  after_var = variables.Variable([1])
  after_queue = data_flow_ops.FIFOQueue(1, dtypes.float32)

  expectations = (
      ("l1", before_var),
      ("l1", before_queue.queue_ref),
      ("l2f", nested_var),
      ("l2f", nested_queue.queue_ref),
      ("l1", after_var),
      ("l1", after_queue.queue_ref),
  )
  for expected, node in expectations:
    self.assertEqual(compat.as_bytes(expected), container(node))
  return constant_op.constant(6.0)
def testContainer(self):
  """Set containers outside & inside of cond_v2.

  Make sure the containers are set correctly for both variable creation
  (tested by variables.Variable) and for stateful ops (tested by FIFOQueue).
  """
  self.skipTest("b/113048653")
  with ops.Graph().as_default() as g:
    with self.session(graph=g):

      v0 = variables.Variable([0])
      q0 = data_flow_ops.FIFOQueue(1, dtypes.float32)

      def container(node):
        """Returns the `container` attr of the op that produced `node`."""
        return node.op.get_attr("container")

      self.assertEqual(compat.as_bytes(""), container(v0))
      self.assertEqual(compat.as_bytes(""), container(q0.queue_ref))

      def true_fn():
        # When this branch is created in cond below,
        # the container should begin with 'l1'
        v1 = variables.Variable([1])
        q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        with ops.container("l2t"):
          v2 = variables.Variable([2])
          q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        v3 = variables.Variable([1])
        q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        self.assertEqual(compat.as_bytes("l1"), container(v1))
        self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
        self.assertEqual(compat.as_bytes("l2t"), container(v2))
        self.assertEqual(compat.as_bytes("l2t"), container(q2.queue_ref))
        self.assertEqual(compat.as_bytes("l1"), container(v3))
        self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))

        return constant_op.constant(2.0)

      def false_fn():
        # When this branch is created in cond below,
        # the container should begin with 'l1'
        v1 = variables.Variable([1])
        q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        with ops.container("l2f"):
          v2 = variables.Variable([2])
          q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        v3 = variables.Variable([1])
        q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        self.assertEqual(compat.as_bytes("l1"), container(v1))
        self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
        self.assertEqual(compat.as_bytes("l2f"), container(v2))
        self.assertEqual(compat.as_bytes("l2f"), container(q2.queue_ref))
        self.assertEqual(compat.as_bytes("l1"), container(v3))
        self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))

        return constant_op.constant(6.0)

      with ops.container("l1"):
        cnd_true = cond_v2.cond_v2(
            constant_op.constant(True), true_fn, false_fn)
        # assertEqual, not the deprecated alias assertEquals, which Python
        # 3.12 removed from unittest.
        self.assertEqual(cnd_true.eval(), 2)

        cnd_false = cond_v2.cond_v2(
            constant_op.constant(False), true_fn, false_fn)
        self.assertEqual(cnd_false.eval(), 6)

        v4 = variables.Variable([3])
        q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
      v5 = variables.Variable([4])
      q5 = data_flow_ops.FIFOQueue(1, dtypes.float32)

      self.assertEqual(compat.as_bytes("l1"), container(v4))
      self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref))
      self.assertEqual(compat.as_bytes(""), container(v5))
      self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))
def testContainer(self):
  """Set containers outside & inside of cond_v2.

  Make sure the containers are set correctly for both variable creation
  (tested by variables.Variable) and for stateful ops (tested by FIFOQueue).
  """
  self.skipTest("b/113048653")
  with ops.Graph().as_default() as g:
    # NOTE(review): a later copy of this test uses self.session(graph=g);
    # test_session is kept here in case this TF version predates it.
    with self.test_session(graph=g):

      v0 = variables.Variable([0])
      q0 = data_flow_ops.FIFOQueue(1, dtypes.float32)

      def container(node):
        """Returns the `container` attr of the op that produced `node`."""
        return node.op.get_attr("container")

      self.assertEqual(compat.as_bytes(""), container(v0))
      self.assertEqual(compat.as_bytes(""), container(q0.queue_ref))

      def true_fn():
        # When this branch is created in cond below,
        # the container should begin with 'l1'
        v1 = variables.Variable([1])
        q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        with ops.container("l2t"):
          v2 = variables.Variable([2])
          q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        v3 = variables.Variable([1])
        q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        self.assertEqual(compat.as_bytes("l1"), container(v1))
        self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
        self.assertEqual(compat.as_bytes("l2t"), container(v2))
        self.assertEqual(compat.as_bytes("l2t"), container(q2.queue_ref))
        self.assertEqual(compat.as_bytes("l1"), container(v3))
        self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))

        return constant_op.constant(2.0)

      def false_fn():
        # When this branch is created in cond below,
        # the container should begin with 'l1'
        v1 = variables.Variable([1])
        q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        with ops.container("l2f"):
          v2 = variables.Variable([2])
          q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        v3 = variables.Variable([1])
        q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        self.assertEqual(compat.as_bytes("l1"), container(v1))
        self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
        self.assertEqual(compat.as_bytes("l2f"), container(v2))
        self.assertEqual(compat.as_bytes("l2f"), container(q2.queue_ref))
        self.assertEqual(compat.as_bytes("l1"), container(v3))
        self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))

        return constant_op.constant(6.0)

      with ops.container("l1"):
        # In this version cond_v2 returns a sequence; element 0 is the result.
        cnd_true = cond_v2.cond_v2(True, true_fn, false_fn)
        # assertEqual, not the deprecated alias assertEquals, which Python
        # 3.12 removed from unittest.
        self.assertEqual(cnd_true[0].eval(), 2)

        cnd_false = cond_v2.cond_v2(False, true_fn, false_fn)
        self.assertEqual(cnd_false[0].eval(), 6)

        v4 = variables.Variable([3])
        q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
      v5 = variables.Variable([4])
      q5 = data_flow_ops.FIFOQueue(1, dtypes.float32)

      self.assertEqual(compat.as_bytes("l1"), container(v4))
      self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref))
      self.assertEqual(compat.as_bytes(""), container(v5))
      self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))