Example #1
0
  def test_multiple_workers_with_shared_queue(self):
    """Checks that two graphs on one server draw from one shared file queue.

    A first graph (``g1``) pulls "ABC" and then "DEF" through the shared
    file-name queue. A second, independently built graph (``g2``) uses the
    same ``name`` and connects to the same local server. It then yields
    "GHI", i.e. it continues where ``g1`` stopped instead of starting over.
    """
    # NOTE(review): presumably restores a Glob stub installed elsewhere
    # (e.g. in setUp) so the real temp files are found -- confirm.
    gfile.Glob = self._orig_glob
    filenames = self._create_sorted_temp_files([
        "ABC\n", "DEF\n", "GHI\n", "JKL\n", "MNO\n", "PQR\n", "STU\n", "VWX\n",
        "YZ\n"
    ])

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"
    shared_file_name_queue_name = "%s/file_name_queue" % name
    example_queue_name = "%s/fifo_queue" % name
    worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name

    def build_inputs():
      """Builds the shared-queue reader in the current default graph."""
      _, inputs = _read_keyed_batch_examples_shared_queue(
          filenames,
          batch_size,
          reader=tf.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          name=name)
      return inputs

    server = tf.train.Server.create_local_server()

    with tf.Graph().as_default() as g1, tf.Session(
        server.target, graph=g1) as session:
      inputs = build_inputs()
      session.run(tf.initialize_local_variables())

      # Run the three queues once manually.
      self._run_queue(shared_file_name_queue_name, session)
      self._run_queue(worker_file_name_queue_name, session)
      self._run_queue(example_queue_name, session)

      self.assertAllEqual(session.run(inputs), [b"ABC"])

      # Run the worker and the example queue.
      self._run_queue(worker_file_name_queue_name, session)
      self._run_queue(example_queue_name, session)

      self.assertAllEqual(session.run(inputs), [b"DEF"])

    with tf.Graph().as_default() as g2, tf.Session(
        server.target, graph=g2) as session:
      inputs = build_inputs()
      # NOTE(review): local variables are not initialized in this graph;
      # presumably the state created via g1 lives on the shared server and
      # is reused here -- confirm.

      # Run the worker and the example queue.
      self._run_queue(worker_file_name_queue_name, session)
      self._run_queue(example_queue_name, session)

      self.assertAllEqual(session.run(inputs), [b"GHI"])

    # assertIsNot reports both objects on failure, unlike assertTrue(...).
    self.assertIsNot(g1, g2)
Example #2
0
    def test_multiple_workers_with_shared_queue(self):
        """Checks that two graphs on one server draw from one shared file queue.

        A first graph (``g1``) pulls "ABC" and then "DEF" through the shared
        file-name queue. A second, independently built graph (``g2``) uses
        the same ``name`` and connects to the same local server. It then
        yields "GHI", i.e. it continues where ``g1`` stopped instead of
        starting over.
        """
        # NOTE(review): presumably restores a Glob stub installed elsewhere
        # (e.g. in setUp) so the real temp files are found -- confirm.
        gfile.Glob = self._orig_glob
        filenames = self._create_sorted_temp_files([
            "ABC\n", "DEF\n", "GHI\n", "JKL\n", "MNO\n", "PQR\n", "STU\n",
            "VWX\n", "YZ\n"
        ])

        batch_size = 1
        queue_capacity = 5
        name = "my_batch"
        shared_file_name_queue_name = "%s/file_name_queue" % name
        example_queue_name = "%s/fifo_queue" % name
        worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name

        def build_inputs():
            """Builds the shared-queue reader in the current default graph."""
            _, inputs = _read_keyed_batch_examples_shared_queue(
                filenames,
                batch_size,
                reader=tf.TextLineReader,
                randomize_input=False,
                num_epochs=1,
                queue_capacity=queue_capacity,
                name=name)
            return inputs

        server = tf.train.Server.create_local_server()

        with tf.Graph().as_default() as g1, tf.Session(server.target,
                                                       graph=g1) as session:
            inputs = build_inputs()
            session.run(tf.initialize_local_variables())

            # Run the three queues once manually.
            self._run_queue(shared_file_name_queue_name, session)
            self._run_queue(worker_file_name_queue_name, session)
            self._run_queue(example_queue_name, session)

            self.assertAllEqual(session.run(inputs), [b"ABC"])

            # Run the worker and the example queue.
            self._run_queue(worker_file_name_queue_name, session)
            self._run_queue(example_queue_name, session)

            self.assertAllEqual(session.run(inputs), [b"DEF"])

        with tf.Graph().as_default() as g2, tf.Session(server.target,
                                                       graph=g2) as session:
            inputs = build_inputs()
            # NOTE(review): local variables are not initialized in this
            # graph; presumably the state created via g1 lives on the shared
            # server and is reused here -- confirm.

            # Run the worker and the example queue.
            self._run_queue(worker_file_name_queue_name, session)
            self._run_queue(example_queue_name, session)

            self.assertAllEqual(session.run(inputs), [b"GHI"])

        # assertIsNot reports both objects on failure, unlike assertTrue(...).
        self.assertIsNot(g1, g2)
