def test_batch_text_lines(self):
    """Batch a five-line text file in order and check the end-of-epoch behavior.

    The file is read unshuffled (randomize_input=False) for one epoch with
    batch_size=3, so the expected batches are [A, B, C], then the shorter
    remainder [D, E]; the next run must raise OutOfRangeError.
    """
    # Restore the real glob implementation.
    # NOTE(review): presumably setUp replaces gfile.Glob with a stub — confirm.
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("A\nB\nC\nD\nE\n")

    batch_size = 3
    queue_capacity = 10
    name = "my_batch"

    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
      inputs = graph_io.read_batch_examples(
          [filename],
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          read_batch_size=10,
          name=name)
      # The batch dimension is not static: the last batch below is shorter
      # than batch_size.
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      # Initialize local variables before the pipeline runs.
      # NOTE(review): presumably the num_epochs counter lives in a local
      # variable — confirm.
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)
      # Stop and join the queue-runner threads even if an assertion fails,
      # so they never outlive this test.
      try:
        self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
        self.assertAllEqual(session.run(inputs), [b"D", b"E"])
        with self.assertRaises(errors.OutOfRangeError):
          session.run(inputs)
      finally:
        coord.request_stop()
        coord.join(threads)
  def test_batch_randomized(self):
    """Verify the graph built for a shuffled TFRecord batching pipeline.

    Checks the output's static shape and tensor name, the queue/reader ops
    created under `name`, the file names held by the file-name queue's input
    Const, and the capacity attribute of the shuffle queue. No queue runners
    are started; the session only evaluates that Const.
    """
    name = "my_batch"
    queue_capacity = 1234
    batch_size = 17

    with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
      inputs = graph_io.read_batch_examples(
          _VALID_FILE_PATTERN,
          batch_size,
          reader=io_ops.TFRecordReader,
          randomize_input=True,
          queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((batch_size,), inputs.get_shape().as_list())
      self.assertEqual("%s:1" % name, inputs.name)

      filename_queue = "%s/file_name_queue" % name
      filename_const = "%s/input" % filename_queue
      shuffle_queue = "%s/random_shuffle_queue" % name
      expected_op_types = {
          filename_const: "Const",
          filename_queue: "FIFOQueue",
          "%s/read/TFRecordReader" % name: "TFRecordReader",
          shuffle_queue: "RandomShuffleQueue",
          name: "QueueDequeueMany",
      }
      op_nodes = test_util.assert_ops_in_graph(expected_op_types, g)

      # Evaluate the Const that seeds the file-name queue.
      # NOTE(review): presumably these are the files matched by
      # _VALID_FILE_PATTERN — confirm.
      [seeded_file_names] = sess.run(["%s:0" % filename_const])
      self.assertEqual(set(_FILE_NAMES), set(seeded_file_names))

      shuffle_capacity = op_nodes[shuffle_queue].attr["capacity"].i
      self.assertEqual(queue_capacity, shuffle_capacity)
Exemple #3
0
    def test_batch_text_lines(self):
        """Batch a five-line text file in order and check the end-of-epoch behavior.

        The file is read unshuffled (randomize_input=False) for one epoch with
        batch_size=3, so the expected batches are [A, B, C], then the shorter
        remainder [D, E]; the next run must raise OutOfRangeError.
        """
        # Restore the real glob implementation.
        # NOTE(review): presumably setUp replaces gfile.Glob with a stub — confirm.
        gfile.Glob = self._orig_glob
        filename = self._create_temp_file("A\nB\nC\nD\nE\n")

        batch_size = 3
        queue_capacity = 10
        name = "my_batch"

        with ops.Graph().as_default() as g, self.test_session(
                graph=g) as session:
            inputs = graph_io.read_batch_examples(
                [filename],
                batch_size,
                reader=io_ops.TextLineReader,
                randomize_input=False,
                num_epochs=1,
                queue_capacity=queue_capacity,
                read_batch_size=10,
                name=name)
            # The batch dimension is not static: the last batch below is
            # shorter than batch_size.
            self.assertAllEqual((None, ), inputs.get_shape().as_list())
            # Initialize local variables before the pipeline runs.
            # NOTE(review): presumably the num_epochs counter lives in a
            # local variable — confirm.
            session.run(variables.local_variables_initializer())

            coord = coordinator.Coordinator()
            threads = queue_runner_impl.start_queue_runners(session,
                                                            coord=coord)
            # Stop and join the queue-runner threads even if an assertion
            # fails, so they never outlive this test.
            try:
                self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
                self.assertAllEqual(session.run(inputs), [b"D", b"E"])
                with self.assertRaises(errors.OutOfRangeError):
                    session.run(inputs)
            finally:
                coord.request_stop()
                coord.join(threads)
Exemple #4
0
    def test_batch_randomized(self):
        """Verify the graph built for a shuffled TFRecord batching pipeline.

        Checks the output's static shape and tensor name, the V2 queue/reader
        ops created under `name`, the file names held by the file-name queue's
        input Const, and the capacity attribute of the shuffle queue. No queue
        runners are started; the session only evaluates that Const.
        """
        name = "my_batch"
        queue_capacity = 1234
        batch_size = 17

        with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
            inputs = graph_io.read_batch_examples(
                _VALID_FILE_PATTERN,
                batch_size,
                reader=io_ops.TFRecordReader,
                randomize_input=True,
                queue_capacity=queue_capacity,
                name=name)
            self.assertAllEqual((batch_size, ), inputs.get_shape().as_list())
            self.assertEqual("%s:1" % name, inputs.name)

            filename_queue = "%s/file_name_queue" % name
            filename_const = "%s/input" % filename_queue
            shuffle_queue = "%s/random_shuffle_queue" % name
            expected_op_types = {
                filename_const: "Const",
                filename_queue: "FIFOQueueV2",
                "%s/read/TFRecordReaderV2" % name: "TFRecordReaderV2",
                shuffle_queue: "RandomShuffleQueueV2",
                name: "QueueDequeueManyV2",
            }
            op_nodes = test_util.assert_ops_in_graph(expected_op_types, g)

            # Evaluate the Const that seeds the file-name queue.
            # NOTE(review): presumably these are the files matched by
            # _VALID_FILE_PATTERN — confirm.
            [seeded_file_names] = sess.run(["%s:0" % filename_const])
            self.assertEqual(set(_FILE_NAMES), set(seeded_file_names))

            shuffle_capacity = op_nodes[shuffle_queue].attr["capacity"].i
            self.assertEqual(queue_capacity, shuffle_capacity)
Exemple #5
0
    def test_read_text_lines_multifile(self):
        """Read lines from two sorted text files one at a time, in order.

        With batch_size=1, randomize_input=False and num_epochs=1, the lines
        ABC, DEF, GHK must come out in file order, followed by OutOfRangeError.
        Also checks the output tensor name and the V2 queue/reader ops created
        under `name`.
        """
        # Restore the real glob implementation.
        # NOTE(review): presumably setUp replaces gfile.Glob with a stub — confirm.
        gfile.Glob = self._orig_glob
        filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])

        batch_size = 1
        queue_capacity = 5
        name = "my_batch"

        with ops.Graph().as_default() as g, self.test_session(
                graph=g) as session:
            inputs = graph_io.read_batch_examples(
                filenames,
                batch_size,
                reader=io_ops.TextLineReader,
                randomize_input=False,
                num_epochs=1,
                queue_capacity=queue_capacity,
                name=name)
            self.assertAllEqual((None, ), inputs.get_shape().as_list())
            session.run(variables.local_variables_initializer())

            coord = coordinator.Coordinator()
            threads = queue_runner_impl.start_queue_runners(session,
                                                            coord=coord)
            # Stop and join the queue-runner threads even if an assertion
            # fails, so they never outlive this test.
            try:
                self.assertEqual("%s:1" % name, inputs.name)
                file_name_queue_name = "%s/file_name_queue" % name
                file_names_name = "%s/input" % file_name_queue_name
                example_queue_name = "%s/fifo_queue" % name
                test_util.assert_ops_in_graph(
                    {
                        file_names_name: "Const",
                        file_name_queue_name: "FIFOQueueV2",
                        "%s/read/TextLineReaderV2" % name: "TextLineReaderV2",
                        example_queue_name: "FIFOQueueV2",
                        name: "QueueDequeueUpToV2"
                    }, g)

                self.assertAllEqual(session.run(inputs), [b"ABC"])
                self.assertAllEqual(session.run(inputs), [b"DEF"])
                self.assertAllEqual(session.run(inputs), [b"GHK"])
                with self.assertRaises(errors.OutOfRangeError):
                    session.run(inputs)
            finally:
                coord.request_stop()
                coord.join(threads)
  def test_read_text_lines_multifile(self):
    """Read lines from two sorted text files one at a time, in order.

    With batch_size=1, randomize_input=False and num_epochs=1, the lines
    ABC, DEF, GHK must come out in file order, followed by OutOfRangeError.
    Also checks the output tensor name and the queue/reader ops created
    under `name`.
    """
    # Restore the real glob implementation.
    # NOTE(review): presumably setUp replaces gfile.Glob with a stub — confirm.
    gfile.Glob = self._orig_glob
    filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
      inputs = graph_io.read_batch_examples(
          filenames,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)
      # Stop and join the queue-runner threads even if an assertion fails,
      # so they never outlive this test.
      try:
        self.assertEqual("%s:1" % name, inputs.name)
        file_name_queue_name = "%s/file_name_queue" % name
        file_names_name = "%s/input" % file_name_queue_name
        example_queue_name = "%s/fifo_queue" % name
        test_util.assert_ops_in_graph({
            file_names_name: "Const",
            file_name_queue_name: "FIFOQueue",
            "%s/read/TextLineReader" % name: "TextLineReader",
            example_queue_name: "FIFOQueue",
            name: "QueueDequeueUpTo"
        }, g)

        self.assertAllEqual(session.run(inputs), [b"ABC"])
        self.assertAllEqual(session.run(inputs), [b"DEF"])
        self.assertAllEqual(session.run(inputs), [b"GHK"])
        with self.assertRaises(errors.OutOfRangeError):
          session.run(inputs)
      finally:
        coord.request_stop()
        coord.join(threads)