コード例 #1
0
 def testSeekNextLimitEpochs(self):
   """Checks that seek_next with num_epochs=1 produces a single pass."""
   expected = [b"a", b"b", b"c"]
   with self.test_session() as sess:
     next_elem = input_pipeline_ops.seek_next(["a", "b", "c"], num_epochs=1)
     # num_epochs presumably introduces a local variable, so the local
     # initializer is run alongside the global one.
     init_ops = [
         tf.local_variables_initializer(),
         tf.global_variables_initializer(),
     ]
     sess.run(init_ops)
     self._assert_output(expected, sess, next_elem)
コード例 #2
0
 def testSeekNext(self):
   """Checks that seek_next without num_epochs wraps back to the first item."""
   string_list = ["a", "b", "c"]
   with self.test_session() as session:
     elem = input_pipeline_ops.seek_next(string_list)
     # tf.initialize_all_variables() is deprecated; this is its replacement.
     session.run(tf.global_variables_initializer())
     self.assertEqual(b"a", session.run(elem))
     self.assertEqual(b"b", session.run(elem))
     self.assertEqual(b"c", session.run(elem))
     # Make sure we loop: the fourth fetch returns to the first element.
     self.assertEqual(b"a", session.run(elem))
コード例 #3
0
 def testSeekNext(self):
     """Checks that seek_next without num_epochs wraps back to the first item."""
     string_list = ["a", "b", "c"]
     with self.test_session() as session:
         elem = input_pipeline_ops.seek_next(string_list)
         # tf.initialize_all_variables() is deprecated; this is its replacement.
         session.run(tf.global_variables_initializer())
         self.assertEqual(b"a", session.run(elem))
         self.assertEqual(b"b", session.run(elem))
         self.assertEqual(b"c", session.run(elem))
         # Make sure we loop: the fourth fetch returns to the first element.
         self.assertEqual(b"a", session.run(elem))
コード例 #4
0
 def testSeekNextLimitEpochsThree(self):
     """Checks that seek_next with num_epochs=3 repeats the list three times."""
     single_pass = [b"a", b"b", b"c"]
     with self.test_session() as sess:
         next_elem = input_pipeline_ops.seek_next(
             ["a", "b", "c"], num_epochs=3)
         init_ops = [variables.local_variables_initializer(),
                     variables.global_variables_initializer()]
         sess.run(init_ops)
         # Three epochs: the whole sequence is expected back-to-back, 3 times.
         self._assert_output(single_pass * 3, sess, next_elem)
コード例 #5
0
 def testSeekNext(self):
     """Checks that seek_next without num_epochs wraps back to the start."""
     with self.test_session() as sess:
         next_elem = input_pipeline_ops.seek_next(["a", "b", "c"])
         sess.run([variables.global_variables_initializer()])
         # One full pass, then one more fetch that must loop back to "a".
         for want in (b"a", b"b", b"c", b"a"):
             self.assertEqual(want, sess.run(next_elem))
コード例 #6
0
 def testSeekNextLimitEpochsThree(self):
   """seek_next(num_epochs=3) should emit a, b, c three times in order."""
   expected = 3 * [b"a", b"b", b"c"]
   with self.test_session() as sess:
     next_elem = input_pipeline_ops.seek_next(["a", "b", "c"], num_epochs=3)
     local_init = variables.local_variables_initializer()
     global_init = variables.global_variables_initializer()
     sess.run([local_init, global_init])
     self._assert_output(expected, sess, next_elem)
コード例 #7
0
 def testSeekNext(self):
   """seek_next with no epoch limit should cycle a, b, c, a, ..."""
   source = ["a", "b", "c"]
   # The trailing b"a" verifies that the producer loops after one pass.
   expected_sequence = [b"a", b"b", b"c", b"a"]
   with self.test_session() as sess:
     next_elem = input_pipeline_ops.seek_next(source)
     sess.run([variables.global_variables_initializer()])
     for want in expected_sequence:
       got = sess.run(next_elem)
       self.assertEqual(want, got)
