Exemplo n.º 1
0
  def testOneShotIteratorInsideContainer(self):
    """Checks that one-shot iterators in distinct containers are independent.

    Two sessions against the same local server each build a one-shot
    iterator under a different `ops.container` name ("iteration0" and
    "iteration1"). Each session must read all 14 repetitions of the
    7-element dataset of squared components and then get `OutOfRangeError`.
    The test would fail if both sessions used the same container name.
    """
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def within_container():
      # Builds the iterator in whatever container is active at the call site
      # and returns the tensors that fetch its next element.
      def _map_fn(x, y, z):
        return math_ops.square(x), math_ops.square(y), math_ops.square(z)
      iterator = (dataset_ops.Dataset.from_tensor_slices(components)
                  .map(_map_fn).repeat(14).make_one_shot_iterator())
      return iterator.get_next()

    server = server_lib.Server.create_local_server()

    # Create two iterators within unique containers, and run them to
    # make sure that the resources aren't shared.
    #
    # The test below would fail if cname were the same across both
    # sessions.
    #
    # The outer loop variable is `j` so that the element index `i` used by
    # the inner loop does not shadow it.
    for j in range(2):
      with session.Session(server.target) as sess:
        cname = "iteration%d" % j
        with ops.container(cname):
          get_next = within_container()

        for _ in range(14):
          for i in range(7):
            result = sess.run(get_next)
            for component, result_component in zip(components, result):
              self.assertAllEqual(component[i]**2, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)
    def testResetFails(self):
        """Checks `Session.reset` error paths and a reset of an unused container.

        Resetting an unknown target must raise NotFoundError. Resetting a gRPC
        target with no server behind it, under a 5 ms operation timeout, must
        raise DeadlineExceededError. Resetting container "test1", which holds
        neither variable, must leave both variables readable from a new
        session.
        """
        # One variable in container "test0", one in the default container.
        with ops.container("test0"):
            v0 = variables.Variable(1.0, name="v0")
        v1 = variables.Variable(2.0, name="v1")

        def assert_both_values(current_sess):
            # Both variables must still hold their initial values.
            self.assertAllEqual(1.0, current_sess.run(v0))
            self.assertAllEqual(2.0, current_sess.run(v1))

        # Resetting a target that does not exist is an error.
        with self.assertRaises(errors_impl.NotFoundError):
            session.Session.reset("nonexistent", ["test0"])

        # Reset with an explicit config: no server is at this address, so the
        # short operation timeout is expected to expire.
        with self.assertRaises(errors_impl.DeadlineExceededError):
            short_timeout = config_pb2.ConfigProto(operation_timeout_in_ms=5)
            session.Session.reset("grpc://localhost:0", ["test0"],
                                  config=short_timeout)

        server = self._cached_server
        sess = session.Session(server.target)
        sess.run(variables.global_variables_initializer())
        assert_both_values(sess)
        # No container holding v0/v1 is reset, but the server is reset.
        session.Session.reset(server.target, ["test1"])
        # A brand-new session must still observe both values.
        sess = session.Session(server.target)
        assert_both_values(sess)
Exemplo n.º 3
0
  def testResetFails(self):
    """Exercises `Session.reset` against bad targets and an unused container.

    Expected outcomes:
    * an unknown target raises NotFoundError;
    * an unreachable gRPC target with a 5 ms timeout raises
      DeadlineExceededError;
    * resetting container "test1" leaves both variables intact.
    """
    with ops.container("test0"):
      v0 = variables.VariableV1(1.0, name="v0")  # lives in "test0"
    v1 = variables.VariableV1(2.0, name="v1")  # lives in the default container
    expected_pairs = ((1.0, v0), (2.0, v1))

    # Unknown target -> NotFoundError.
    with self.assertRaises(errors_impl.NotFoundError):
      session.Session.reset("nonexistent", ["test0"])

    # Reset with a config; there is no server at this address, so the reset
    # is expected to time out.
    with self.assertRaises(errors_impl.DeadlineExceededError):
      session.Session.reset(
          "grpc://localhost:0",
          ["test0"],
          config=config_pb2.ConfigProto(operation_timeout_in_ms=5),
      )

    server = self._cached_server
    sess = session.Session(server.target)
    sess.run(variables.global_variables_initializer())
    for expected, var in expected_pairs:
      self.assertAllEqual(expected, sess.run(var))

    # Resetting "test1" resets the server but none of the containers that
    # hold v0/v1, so a fresh session still reads both initial values.
    session.Session.reset(server.target, ["test1"])
    sess = session.Session(server.target)
    for expected, var in expected_pairs:
      self.assertAllEqual(expected, sess.run(var))
