def testSeekNextLimitEpochs(self):
  """Runs `seek_next` over three strings with `num_epochs=1`.

  The expected single pass over the list is handed to `_assert_output`
  together with the session and the element tensor.
  """
  inputs = ["a", "b", "c"]
  expected = [b"a", b"b", b"c"]
  with self.test_session() as sess:
    next_elem = input_pipeline_ops.seek_next(inputs, num_epochs=1)
    # Local and global variables are both initialized before any read.
    init_ops = [
        tf.local_variables_initializer(),
        tf.global_variables_initializer(),
    ]
    sess.run(init_ops)
    self._assert_output(expected, sess, next_elem)
def testSeekNext(self):
  """Reads four elements from `seek_next` over a three-string list.

  No `num_epochs` is passed, and the fourth read is expected to return the
  first string again.
  """
  string_list = ["a", "b", "c"]
  with self.test_session() as session:
    elem = input_pipeline_ops.seek_next(string_list)
    # `tf.initialize_all_variables()` is deprecated; use its documented
    # replacement, `tf.global_variables_initializer()`.
    session.run(tf.global_variables_initializer())
    self.assertEqual(b"a", session.run(elem))
    self.assertEqual(b"b", session.run(elem))
    self.assertEqual(b"c", session.run(elem))
    # Make sure we loop back to the start of the list.
    self.assertEqual(b"a", session.run(elem))
def testSeekNextLimitEpochsThree(self):
  """Runs `seek_next` over three strings with `num_epochs=3`.

  The expected output passed to `_assert_output` is the list repeated three
  times.
  """
  inputs = ["a", "b", "c"]
  with self.test_session() as sess:
    next_elem = input_pipeline_ops.seek_next(inputs, num_epochs=3)
    init_ops = [
        variables.local_variables_initializer(),
        variables.global_variables_initializer(),
    ]
    sess.run(init_ops)
    # Three full passes over [a, b, c] are expected.
    one_epoch = [b"a", b"b", b"c"]
    self._assert_output(one_epoch * 3, sess, next_elem)
def testSeekNext(self):
  """Checks four consecutive reads from `seek_next` over three strings."""
  inputs = ["a", "b", "c"]
  with self.test_session() as sess:
    next_elem = input_pipeline_ops.seek_next(inputs)
    sess.run([variables.global_variables_initializer()])
    # The trailing b"a" makes sure reading loops back to the first string.
    for expected in (b"a", b"b", b"c", b"a"):
      self.assertEqual(expected, sess.run(next_elem))
