Example #1
 def testManagedMainErrorTwoQueues(self):
     # Tests that the supervisor correctly raises a main loop
     # error even when using multiple queues for input.
     logdir = self._test_dir("managed_main_error_two_queues")
     os.makedirs(logdir)
     data_path = self._csv_data(logdir)
     with self.assertRaisesRegexp(RuntimeError, "fail at step 3"):
         with ops.Graph().as_default():
             # Create an input pipeline that reads the file 3 times.
             filename_queue = input_lib.string_input_producer([data_path],
                                                              num_epochs=3)
             reader = io_ops.TextLineReader()
             _, csv = reader.read(filename_queue)
             rec = parsing_ops.decode_csv(csv,
                                          record_defaults=[[1], [1], [1]])
             shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
             sv = supervisor.Supervisor(logdir=logdir)
             with sv.managed_session("") as sess:
                 for step in range(9):
                     if sv.should_stop():
                         break
                     elif step == 3:
                         raise RuntimeError("fail at step 3")
                     else:
                         sess.run(shuff_rec)
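
The two supervisor tests in this listing (this one and Example #5) are methods of a test class in TensorFlow's test suite; `self._test_dir` and `self._csv_data` are helpers from that class, so the snippets are not runnable standalone. A sketch of the imports they assume, following the TF 1.x internal module layout:

import os

from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import supervisor
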
Example #2
 # `input` is TensorFlow's training input module (`tf.train` /
 # `tensorflow.python.training.input`); the batching parameters are free
 # variables captured from the enclosing function's scope.
 def fn(tensors, scope):
     return input.shuffle_batch(
         tensors,
         batch_size=batch_size,
         num_threads=num_threads,
         capacity=capacity,
         enqueue_many=enqueue_many,
         min_after_dequeue=min_after_dequeue,
         seed=seed,
         allow_smaller_final_batch=allow_smaller_final_batch,
         name=scope)
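
For context, a closure like this only works because `batch_size`, `capacity`, and the other batching parameters are bound in an enclosing scope. A minimal sketch of such a factory, assuming TF 1.x; the factory name and the queue-sizing defaults below are illustrative, not from the source:

from tensorflow.python.training import input  # pylint: disable=redefined-builtin


def make_shuffle_batch_fn(batch_size, num_threads=1, capacity=None,
                          enqueue_many=False, min_after_dequeue=None,
                          seed=None, allow_smaller_final_batch=False):
  # Illustrative heuristics for queue sizing; not from the source.
  capacity = capacity or 32 * batch_size
  min_after_dequeue = min_after_dequeue or 8 * batch_size

  def fn(tensors, scope):
    return input.shuffle_batch(
        tensors,
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=capacity,
        enqueue_many=enqueue_many,
        min_after_dequeue=min_after_dequeue,
        seed=seed,
        allow_smaller_final_batch=allow_smaller_final_batch,
        name=scope)

  return fn
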
Example #3
 def _apply_transform(self, transform_input, **kwargs):
   batched = input_ops.shuffle_batch(transform_input,
                                     batch_size=self.batch_size,
                                     capacity=self.queue_capacity,
                                     min_after_dequeue=self.min_after_dequeue,
                                     num_threads=self.num_threads,
                                     seed=self.seed,
                                     enqueue_many=True)
   # TODO(jamieas): batch will soon return a list regardless of the number of
   # enqueued tensors. Remove the following once that change is in place.
   if not isinstance(batched, (tuple, list)):
     batched = (batched,)
   # pylint: disable=not-callable
   return self.return_type(*batched)
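
The next example is the same method from another revision of the library; it differs only in that its signature omits the `**kwargs` parameter.
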
Example #4
 def _apply_transform(self, transform_input):
   batched = input_ops.shuffle_batch(transform_input,
                                     batch_size=self.batch_size,
                                     capacity=self.queue_capacity,
                                     min_after_dequeue=self.min_after_dequeue,
                                     num_threads=self.num_threads,
                                     seed=self.seed,
                                     enqueue_many=True)
   # TODO(jamieas): batch will soon return a list regardless of the number of
   # enqueued tensors. Remove the following once that change is in place.
   if not isinstance(batched, (tuple, list)):
     batched = (batched,)
   # pylint: disable=not-callable
   return self.return_type(*batched)
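
Both `_apply_transform` variants above depend on attributes and a `return_type` supplied by the transform class they belong to (they appear to come from the `tf.contrib.learn` DataFrame transforms, where `input_ops` is `tensorflow.python.training.input`). A minimal sketch of that scaffolding, assuming a namedtuple-based `return_type`; the class and field names here are hypothetical:

import collections


class ShuffleBatchSketch(object):
  """Hypothetical host class for the `_apply_transform` methods above."""

  # Hypothetical one-field output schema; the real transform builds its
  # namedtuple from the names of the tensors being batched.
  return_type = collections.namedtuple('ShuffleBatchOutput', ['output'])

  def __init__(self, batch_size, queue_capacity, min_after_dequeue,
               num_threads=1, seed=None):
    self.batch_size = batch_size
    self.queue_capacity = queue_capacity
    self.min_after_dequeue = min_after_dequeue
    self.num_threads = num_threads
    self.seed = seed
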
Example #5
 def testManagedEndOfInputTwoQueues(self):
   # Tests that the supervisor finishes without an error when using
   # a fixed number of epochs, reading from two queues, the second
   # one producing a batch from the first one.
   logdir = self._test_dir("managed_end_of_input_two_queues")
   os.makedirs(logdir)
   data_path = self._csv_data(logdir)
   with ops.Graph().as_default():
     # Create an input pipeline that reads the file 3 times.
     filename_queue = input_lib.string_input_producer(
         [data_path], num_epochs=3)
     reader = io_ops.TextLineReader()
     _, csv = reader.read(filename_queue)
     rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
     shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
     sv = supervisor.Supervisor(logdir=logdir)
     with sv.managed_session("") as sess:
       while not sv.should_stop():
         sess.run(shuff_rec)
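
In this loop, end of input surfaces as an `OutOfRangeError` from `sess.run` once the epoch-limited queues are exhausted; `managed_session` treats that error as a clean stop (it is in the coordinator's default `clean_stop_exception_types`), so the test finishes without re-raising.
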
Example #6
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import input as input_ops


def read_batch_examples(file_pattern,
                        batch_size,
                        reader,
                        randomize_input=True,
                        queue_capacity=10000,
                        num_threads=1,
                        name='dequeue_examples'):
    """Adds operations to read, queue, batch `Example` protos.

  Given file pattern (or list of files), will setup a queue for file names,
  read `Example` proto using provided `reader`, use batch queue to create
  batches of examples of size `batch_size`.

  All queue runners are added to the queue runners collection, and may be
  started via `start_queue_runners`.

  All ops are added to the default graph.

  Args:
    file_pattern: List of files or pattern of file paths containing
        `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    reader: A function or class that returns an object with
      `read` method, (filename tensor) -> (example tensor).
    randomize_input: Whether the input should be randomized.
    queue_capacity: Capacity for input queue.
    num_threads: The number of threads enqueuing examples.
    name: Name of resulting op.

  Returns:
    String `Tensor` of batched `Example` proto.

  Raises:
    ValueError: for invalid inputs.
  """
    # Retrieve the files to read.
    if isinstance(file_pattern, list):
        file_names = file_pattern
        if not file_names:
            raise ValueError('No files given to dequeue_examples.')
    else:
        file_names = list(gfile.Glob(file_pattern))
        if not file_names:
            raise ValueError('No files match %s.' % file_pattern)

    # Sort the files so the order is deterministic for unit tests. They'll be
    # shuffled in `string_input_producer` if `randomize_input` is enabled.
    if not randomize_input:
        file_names = sorted(file_names)

    # Check input parameters are given and reasonable.
    if (not queue_capacity) or (queue_capacity <= 0):
        raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
    if (batch_size is None) or (
        (not isinstance(batch_size, ops.Tensor)) and
        (batch_size <= 0 or batch_size > queue_capacity)):
        raise ValueError('Invalid batch_size %s, with queue_capacity %s.' %
                         (batch_size, queue_capacity))
    if (not num_threads) or (num_threads <= 0):
        raise ValueError('Invalid num_threads %s.' % num_threads)

    with ops.name_scope(name) as scope:
        # Set up the filename queue, with optional shuffling.
        with ops.name_scope('file_name_queue') as file_name_queue_scope:
            file_name_queue = input_ops.string_input_producer(
                constant_op.constant(file_names, name='input'),
                shuffle=randomize_input,
                name=file_name_queue_scope)

        # Create reader and set it to read from filename queue.
        with ops.name_scope('read'):
            _, example_proto = reader().read(file_name_queue)

        # Set up the batching queue.
        if randomize_input:
            if isinstance(batch_size, ops.Tensor):
                min_after_dequeue = int(queue_capacity * 0.4)
            else:
                min_after_dequeue = max(queue_capacity - (3 * batch_size),
                                        batch_size)
            examples = input_ops.shuffle_batch(
                [example_proto],
                batch_size,
                capacity=queue_capacity,
                num_threads=num_threads,
                min_after_dequeue=min_after_dequeue,
                name=scope)
        else:
            examples = input_ops.batch([example_proto],
                                       batch_size,
                                       capacity=queue_capacity,
                                       num_threads=num_threads,
                                       name=scope)

        return examples
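
A hedged usage sketch: the file pattern, reader, and batch size below are illustrative, and batches only flow once the queue runners have been started inside a session.

from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner_impl

with ops.Graph().as_default():
    # Illustrative inputs; the pattern and batch size are assumptions.
    examples = read_batch_examples(file_pattern='/tmp/data/*.tfrecord',
                                   batch_size=128,
                                   reader=io_ops.TFRecordReader,
                                   randomize_input=True)
    with session_lib.Session() as sess:
        coord = coordinator.Coordinator()
        threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
        batch = sess.run(examples)  # A [128] vector of serialized protos.
        coord.request_stop()
        coord.join(threads)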