Exemplo n.º 4
0
  def testOneShotIteratorInsideContainer(self):
    """Runs one-shot iterators in two distinct containers on one server.

    Each session builds its iterator under its own container name
    ("iteration0", "iteration1"). It reads all 14 repetitions of the
    7-element dataset of squared components, then expects OutOfRangeError.
    The test would fail if both sessions shared one container name.
    """
    components = (
        np.arange(7),
        np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
        np.array(37.0) * np.arange(7),
    )
    num_elements = 7
    num_repeats = 14

    def build_get_next():
      # Builds the pipeline in the container active at the call site and
      # returns the iterator's next-element tensors.
      def _square_components(x, y, z):
        return tuple(math_ops.square(t) for t in (x, y, z))

      dataset = dataset_ops.Dataset.from_tensor_slices(components)
      dataset = dataset.map(_square_components).repeat(num_repeats)
      return dataset.make_one_shot_iterator().get_next()

    server = server_lib.Server.create_local_server()

    # Two sessions, two unique container names: the iterator resources must
    # not be shared, otherwise the second session would not see a full run.
    for run_idx in range(2):
      with session.Session(server.target) as sess:
        with ops.container("iteration%d" % run_idx):
          get_next = build_get_next()

        # Walk every element of every repetition in order.
        for step in range(num_repeats * num_elements):
          i = step % num_elements
          result = sess.run(get_next)
          for component, result_component in zip(components, result):
            self.assertAllEqual(component[i]**2, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)
Exemplo n.º 5
0
 def __call__(self, inputs, *args, **kwargs):
     """Calls the parent `__call__` under this network's container and scope.

     The call runs inside `ops.container(self._container)`. It also runs
     inside the current variable scope, re-entered with `use_resource=True`.
     All arguments are forwarded unchanged and the parent's result is
     returned.
     """
     # TODO(josh11b,ashankar,agarwal): Can we reduce the number of context
     # managers here and/or move some of the work into the constructor
     # for performance reasons?
     with ops.container(self._container), \
             variable_scope.variable_scope(
                 variable_scope.get_variable_scope(), use_resource=True):
         return super(Network, self).__call__(inputs, *args, **kwargs)
Exemplo n.º 6
0
 def __call__(self, inputs, *args, **kwargs):
   """Forwards to the parent `__call__` within this network's container.

   The current variable scope is re-entered with `use_resource=True` before
   delegating; the parent's return value is passed through unchanged.
   """
   # TODO(josh11b,ashankar,agarwal): Can we reduce the number of context
   # managers here and/or move some of the work into the constructor
   # for performance reasons?
   with ops.container(self._container):
     resource_scope = variable_scope.variable_scope(
         variable_scope.get_variable_scope(), use_resource=True)
     with resource_scope:
       parent_call = super(Network, self).__call__
       return parent_call(inputs, *args, **kwargs)
 def testContainerEager(self):
   """Same-named eager resource variables in different containers are distinct.

   One variable is created in the default container and one in container
   "different", both named "same". Assigning 2 to the second must leave the
   first at 1.
   """
   with context.eager_mode():
     default_var = resource_variable_ops.ResourceVariable(
         initial_value=lambda: 1, name="same")
     with ops.container("different"):
       other_var = resource_variable_ops.ResourceVariable(
           initial_value=lambda: 0, name="same")
     other_var.assign(2)
     for expected, var in ((1, default_var), (2, other_var)):
       self.assertEqual(expected, var.read_value().numpy())
Exemplo n.º 8
0
 def testContainerEager(self):
     """Eager variables named "same" in two containers do not alias.

     After assigning 2 to the variable created under container "different",
     the default-container variable must still read 1.
     """
     def make_same_named_var(value):
         # Every variable built here is literally named "same"; only the
         # active container differs between calls.
         return resource_variable_ops.ResourceVariable(
             initial_value=lambda: value, name="same")

     with context.eager_mode():
         outside = make_same_named_var(1)
         with ops.container("different"):
             inside = make_same_named_var(0)
         inside.assign(2)
         self.assertEqual(1, outside.read_value().numpy())
         self.assertEqual(2, inside.read_value().numpy())