def _read_keyed_batch_examples_helper(file_pattern,
                                      batch_size,
                                      reader,
                                      randomize_input=True,
                                      num_epochs=None,
                                      queue_capacity=10000,
                                      num_threads=1,
                                      read_batch_size=1,
                                      filter_fn=None,
                                      parse_fn=None,
                                      setup_shared_queue=False,
                                      name=None,
                                      seed=None):
  """Builds the ops that read, queue and batch `Example` protos with keys.

  Args:
    file_pattern: List of files or patterns of file paths containing
      `Example` records. See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    reader: A function or class that returns an object with `read` method,
      (filename tensor) -> (example tensor).
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If `None`, cycles through the dataset forever. NOTE - If
      specified, creates a variable that must be initialized, so call
      `tf.compat.v1.local_variables_initializer()` and run the op in a
      session.
    queue_capacity: Capacity for input queue.
    num_threads: The number of threads enqueuing examples.
    read_batch_size: An int or scalar `Tensor` specifying the number of
      records to read at once.
    filter_fn: Filtering function, takes both keys as well as `Example`
      Tensors and returns a boolean mask of the same shape as the input
      Tensors to be applied for filtering. If `None`, no filtering is done.
    parse_fn: Parsing function, takes `Example` Tensor and returns the parsed
      representation. If `None`, no parsing is done.
    setup_shared_queue: Whether to set up a shared queue for file names.
    name: Name of resulting op.
    seed: An integer (optional). Seed used if randomize_input == True.

  Returns:
    Returns tuple of:
      - `Tensor` of string keys.
      - String `Tensor` of batched `Example` proto.
    When `parse_fn` is given and the batched result is a dict, the keys are
    popped from it under `KEY_FEATURE_NAME` and returned alongside the
    remaining dict; otherwise the output of the batching op is returned
    unchanged.

  Raises:
    ValueError: for invalid inputs (`queue_capacity`, `batch_size`,
      `read_batch_size`, `num_threads` or `num_epochs`).
  """
  # Resolve the files to read before any argument validation.
  file_names = _get_file_names(file_pattern, randomize_input)

  # Validate the arguments; Tensor-valued sizes skip the numeric range checks.
  if not queue_capacity or queue_capacity <= 0:
    raise ValueError('Invalid queue_capacity %s.' % queue_capacity)

  batch_size_is_tensor = isinstance(batch_size, ops.Tensor)
  if batch_size is None or (
      not batch_size_is_tensor and
      (batch_size <= 0 or batch_size >= queue_capacity)):
    raise ValueError('Invalid batch_size %s, with queue_capacity %s.' %
                     (batch_size, queue_capacity))

  read_batch_size_is_tensor = isinstance(read_batch_size, ops.Tensor)
  if read_batch_size is None or (
      not read_batch_size_is_tensor and read_batch_size <= 0):
    raise ValueError('Invalid read_batch_size %s.' % read_batch_size)

  if not num_threads or num_threads <= 0:
    raise ValueError('Invalid num_threads %s.' % num_threads)

  if num_epochs is not None and num_epochs <= 0:
    raise ValueError('Invalid num_epochs %s.' % num_epochs)

  with ops.name_scope(name, 'read_batch_examples', [file_pattern]) as scope:
    with ops.name_scope('file_name_queue') as file_name_queue_scope:
      if setup_shared_queue:
        # Single-slot string queue fed by a queue runner from `seek_next`.
        file_name_queue = data_flow_ops.FIFOQueue(
            capacity=1, dtypes=[dtypes.string], shapes=[[]])
        next_file_name = input_pipeline_ops.seek_next(
            file_names,
            shuffle=randomize_input,
            num_epochs=num_epochs,
            seed=seed)
        enqueue_op = file_name_queue.enqueue(next_file_name)
        runner = queue_runner.QueueRunner(file_name_queue, [enqueue_op])
        queue_runner.add_queue_runner(runner)
      else:
        file_names_tensor = constant_op.constant(file_names, name='input')
        file_name_queue = input_ops.string_input_producer(
            file_names_tensor,
            shuffle=randomize_input,
            num_epochs=num_epochs,
            name=file_name_queue_scope,
            seed=seed)

    example_list = _get_examples(file_name_queue, reader, num_threads,
                                 read_batch_size, filter_fn, parse_fn)

    enqueue_many = read_batch_size > 1

    # A smaller final batch is allowed only when the epoch count is bounded.
    allow_smaller_final_batch = num_epochs is not None

    # Arguments shared by both batching ops below.
    shared_batch_args = dict(
        capacity=queue_capacity,
        enqueue_many=enqueue_many,
        name=scope,
        allow_smaller_final_batch=allow_smaller_final_batch)

    # Set up the batching queue from the list of read example tensors.
    if randomize_input:
      if batch_size_is_tensor:
        # For a Tensor batch size, use 40% of the queue capacity.
        min_after_dequeue = int(queue_capacity * 0.4)
      else:
        # Leave room for three batches, but never go below one batch.
        min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
      batched = input_ops.shuffle_batch_join(
          example_list,
          batch_size,
          min_after_dequeue=min_after_dequeue,
          seed=seed,
          **shared_batch_args)
    else:
      batched = input_ops.batch_join(
          example_list, batch_size, **shared_batch_args)

    if not (parse_fn and isinstance(batched, dict)):
      return batched

    # Parsed dict output: split the keys out from the remaining features.
    queued_keys = batched.pop(KEY_FEATURE_NAME)
    return queued_keys, batched