def test_batch_text_lines(self):
    """Read a five-line text file in batches of three for a single epoch.

    The first dequeue must return the lines A, B, C, the second the
    remaining lines D, E, and a third dequeue must raise
    `errors.OutOfRangeError` because `num_epochs=1`.
    """
    # NOTE(review): presumably restores the real glob that the fixture
    # replaced with a stub -- confirm against setUp.
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("A\nB\nC\nD\nE\n")
    batch_size = 3
    queue_capacity = 10
    name = "my_batch"

    with ops.Graph().as_default() as g, self.session(graph=g) as session:
        inputs = graph_io.read_batch_examples(
            [filename],
            batch_size,
            reader=io_ops.TextLineReader,
            randomize_input=False,
            num_epochs=1,
            queue_capacity=queue_capacity,
            read_batch_size=10,
            name=name)
        # The static batch dimension is unknown: the last batch is shorter.
        self.assertAllEqual((None,), inputs.get_shape().as_list())
        session.run(variables.local_variables_initializer())

        coord = coordinator.Coordinator()
        threads = queue_runner_impl.start_queue_runners(session, coord=coord)
        try:
            self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
            self.assertAllEqual(session.run(inputs), [b"D", b"E"])
            with self.assertRaises(errors.OutOfRangeError):
                session.run(inputs)
        finally:
            # Always stop and join the queue-runner threads, even when an
            # assertion above fails, so they do not outlive this test.
            coord.request_stop()
            coord.join(threads)
def test_batch_randomized_multiple_globs(self):
    """Check the graph built for randomized TFRecord reads over two globs.

    Verifies the output shape and tensor name, the presence and types of
    the expected ops, that the file-name constant holds the union of both
    file lists, and that the example queue has the requested capacity.
    """
    name = "my_batch"
    batch_size = 17
    queue_capacity = 1234

    with ops.Graph().as_default() as graph, \
            self.session(graph=graph) as session:
        batch = graph_io.read_batch_examples(
            [_VALID_FILE_PATTERN, _VALID_FILE_PATTERN_2],
            batch_size,
            reader=io_ops.TFRecordReader,
            randomize_input=True,
            queue_capacity=queue_capacity,
            name=name)

        self.assertAllEqual((batch_size,), batch.get_shape().as_list())
        self.assertEqual("%s:1" % name, batch.name)

        # Names of the ops the reader pipeline is expected to create.
        file_name_queue_name = "%s/file_name_queue" % name
        file_names_name = "%s/input" % file_name_queue_name
        example_queue_name = "%s/random_shuffle_queue" % name
        reader_name = "%s/read/TFRecordReaderV2" % name

        expected_ops = {
            file_names_name: "Const",
            file_name_queue_name: "FIFOQueueV2",
            reader_name: "TFRecordReaderV2",
            example_queue_name: "RandomShuffleQueueV2",
            name: "QueueDequeueManyV2",
        }
        op_nodes = test_util.assert_ops_in_graph(expected_ops, graph)

        # The constant feeding the file-name queue must list every file
        # from both patterns (order is not checked).
        fetched_names = session.run(["%s:0" % file_names_name])[0]
        self.assertEqual(set(_FILE_NAMES + _FILE_NAMES_2), set(fetched_names))

        example_queue = op_nodes[example_queue_name]
        self.assertEqual(queue_capacity, example_queue.attr["capacity"].i)
def test_read_text_lines_multifile(self):
    """Read text lines one at a time from two files for a single epoch.

    The first file holds one line (ABC) and the second two lines (DEF,
    GHK). With `batch_size=1` and `randomize_input=False` the test expects
    the lines back in that order, followed by `errors.OutOfRangeError`.
    It also checks the output tensor name and the types of the ops
    created in the graph.
    """
    # NOTE(review): presumably restores the real glob that the fixture
    # replaced with a stub -- confirm against setUp.
    gfile.Glob = self._orig_glob
    filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])
    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.session(graph=g) as session:
        inputs = graph_io.read_batch_examples(
            filenames,
            batch_size,
            reader=io_ops.TextLineReader,
            randomize_input=False,
            num_epochs=1,
            queue_capacity=queue_capacity,
            name=name)
        self.assertAllEqual((None,), inputs.get_shape().as_list())
        session.run(variables.local_variables_initializer())

        coord = coordinator.Coordinator()
        threads = queue_runner_impl.start_queue_runners(session, coord=coord)
        try:
            self.assertEqual("%s:1" % name, inputs.name)
            file_name_queue_name = "%s/file_name_queue" % name
            file_names_name = "%s/input" % file_name_queue_name
            example_queue_name = "%s/fifo_queue" % name
            test_util.assert_ops_in_graph(
                {
                    file_names_name: "Const",
                    file_name_queue_name: "FIFOQueueV2",
                    "%s/read/TextLineReaderV2" % name: "TextLineReaderV2",
                    example_queue_name: "FIFOQueueV2",
                    name: "QueueDequeueUpToV2",
                }, g)

            self.assertAllEqual(session.run(inputs), [b"ABC"])
            self.assertAllEqual(session.run(inputs), [b"DEF"])
            self.assertAllEqual(session.run(inputs), [b"GHK"])
            with self.assertRaises(errors.OutOfRangeError):
                session.run(inputs)
        finally:
            # Always stop and join the queue-runner threads, even when an
            # assertion above fails, so they do not outlive this test.
            coord.request_stop()
            coord.join(threads)