Exemplo n.º 9
0
 def testContainer(self):
   """Checks the "container" attr recorded on variable ops.

   Nested `ops.container` scopes override the outer one, and leaving a scope
   restores the outer value. An explicit `container=` argument to
   `gen_state_ops.variable` takes precedence over the active scope.
   """
   with ops.Graph().as_default():
     var_default = variables.Variable([0])
     with ops.container("l1"):
       var_l1 = variables.Variable([1])
       with ops.container("l2"):
         var_l2 = variables.Variable([2])
         # Explicit container "l3" even though scope "l2" is active.
         var_explicit = gen_state_ops.variable(
             shape=[1],
             dtype=dtypes.float32,
             name="VariableInL3",
             container="l3",
             shared_name="")
       var_l1_again = variables.Variable([3])
     var_default_again = variables.Variable([4])

   expectations = (
       ("", var_default),
       ("l1", var_l1),
       ("l2", var_l2),
       ("l3", var_explicit),
       ("l1", var_l1_again),
       ("", var_default_again),
   )
   for container_name, var in expectations:
     self.assertEqual(compat.as_bytes(container_name),
                      var.op.get_attr("container"))
 def testContainer(self):
   """Verifies container attrs set by nested `ops.container` scopes.

   Variables pick up the innermost active container. `gen_state_ops.variable`
   called with `container="l3"` keeps "l3" regardless of scope. Exiting a
   scope returns to the enclosing container (or "" at top level).
   """
   def assert_container(expected, node):
     # The expected name is converted with compat.as_bytes before comparing
     # against the op's "container" attr.
     self.assertEqual(compat.as_bytes(expected), node.op.get_attr("container"))

   with ops.Graph().as_default():
     v0 = variables.Variable([0])
     with ops.container("l1"):
       v1 = variables.Variable([1])
       with ops.container("l2"):
         v2 = variables.Variable([2])
         special_v = gen_state_ops.variable(
             shape=[1],
             dtype=dtypes.float32,
             name="VariableInL3",
             container="l3",
             shared_name="")
       v3 = variables.Variable([3])
     v4 = variables.Variable([4])

   assert_container("", v0)
   assert_container("l1", v1)
   assert_container("l2", v2)
   assert_container("l3", special_v)
   assert_container("l1", v3)
   assert_container("", v4)
  def testMultipleContainers(self):
    """Resetting one container leaves a variable in another container alive.

    After `Session.reset` on "test0", the open session is expected to abort.
    A new session then finds the "test0" variable uninitialized, while the
    "test1" variable still holds its value.
    """
    with ops.container("test0"):
      v0 = variables.Variable(1.0, name="v0")
    # Same name "v0" as above, but placed in container "test1".
    with ops.container("test1"):
      v1 = variables.Variable(2.0, name="v0")

    server = server_lib.Server.create_local_server()
    sess = session.Session(server.target)
    sess.run(variables.global_variables_initializer())
    for expected, var in ((1.0, v0), (2.0, v1)):
      self.assertAllEqual(expected, sess.run(var))

    # Resetting "test0" aborts the session that is still open.
    session.Session.reset(server.target, ["test0"])
    with self.assertRaises(errors_impl.AbortedError):
      sess.run(v1)

    # A new session on the same target: v0's memory was released by the
    # reset, so it is uninitialized; v1 in "test1" is untouched.
    sess = session.Session(server.target)
    with self.assertRaises(errors_impl.FailedPreconditionError):
      sess.run(v0)
    self.assertAllEqual(2.0, sess.run(v1))
    def testMultipleContainers(self):
        """Checks that `Session.reset` only clears the named container.

        Two same-named variables live in containers "test0" and "test1".
        Resetting "test0" aborts the open session. A fresh session then sees
        the "test0" variable as uninitialized and the "test1" variable as
        still valid.
        """
        with ops.container("test0"):
            v0 = variables.Variable(1.0, name="v0")
        with ops.container("test1"):
            v1 = variables.Variable(2.0, name="v0")

        def expect_error(error_type, current_sess, var):
            # Running `var` in `current_sess` must raise `error_type`.
            with self.assertRaises(error_type):
                current_sess.run(var)

        server = server_lib.Server.create_local_server()
        sess = session.Session(server.target)
        sess.run(variables.global_variables_initializer())
        self.assertAllEqual(1.0, sess.run(v0))
        self.assertAllEqual(2.0, sess.run(v1))

        # Reset "test0"; the still-open session aborts.
        session.Session.reset(server.target, ["test0"])
        expect_error(errors_impl.AbortedError, sess, v1)

        # Reconnect: "test0" was released (v0 uninitialized), "test1" kept.
        sess = session.Session(server.target)
        expect_error(errors_impl.FailedPreconditionError, sess, v0)
        self.assertAllEqual(2.0, sess.run(v1))
