def _apply_transform(self, transform_input):
  filename_queue = input_ops.string_input_producer(self._work_units,
                                                   shuffle=self.shuffle,
                                                   seed=self._seed)
  if self.shuffle:
    queue = data_flow_ops.RandomShuffleQueue(
        capacity=self.queue_capacity,
        min_after_dequeue=self.min_after_dequeue,
        dtypes=[dtypes.string, dtypes.string],
        shapes=[[], []],
        seed=self.seed)
  else:
    queue = data_flow_ops.FIFOQueue(capacity=self.queue_capacity,
                                    dtypes=[dtypes.string, dtypes.string],
                                    shapes=[[], []])
  enqueue_ops = []
  for _ in range(self.num_threads):
    reader = self._reader_cls(**self._reader_kwargs)
    enqueue_ops.append(queue.enqueue(reader.read(filename_queue)))
  runner = queue_runner.QueueRunner(queue, enqueue_ops)
  queue_runner.add_queue_runner(runner)
  dequeued = queue.dequeue_many(self.batch_size)
  # pylint: disable=not-callable
  return self.return_type(*dequeued)
def _apply_transform(self, transform_input):
  filename_queue = input_ops.string_input_producer(self.work_units,
                                                   num_epochs=self.num_epochs,
                                                   shuffle=self.shuffle,
                                                   seed=self.seed)
  reader_ops = []
  for _ in range(self.num_threads):
    reader = self._reader_cls(**self._reader_kwargs)
    reader_ops.append(reader.read_up_to(filename_queue, self.enqueue_size))
  if self.shuffle:
    dequeued = input_ops.shuffle_batch_join(
        reader_ops,
        self.batch_size,
        capacity=self.queue_capacity,
        min_after_dequeue=self.min_after_dequeue,
        seed=self.seed,
        enqueue_many=True,
        shared_name=None,
        name=None)
  else:
    dequeued = input_ops.batch_join(reader_ops,
                                    self.batch_size,
                                    capacity=self.queue_capacity,
                                    enqueue_many=True,
                                    dynamic_pad=False,
                                    shared_name=None,
                                    name=None)
  # pylint: disable=not-callable
  return self.return_type(*dequeued)
def parallel_read(data_sources,
                  reader_class,
                  num_epochs=None,
                  num_readers=4,
                  reader_kwargs=None,
                  shuffle=True,
                  dtypes=None,
                  capacity=256,
                  min_after_dequeue=128):
  """Reads multiple records in parallel from data_sources using n readers.

  It uses a ParallelReader to read from multiple files in parallel using
  multiple readers created using `reader_class` with `reader_kwargs`.

  If shuffle is True the common_queue would be a RandomShuffleQueue otherwise
  it would be a FIFOQueue.

  Usage:
    data_sources = ['path_to/train*']
    key, value = parallel_read(data_sources, tf.CSVReader, num_readers=4)

  Args:
    data_sources: a list/tuple of files or the location of the data, i.e.
      /cns/../train@128, /cns/.../train* or /tmp/.../train*
    reader_class: one of the io_ops.ReaderBase subclasses ex: TFRecordReader
    num_epochs: The number of times each data source is read. If left as None,
      the data will be cycled through indefinitely.
    num_readers: an integer, number of Readers to create.
    reader_kwargs: an optional dict, of kwargs for the reader.
    shuffle: boolean, whether to shuffle the files and the records by using
      RandomShuffleQueue as common_queue.
    dtypes: A list of types. The length of dtypes must equal the number of
      elements in each record. If it is None it will default to
      [tf.string, tf.string] for (key, value).
    capacity: integer, capacity of the common_queue.
    min_after_dequeue: integer, minimum number of records in the common_queue
      after dequeue. Needed for a good shuffle.

  Returns:
    key, value: a tuple of keys and values from the data_source.
  """
  data_files = get_data_files(data_sources)
  with ops.name_scope('parallel_read'):
    filename_queue = tf_input.string_input_producer(data_files,
                                                    num_epochs=num_epochs,
                                                    shuffle=shuffle)
    dtypes = dtypes or [tf_dtypes.string, tf_dtypes.string]
    if shuffle:
      common_queue = data_flow_ops.RandomShuffleQueue(
          capacity=capacity,
          min_after_dequeue=min_after_dequeue,
          dtypes=dtypes)
    else:
      common_queue = data_flow_ops.FIFOQueue(capacity=capacity, dtypes=dtypes)
    return ParallelReader(reader_class,
                          common_queue,
                          num_readers=num_readers,
                          reader_kwargs=reader_kwargs).read(filename_queue)
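# Hedged usage sketch (not from the source): reading (key, value) pairs with
# parallel_read under a Supervisor, mirroring the test pattern used elsewhere
# in this section. The file pattern and logdir below are illustrative.
data_sources = ['/tmp/data/train-*.tfrecord']  # hypothetical file pattern
key, value = parallel_read(data_sources,
                           io_ops.TFRecordReader,
                           num_readers=4,
                           shuffle=True)
sv = supervisor.Supervisor(logdir='/tmp/parallel_read_demo')  # hypothetical
with sv.prepare_or_wait_for_session() as sess:
  sv.start_queue_runners(sess)
  # Each run returns one dequeued (key, serialized record) pair.
  k, v = sess.run([key, value])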
def _apply_transform(self, transform_input, **kwargs):
  filename_queue = input_ops.string_input_producer(
      self.work_units,
      num_epochs=kwargs.get("num_epochs"),
      shuffle=self.shuffle,
      seed=self.seed)
  reader_ops = []
  for _ in range(self.num_threads):
    reader = self._reader_cls(**self._reader_kwargs)
    reader_ops.append(reader.read_up_to(filename_queue, self.enqueue_size))
  if self.shuffle:
    dequeued = input_ops.shuffle_batch_join(
        reader_ops,
        self.batch_size,
        capacity=self.queue_capacity,
        min_after_dequeue=self.min_after_dequeue,
        seed=self.seed,
        enqueue_many=True,
        shared_name=None,
        name=None)
  else:
    dequeued = input_ops.batch_join(reader_ops,
                                    self.batch_size,
                                    capacity=self.queue_capacity,
                                    enqueue_many=True,
                                    dynamic_pad=False,
                                    shared_name=None,
                                    name=None)
  # pylint: disable=not-callable
  return self.return_type(*dequeued)
def testManagedMainErrorTwoQueues(self):
  # Tests that the supervisor correctly raises a main loop
  # error even when using multiple queues for input.
  logdir = self._test_dir("managed_main_error_two_queues")
  os.makedirs(logdir)
  data_path = self._csv_data(logdir)
  with self.assertRaisesRegexp(RuntimeError, "fail at step 3"):
    with ops.Graph().as_default():
      # Create an input pipeline that reads the file 3 times.
      filename_queue = input_lib.string_input_producer([data_path],
                                                       num_epochs=3)
      reader = io_ops.TextLineReader()
      _, csv = reader.read(filename_queue)
      rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
      shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
      sv = supervisor.Supervisor(logdir=logdir)
      with sv.managed_session("") as sess:
        for step in range(9):
          if sv.should_stop():
            break
          elif step == 3:
            raise RuntimeError("fail at step 3")
          else:
            sess.run(shuff_rec)
def single_pass_read(data_sources, reader_class, reader_kwargs=None,
                     scope=None):
  """Reads sequentially the data_sources using the reader, doing a single pass.

  Args:
    data_sources: a list/tuple of files or the location of the data, i.e.
      /path/to/train@128, /path/to/train* or /tmp/.../train*
    reader_class: one of the io_ops.ReaderBase subclasses ex: TFRecordReader.
    reader_kwargs: an optional dict, of kwargs for the reader.
    scope: Optional name scope for the ops.

  Returns:
    key, value: a tuple of keys and values from the data_source.
  """
  data_files = get_data_files(data_sources)
  with ops.name_scope(scope, 'single_pass_read'):
    filename_queue = tf_input.string_input_producer(data_files,
                                                    num_epochs=1,
                                                    shuffle=False,
                                                    capacity=1,
                                                    name='filenames')
    reader_kwargs = reader_kwargs or {}
    return reader_class(**reader_kwargs).read(filename_queue)
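# Hedged sketch (assumption, not from the source): scanning a dataset exactly
# once with single_pass_read and stopping on OutOfRangeError, the same loop
# shape the parallel-reader tests in this section use. Paths are illustrative.
key, value = single_pass_read(['/tmp/data/validation-*.tfrecord'],
                              reader_class=io_ops.TFRecordReader)
sv = supervisor.Supervisor(logdir='/tmp/single_pass_demo')  # hypothetical
with sv.prepare_or_wait_for_session() as sess:
  sv.start_queue_runners(sess)
  num_records = 0
  while True:
    try:
      sess.run([key, value])
      num_records += 1
    except errors_impl.OutOfRangeError:
      break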
def _verify_read_up_to_out(self, shared_queue):
  with self.cached_session():
    num_files = 3
    num_records_per_file = 7
    tfrecord_paths = test_utils.create_tfrecord_files(
        tempfile.mkdtemp(),
        num_files=num_files,
        num_records_per_file=num_records_per_file)

    p_reader = parallel_reader.ParallelReader(
        io_ops.TFRecordReader, shared_queue, num_readers=5)

    data_files = parallel_reader.get_data_files(tfrecord_paths)
    filename_queue = input_lib.string_input_producer(data_files, num_epochs=1)
    key, value = p_reader.read_up_to(filename_queue, 4)

    count0 = 0
    count1 = 0
    count2 = 0
    all_keys_count = 0
    all_values_count = 0

    sv = supervisor.Supervisor(logdir=tempfile.mkdtemp())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)
      while True:
        try:
          current_keys, current_values = sess.run([key, value])
          self.assertEqual(len(current_keys), len(current_values))
          all_keys_count += len(current_keys)
          all_values_count += len(current_values)
          for current_key in current_keys:
            if '0-of-3' in str(current_key):
              count0 += 1
            if '1-of-3' in str(current_key):
              count1 += 1
            if '2-of-3' in str(current_key):
              count2 += 1
        except errors_impl.OutOfRangeError:
          break

      self.assertEqual(count0, num_records_per_file)
      self.assertEqual(count1, num_records_per_file)
      self.assertEqual(count2, num_records_per_file)
      self.assertEqual(all_keys_count, num_files * num_records_per_file)
      self.assertEqual(all_values_count, all_keys_count)
      self.assertEqual(count0 + count1 + count2, all_keys_count)
def _get_shared_file_name_queue(file_names, shuffle, num_epochs, name):
  # Creating a dummy variable so we can put the shared queue in ps if there is
  # a PS and in the worker otherwise. TODO(rohanj): Figure out how to place an
  # op on PS without this hack
  dummy_var = var_ops.Variable(initial_value=0, name='queue_placement_var')
  with ops.device(dummy_var.device):
    shared_file_name_queue = input_ops.string_input_producer(
        constant_op.constant(file_names, name='input'),
        shuffle=shuffle,
        num_epochs=num_epochs,
        name=name,
        shared_name=name)
    return shared_file_name_queue
def _verify_read_up_to_out(self, shared_queue):
  with self.test_session():
    num_files = 3
    num_records_per_file = 7
    tfrecord_paths = test_utils.create_tfrecord_files(
        self.get_temp_dir(),
        num_files=num_files,
        num_records_per_file=num_records_per_file)

    p_reader = parallel_reader.ParallelReader(
        io_ops.TFRecordReader, shared_queue, num_readers=5)

    data_files = parallel_reader.get_data_files(tfrecord_paths)
    filename_queue = input_lib.string_input_producer(data_files, num_epochs=1)
    key, value = p_reader.read_up_to(filename_queue, 4)

    count0 = 0
    count1 = 0
    count2 = 0
    all_keys_count = 0
    all_values_count = 0

    sv = supervisor.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)
      while True:
        try:
          current_keys, current_values = sess.run([key, value])
          self.assertEquals(len(current_keys), len(current_values))
          all_keys_count += len(current_keys)
          all_values_count += len(current_values)
          for current_key in current_keys:
            if '0-of-3' in str(current_key):
              count0 += 1
            if '1-of-3' in str(current_key):
              count1 += 1
            if '2-of-3' in str(current_key):
              count2 += 1
        except errors_impl.OutOfRangeError:
          break

      self.assertEquals(count0, num_records_per_file)
      self.assertEquals(count1, num_records_per_file)
      self.assertEquals(count2, num_records_per_file)
      self.assertEquals(all_keys_count, num_files * num_records_per_file)
      self.assertEquals(all_values_count, all_keys_count)
      self.assertEquals(count0 + count1 + count2, all_keys_count)
def single_pass_read(data_sources, reader_class, reader_kwargs=None):
  """Reads sequentially the data_sources using the reader, doing a single pass.

  Args:
    data_sources: a list/tuple of files or the location of the data, i.e.
      /path/to/train@128, /path/to/train* or /tmp/.../train*
    reader_class: one of the io_ops.ReaderBase subclasses ex: TFRecordReader.
    reader_kwargs: an optional dict, of kwargs for the reader.

  Returns:
    key, value: a tuple of keys and values from the data_source.
  """
  data_files = get_data_files(data_sources)
  with ops.name_scope("single_pass_read"):
    filename_queue = tf_input.string_input_producer(data_files,
                                                    num_epochs=1,
                                                    shuffle=False,
                                                    capacity=1)
    reader_kwargs = reader_kwargs or {}
    return reader_class(**reader_kwargs).read(filename_queue)
def testManagedEndOfInputOneQueue(self):
  # Tests that the supervisor finishes without an error when using
  # a fixed number of epochs, reading from a single queue.
  logdir = self._test_dir("managed_end_of_input_one_queue")
  os.makedirs(logdir)
  data_path = self._csv_data(logdir)
  with ops.Graph().as_default():
    # Create an input pipeline that reads the file 3 times.
    filename_queue = input_lib.string_input_producer([data_path],
                                                     num_epochs=3)
    reader = io_ops.TextLineReader()
    _, csv = reader.read(filename_queue)
    rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
    sv = supervisor.Supervisor(logdir=logdir)
    with sv.managed_session("") as sess:
      while not sv.should_stop():
        sess.run(rec)
def testReadFromSameFile(self):
  with self.cached_session() as sess:
    reader1 = io_ops.LMDBReader(name="test_read_from_same_file1")
    reader2 = io_ops.LMDBReader(name="test_read_from_same_file2")
    filename_queue = input_lib.string_input_producer(
        [self.db_path], num_epochs=None)
    key1, value1 = reader1.read(filename_queue)
    key2, value2 = reader2.read(filename_queue)

    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
    for _ in range(3):
      for _ in range(10):
        k1, v1, k2, v2 = self.evaluate([key1, value1, key2, value2])
        self.assertAllEqual(compat.as_bytes(k1), compat.as_bytes(k2))
        self.assertAllEqual(compat.as_bytes(v1), compat.as_bytes(v2))

    coord.request_stop()
    coord.join(threads)
def testReadFromSameFile(self):
  with self.test_session() as sess:
    reader1 = io_ops.LMDBReader(name="test_read_from_same_file1")
    reader2 = io_ops.LMDBReader(name="test_read_from_same_file2")
    filename_queue = input_lib.string_input_producer(
        [self.db_path], num_epochs=None)
    key1, value1 = reader1.read(filename_queue)
    key2, value2 = reader2.read(filename_queue)

    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
    for _ in range(3):
      for _ in range(10):
        k1, v1, k2, v2 = sess.run([key1, value1, key2, value2])
        self.assertAllEqual(compat.as_bytes(k1), compat.as_bytes(k2))
        self.assertAllEqual(compat.as_bytes(v1), compat.as_bytes(v2))

    coord.request_stop()
    coord.join(threads)
def testReadFromFileRepeatedly(self):
  with self.cached_session() as sess:
    reader = io_ops.LMDBReader(name="test_read_from_file_repeated")
    filename_queue = input_lib.string_input_producer(
        [self.db_path], num_epochs=None)
    key, value = reader.read(filename_queue)

    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
    # Iterate over the lmdb 3 times.
    for _ in range(3):
      # Go over all 10 records each time.
      for j in range(10):
        k, v = self.evaluate([key, value])
        self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(j)))
        self.assertAllEqual(
            compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + j))))

    coord.request_stop()
    coord.join(threads)
def _get_filename_queue(self, epoch_limit):
  """Constructs a filename queue with an epoch limit.

  `epoch_limit` is intended as an error checking fallback to prevent a reader
  from infinitely looping in its requests for more work items if none are
  available in any file. It should be set high enough that it is never reached
  assuming at least one record exists in some file.

  Args:
    epoch_limit: The maximum number of times to read through the complete list
      of files before throwing an OutOfRangeError.

  Returns:
    A tuple of (filename_queue, epoch_limiter):
      filename_queue: A FIFOQueue with filename work items.
      epoch_limiter: The local variable used for epoch limitation. This should
        be set to zero before a reader is passed `filename_queue` in order to
        reset the epoch limiter's state.
  """
  epoch_limiter = variable_scope.variable(
      initial_value=constant_op.constant(0, dtype=dtypes.int64),
      name="epoch_limiter",
      trainable=False,
      collections=[ops.GraphKeys.LOCAL_VARIABLES])
  filenames_tensor = array_ops.reshape(
      ops.convert_to_tensor(self._filenames), [-1])
  # We can't rely on epoch_limiter being initialized, since queue runners are
  # started before local variables are initialized. Instead, we ignore epoch
  # limits before variable initialization. This means that prior to variable
  # initialization, a QueueRunner may cause a reader to enter an un-checked
  # infinite loop. However, as soon as local variables are initialized, we
  # will start incrementing and checking epoch_limiter, which will interrupt
  # any in-progress loops.
  conditional_count_up_to = control_flow_ops.cond(
      state_ops.is_variable_initialized(epoch_limiter),
      lambda: epoch_limiter.count_up_to(epoch_limit),
      lambda: constant_op.constant(0, dtype=dtypes.int64))
  with ops.control_dependencies([conditional_count_up_to]):
    filenames_tensor = array_ops.identity(filenames_tensor)
  filename_queue = input_lib.string_input_producer(
      filenames_tensor, shuffle=False, capacity=1)
  return filename_queue, epoch_limiter
def testReadFromFileRepeatedly(self):
  with self.test_session() as sess:
    reader = io_ops.LMDBReader(name="test_read_from_file_repeated")
    filename_queue = input_lib.string_input_producer(
        [self.db_path], num_epochs=None)
    key, value = reader.read(filename_queue)

    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
    # Iterate over the lmdb 3 times.
    for _ in range(3):
      # Go over all 10 records each time.
      for j in range(10):
        k, v = sess.run([key, value])
        self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(j)))
        self.assertAllEqual(
            compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + j))))

    coord.request_stop()
    coord.join(threads)
def _verify_all_data_sources_read(self, shared_queue):
  with self.cached_session():
    tfrecord_paths = test_utils.create_tfrecord_files(
        self.get_temp_dir(), num_files=3)

    num_readers = len(tfrecord_paths)
    p_reader = parallel_reader.ParallelReader(
        io_ops.TFRecordReader, shared_queue, num_readers=num_readers)

    data_files = parallel_reader.get_data_files(tfrecord_paths)
    filename_queue = input_lib.string_input_producer(data_files)
    key, value = p_reader.read(filename_queue)

    count0 = 0
    count1 = 0
    count2 = 0
    num_reads = 50

    sv = supervisor.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)
      for _ in range(num_reads):
        current_key, _ = sess.run([key, value])
        if '0-of-3' in str(current_key):
          count0 += 1
        if '1-of-3' in str(current_key):
          count1 += 1
        if '2-of-3' in str(current_key):
          count2 += 1

      self.assertGreater(count0, 0)
      self.assertGreater(count1, 0)
      self.assertGreater(count2, 0)
      self.assertEqual(count0 + count1 + count2, num_reads)
def _verify_all_data_sources_read(self, shared_queue):
  with self.test_session():
    tfrecord_paths = test_utils.create_tfrecord_files(
        self.get_temp_dir(), num_files=3)

    num_readers = len(tfrecord_paths)
    p_reader = parallel_reader.ParallelReader(
        io_ops.TFRecordReader, shared_queue, num_readers=num_readers)

    data_files = parallel_reader.get_data_files(tfrecord_paths)
    filename_queue = input_lib.string_input_producer(data_files)
    key, value = p_reader.read(filename_queue)

    count0 = 0
    count1 = 0
    count2 = 0
    num_reads = 50

    sv = supervisor.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)
      for _ in range(num_reads):
        current_key, _ = sess.run([key, value])
        if '0-of-3' in str(current_key):
          count0 += 1
        if '1-of-3' in str(current_key):
          count1 += 1
        if '2-of-3' in str(current_key):
          count2 += 1

      self.assertGreater(count0, 0)
      self.assertGreater(count1, 0)
      self.assertGreater(count2, 0)
      self.assertEquals(count0 + count1 + count2, num_reads)
def _read_keyed_batch_examples_helper(file_pattern,
                                      batch_size,
                                      reader,
                                      randomize_input=True,
                                      num_epochs=None,
                                      queue_capacity=10000,
                                      num_threads=1,
                                      read_batch_size=1,
                                      parse_fn=None,
                                      setup_shared_queue=False,
                                      name=None):
  """Adds operations to read, queue, batch `Example` protos.

  Args:
    file_pattern: List of files or pattern of file paths containing
      `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    reader: A function or class that returns an object with
      `read` method, (filename tensor) -> (example tensor).
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If `None`, cycles through the dataset forever.
      NOTE - If specified, creates a variable that must be initialized, so call
      `tf.initialize_all_variables()` as shown in the tests.
    queue_capacity: Capacity for input queue.
    num_threads: The number of threads enqueuing examples.
    read_batch_size: An int or scalar `Tensor` specifying the number of
      records to read at once
    parse_fn: Parsing function, takes `Example` Tensor returns parsed
      representation. If `None`, no parsing is done.
    setup_shared_queue: Whether to set up a shared queue for file names.
    name: Name of resulting op.

  Returns:
    Returns tuple of:
    - `Tensor` of string keys.
    - String `Tensor` of batched `Example` proto.

  Raises:
    ValueError: for invalid inputs.
  """
  # Retrieve files to read.
  file_names = _get_file_names(file_pattern, randomize_input)

  # Check input parameters are given and reasonable.
  if (not queue_capacity) or (queue_capacity <= 0):
    raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
  if (batch_size is None) or (
      (not isinstance(batch_size, ops.Tensor)) and
      (batch_size <= 0 or batch_size > queue_capacity)):
    raise ValueError('Invalid batch_size %s, with queue_capacity %s.' %
                     (batch_size, queue_capacity))
  if (read_batch_size is None) or (
      (not isinstance(read_batch_size, ops.Tensor)) and
      (read_batch_size <= 0)):
    raise ValueError('Invalid read_batch_size %s.' % read_batch_size)
  if (not num_threads) or (num_threads <= 0):
    raise ValueError('Invalid num_threads %s.' % num_threads)
  if (num_epochs is not None) and (num_epochs <= 0):
    raise ValueError('Invalid num_epochs %s.' % num_epochs)

  with ops.name_scope(name, 'read_batch_examples', [file_pattern]) as scope:
    with ops.name_scope('file_name_queue') as file_name_queue_scope:
      if setup_shared_queue:
        shared_file_name_queue = _get_shared_file_name_queue(
            file_names, randomize_input, num_epochs, file_name_queue_scope)
        file_name_queue = data_flow_ops.FIFOQueue(
            capacity=1, dtypes=[dtypes.string], shapes=[[]])
        enqueue_op = file_name_queue.enqueue(shared_file_name_queue.dequeue())
        queue_runner.add_queue_runner(
            queue_runner.QueueRunner(file_name_queue, [enqueue_op]))
      else:
        file_name_queue = input_ops.string_input_producer(
            constant_op.constant(file_names, name='input'),
            shuffle=randomize_input,
            num_epochs=num_epochs,
            name=file_name_queue_scope)

    example_list = _get_examples(file_name_queue, reader, num_threads,
                                 read_batch_size, parse_fn)

    enqueue_many = read_batch_size > 1

    if num_epochs is None:
      allow_smaller_final_batch = False
    else:
      allow_smaller_final_batch = True

    # Setup batching queue given list of read example tensors.
    if randomize_input:
      if isinstance(batch_size, ops.Tensor):
        min_after_dequeue = int(queue_capacity * 0.4)
      else:
        min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
      queued_examples_with_keys = input_ops.shuffle_batch_join(
          example_list,
          batch_size,
          capacity=queue_capacity,
          min_after_dequeue=min_after_dequeue,
          enqueue_many=enqueue_many,
          name=scope,
          allow_smaller_final_batch=allow_smaller_final_batch)
    else:
      queued_examples_with_keys = input_ops.batch_join(
          example_list,
          batch_size,
          capacity=queue_capacity,
          enqueue_many=enqueue_many,
          name=scope,
          allow_smaller_final_batch=allow_smaller_final_batch)
    if parse_fn and isinstance(queued_examples_with_keys, dict):
      queued_keys = queued_examples_with_keys.pop(KEY_FEATURE_NAME)
      return queued_keys, queued_examples_with_keys
    return queued_examples_with_keys
def read_batch_examples(file_pattern, batch_size, reader,
                        randomize_input=True, queue_capacity=10000,
                        num_threads=1, name='dequeue_examples'):
  """Adds operations to read, queue, batch `Example` protos.

  Given file pattern (or list of files), will setup a queue for file names,
  read `Example` proto using provided `reader`, use batch queue to create
  batches of examples of size `batch_size`.

  All queue runners are added to the queue runners collection, and may be
  started via `start_queue_runners`.

  All ops are added to the default graph.

  Args:
    file_pattern: List of files or pattern of file paths containing
      `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    reader: A function or class that returns an object with
      `read` method, (filename tensor) -> (example tensor).
    randomize_input: Whether the input should be randomized.
    queue_capacity: Capacity for input queue.
    num_threads: The number of threads enqueuing examples.
    name: Name of resulting op.

  Returns:
    String `Tensor` of batched `Example` proto.

  Raises:
    ValueError: for invalid inputs.
  """
  # Retrieve files to read.
  if isinstance(file_pattern, list):
    file_names = file_pattern
    if not file_names:
      raise ValueError('No files given to dequeue_examples.')
  else:
    file_names = list(gfile.Glob(file_pattern))
    if not file_names:
      raise ValueError('No files match %s.' % file_pattern)

  # Sort files so it will be deterministic for unit tests. They'll be shuffled
  # in `string_input_producer` if `randomize_input` is enabled.
  if not randomize_input:
    file_names = sorted(file_names)

  # Check input parameters are given and reasonable.
  if (not queue_capacity) or (queue_capacity <= 0):
    raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
  if (batch_size is None) or (
      (not isinstance(batch_size, ops.Tensor)) and
      (batch_size <= 0 or batch_size > queue_capacity)):
    raise ValueError('Invalid batch_size %s, with queue_capacity %s.' %
                     (batch_size, queue_capacity))
  if (not num_threads) or (num_threads <= 0):
    raise ValueError('Invalid num_threads %s.' % num_threads)

  with ops.name_scope(name) as scope:
    # Setup filename queue with shuffling.
    with ops.name_scope('file_name_queue') as file_name_queue_scope:
      file_name_queue = input_ops.string_input_producer(
          constant_op.constant(file_names, name='input'),
          shuffle=randomize_input,
          name=file_name_queue_scope)

    # Create reader and set it to read from filename queue.
    with ops.name_scope('read'):
      _, example_proto = reader().read(file_name_queue)

    # Setup batching queue.
    if randomize_input:
      if isinstance(batch_size, ops.Tensor):
        min_after_dequeue = int(queue_capacity * 0.4)
      else:
        min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
      examples = input_ops.shuffle_batch(
          [example_proto],
          batch_size,
          capacity=queue_capacity,
          num_threads=num_threads,
          min_after_dequeue=min_after_dequeue,
          name=scope)
    else:
      examples = input_ops.batch(
          [example_proto],
          batch_size,
          capacity=queue_capacity,
          num_threads=num_threads,
          name=scope)

    return examples
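# Hedged usage sketch (illustrative, not from the source): batching CSV lines
# with read_batch_examples and a TextLineReader, then decoding the batch the
# same way the supervisor CSV tests in this section decode single lines. The
# path, batch size, and record_defaults are assumptions.
examples = read_batch_examples('/tmp/data/train-*.csv',
                               batch_size=32,
                               reader=io_ops.TextLineReader,
                               randomize_input=True,
                               queue_capacity=10000,
                               num_threads=2)
# `examples` is a string Tensor of shape [32]; decode_csv keeps that shape,
# yielding one [32] column Tensor per record_default. Queue runners must be
# started (e.g. via a Supervisor) before evaluating these tensors.
rec = parsing_ops.decode_csv(examples, record_defaults=[[1], [1], [1]])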
def _read_keyed_batch_examples_helper(file_pattern,
                                      batch_size,
                                      reader,
                                      randomize_input=True,
                                      num_epochs=None,
                                      queue_capacity=10000,
                                      num_threads=1,
                                      read_batch_size=1,
                                      parse_fn=None,
                                      setup_shared_queue=False,
                                      name=None):
  # Retrieve files to read.
  file_names = _get_file_names(file_pattern, randomize_input)

  # Check input parameters are given and reasonable.
  if (not queue_capacity) or (queue_capacity <= 0):
    raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
  if (batch_size is None) or (
      (not isinstance(batch_size, ops.Tensor)) and
      (batch_size <= 0 or batch_size > queue_capacity)):
    raise ValueError('Invalid batch_size %s, with queue_capacity %s.' %
                     (batch_size, queue_capacity))
  if (read_batch_size is None) or (
      (not isinstance(read_batch_size, ops.Tensor)) and
      (read_batch_size <= 0)):
    raise ValueError('Invalid read_batch_size %s.' % read_batch_size)
  if (not num_threads) or (num_threads <= 0):
    raise ValueError('Invalid num_threads %s.' % num_threads)
  if (num_epochs is not None) and (num_epochs <= 0):
    raise ValueError('Invalid num_epochs %s.' % num_epochs)

  with ops.name_scope(name, 'read_batch_examples', [file_pattern]) as scope:
    with ops.name_scope('file_name_queue') as file_name_queue_scope:
      if setup_shared_queue:
        shared_file_name_queue = _get_shared_file_name_queue(
            file_names, randomize_input, num_epochs, file_name_queue_scope)
        file_name_queue = data_flow_ops.FIFOQueue(
            capacity=1, dtypes=[dtypes.string], shapes=[[]])
        enqueue_op = file_name_queue.enqueue(shared_file_name_queue.dequeue())
        queue_runner.add_queue_runner(
            queue_runner.QueueRunner(file_name_queue, [enqueue_op]))
      else:
        file_name_queue = input_ops.string_input_producer(
            constant_op.constant(file_names, name='input'),
            shuffle=randomize_input,
            num_epochs=num_epochs,
            name=file_name_queue_scope)

    example_list = _get_examples(file_name_queue, reader, num_threads,
                                 read_batch_size, parse_fn)

    enqueue_many = read_batch_size > 1

    if num_epochs is not None:
      allow_smaller_final_batch = True
    else:
      allow_smaller_final_batch = False

    # Setup batching queue given list of read example tensors.
    if randomize_input:
      if isinstance(batch_size, ops.Tensor):
        min_after_dequeue = int(queue_capacity * 0.4)
      else:
        min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
      queued_examples_with_keys = input_ops.shuffle_batch_join(
          example_list,
          batch_size,
          capacity=queue_capacity,
          min_after_dequeue=min_after_dequeue,
          enqueue_many=enqueue_many,
          name=scope,
          allow_smaller_final_batch=allow_smaller_final_batch)
    else:
      queued_examples_with_keys = input_ops.batch_join(
          example_list,
          batch_size,
          capacity=queue_capacity,
          enqueue_many=enqueue_many,
          name=scope,
          allow_smaller_final_batch=allow_smaller_final_batch)
    if parse_fn and isinstance(queued_examples_with_keys, dict):
      queued_keys = queued_examples_with_keys.pop(KEY_FEATURE_NAME)
      return queued_keys, queued_examples_with_keys
    return queued_examples_with_keys
def read_keyed_batch_examples(file_pattern,
                              batch_size,
                              reader,
                              randomize_input=True,
                              num_epochs=None,
                              queue_capacity=10000,
                              num_threads=1,
                              read_batch_size=1,
                              parse_fn=None,
                              name=None):
  """Adds operations to read, queue, batch `Example` protos.

  Given file pattern (or list of files), will setup a queue for file names,
  read `Example` proto using provided `reader`, use batch queue to create
  batches of examples of size `batch_size`.

  All queue runners are added to the queue runners collection, and may be
  started via `start_queue_runners`.

  All ops are added to the default graph.

  Use `parse_fn` if you need to do parsing / processing on single examples.

  Args:
    file_pattern: List of files or pattern of file paths containing
      `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    reader: A function or class that returns an object with
      `read` method, (filename tensor) -> (example tensor).
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If `None`, cycles through the dataset forever.
      NOTE - If specified, creates a variable that must be initialized, so call
      `tf.initialize_all_variables()` as shown in the tests.
    queue_capacity: Capacity for input queue.
    num_threads: The number of threads enqueuing examples.
    read_batch_size: An int or scalar `Tensor` specifying the number of
      records to read at once
    parse_fn: Parsing function, takes `Example` Tensor returns parsed
      representation. If `None`, no parsing is done.
    name: Name of resulting op.

  Returns:
    Returns tuple of:
    - `Tensor` of string keys.
    - String `Tensor` of batched `Example` proto.

  Raises:
    ValueError: for invalid inputs.
  """
  # Retrieve files to read.
  if isinstance(file_pattern, list):
    file_names = file_pattern
    if not file_names:
      raise ValueError('No files given to dequeue_examples.')
  else:
    file_names = list(gfile.Glob(file_pattern))
    if not file_names:
      raise ValueError('No files match %s.' % file_pattern)

  # Sort files so it will be deterministic for unit tests. They'll be shuffled
  # in `string_input_producer` if `randomize_input` is enabled.
  if not randomize_input:
    file_names = sorted(file_names)

  # Check input parameters are given and reasonable.
  if (not queue_capacity) or (queue_capacity <= 0):
    raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
  if (batch_size is None) or (
      (not isinstance(batch_size, ops.Tensor)) and
      (batch_size <= 0 or batch_size > queue_capacity)):
    raise ValueError('Invalid batch_size %s, with queue_capacity %s.' %
                     (batch_size, queue_capacity))
  if (read_batch_size is None) or (
      (not isinstance(read_batch_size, ops.Tensor)) and
      (read_batch_size <= 0)):
    raise ValueError('Invalid read_batch_size %s.' % read_batch_size)
  if (not num_threads) or (num_threads <= 0):
    raise ValueError('Invalid num_threads %s.' % num_threads)
  if (num_epochs is not None) and (num_epochs <= 0):
    raise ValueError('Invalid num_epochs %s.' % num_epochs)

  with ops.op_scope([file_pattern], name, 'read_batch_examples') as scope:
    # Setup filename queue with shuffling.
    with ops.name_scope('file_name_queue') as file_name_queue_scope:
      file_name_queue = input_ops.string_input_producer(
          constant_op.constant(file_names, name='input'),
          shuffle=randomize_input,
          num_epochs=num_epochs,
          name=file_name_queue_scope)

    # Create readers, one per thread and set them to read from filename queue.
    with ops.name_scope('read'):
      example_list = []
      for _ in range(num_threads):
        if read_batch_size > 1:
          keys, examples_proto = reader().read_up_to(file_name_queue,
                                                     read_batch_size)
        else:
          keys, examples_proto = reader().read(file_name_queue)
        if parse_fn:
          parsed_examples = parse_fn(examples_proto)
          # Map keys into example map because batch_join doesn't support
          # tuple of Tensor + dict.
          if isinstance(parsed_examples, dict):
            parsed_examples[KEY_FEATURE_NAME] = keys
            example_list.append(parsed_examples)
          else:
            example_list.append((keys, parsed_examples))
        else:
          example_list.append((keys, examples_proto))

    enqueue_many = read_batch_size > 1

    # Setup batching queue given list of read example tensors.
    if randomize_input:
      if isinstance(batch_size, ops.Tensor):
        min_after_dequeue = int(queue_capacity * 0.4)
      else:
        min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
      queued_examples_with_keys = input_ops.shuffle_batch_join(
          example_list,
          batch_size,
          capacity=queue_capacity,
          min_after_dequeue=min_after_dequeue,
          enqueue_many=enqueue_many,
          name=scope)
    else:
      queued_examples_with_keys = input_ops.batch_join(
          example_list,
          batch_size,
          capacity=queue_capacity,
          enqueue_many=enqueue_many,
          name=scope)
    if parse_fn and isinstance(queued_examples_with_keys, dict):
      queued_keys = queued_examples_with_keys.pop(KEY_FEATURE_NAME)
      return queued_keys, queued_examples_with_keys
    return queued_examples_with_keys
def read_keyed_batch_examples(file_pattern,
                              batch_size,
                              reader,
                              randomize_input=True,
                              num_epochs=None,
                              queue_capacity=10000,
                              num_threads=1,
                              read_batch_size=1,
                              parse_fn=None,
                              name=None):
  """Adds operations to read, queue, batch `Example` protos.

  Given file pattern (or list of files), will setup a queue for file names,
  read `Example` proto using provided `reader`, use batch queue to create
  batches of examples of size `batch_size`.

  All queue runners are added to the queue runners collection, and may be
  started via `start_queue_runners`.

  All ops are added to the default graph.

  Use `parse_fn` if you need to do parsing / processing on single examples.

  Args:
    file_pattern: List of files or pattern of file paths containing
      `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    reader: A function or class that returns an object with
      `read` method, (filename tensor) -> (example tensor).
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If `None`, cycles through the dataset forever.
      NOTE - If specified, creates a variable that must be initialized, so call
      `tf.initialize_all_variables()` as shown in the tests.
    queue_capacity: Capacity for input queue.
    num_threads: The number of threads enqueuing examples.
    read_batch_size: An int or scalar `Tensor` specifying the number of
      records to read at once
    parse_fn: Parsing function, takes `Example` Tensor returns parsed
      representation. If `None`, no parsing is done.
    name: Name of resulting op.

  Returns:
    Returns tuple of:
    - `Tensor` of string keys.
    - String `Tensor` of batched `Example` proto.

  Raises:
    ValueError: for invalid inputs.
  """
  # Retrieve files to read.
  if isinstance(file_pattern, list):
    file_names = file_pattern
    if not file_names:
      raise ValueError('No files given to dequeue_examples.')
  else:
    file_names = list(gfile.Glob(file_pattern))
    if not file_names:
      raise ValueError('No files match %s.' % file_pattern)

  # Sort files so it will be deterministic for unit tests. They'll be shuffled
  # in `string_input_producer` if `randomize_input` is enabled.
  if not randomize_input:
    file_names = sorted(file_names)

  # Check input parameters are given and reasonable.
  if (not queue_capacity) or (queue_capacity <= 0):
    raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
  if (batch_size is None) or (
      (not isinstance(batch_size, ops.Tensor)) and
      (batch_size <= 0 or batch_size > queue_capacity)):
    raise ValueError('Invalid batch_size %s, with queue_capacity %s.' %
                     (batch_size, queue_capacity))
  if (read_batch_size is None) or (
      (not isinstance(read_batch_size, ops.Tensor)) and
      (read_batch_size <= 0)):
    raise ValueError('Invalid read_batch_size %s.' % read_batch_size)
  if (not num_threads) or (num_threads <= 0):
    raise ValueError('Invalid num_threads %s.' % num_threads)
  if (num_epochs is not None) and (num_epochs <= 0):
    raise ValueError('Invalid num_epochs %s.' % num_epochs)

  with ops.name_scope(name, 'read_batch_examples', [file_pattern]) as scope:
    # Setup filename queue with shuffling.
    with ops.name_scope('file_name_queue') as file_name_queue_scope:
      file_name_queue = input_ops.string_input_producer(
          constant_op.constant(file_names, name='input'),
          shuffle=randomize_input,
          num_epochs=num_epochs,
          name=file_name_queue_scope)

    # Create readers, one per thread and set them to read from filename queue.
    with ops.name_scope('read'):
      example_list = []
      for _ in range(num_threads):
        if read_batch_size > 1:
          keys, examples_proto = reader().read_up_to(file_name_queue,
                                                     read_batch_size)
        else:
          keys, examples_proto = reader().read(file_name_queue)
        if parse_fn:
          parsed_examples = parse_fn(examples_proto)
          # Map keys into example map because batch_join doesn't support
          # tuple of Tensor + dict.
          if isinstance(parsed_examples, dict):
            parsed_examples[KEY_FEATURE_NAME] = keys
            example_list.append(parsed_examples)
          else:
            example_list.append((keys, parsed_examples))
        else:
          example_list.append((keys, examples_proto))

    enqueue_many = read_batch_size > 1

    if num_epochs is not None:
      allow_smaller_final_batch = True
    else:
      allow_smaller_final_batch = False

    # Setup batching queue given list of read example tensors.
    if randomize_input:
      if isinstance(batch_size, ops.Tensor):
        min_after_dequeue = int(queue_capacity * 0.4)
      else:
        min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
      queued_examples_with_keys = input_ops.shuffle_batch_join(
          example_list,
          batch_size,
          capacity=queue_capacity,
          min_after_dequeue=min_after_dequeue,
          enqueue_many=enqueue_many,
          name=scope,
          allow_smaller_final_batch=allow_smaller_final_batch)
    else:
      queued_examples_with_keys = input_ops.batch_join(
          example_list,
          batch_size,
          capacity=queue_capacity,
          enqueue_many=enqueue_many,
          name=scope,
          allow_smaller_final_batch=allow_smaller_final_batch)
    if parse_fn and isinstance(queued_examples_with_keys, dict):
      queued_keys = queued_examples_with_keys.pop(KEY_FEATURE_NAME)
      return queued_keys, queued_examples_with_keys
    return queued_examples_with_keys
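# Hedged sketch (assumption, not from the source): the keyed variant returns
# the record keys alongside the batched protos. The file pattern and sizes are
# illustrative; per the docstring, a finite num_epochs creates a variable that
# must be initialized before queue runners are started.
keys, examples = read_keyed_batch_examples(
    '/tmp/data/train-*.tfrecord',
    batch_size=32,
    reader=io_ops.TFRecordReader,
    randomize_input=True,
    num_epochs=1,
    read_batch_size=32)  # read_up_to 32 records per reader call
# `keys` and `examples` are aligned string Tensors; with num_epochs=1 the
# final batch may be smaller than 32 (allow_smaller_final_batch=True).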