コード例 #1
0
    def test_batch_text_lines(self):
        """Read a five-line text file in batches of three for a single epoch.

        The first batch must be full (``A``, ``B``, ``C``), the second holds the
        two leftover lines (``D``, ``E``), and one more read must raise
        ``errors.OutOfRangeError`` because the single epoch is exhausted.
        """
        # NOTE(review): presumably setUp replaces gfile.Glob with a stub;
        # restore the real one so the on-disk temporary file is matched.
        gfile.Glob = self._orig_glob
        filename = self._create_temp_file("A\nB\nC\nD\nE\n")

        batch_size = 3
        queue_capacity = 10
        name = "my_batch"

        with ops.Graph().as_default() as g, self.session(graph=g) as session:
            inputs = graph_io.read_batch_examples(
                [filename],
                batch_size,
                reader=io_ops.TextLineReader,
                randomize_input=False,
                num_epochs=1,
                queue_capacity=queue_capacity,
                read_batch_size=10,
                name=name)
            # The leading dimension is not statically fixed: as asserted
            # below, the final batch can be shorter than batch_size.
            self.assertAllEqual((None, ), inputs.get_shape().as_list())
            # Initialize local variables before the queue runners start
            # (presumably the num_epochs limit is tracked in one).
            session.run(variables.local_variables_initializer())

            coord = coordinator.Coordinator()
            threads = queue_runner_impl.start_queue_runners(session,
                                                            coord=coord)
            try:
                self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
                self.assertAllEqual(session.run(inputs), [b"D", b"E"])
                with self.assertRaises(errors.OutOfRangeError):
                    session.run(inputs)
            finally:
                # Stop and join the queue-runner threads even when an
                # assertion above fails, so they never outlive the session
                # that the enclosing `with` block closes.
                coord.request_stop()
                coord.join(threads)
コード例 #2
0
    def test_batch_randomized_multiple_globs(self):
        """Build a shuffled TFRecord batch from two file patterns; check the graph.

        No queue runners are started. The test only inspects what was built:
        the output tensor's static shape and name, the expected op names and
        types, the contents of the file-name constant, and the capacity
        attribute of the shuffle queue.
        """
        name = "my_batch"
        queue_capacity = 1234
        batch_size = 17

        with ops.Graph().as_default() as graph, \
                self.session(graph=graph) as session:
            batch = graph_io.read_batch_examples(
                [_VALID_FILE_PATTERN, _VALID_FILE_PATTERN_2],
                batch_size,
                reader=io_ops.TFRecordReader,
                randomize_input=True,
                queue_capacity=queue_capacity,
                name=name)

            # The batch dimension is expected to be statically known here.
            self.assertAllEqual((batch_size, ), batch.get_shape().as_list())
            self.assertEqual("%s:1" % name, batch.name)

            fn_queue = "%s/file_name_queue" % name
            fn_const = "%s/input" % fn_queue
            shuffle_queue = "%s/random_shuffle_queue" % name
            reader_op = "%s/read/TFRecordReaderV2" % name

            # Op name -> op type that must be present in the graph.
            expected_ops = {
                fn_const: "Const",
                fn_queue: "FIFOQueueV2",
                reader_op: "TFRecordReaderV2",
                shuffle_queue: "RandomShuffleQueueV2",
                name: "QueueDequeueManyV2"
            }
            nodes_by_name = test_util.assert_ops_in_graph(expected_ops, graph)

            # Evaluating the constant yields every file matched by both globs.
            (matched_files, ) = session.run(["%s:0" % fn_const])
            self.assertEqual(set(_FILE_NAMES + _FILE_NAMES_2),
                             set(matched_files))

            capacity_attr = nodes_by_name[shuffle_queue].attr["capacity"]
            self.assertEqual(queue_capacity, capacity_attr.i)
コード例 #3
0
    def test_read_text_lines_multifile(self):
        """Read lines from two text files one at a time for a single epoch.

        The files hold ``ABC`` and ``DEF``/``GHK``. With ``batch_size=1`` and no
        shuffling, the batches must come out as ``[ABC]``, ``[DEF]``, ``[GHK]``,
        after which ``errors.OutOfRangeError`` signals the end of the epoch.
        """
        # NOTE(review): presumably setUp replaces gfile.Glob with a stub;
        # restore the real one so the on-disk temporary files are matched.
        gfile.Glob = self._orig_glob
        filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])

        batch_size = 1
        queue_capacity = 5
        name = "my_batch"

        with ops.Graph().as_default() as g, self.session(graph=g) as session:
            inputs = graph_io.read_batch_examples(
                filenames,
                batch_size,
                reader=io_ops.TextLineReader,
                randomize_input=False,
                num_epochs=1,
                queue_capacity=queue_capacity,
                name=name)
            self.assertAllEqual((None, ), inputs.get_shape().as_list())

            # Structural graph checks need no running threads, so do them
            # before the queue runners start; a failure here then cannot
            # leave background threads behind.
            self.assertEqual("%s:1" % name, inputs.name)
            file_name_queue_name = "%s/file_name_queue" % name
            file_names_name = "%s/input" % file_name_queue_name
            example_queue_name = "%s/fifo_queue" % name
            test_util.assert_ops_in_graph(
                {
                    file_names_name: "Const",
                    file_name_queue_name: "FIFOQueueV2",
                    "%s/read/TextLineReaderV2" % name: "TextLineReaderV2",
                    example_queue_name: "FIFOQueueV2",
                    name: "QueueDequeueUpToV2"
                }, g)

            # Initialize local variables before the queue runners start
            # (presumably the num_epochs limit is tracked in one).
            session.run(variables.local_variables_initializer())

            coord = coordinator.Coordinator()
            threads = queue_runner_impl.start_queue_runners(session,
                                                            coord=coord)
            try:
                self.assertAllEqual(session.run(inputs), [b"ABC"])
                self.assertAllEqual(session.run(inputs), [b"DEF"])
                self.assertAllEqual(session.run(inputs), [b"GHK"])
                with self.assertRaises(errors.OutOfRangeError):
                    session.run(inputs)
            finally:
                # Stop and join the queue-runner threads even when an
                # assertion above fails, so they never outlive the session
                # that the enclosing `with` block closes.
                coord.request_stop()
                coord.join(threads)