コード例 #8
0
def _read_keyed_batch_examples_helper(file_pattern,
                                      batch_size,
                                      reader,
                                      randomize_input=True,
                                      num_epochs=None,
                                      queue_capacity=10000,
                                      num_threads=1,
                                      read_batch_size=1,
                                      filter_fn=None,
                                      parse_fn=None,
                                      setup_shared_queue=False,
                                      name=None,
                                      seed=None):
  """Adds operations to read, queue, batch `Example` protos.

  Args:
    file_pattern: List of files or patterns of file paths containing
        `Example` records. See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    reader: A function or class that returns an object with
      `read` method, (filename tensor) -> (example tensor).
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If `None`, cycles through the dataset forever.
      NOTE - If specified, creates a variable that must be initialized, so call
      `tf.compat.v1.local_variables_initializer()` and run the op in a session.
    queue_capacity: Capacity for input queue.
    num_threads: The number of threads enqueuing examples.
    read_batch_size: An int or scalar `Tensor` specifying the number of
      records to read at once.
    filter_fn: Filtering function, takes both keys as well `Example` Tensors
      and returns a boolean mask of the same shape as the input Tensors to
      be applied for filtering. If `None`, no filtering is done.
    parse_fn: Parsing function, takes `Example` Tensor returns parsed
      representation. If `None`, no parsing is done.
    setup_shared_queue: Whether to set up a shared queue for file names.
    name: Name of resulting op.
    seed: An integer (optional). Seed used if randomize_input == True.

  Returns:
    Returns tuple of:
    - `Tensor` of string keys.
    - String `Tensor` of batched `Example` proto.

  Raises:
    ValueError: for invalid inputs.
  """
  # Retrieve files to read.
  file_names = _get_file_names(file_pattern, randomize_input)

  # Check input parameters are given and reasonable.
  if (not queue_capacity) or (queue_capacity <= 0):
    raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
  if (batch_size is None) or (
      (not isinstance(batch_size, ops.Tensor)) and
      (batch_size <= 0 or batch_size >= queue_capacity)):
    raise ValueError('Invalid batch_size %s, with queue_capacity %s.' %
                     (batch_size, queue_capacity))
  if (read_batch_size is None) or (
      (not isinstance(read_batch_size, ops.Tensor)) and (read_batch_size <= 0)):
    raise ValueError('Invalid read_batch_size %s.' % read_batch_size)
  if (not num_threads) or (num_threads <= 0):
    raise ValueError('Invalid num_threads %s.' % num_threads)
  if (num_epochs is not None) and (num_epochs <= 0):
    raise ValueError('Invalid num_epochs %s.' % num_epochs)

  with ops.name_scope(name, 'read_batch_examples', [file_pattern]) as scope:
    with ops.name_scope('file_name_queue') as file_name_queue_scope:
      if setup_shared_queue:
        file_name_queue = data_flow_ops.FIFOQueue(
            capacity=1, dtypes=[dtypes.string], shapes=[[]])
        enqueue_op = file_name_queue.enqueue(
            input_pipeline_ops.seek_next(
                file_names,
                shuffle=randomize_input,
                num_epochs=num_epochs,
                seed=seed))
        queue_runner.add_queue_runner(
            queue_runner.QueueRunner(file_name_queue, [enqueue_op]))
      else:
        file_name_queue = input_ops.string_input_producer(
            constant_op.constant(file_names, name='input'),
            shuffle=randomize_input,
            num_epochs=num_epochs,
            name=file_name_queue_scope,
            seed=seed)

    example_list = _get_examples(file_name_queue, reader, num_threads,
                                 read_batch_size, filter_fn, parse_fn)

    enqueue_many = read_batch_size > 1

    if num_epochs is None:
      allow_smaller_final_batch = False
    else:
      allow_smaller_final_batch = True

    # Setup batching queue given list of read example tensors.
    if randomize_input:
      if isinstance(batch_size, ops.Tensor):
        min_after_dequeue = int(queue_capacity * 0.4)
      else:
        min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
      queued_examples_with_keys = input_ops.shuffle_batch_join(
          example_list,
          batch_size,
          capacity=queue_capacity,
          min_after_dequeue=min_after_dequeue,
          enqueue_many=enqueue_many,
          name=scope,
          allow_smaller_final_batch=allow_smaller_final_batch,
          seed=seed)
    else:
      queued_examples_with_keys = input_ops.batch_join(
          example_list,
          batch_size,
          capacity=queue_capacity,
          enqueue_many=enqueue_many,
          name=scope,
          allow_smaller_final_batch=allow_smaller_final_batch)
    if parse_fn and isinstance(queued_examples_with_keys, dict):
      queued_keys = queued_examples_with_keys.pop(KEY_FEATURE_NAME)
      return queued_keys, queued_examples_with_keys
    return queued_examples_with_keys