Exemplo n.º 13
0
        def false_fn():
          """Builds the false branch and checks the containers of its ops.

          When this branch is created in the cond below, the active container
          should begin as 'l1'. The nested "l2f" scope overrides it, and
          leaving that scope restores 'l1'. Returns the branch output 6.0.
          """
          var_before = variables.Variable([1])
          queue_before = data_flow_ops.FIFOQueue(1, dtypes.float32)

          with ops.container("l2f"):
            var_inner = variables.Variable([2])
            queue_inner = data_flow_ops.FIFOQueue(1, dtypes.float32)

          var_after = variables.Variable([1])
          queue_after = data_flow_ops.FIFOQueue(1, dtypes.float32)

          for expected, node in (
              ("l1", var_before),
              ("l1", queue_before.queue_ref),
              ("l2f", var_inner),
              ("l2f", queue_inner.queue_ref),
              ("l1", var_after),
              ("l1", queue_after.queue_ref),
          ):
            self.assertEqual(compat.as_bytes(expected), container(node))

          return constant_op.constant(6.0)
Exemplo n.º 14
0
        def false_fn():
          """False branch of the cond: creates ops and verifies containers.

          Ops created directly in this branch are expected in 'l1', the
          container the branch is built under. Ops created inside the nested
          "l2f" scope are expected in "l2f". Returns the constant 6.0.
          """
          def check(expected, node):
            # Compare the op's "container" attr with the expected name.
            self.assertEqual(compat.as_bytes(expected), container(node))

          v1 = variables.Variable([1])
          q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          with ops.container("l2f"):
            v2 = variables.Variable([2])
            q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          v3 = variables.Variable([1])
          q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          check("l1", v1)
          check("l1", q1.queue_ref)
          check("l2f", v2)
          check("l2f", q2.queue_ref)
          check("l1", v3)
          check("l1", q3.queue_ref)

          return constant_op.constant(6.0)
Exemplo n.º 15
0
    def testContainer(self):
        """Set containers outside & inside of cond_v2.

        Make sure the containers are set correctly for both variable creation
        (tested by variables.Variable) and for stateful ops (tested by
        FIFOQueue).
        """
        # Currently skipped; tracked by b/113048653.
        self.skipTest("b/113048653")
        with ops.Graph().as_default() as g:
            with self.session(graph=g):

                v0 = variables.Variable([0])
                q0 = data_flow_ops.FIFOQueue(1, dtypes.float32)

                def container(node):
                    # Returns the "container" attr of the op producing `node`.
                    return node.op.get_attr("container")

                self.assertEqual(compat.as_bytes(""), container(v0))
                self.assertEqual(compat.as_bytes(""), container(q0.queue_ref))

                def true_fn():
                    # When this branch is created in cond below,
                    # the container should begin with 'l1'
                    v1 = variables.Variable([1])
                    q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

                    with ops.container("l2t"):
                        v2 = variables.Variable([2])
                        q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

                    v3 = variables.Variable([1])
                    q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

                    self.assertEqual(compat.as_bytes("l1"), container(v1))
                    self.assertEqual(compat.as_bytes("l1"),
                                     container(q1.queue_ref))
                    self.assertEqual(compat.as_bytes("l2t"), container(v2))
                    self.assertEqual(compat.as_bytes("l2t"),
                                     container(q2.queue_ref))
                    self.assertEqual(compat.as_bytes("l1"), container(v3))
                    self.assertEqual(compat.as_bytes("l1"),
                                     container(q3.queue_ref))

                    return constant_op.constant(2.0)

                def false_fn():
                    # When this branch is created in cond below,
                    # the container should begin with 'l1'
                    v1 = variables.Variable([1])
                    q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

                    with ops.container("l2f"):
                        v2 = variables.Variable([2])
                        q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

                    v3 = variables.Variable([1])
                    q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

                    self.assertEqual(compat.as_bytes("l1"), container(v1))
                    self.assertEqual(compat.as_bytes("l1"),
                                     container(q1.queue_ref))
                    self.assertEqual(compat.as_bytes("l2f"), container(v2))
                    self.assertEqual(compat.as_bytes("l2f"),
                                     container(q2.queue_ref))
                    self.assertEqual(compat.as_bytes("l1"), container(v3))
                    self.assertEqual(compat.as_bytes("l1"),
                                     container(q3.queue_ref))

                    return constant_op.constant(6.0)

                with ops.container("l1"):
                    cnd_true = cond_v2.cond_v2(constant_op.constant(True),
                                               true_fn, false_fn)
                    # assertEqual, not the assertEquals alias, which was
                    # removed from unittest in Python 3.12.
                    self.assertEqual(cnd_true.eval(), 2)

                    cnd_false = cond_v2.cond_v2(constant_op.constant(False),
                                                true_fn, false_fn)
                    self.assertEqual(cnd_false.eval(), 6)

                    v4 = variables.Variable([3])
                    q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
                v5 = variables.Variable([4])
                q5 = data_flow_ops.FIFOQueue(1, dtypes.float32)

            self.assertEqual(compat.as_bytes("l1"), container(v4))
            self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref))
            self.assertEqual(compat.as_bytes(""), container(v5))
            self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))