Example #3
0
    def test_read_text_lines_multifile_with_shared_queue(self):
        """Reads "ABC", "DEF", "GHK" through the shared-queue batch reader.

        Besides the values, this checks four things:
        - the static shapes of ``keys`` and ``inputs``;
        - the name of the output tensor;
        - the types of the ops created under ``name``;
        - that a further read raises OutOfRangeError once the single epoch
          is used up.
        """
        gfile.Glob = self._orig_glob
        file_contents = ["ABC\n", "DEF\nGHK\n"]
        filenames = self._create_sorted_temp_files(file_contents)

        name = "my_batch"
        batch_size = 1
        queue_capacity = 5
        example_queue_name = "%s/fifo_queue" % name
        worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name
        # Op name -> expected op type inside the constructed graph.
        expected_ops = {
            "%s/read/TextLineReaderV2" % name: "TextLineReaderV2",
            example_queue_name: "FIFOQueueV2",
            worker_file_name_queue_name: "FIFOQueueV2",
            name: "QueueDequeueUpToV2",
        }

        with ops.Graph().as_default() as graph, self.test_session(
                graph=graph) as sess:
            keys, inputs = _read_keyed_batch_examples_shared_queue(
                filenames,
                batch_size,
                reader=io_ops.TextLineReader,
                randomize_input=False,
                num_epochs=1,
                queue_capacity=queue_capacity,
                name=name)
            for tensor in (keys, inputs):
                self.assertAllEqual((None, ), tensor.get_shape().as_list())
            init_ops = [
                variables.local_variables_initializer(),
                variables.global_variables_initializer(),
            ]
            sess.run(init_ops)

            coord = coordinator.Coordinator()
            threads = queue_runner_impl.start_queue_runners(sess, coord=coord)

            self.assertEqual("%s:1" % name, inputs.name)
            test_util.assert_ops_in_graph(expected_ops, graph)

            # One value per run; the epoch is exhausted after the last one.
            for expected in (b"ABC", b"DEF", b"GHK"):
                self.assertAllEqual(sess.run(inputs), [expected])
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(inputs)

            coord.request_stop()
            coord.join(threads)
Example #4
0
    def test_read_text_lines_multifile_with_shared_queue(self):
        """Reads "ABC", "DEF", "GHK" through the shared-queue batch reader.

        Besides the values, this checks three things:
        - the name of the output tensor;
        - the types of the ops created under ``name``, including the shared
          file-name queue and its constant input;
        - that a further read raises OutOfRangeError once the single epoch
          is used up.
        """
        gfile.Glob = self._orig_glob
        filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])

        name = "my_batch"
        batch_size = 1
        queue_capacity = 5
        shared_file_name_queue_name = "%s/file_name_queue" % name
        example_queue_name = "%s/fifo_queue" % name
        worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name
        # Op name -> expected op type inside the constructed graph.
        expected_ops = {
            "%s/input" % shared_file_name_queue_name: "Const",
            shared_file_name_queue_name: "FIFOQueue",
            "%s/read/TextLineReader" % name: "TextLineReader",
            example_queue_name: "FIFOQueue",
            worker_file_name_queue_name: "FIFOQueue",
            name: "QueueDequeueUpTo",
        }

        with tf.Graph().as_default() as graph, self.test_session(
                graph=graph) as sess:
            _, inputs = _read_keyed_batch_examples_shared_queue(
                filenames,
                batch_size,
                reader=tf.TextLineReader,
                randomize_input=False,
                num_epochs=1,
                queue_capacity=queue_capacity,
                name=name)
            sess.run(tf.initialize_local_variables())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess, coord=coord)

            self.assertEqual("%s:1" % name, inputs.name)
            test_util.assert_ops_in_graph(expected_ops, graph)

            # One value per run; the epoch is exhausted after the last one.
            for expected in (b"ABC", b"DEF", b"GHK"):
                self.assertAllEqual(sess.run(inputs), [expected])
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(inputs)

            coord.request_stop()
            coord.join(threads)
Example #5
0
  def test_read_text_lines_multifile_with_shared_queue(self):
    """Reads "ABC", "DEF", "GHK" through the shared-queue batch reader.

    Besides the values, this checks four things:
    - the static shapes of ``keys`` and ``inputs``;
    - the name of the output tensor;
    - the types of the ops created under ``name``;
    - that a further read raises OutOfRangeError once the single epoch is
      used up.
    """
    gfile.Glob = self._orig_glob
    filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])

    name = "my_batch"
    batch_size = 1
    queue_capacity = 5
    shared_file_name_queue_name = "%s/file_name_queue" % name
    example_queue_name = "%s/fifo_queue" % name
    worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name
    # Op name -> expected op type inside the constructed graph.
    expected_ops = {
        "%s/input" % shared_file_name_queue_name: "Const",
        shared_file_name_queue_name: "FIFOQueue",
        "%s/read/TextLineReader" % name: "TextLineReader",
        example_queue_name: "FIFOQueue",
        worker_file_name_queue_name: "FIFOQueue",
        name: "QueueDequeueUpTo",
    }

    with tf.Graph().as_default() as graph, self.test_session(
        graph=graph) as sess:
      keys, inputs = _read_keyed_batch_examples_shared_queue(
          filenames,
          batch_size,
          reader=tf.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          name=name)
      for tensor in (keys, inputs):
        self.assertAllEqual((None,), tensor.get_shape().as_list())
      sess.run(tf.local_variables_initializer())

      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess, coord=coord)

      self.assertEqual("%s:1" % name, inputs.name)
      test_util.assert_ops_in_graph(expected_ops, graph)

      # One value per run; the epoch is exhausted after the last one.
      for expected in (b"ABC", b"DEF", b"GHK"):
        self.assertAllEqual(sess.run(inputs), [expected])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(inputs)

      coord.request_stop()
      coord.join(threads)