Exemplo n.º 16
0
  def testContainer(self):
    """Set containers outside & inside of cond_v2.

    Make sure the containers are set correctly for both variable creation
    (tested by variables.Variable) and for stateful ops (tested by FIFOQueue)
    """
    # Currently skipped; tracked by b/113048653.
    self.skipTest("b/113048653")
    with ops.Graph().as_default() as g:
      # NOTE(review): `test_session` is deprecated in newer TF in favor of
      # `session`. It is kept here because this variant indexes cond_v2's
      # result with [0], which suggests an older API; confirm before
      # switching.
      with self.test_session(graph=g):

        v0 = variables.Variable([0])
        q0 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        def container(node):
          # Returns the "container" attr of the op producing `node`.
          return node.op.get_attr("container")

        self.assertEqual(compat.as_bytes(""), container(v0))
        self.assertEqual(compat.as_bytes(""), container(q0.queue_ref))

        def true_fn():
          # When this branch is created in cond below,
          # the container should begin with 'l1'
          v1 = variables.Variable([1])
          q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          with ops.container("l2t"):
            v2 = variables.Variable([2])
            q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          v3 = variables.Variable([1])
          q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          self.assertEqual(compat.as_bytes("l1"), container(v1))
          self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
          self.assertEqual(compat.as_bytes("l2t"), container(v2))
          self.assertEqual(compat.as_bytes("l2t"), container(q2.queue_ref))
          self.assertEqual(compat.as_bytes("l1"), container(v3))
          self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))

          return constant_op.constant(2.0)

        def false_fn():
          # When this branch is created in cond below,
          # the container should begin with 'l1'
          v1 = variables.Variable([1])
          q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          with ops.container("l2f"):
            v2 = variables.Variable([2])
            q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          v3 = variables.Variable([1])
          q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

          self.assertEqual(compat.as_bytes("l1"), container(v1))
          self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
          self.assertEqual(compat.as_bytes("l2f"), container(v2))
          self.assertEqual(compat.as_bytes("l2f"), container(q2.queue_ref))
          self.assertEqual(compat.as_bytes("l1"), container(v3))
          self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))

          return constant_op.constant(6.0)

        with ops.container("l1"):
          cnd_true = cond_v2.cond_v2(True, true_fn, false_fn)
          # assertEqual, not the assertEquals alias, which was removed from
          # unittest in Python 3.12.
          self.assertEqual(cnd_true[0].eval(), 2)

          cnd_false = cond_v2.cond_v2(False, true_fn, false_fn)
          self.assertEqual(cnd_false[0].eval(), 6)

          v4 = variables.Variable([3])
          q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
        v5 = variables.Variable([4])
        q5 = data_flow_ops.FIFOQueue(1, dtypes.float32)

      self.assertEqual(compat.as_bytes("l1"), container(v4))
      self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref))
      self.assertEqual(compat.as_bytes(""), container(v5))
      self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))