예제 #1
0
  def _apply_transform(self, transform_input):
    """Builds a batched (key, value) read pipeline around an explicit queue.

    `self.num_threads` readers, each created as
    `self._reader_cls(**self._reader_kwargs)`, pull file names from a filename
    queue over `self._work_units`. Each reader enqueues its string (key, value)
    record into one common queue. That queue is a `RandomShuffleQueue` when
    `self.shuffle` is true and a `FIFOQueue` otherwise. A `QueueRunner` owning
    the enqueue ops is registered, and `self.batch_size` records are dequeued
    at once.

    Args:
      transform_input: not used by this implementation.

    Returns:
      `self.return_type` built from the dequeued (keys, values) tensors.
    """
    filename_queue = input_ops.string_input_producer(
        self._work_units, shuffle=self.shuffle, seed=self._seed)

    # Records are scalar string (key, value) pairs in either queue flavour.
    record_dtypes = [dtypes.string, dtypes.string]
    record_shapes = [[], []]
    if not self.shuffle:
      queue = data_flow_ops.FIFOQueue(capacity=self.queue_capacity,
                                      dtypes=record_dtypes,
                                      shapes=record_shapes)
    else:
      queue = data_flow_ops.RandomShuffleQueue(
          capacity=self.queue_capacity,
          min_after_dequeue=self.min_after_dequeue,
          dtypes=record_dtypes,
          shapes=record_shapes,
          seed=self.seed)

    # One fresh reader per thread; all of them share filename_queue.
    enqueue_ops = [
        queue.enqueue(
            self._reader_cls(**self._reader_kwargs).read(filename_queue))
        for _ in range(self.num_threads)
    ]
    queue_runner.add_queue_runner(queue_runner.QueueRunner(queue, enqueue_ops))
    batch = queue.dequeue_many(self.batch_size)

    # pylint: disable=not-callable
    return self.return_type(*batch)
예제 #2
0
  def _apply_transform(self, transform_input):
    """Builds a batched read pipeline on `batch_join` / `shuffle_batch_join`.

    A filename queue over `self.work_units` (with `self.num_epochs`) feeds
    `self.num_threads` readers. Each reader contributes one
    `read_up_to(filename_queue, self.enqueue_size)` op. The per-reader outputs
    are joined into batches of `self.batch_size`: through `shuffle_batch_join`
    when `self.shuffle` is true, and through `batch_join` otherwise.

    Args:
      transform_input: not used by this implementation.

    Returns:
      `self.return_type` built from the dequeued batch tensors.
    """
    filename_queue = input_ops.string_input_producer(
        self.work_units,
        num_epochs=self.num_epochs,
        shuffle=self.shuffle,
        seed=self.seed)
    reader_ops = [
        self._reader_cls(**self._reader_kwargs).read_up_to(
            filename_queue, self.enqueue_size)
        for _ in range(self.num_threads)
    ]

    # Options common to both join variants. Each read_up_to op emits several
    # records at once, hence enqueue_many=True.
    join_options = dict(capacity=self.queue_capacity,
                        enqueue_many=True,
                        shared_name=None,
                        name=None)
    if not self.shuffle:
      dequeued = input_ops.batch_join(
          reader_ops, self.batch_size, dynamic_pad=False, **join_options)
    else:
      dequeued = input_ops.shuffle_batch_join(
          reader_ops,
          self.batch_size,
          min_after_dequeue=self.min_after_dequeue,
          seed=self.seed,
          **join_options)

    # pylint: disable=not-callable
    return self.return_type(*dequeued)
예제 #3
0
def parallel_read(data_sources,
                  reader_class,
                  num_epochs=None,
                  num_readers=4,
                  reader_kwargs=None,
                  shuffle=True,
                  dtypes=None,
                  capacity=256,
                  min_after_dequeue=128):
    """Reads records from `data_sources` in parallel using several readers.

    A `ParallelReader` drives `num_readers` instances of `reader_class` (each
    built with `reader_kwargs`) over a filename queue. It funnels their records
    into one common queue: a `RandomShuffleQueue` when `shuffle` is true, and a
    `FIFOQueue` otherwise.

    Usage:
        data_sources = ['path_to/train*']
        key, value = parallel_read(data_sources, tf.TFRecordReader,
                                   num_readers=4)

    Args:
      data_sources: a list/tuple of files or the location of the data, e.g.
        /cns/../train@128, /cns/.../train* or /tmp/.../train*.
      reader_class: an io_ops.ReaderBase subclass, e.g. TFRecordReader.
      num_epochs: number of times each data source is read. If None, the data
        is cycled through indefinitely.
      num_readers: an integer, the number of readers to create.
      reader_kwargs: an optional dict of keyword arguments for each reader.
      shuffle: boolean, whether to shuffle the file names and, by using a
        RandomShuffleQueue as the common queue, the records.
      dtypes: a list of types whose length equals the number of elements in
        each record. If None (or empty) it defaults to
        [tf.string, tf.string] for (key, value).
      capacity: integer, capacity of the common queue.
      min_after_dequeue: integer, minimum number of records left in the common
        queue after a dequeue (used only when shuffling). It is needed for a
        good shuffle.

    Returns:
      key, value: a tuple of keys and values from the data sources.
    """
    record_dtypes = dtypes or [tf_dtypes.string, tf_dtypes.string]
    data_files = get_data_files(data_sources)
    with ops.name_scope('parallel_read'):
        filename_queue = tf_input.string_input_producer(
            data_files, num_epochs=num_epochs, shuffle=shuffle)
        queue_options = {'capacity': capacity, 'dtypes': record_dtypes}
        if shuffle:
            common_queue = data_flow_ops.RandomShuffleQueue(
                min_after_dequeue=min_after_dequeue, **queue_options)
        else:
            common_queue = data_flow_ops.FIFOQueue(**queue_options)
        reader = ParallelReader(reader_class,
                                common_queue,
                                num_readers=num_readers,
                                reader_kwargs=reader_kwargs)
        return reader.read(filename_queue)
예제 #4
0
    def _apply_transform(self, transform_input, **kwargs):
        """Builds a batched read pipeline on `batch_join` / `shuffle_batch_join`.

        `self.num_threads` readers each contribute one
        `read_up_to(filename_queue, self.enqueue_size)` op over a filename
        queue built from `self.work_units`. Their outputs are joined into
        batches of `self.batch_size`: shuffled when `self.shuffle` is true,
        in order otherwise.

        Args:
          transform_input: not used by this implementation.
          **kwargs: only the optional "num_epochs" entry is read. It is
            forwarded as `num_epochs` to `string_input_producer` (None when
            absent).

        Returns:
          `self.return_type` built from the dequeued batch tensors.
        """
        epochs = kwargs.get("num_epochs")
        filename_queue = input_ops.string_input_producer(
            self.work_units,
            num_epochs=epochs,
            shuffle=self.shuffle,
            seed=self.seed)

        def _make_read_op():
            # A fresh reader per thread, all sharing the same filename queue.
            reader = self._reader_cls(**self._reader_kwargs)
            return reader.read_up_to(filename_queue, self.enqueue_size)

        reader_ops = [_make_read_op() for _ in range(self.num_threads)]

        # read_up_to emits several records per op, hence enqueue_many=True.
        join_kwargs = {
            "capacity": self.queue_capacity,
            "enqueue_many": True,
            "shared_name": None,
            "name": None,
        }
        if self.shuffle:
            join_fn = input_ops.shuffle_batch_join
            join_kwargs.update(min_after_dequeue=self.min_after_dequeue,
                               seed=self.seed)
        else:
            join_fn = input_ops.batch_join
            join_kwargs["dynamic_pad"] = False
        dequeued = join_fn(reader_ops, self.batch_size, **join_kwargs)

        # pylint: disable=not-callable
        return self.return_type(*dequeued)
예제 #5
0
 def testManagedMainErrorTwoQueues(self):
     # Tests that the supervisor correctly raises a main loop
     # error even when using multiple queues for input.
     logdir = self._test_dir("managed_main_error_two_queues")
     os.makedirs(logdir)
     data_path = self._csv_data(logdir)
     # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
     # which was removed from unittest in Python 3.12.
     with self.assertRaisesRegex(RuntimeError, "fail at step 3"):
         with ops.Graph().as_default():
             # Create an input pipeline that reads the file 3 times.
             filename_queue = input_lib.string_input_producer([data_path],
                                                              num_epochs=3)
             reader = io_ops.TextLineReader()
             _, csv = reader.read(filename_queue)
             rec = parsing_ops.decode_csv(csv,
                                          record_defaults=[[1], [1], [1]])
             # Second input queue: shuffle_batch(tensors, batch_size=1,
             # capacity=6, min_after_dequeue=4).
             shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
             sv = supervisor.Supervisor(logdir=logdir)
             with sv.managed_session("") as sess:
                 for step in range(9):
                     if sv.should_stop():
                         break
                     elif step == 3:
                         raise RuntimeError("fail at step 3")
                     else:
                         sess.run(shuff_rec)
예제 #6
0
def single_pass_read(data_sources,
                     reader_class,
                     reader_kwargs=None,
                     scope=None):
    """Reads `data_sources` sequentially, making exactly one pass over them.

    File names are served in order (no shuffling) from a one-epoch,
    capacity-1 filename queue named 'filenames'. A single `reader_class`
    instance reads from that queue.

    Args:
      data_sources: a list/tuple of files or the location of the data, e.g.
        /path/to/train@128, /path/to/train* or /tmp/.../train*.
      reader_class: an io_ops.ReaderBase subclass, e.g. TFRecordReader.
      reader_kwargs: an optional dict of keyword arguments for the reader.
      scope: optional name scope for the created ops. It defaults to
        'single_pass_read'.

    Returns:
      key, value: a tuple of keys and values from the data sources.
    """
    data_files = get_data_files(data_sources)
    with ops.name_scope(scope, 'single_pass_read'):
        filename_queue = tf_input.string_input_producer(
            data_files,
            num_epochs=1,
            shuffle=False,
            capacity=1,
            name='filenames')
        reader = reader_class(**(reader_kwargs or {}))
        return reader.read(filename_queue)
예제 #7
0
    def _apply_transform(self, transform_input):
        """Builds a batched (key, value) read pipeline around an explicit queue.

        File names from `self._work_units` feed `self.num_threads` readers.
        Each reader's string (key, value) record is enqueued into a common
        queue: a `RandomShuffleQueue` if `self.shuffle`, otherwise a
        `FIFOQueue`. A `QueueRunner` for the enqueue ops is registered, and
        `self.batch_size` records are dequeued at once.

        Args:
          transform_input: not used by this implementation.

        Returns:
          `self.return_type` built from the dequeued (keys, values) tensors.
        """
        filename_queue = input_ops.string_input_producer(
            self._work_units, shuffle=self.shuffle, seed=self._seed)

        # Settings shared by both queue kinds: scalar string (key, value) pairs.
        queue_args = {
            "capacity": self.queue_capacity,
            "dtypes": [dtypes.string, dtypes.string],
            "shapes": [[], []],
        }
        if self.shuffle:
            queue = data_flow_ops.RandomShuffleQueue(
                min_after_dequeue=self.min_after_dequeue,
                seed=self.seed,
                **queue_args)
        else:
            queue = data_flow_ops.FIFOQueue(**queue_args)

        enqueue_ops = []
        for _ in range(self.num_threads):
            record = self._reader_cls(**self._reader_kwargs).read(filename_queue)
            enqueue_ops.append(queue.enqueue(record))

        queue_runner.add_queue_runner(
            queue_runner.QueueRunner(queue, enqueue_ops))
        dequeued = queue.dequeue_many(self.batch_size)

        # pylint: disable=not-callable
        return self.return_type(*dequeued)
예제 #8
0
def parallel_read(data_sources,
                  reader_class,
                  num_epochs=None,
                  num_readers=4,
                  reader_kwargs=None,
                  shuffle=True,
                  dtypes=None,
                  capacity=256,
                  min_after_dequeue=128):
  """Reads records from `data_sources` in parallel using several readers.

  A `ParallelReader` runs `num_readers` instances of `reader_class` (each
  constructed with `reader_kwargs`) over a filename queue. It collects their
  records in one common queue: a `RandomShuffleQueue` when `shuffle` is true,
  and a `FIFOQueue` otherwise.

  Usage:
      data_sources = ['path_to/train*']
      key, value = parallel_read(data_sources, tf.TFRecordReader,
                                 num_readers=4)

  Args:
    data_sources: a list/tuple of files or the location of the data, e.g.
      /cns/../train@128, /cns/.../train* or /tmp/.../train*.
    reader_class: an io_ops.ReaderBase subclass, e.g. TFRecordReader.
    num_epochs: number of times each data source is read. If None, the data is
      cycled through indefinitely.
    num_readers: an integer, the number of readers to create.
    reader_kwargs: an optional dict of keyword arguments for each reader.
    shuffle: boolean, whether to shuffle the file names and, by using a
      RandomShuffleQueue as the common queue, the records.
    dtypes: a list of types whose length equals the number of elements in each
      record. If None (or empty) it defaults to [tf.string, tf.string] for
      (key, value).
    capacity: integer, capacity of the common queue.
    min_after_dequeue: integer, minimum number of records left in the common
      queue after a dequeue (used only when shuffling). It is needed for a good
      shuffle.

  Returns:
    key, value: a tuple of keys and values from the data sources.
  """
  files = get_data_files(data_sources)
  with ops.name_scope('parallel_read'):
    filename_queue = tf_input.string_input_producer(
        files, num_epochs=num_epochs, shuffle=shuffle)
    record_dtypes = dtypes or [tf_dtypes.string, tf_dtypes.string]
    common_queue = (
        data_flow_ops.RandomShuffleQueue(capacity=capacity,
                                         min_after_dequeue=min_after_dequeue,
                                         dtypes=record_dtypes)
        if shuffle else
        data_flow_ops.FIFOQueue(capacity=capacity, dtypes=record_dtypes))
    parallel = ParallelReader(reader_class,
                              common_queue,
                              num_readers=num_readers,
                              reader_kwargs=reader_kwargs)
    return parallel.read(filename_queue)
예제 #9
0
    def _verify_read_up_to_out(self, shared_queue):
        """Checks that `ParallelReader.read_up_to` yields each record exactly once.

        Three TFRecord files of seven records each are consumed for one epoch
        by five `TFRecordReader`s sharing `shared_queue`, in chunks of up to
        four records. Each record is attributed to its source file through the
        '<i>-of-3' substring of its key.

        Args:
          shared_queue: the common queue handed to the `ParallelReader`.
        """
        with self.cached_session():
            num_files = 3
            num_records_per_file = 7
            tfrecord_paths = test_utils.create_tfrecord_files(
                tempfile.mkdtemp(),
                num_files=num_files,
                num_records_per_file=num_records_per_file)

        p_reader = parallel_reader.ParallelReader(
            io_ops.TFRecordReader, shared_queue, num_readers=5)
        filename_queue = input_lib.string_input_producer(
            parallel_reader.get_data_files(tfrecord_paths), num_epochs=1)
        key, value = p_reader.read_up_to(filename_queue, 4)

        shard_tags = ('0-of-3', '1-of-3', '2-of-3')
        per_file = [0] * len(shard_tags)
        total_keys = 0
        total_values = 0

        sv = supervisor.Supervisor(logdir=tempfile.mkdtemp())
        with sv.prepare_or_wait_for_session() as sess:
            sv.start_queue_runners(sess)
            # Drain the one-epoch pipeline until it reports end of input.
            while True:
                try:
                    batch_keys, batch_values = sess.run([key, value])
                except errors_impl.OutOfRangeError:
                    break
                self.assertEqual(len(batch_keys), len(batch_values))
                total_keys += len(batch_keys)
                total_values += len(batch_values)
                for batch_key in batch_keys:
                    key_text = str(batch_key)
                    for idx, tag in enumerate(shard_tags):
                        if tag in key_text:
                            per_file[idx] += 1

        for file_count in per_file:
            self.assertEqual(file_count, num_records_per_file)
        self.assertEqual(total_keys, num_files * num_records_per_file)
        self.assertEqual(total_values, total_keys)
        self.assertEqual(sum(per_file), total_keys)
예제 #10
0
def _get_shared_file_name_queue(file_names, shuffle, num_epochs, name):
    """Returns a `string_input_producer` queue over `file_names` shared as `name`.

    Args:
      file_names: the file names fed into the queue (wrapped in a constant
        named 'input').
      shuffle: forwarded to `string_input_producer`.
      num_epochs: forwarded to `string_input_producer`.
      name: used both as the op name and as the queue's `shared_name`.

    Returns:
      The queue created by `string_input_producer`.
    """
    # A throwaway variable is created only for its device. The shared queue is
    # placed next to it, i.e. on a PS if there is one and on the worker
    # otherwise. TODO(rohanj): Figure out how to place an op on PS without this
    # hack
    placement_anchor = var_ops.Variable(initial_value=0,
                                        name='queue_placement_var')
    with ops.device(placement_anchor.device):
        names_tensor = constant_op.constant(file_names, name='input')
        return input_ops.string_input_producer(
            names_tensor,
            shuffle=shuffle,
            num_epochs=num_epochs,
            name=name,
            shared_name=name)
예제 #11
0
  def _verify_read_up_to_out(self, shared_queue):
    """Checks that `ParallelReader.read_up_to` yields each record exactly once.

    Three TFRecord files of seven records each are read for one epoch by five
    `TFRecordReader`s sharing `shared_queue`, in chunks of up to four records.
    Records are attributed to their file through the '<i>-of-3' substring of
    their key.

    Args:
      shared_queue: the common queue handed to the `ParallelReader`.
    """
    # cached_session() replaces the deprecated test_session().
    with self.cached_session():
      num_files = 3
      num_records_per_file = 7
      tfrecord_paths = test_utils.create_tfrecord_files(
          self.get_temp_dir(),
          num_files=num_files,
          num_records_per_file=num_records_per_file)

    p_reader = parallel_reader.ParallelReader(
        io_ops.TFRecordReader, shared_queue, num_readers=5)

    data_files = parallel_reader.get_data_files(tfrecord_paths)
    filename_queue = input_lib.string_input_producer(data_files, num_epochs=1)
    key, value = p_reader.read_up_to(filename_queue, 4)

    count0 = 0
    count1 = 0
    count2 = 0
    all_keys_count = 0
    all_values_count = 0

    sv = supervisor.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)
      while True:
        try:
          current_keys, current_values = sess.run([key, value])
          # assertEqual: the assertEquals alias was removed in Python 3.12.
          self.assertEqual(len(current_keys), len(current_values))
          all_keys_count += len(current_keys)
          all_values_count += len(current_values)
          for current_key in current_keys:
            if '0-of-3' in str(current_key):
              count0 += 1
            if '1-of-3' in str(current_key):
              count1 += 1
            if '2-of-3' in str(current_key):
              count2 += 1
        except errors_impl.OutOfRangeError:
          break

    self.assertEqual(count0, num_records_per_file)
    self.assertEqual(count1, num_records_per_file)
    self.assertEqual(count2, num_records_per_file)
    self.assertEqual(
        all_keys_count,
        num_files * num_records_per_file)
    self.assertEqual(all_values_count, all_keys_count)
    self.assertEqual(
        count0 + count1 + count2,
        all_keys_count)
예제 #12
0
def _get_shared_file_name_queue(file_names, shuffle, num_epochs, name):
  """Builds a file-name queue registered under `shared_name=name`.

  Args:
    file_names: file names fed into the queue via a constant named 'input'.
    shuffle: forwarded to `string_input_producer`.
    num_epochs: forwarded to `string_input_producer`.
    name: both the op name and the queue's `shared_name`.

  Returns:
    The queue returned by `string_input_producer`.
  """
  # The dummy variable exists only so its device can be borrowed. That device
  # is a PS when there is one and the worker otherwise.
  # TODO(rohanj): Figure out how to place an op on PS without this hack
  anchor = var_ops.Variable(initial_value=0, name='queue_placement_var')
  with ops.device(anchor.device):
    return input_ops.string_input_producer(
        constant_op.constant(file_names, name='input'),
        shuffle=shuffle,
        num_epochs=num_epochs,
        name=name,
        shared_name=name)
예제 #13
0
def single_pass_read(data_sources, reader_class, reader_kwargs=None,
                     scope=None):
    """Reads `data_sources` sequentially, making exactly one pass over them.

    File names are served in order (no shuffling) from a one-epoch, capacity-1
    filename queue. A single `reader_class` instance reads from that queue.

    Args:
      data_sources: a list/tuple of files or the location of the data, e.g.
        /path/to/train@128, /path/to/train* or /tmp/.../train*.
      reader_class: an io_ops.ReaderBase subclass, e.g. TFRecordReader.
      reader_kwargs: an optional dict of keyword arguments for the reader.
      scope: optional name scope for the created ops. When None, the scope
        "single_pass_read" is used, exactly as before this parameter existed.

    Returns:
      key, value: a tuple of keys and values from the data sources.
    """
    data_files = get_data_files(data_sources)
    # (name, default_name) form: with scope=None this opens the same
    # "single_pass_read" scope the function always used.
    with ops.name_scope(scope, "single_pass_read"):
        filename_queue = tf_input.string_input_producer(
            data_files, num_epochs=1, shuffle=False, capacity=1)
        reader_kwargs = reader_kwargs or {}
        return reader_class(**reader_kwargs).read(filename_queue)
예제 #14
0
 def testManagedEndOfInputOneQueue(self):
   # With a single input queue capped at a fixed number of epochs, the
   # supervised session must finish without raising once input runs out.
   logdir = self._test_dir("managed_end_of_input_one_queue")
   os.makedirs(logdir)
   csv_path = self._csv_data(logdir)
   with ops.Graph().as_default():
     # Three epochs over a one-file input pipeline.
     filenames = input_lib.string_input_producer([csv_path], num_epochs=3)
     _, line = io_ops.TextLineReader().read(filenames)
     record = parsing_ops.decode_csv(line, record_defaults=[[1], [1], [1]])
     sv = supervisor.Supervisor(logdir=logdir)
     with sv.managed_session("") as sess:
       while not sv.should_stop():
         sess.run(record)
예제 #15
0
 def testManagedEndOfInputOneQueue(self):
     # The managed session must wind down without an error once a single
     # epoch-limited input queue is exhausted.
     logdir = self._test_dir("managed_end_of_input_one_queue")
     os.makedirs(logdir)
     csv_file = self._csv_data(logdir)
     with ops.Graph().as_default():
         # The pipeline reads the single CSV file three times.
         filename_queue = input_lib.string_input_producer([csv_file],
                                                          num_epochs=3)
         line_reader = io_ops.TextLineReader()
         _, line = line_reader.read(filename_queue)
         columns = parsing_ops.decode_csv(line,
                                          record_defaults=[[1], [1], [1]])
         sv = supervisor.Supervisor(logdir=logdir)
         with sv.managed_session("") as sess:
             while True:
                 if sv.should_stop():
                     break
                 sess.run(columns)
예제 #16
0
  def testReadFromSameFile(self):
    """Two LMDB readers sharing one filename queue return identical records."""
    with self.cached_session() as sess:
      first = io_ops.LMDBReader(name="test_read_from_same_file1")
      second = io_ops.LMDBReader(name="test_read_from_same_file2")
      filename_queue = input_lib.string_input_producer(
          [self.db_path], num_epochs=None)
      key1, value1 = first.read(filename_queue)
      key2, value2 = second.read(filename_queue)
      fetches = [key1, value1, key2, value2]

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
      # Three passes of ten reads each; both readers must stay in lockstep.
      for _ in range(3 * 10):
        k1, v1, k2, v2 = self.evaluate(fetches)
        self.assertAllEqual(compat.as_bytes(k1), compat.as_bytes(k2))
        self.assertAllEqual(compat.as_bytes(v1), compat.as_bytes(v2))
      coord.request_stop()
      coord.join(threads)
예제 #17
0
  def testReadFromSameFile(self):
    """Two LMDB readers sharing one filename queue return identical records."""
    # cached_session() replaces the deprecated test_session().
    with self.cached_session() as sess:
      reader1 = io_ops.LMDBReader(name="test_read_from_same_file1")
      reader2 = io_ops.LMDBReader(name="test_read_from_same_file2")
      filename_queue = input_lib.string_input_producer(
          [self.db_path], num_epochs=None)
      key1, value1 = reader1.read(filename_queue)
      key2, value2 = reader2.read(filename_queue)

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
      for _ in range(3):
        for _ in range(10):
          k1, v1, k2, v2 = sess.run([key1, value1, key2, value2])
          self.assertAllEqual(compat.as_bytes(k1), compat.as_bytes(k2))
          self.assertAllEqual(compat.as_bytes(v1), compat.as_bytes(v2))
      coord.request_stop()
      coord.join(threads)
예제 #18
0
  def testReadFromFileRepeatedly(self):
    """Reading the LMDB three times yields the same ordered records each pass."""
    with self.cached_session() as sess:
      reader = io_ops.LMDBReader(name="test_read_from_file_repeated")
      filename_queue = input_lib.string_input_producer(
          [self.db_path], num_epochs=None)
      key, value = reader.read(filename_queue)
      # Expected j-th record: key str(j), value the j-th letter from "a".
      expected_records = [(str(j), str(chr(ord("a") + j))) for j in range(10)]

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
      for _ in range(3):
        for want_key, want_value in expected_records:
          k, v = self.evaluate([key, value])
          self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(want_key))
          self.assertAllEqual(
              compat.as_bytes(v), compat.as_bytes(want_value))
      coord.request_stop()
      coord.join(threads)
예제 #19
0
    def _get_filename_queue(self, epoch_limit):
        """Constructs a filename queue with an epoch limit.

        `epoch_limit` is an error-checking fallback that stops a reader from
        looping forever while requesting work items when no file has any. Set
        it high enough that it is never reached as long as at least one record
        exists in some file.

        Args:
          epoch_limit: The maximum number of times to read through the complete
            list of files before throwing an OutOfRangeError.

        Returns:
          A tuple of (filename_queue, epoch_limiter):
            filename_queue: A FIFOQueue with filename work items.
            epoch_limiter: The local variable used for epoch limitation. It
              should be set to zero before a reader is passed `filename_queue`,
              to reset the epoch limiter's state.
        """
        epoch_limiter = variable_scope.variable(
            initial_value=constant_op.constant(0, dtype=dtypes.int64),
            name="epoch_limiter",
            trainable=False,
            collections=[ops.GraphKeys.LOCAL_VARIABLES])
        flat_filenames = array_ops.reshape(
            ops.convert_to_tensor(self._filenames), [-1])

        def _bump_epoch_count():
            return epoch_limiter.count_up_to(epoch_limit)

        def _skip_epoch_count():
            return constant_op.constant(0, dtype=dtypes.int64)

        # Queue runners start before local variables are initialized, so the
        # limit is only enforced once epoch_limiter is initialized. Until then
        # a QueueRunner may drive a reader into an unchecked infinite loop.
        # After initialization, every pass increments and checks the limiter,
        # which interrupts such loops.
        limiter_tick = control_flow_ops.cond(
            state_ops.is_variable_initialized(epoch_limiter),
            _bump_epoch_count,
            _skip_epoch_count)
        with ops.control_dependencies([limiter_tick]):
            gated_filenames = array_ops.identity(flat_filenames)
        filename_queue = input_lib.string_input_producer(
            gated_filenames, shuffle=False, capacity=1)
        return filename_queue, epoch_limiter
예제 #20
0
  def testReadFromFileRepeatedly(self):
    """Reading the LMDB three times yields the same ordered records each pass."""
    # cached_session() replaces the deprecated test_session().
    with self.cached_session() as sess:
      reader = io_ops.LMDBReader(name="test_read_from_file_repeated")
      filename_queue = input_lib.string_input_producer(
          [self.db_path], num_epochs=None)
      key, value = reader.read(filename_queue)

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
      # Iterate over the lmdb 3 times.
      for _ in range(3):
        # Go over all 10 records each time.
        for j in range(10):
          k, v = sess.run([key, value])
          self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(j)))
          self.assertAllEqual(
              compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + j))))
      coord.request_stop()
      coord.join(threads)
예제 #21
0
  def _get_filename_queue(self, epoch_limit):
    """Constructs a filename queue guarded by an epoch limit.

    `epoch_limit` is a safety net that keeps a reader from requesting work
    items forever when no file contains any. It should be high enough that it
    is never reached while at least one record exists in some file.

    Args:
      epoch_limit: The maximum number of times to read through the complete
        list of files before throwing an OutOfRangeError.

    Returns:
      A tuple of (filename_queue, epoch_limiter):
        filename_queue: A FIFOQueue with filename work items.
        epoch_limiter: The local variable used for epoch limitation. It should
          be set to zero before a reader is passed `filename_queue`, to reset
          the epoch limiter's state.
    """
    epoch_limiter = variable_scope.variable(
        initial_value=constant_op.constant(0, dtype=dtypes.int64),
        name="epoch_limiter",
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES])
    filenames = ops.convert_to_tensor(self._filenames)
    filenames = array_ops.reshape(filenames, [-1])
    # Local variables are initialized only after queue runners have started,
    # so the limit is skipped while epoch_limiter is uninitialized. Until then
    # a QueueRunner may loop unchecked. Once it is initialized, each pass
    # counts up and the limit interrupts any loop in progress.
    limiter_ready = state_ops.is_variable_initialized(epoch_limiter)
    maybe_count = control_flow_ops.cond(
        limiter_ready,
        lambda: epoch_limiter.count_up_to(epoch_limit),
        lambda: constant_op.constant(0, dtype=dtypes.int64))
    with ops.control_dependencies([maybe_count]):
      gated = array_ops.identity(filenames)
    queue = input_lib.string_input_producer(gated, shuffle=False, capacity=1)
    return queue, epoch_limiter
예제 #22
0
    def _verify_all_data_sources_read(self, shared_queue):
        """Checks that reads through `shared_queue` reach every source file.

        Three TFRecord files are read with one reader per file for 50 records.
        Each file, identified by the '<i>-of-3' substring of a record's key,
        must contribute at least one record. Every read must be attributed to
        some file.

        Args:
          shared_queue: the common queue handed to the `ParallelReader`.
        """
        with self.cached_session():
            tfrecord_paths = test_utils.create_tfrecord_files(
                self.get_temp_dir(), num_files=3)

        p_reader = parallel_reader.ParallelReader(
            io_ops.TFRecordReader,
            shared_queue,
            num_readers=len(tfrecord_paths))
        filename_queue = input_lib.string_input_producer(
            parallel_reader.get_data_files(tfrecord_paths))
        key, value = p_reader.read(filename_queue)

        num_reads = 50
        shard_tags = ('0-of-3', '1-of-3', '2-of-3')
        hits = [0, 0, 0]

        sv = supervisor.Supervisor(logdir=self.get_temp_dir())
        with sv.prepare_or_wait_for_session() as sess:
            sv.start_queue_runners(sess)
            for _ in range(num_reads):
                current_key, _unused_value = sess.run([key, value])
                key_text = str(current_key)
                for idx, tag in enumerate(shard_tags):
                    if tag in key_text:
                        hits[idx] += 1

        for count in hits:
            self.assertGreater(count, 0)
        self.assertEqual(sum(hits), num_reads)
예제 #23
0
  def _verify_all_data_sources_read(self, shared_queue):
    """Checks that reads through `shared_queue` reach every source file.

    Three TFRecord files are read with one reader per file for 50 records.
    Each file (matched by the '<i>-of-3' substring of the key) must contribute
    at least one record. The per-file counts must add up to the number of
    reads.

    Args:
      shared_queue: the common queue handed to the `ParallelReader`.
    """
    # cached_session() replaces the deprecated test_session().
    with self.cached_session():
      tfrecord_paths = test_utils.create_tfrecord_files(
          self.get_temp_dir(), num_files=3)

    num_readers = len(tfrecord_paths)
    p_reader = parallel_reader.ParallelReader(
        io_ops.TFRecordReader, shared_queue, num_readers=num_readers)

    data_files = parallel_reader.get_data_files(tfrecord_paths)
    filename_queue = input_lib.string_input_producer(data_files)
    key, value = p_reader.read(filename_queue)

    count0 = 0
    count1 = 0
    count2 = 0

    num_reads = 50

    sv = supervisor.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)

      for _ in range(num_reads):
        current_key, _ = sess.run([key, value])
        if '0-of-3' in str(current_key):
          count0 += 1
        if '1-of-3' in str(current_key):
          count1 += 1
        if '2-of-3' in str(current_key):
          count2 += 1

    self.assertGreater(count0, 0)
    self.assertGreater(count1, 0)
    self.assertGreater(count2, 0)
    # assertEqual: the assertEquals alias was removed in Python 3.12.
    self.assertEqual(count0 + count1 + count2, num_reads)
예제 #24
0
 def testManagedMainErrorTwoQueues(self):
   # Tests that the supervisor correctly raises a main loop
   # error even when using multiple queues for input.
   logdir = self._test_dir("managed_main_error_two_queues")
   os.makedirs(logdir)
   data_path = self._csv_data(logdir)
   # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
   # which was removed from unittest in Python 3.12.
   with self.assertRaisesRegex(RuntimeError, "fail at step 3"):
     with ops.Graph().as_default():
       # Create an input pipeline that reads the file 3 times.
       filename_queue = input_lib.string_input_producer(
           [data_path], num_epochs=3)
       reader = io_ops.TextLineReader()
       _, csv = reader.read(filename_queue)
       rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
       # Second input queue: shuffle_batch(tensors, batch_size=1, capacity=6,
       # min_after_dequeue=4).
       shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
       sv = supervisor.Supervisor(logdir=logdir)
       with sv.managed_session("") as sess:
         for step in range(9):
           if sv.should_stop():
             break
           elif step == 3:
             raise RuntimeError("fail at step 3")
           else:
             sess.run(shuff_rec)
예제 #25
0
def _read_keyed_batch_examples_helper(file_pattern,
                                      batch_size,
                                      reader,
                                      randomize_input=True,
                                      num_epochs=None,
                                      queue_capacity=10000,
                                      num_threads=1,
                                      read_batch_size=1,
                                      parse_fn=None,
                                      setup_shared_queue=False,
                                      name=None):
    """Builds ops that read, queue and batch keyed `Example` protos.

    Args:
      file_pattern: List of files, or a glob pattern of file paths, holding
        `Example` records. See `tf.gfile.Glob` for pattern rules.
      batch_size: Python int or scalar `Tensor` giving the batch size. A
        Python int must lie in `(0, queue_capacity]`.
      reader: A function or class returning a reader object; it is passed
        through to `_get_examples`.
      randomize_input: If True, file names are shuffled and examples are
        batched with `shuffle_batch_join`; otherwise `batch_join` is used.
      num_epochs: Number of passes over the data, or `None` to cycle forever.
        When set, a smaller final batch is allowed.
        NOTE - If specified, creates a variable that must be initialized, so
        call `tf.initialize_all_variables()` as shown in the tests.
      queue_capacity: Positive capacity of the batching queue.
      num_threads: Positive number of threads enqueuing examples.
      read_batch_size: Python int or scalar `Tensor`; number of records read
        at once. Values above 1 make the batching queue use `enqueue_many`.
      parse_fn: Optional function mapping an `Example` tensor to its parsed
        representation. `None` disables parsing.
      setup_shared_queue: If True, file names come from
        `_get_shared_file_name_queue` and are relayed through a local
        capacity-1 FIFO queue.
      name: Name scope for the created ops.

    Returns:
      If `parse_fn` is set and batching yields a dict: a tuple
      `(keys, features)` with `KEY_FEATURE_NAME` removed from `features`.
      Otherwise the batching op's output (string keys first, then the
      batched examples).

    Raises:
      ValueError: for invalid inputs.
    """
    # Resolve the concrete list of input files.
    file_names = _get_file_names(file_pattern, randomize_input)

    # Argument validation; the first failing check determines the error.
    if not queue_capacity or queue_capacity <= 0:
        raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
    if batch_size is None or (not isinstance(batch_size, ops.Tensor) and
                              (batch_size <= 0 or
                               batch_size > queue_capacity)):
        raise ValueError('Invalid batch_size %s, with queue_capacity %s.' %
                         (batch_size, queue_capacity))
    if read_batch_size is None or (
            not isinstance(read_batch_size, ops.Tensor) and
            read_batch_size <= 0):
        raise ValueError('Invalid read_batch_size %s.' % read_batch_size)
    if not num_threads or num_threads <= 0:
        raise ValueError('Invalid num_threads %s.' % num_threads)
    if num_epochs is not None and num_epochs <= 0:
        raise ValueError('Invalid num_epochs %s.' % num_epochs)

    with ops.name_scope(name, 'read_batch_examples', [file_pattern]) as scope:
        with ops.name_scope('file_name_queue') as file_name_queue_scope:
            if not setup_shared_queue:
                file_name_queue = input_ops.string_input_producer(
                    constant_op.constant(file_names, name='input'),
                    shuffle=randomize_input,
                    num_epochs=num_epochs,
                    name=file_name_queue_scope)
            else:
                # Relay names from the shared queue into a local single-slot
                # FIFO queue, fed by its own queue runner.
                shared_queue = _get_shared_file_name_queue(
                    file_names, randomize_input, num_epochs,
                    file_name_queue_scope)
                file_name_queue = data_flow_ops.FIFOQueue(
                    capacity=1, dtypes=[dtypes.string], shapes=[[]])
                relay_op = file_name_queue.enqueue(shared_queue.dequeue())
                queue_runner.add_queue_runner(
                    queue_runner.QueueRunner(file_name_queue, [relay_op]))

        example_list = _get_examples(file_name_queue, reader, num_threads,
                                     read_batch_size, parse_fn)

        # Multi-record reads arrive as batches and must be enqueued
        # element-wise.
        enqueue_many = read_batch_size > 1
        # A finite epoch count may leave a partial final batch.
        allow_smaller_final_batch = num_epochs is not None

        batching_kwargs = dict(
            capacity=queue_capacity,
            enqueue_many=enqueue_many,
            name=scope,
            allow_smaller_final_batch=allow_smaller_final_batch)
        if randomize_input:
            # min_after_dequeue must be a Python int, so a `Tensor` batch
            # size falls back to 40% of the capacity.
            if isinstance(batch_size, ops.Tensor):
                min_after_dequeue = int(queue_capacity * 0.4)
            else:
                min_after_dequeue = max(queue_capacity - (3 * batch_size),
                                        batch_size)
            batched = input_ops.shuffle_batch_join(
                example_list,
                batch_size,
                min_after_dequeue=min_after_dequeue,
                **batching_kwargs)
        else:
            batched = input_ops.batch_join(example_list, batch_size,
                                           **batching_kwargs)

        if parse_fn and isinstance(batched, dict):
            # The keys ride inside the feature dict; split them back out.
            keys = batched.pop(KEY_FEATURE_NAME)
            return keys, batched
        return batched
예제 #26
0
파일: graph_io.py 프로젝트: 01-/tensorflow
def read_batch_examples(file_pattern, batch_size, reader,
                        randomize_input=True, queue_capacity=10000,
                        num_threads=1, name='dequeue_examples'):
  """Builds a queue-based input pipeline yielding batches of `Example` protos.

  File names come from `file_pattern` (an explicit list, or a glob expanded
  with `gfile.Glob`) and feed a file-name queue. A single reader instance
  reads records from it, and a batching queue groups the records into
  batches of `batch_size`.

  All queue runners are added to the queue-runner collection and can be
  started via `start_queue_runners`. All ops are added to the default graph.

  Args:
    file_pattern: List of file paths, or a glob pattern (see `tf.gfile.Glob`
      for the syntax) matching files that hold `Example` records.
    batch_size: Python int or scalar `Tensor`; examples per batch. A Python
      int must lie in `(0, queue_capacity]`.
    reader: A function or class whose result has a `read(filename_queue)`
      method returning a `(key, value)` pair of tensors.
    randomize_input: If True, file names are shuffled by the producer and
      records are batched with `shuffle_batch`. If False, files are read in
      sorted order and batched with `batch`.
    queue_capacity: Positive capacity of the batching queue.
    num_threads: Positive number of threads enqueuing into the batching queue.
    name: Name scope for the created ops.

  Returns:
    String `Tensor` holding a batch of `Example` protos.

  Raises:
    ValueError: if no files are given or matched, or if `queue_capacity`,
      `batch_size` or `num_threads` is invalid.
  """
  # Resolve the input files: glob a pattern, or use an explicit list as is.
  if not isinstance(file_pattern, list):
    file_names = list(gfile.Glob(file_pattern))
    if not file_names:
      raise ValueError('No files match %s.' % file_pattern)
  elif not file_pattern:
    raise ValueError('No files given to dequeue_examples.')
  else:
    file_names = file_pattern

  # Sorted order keeps unit tests deterministic; when randomizing, the
  # shuffling happens in `string_input_producer` instead.
  file_names = file_names if randomize_input else sorted(file_names)

  # Argument validation; the first failing check determines the error.
  if not queue_capacity or queue_capacity <= 0:
    raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
  if batch_size is None or (not isinstance(batch_size, ops.Tensor) and
                            (batch_size <= 0 or batch_size > queue_capacity)):
    raise ValueError(
        'Invalid batch_size %s, with queue_capacity %s.' %
        (batch_size, queue_capacity))
  if not num_threads or num_threads <= 0:
    raise ValueError('Invalid num_threads %s.' % num_threads)

  with ops.name_scope(name) as scope:
    with ops.name_scope('file_name_queue') as file_name_queue_scope:
      file_name_queue = input_ops.string_input_producer(
          constant_op.constant(file_names, name='input'),
          shuffle=randomize_input, name=file_name_queue_scope)

    # A single reader; the record key is discarded.
    with ops.name_scope('read'):
      _, serialized_example = reader().read(file_name_queue)

    if not randomize_input:
      return input_ops.batch(
          [serialized_example], batch_size, capacity=queue_capacity,
          num_threads=num_threads, name=scope)

    # min_after_dequeue must be a Python int, so a `Tensor` batch size falls
    # back to 40% of the capacity.
    if isinstance(batch_size, ops.Tensor):
      min_after_dequeue = int(queue_capacity * 0.4)
    else:
      min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
    return input_ops.shuffle_batch(
        [serialized_example], batch_size, capacity=queue_capacity,
        num_threads=num_threads, min_after_dequeue=min_after_dequeue,
        name=scope)
예제 #27
0
def _read_keyed_batch_examples_helper(file_pattern,
                                      batch_size,
                                      reader,
                                      randomize_input=True,
                                      num_epochs=None,
                                      queue_capacity=10000,
                                      num_threads=1,
                                      read_batch_size=1,
                                      parse_fn=None,
                                      setup_shared_queue=False,
                                      name=None):
  """Builds ops that read, queue and batch keyed `Example` protos.

  Args:
    file_pattern: List of files, or a glob pattern of file paths, holding
      `Example` records.
    batch_size: Python int or scalar `Tensor` giving the batch size. A Python
      int must lie in `(0, queue_capacity]`.
    reader: A function or class returning a reader object; it is passed
      through to `_get_examples`.
    randomize_input: If True, file names are shuffled and examples are batched
      with `shuffle_batch_join`; otherwise `batch_join` is used.
    num_epochs: Number of passes over the data, or `None` to cycle forever.
      When set, a smaller final batch is allowed.
    queue_capacity: Positive capacity of the batching queue.
    num_threads: Positive number of threads enqueuing examples.
    read_batch_size: Python int or scalar `Tensor`; number of records read at
      once. Values above 1 make the batching queue use `enqueue_many`.
    parse_fn: Optional function mapping an `Example` tensor to its parsed
      representation. `None` disables parsing.
    setup_shared_queue: If True, file names come from
      `_get_shared_file_name_queue` and are relayed through a local
      capacity-1 FIFO queue.
    name: Name scope for the created ops.

  Returns:
    If `parse_fn` is set and batching yields a dict: a tuple
    `(keys, features)` with `KEY_FEATURE_NAME` removed from `features`.
    Otherwise the output of the batching op, unchanged.

  Raises:
    ValueError: for invalid inputs.
  """
  # Resolve the concrete list of input files.
  file_names = _get_file_names(file_pattern, randomize_input)

  # Argument validation; the first failing check determines the error.
  if not queue_capacity or queue_capacity <= 0:
    raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
  if batch_size is None or (not isinstance(batch_size, ops.Tensor) and
                            (batch_size <= 0 or batch_size > queue_capacity)):
    raise ValueError(
        'Invalid batch_size %s, with queue_capacity %s.' %
        (batch_size, queue_capacity))
  if read_batch_size is None or (
      not isinstance(read_batch_size, ops.Tensor) and read_batch_size <= 0):
    raise ValueError('Invalid read_batch_size %s.' % read_batch_size)
  if not num_threads or num_threads <= 0:
    raise ValueError('Invalid num_threads %s.' % num_threads)
  if num_epochs is not None and num_epochs <= 0:
    raise ValueError('Invalid num_epochs %s.' % num_epochs)

  with ops.name_scope(name, 'read_batch_examples', [file_pattern]) as scope:
    with ops.name_scope('file_name_queue') as file_name_queue_scope:
      if not setup_shared_queue:
        file_name_queue = input_ops.string_input_producer(
            constant_op.constant(file_names, name='input'),
            shuffle=randomize_input,
            num_epochs=num_epochs,
            name=file_name_queue_scope)
      else:
        # Relay names from the shared queue into a local single-slot FIFO
        # queue, fed by its own queue runner.
        shared_queue = _get_shared_file_name_queue(
            file_names, randomize_input, num_epochs, file_name_queue_scope)
        file_name_queue = data_flow_ops.FIFOQueue(
            capacity=1, dtypes=[dtypes.string], shapes=[[]])
        relay_op = file_name_queue.enqueue(shared_queue.dequeue())
        queue_runner.add_queue_runner(
            queue_runner.QueueRunner(file_name_queue, [relay_op]))

    example_list = _get_examples(file_name_queue, reader, num_threads,
                                 read_batch_size, parse_fn)

    # Multi-record reads arrive as batches and must be enqueued element-wise.
    enqueue_many = read_batch_size > 1
    # A finite epoch count may leave a partial final batch.
    allow_smaller_final_batch = num_epochs is not None

    batching_kwargs = dict(capacity=queue_capacity,
                           enqueue_many=enqueue_many,
                           name=scope,
                           allow_smaller_final_batch=allow_smaller_final_batch)
    if randomize_input:
      # min_after_dequeue must be a Python int, so a `Tensor` batch size
      # falls back to 40% of the capacity.
      min_after_dequeue = (
          int(queue_capacity * 0.4) if isinstance(batch_size, ops.Tensor)
          else max(queue_capacity - (3 * batch_size), batch_size))
      batched = input_ops.shuffle_batch_join(
          example_list, batch_size, min_after_dequeue=min_after_dequeue,
          **batching_kwargs)
    else:
      batched = input_ops.batch_join(example_list, batch_size,
                                     **batching_kwargs)

    if parse_fn and isinstance(batched, dict):
      # The keys ride inside the feature dict; split them back out.
      keys = batched.pop(KEY_FEATURE_NAME)
      return keys, batched
    return batched
예제 #28
0
def read_keyed_batch_examples(
    file_pattern, batch_size, reader,
    randomize_input=True, num_epochs=None,
    queue_capacity=10000, num_threads=1,
    read_batch_size=1, parse_fn=None,
    name=None):
  """Adds operations to read, queue, batch keyed `Example` protos.

  Given file pattern (or list of files), will setup a queue for file names,
  read `Example` proto using provided `reader`, use batch queue to create
  batches of examples of size `batch_size`. Every record is kept together
  with the key its reader returned.

  All queue runners are added to the queue runners collection, and may be
  started via `start_queue_runners`.

  All ops are added to the default graph.

  Use `parse_fn` if you need to do parsing / processing on single examples.

  Args:
    file_pattern: List of files or pattern of file paths containing
        `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    reader: A function or class that returns an object with a `read` method,
      (filename queue) -> (key tensor, example tensor), and, when
      `read_batch_size > 1`, a `read_up_to(filename_queue, num_records)`
      method.
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If `None`, cycles through the dataset forever.
      NOTE - If specified, creates a variable that must be initialized, so call
      `tf.initialize_all_variables()` as shown in the tests.
    queue_capacity: Capacity for input queue.
    num_threads: The number of threads enqueuing examples; one reader is
      created per thread.
    read_batch_size: An int or scalar `Tensor` specifying the number of
      records to read at once. Values above 1 use `read_up_to` and make the
      batching queue use `enqueue_many`.
    parse_fn: Parsing function, takes `Example` Tensor returns parsed
      representation. If `None`, no parsing is done.
    name: Name of resulting op.

  Returns:
    Batched keys together with batched examples:
    - If `parse_fn` returns a dict: a tuple `(keys, features)`, where `keys`
      is the batched string key `Tensor` and `features` is the batched
      feature dict (with `KEY_FEATURE_NAME` removed).
    - Otherwise: the output of `batch_join` / `shuffle_batch_join` for the
      enqueued `(keys, examples)` pairs, i.e. batched keys first, then the
      batched serialized `Example` protos (or the batched `parse_fn` output).

  Raises:
    ValueError: for invalid inputs.
  """
  # Retrieve files to read.
  if isinstance(file_pattern, list):
    file_names = file_pattern
    if not file_names:
      raise ValueError('No files given to dequeue_examples.')
  else:
    file_names = list(gfile.Glob(file_pattern))
    if not file_names:
      raise ValueError('No files match %s.' % file_pattern)

  # Sort files so it will be deterministic for unit tests. They'll be shuffled
  # in `string_input_producer` if `randomize_input` is enabled.
  if not randomize_input:
    file_names = sorted(file_names)

  # Check input parameters are given and reasonable.
  if (not queue_capacity) or (queue_capacity <= 0):
    raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
  if (batch_size is None) or (
      (not isinstance(batch_size, ops.Tensor)) and
      (batch_size <= 0 or batch_size > queue_capacity)):
    raise ValueError(
        'Invalid batch_size %s, with queue_capacity %s.' %
        (batch_size, queue_capacity))
  if (read_batch_size is None) or (
      (not isinstance(read_batch_size, ops.Tensor)) and
      (read_batch_size <= 0)):
    raise ValueError('Invalid read_batch_size %s.' % read_batch_size)
  if (not num_threads) or (num_threads <= 0):
    raise ValueError('Invalid num_threads %s.' % num_threads)
  if (num_epochs is not None) and (num_epochs <= 0):
    raise ValueError('Invalid num_epochs %s.' % num_epochs)

  with ops.op_scope([file_pattern], name, 'read_batch_examples') as scope:
    # Setup filename queue with shuffling.
    with ops.name_scope('file_name_queue') as file_name_queue_scope:
      file_name_queue = input_ops.string_input_producer(
          constant_op.constant(file_names, name='input'),
          shuffle=randomize_input, num_epochs=num_epochs,
          name=file_name_queue_scope)

    # Create readers, one per thread and set them to read from filename queue.
    with ops.name_scope('read'):
      example_list = []
      for _ in range(num_threads):
        # NOTE(review): `read_batch_size > 1` is used as a Python condition
        # here, so a `Tensor` read_batch_size (allowed by the validation
        # above) appears unable to pass this point — confirm.
        if read_batch_size > 1:
          keys, examples_proto = reader().read_up_to(file_name_queue,
                                                     read_batch_size)
        else:
          keys, examples_proto = reader().read(file_name_queue)
        if parse_fn:
          parsed_examples = parse_fn(examples_proto)
          # Map keys into example map because batch_join doesn't support
          # tuple of Tensor + dict.
          if isinstance(parsed_examples, dict):
            parsed_examples[KEY_FEATURE_NAME] = keys
            example_list.append(parsed_examples)
          else:
            example_list.append((keys, parsed_examples))
        else:
          example_list.append((keys, examples_proto))

    enqueue_many = read_batch_size > 1

    # Setup batching queue given list of read example tensors.
    if randomize_input:
      if isinstance(batch_size, ops.Tensor):
        min_after_dequeue = int(queue_capacity * 0.4)
      else:
        min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
      queued_examples_with_keys = input_ops.shuffle_batch_join(
          example_list, batch_size, capacity=queue_capacity,
          min_after_dequeue=min_after_dequeue,
          enqueue_many=enqueue_many, name=scope)
    else:
      queued_examples_with_keys = input_ops.batch_join(
          example_list, batch_size, capacity=queue_capacity,
          enqueue_many=enqueue_many, name=scope)
    if parse_fn and isinstance(queued_examples_with_keys, dict):
      # Split the keys back out of the feature dict.
      queued_keys = queued_examples_with_keys.pop(KEY_FEATURE_NAME)
      return queued_keys, queued_examples_with_keys
    return queued_examples_with_keys
예제 #29
0
def _read_keyed_batch_examples_helper(file_pattern,
                                      batch_size,
                                      reader,
                                      randomize_input=True,
                                      num_epochs=None,
                                      queue_capacity=10000,
                                      num_threads=1,
                                      read_batch_size=1,
                                      parse_fn=None,
                                      setup_shared_queue=False,
                                      name=None):
  """Builds ops that read, queue and batch keyed `Example` protos.

  Args:
    file_pattern: List of files, or a glob pattern of file paths, holding
      `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: Python int or scalar `Tensor` giving the batch size. A Python
      int must lie in `(0, queue_capacity]`.
    reader: A function or class returning a reader object; it is passed
      through to `_get_examples`.
    randomize_input: If True, file names are shuffled and examples are batched
      with `shuffle_batch_join`; otherwise `batch_join` is used.
    num_epochs: Number of passes over the data, or `None` to cycle forever.
      When set, a smaller final batch is allowed.
      NOTE - If specified, creates a variable that must be initialized, so
      call `tf.initialize_all_variables()` as shown in the tests.
    queue_capacity: Positive capacity of the batching queue.
    num_threads: Positive number of threads enqueuing examples.
    read_batch_size: Python int or scalar `Tensor`; number of records read at
      once. Values above 1 make the batching queue use `enqueue_many`.
    parse_fn: Optional function mapping an `Example` tensor to its parsed
      representation. `None` disables parsing.
    setup_shared_queue: If True, file names come from
      `_get_shared_file_name_queue` and are relayed through a local
      capacity-1 FIFO queue.
    name: Name scope for the created ops.

  Returns:
    If `parse_fn` is set and batching yields a dict: a tuple
    `(keys, features)` with `KEY_FEATURE_NAME` removed from `features`.
    Otherwise the batching op's output (string keys first, then the batched
    examples).

  Raises:
    ValueError: for invalid inputs.
  """
  # Resolve the concrete list of input files.
  file_names = _get_file_names(file_pattern, randomize_input)

  # Argument validation; the first failing check determines the error.
  if not queue_capacity or queue_capacity <= 0:
    raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
  if batch_size is None or (not isinstance(batch_size, ops.Tensor) and
                            (batch_size <= 0 or batch_size > queue_capacity)):
    raise ValueError(
        'Invalid batch_size %s, with queue_capacity %s.' %
        (batch_size, queue_capacity))
  if read_batch_size is None or (
      not isinstance(read_batch_size, ops.Tensor) and read_batch_size <= 0):
    raise ValueError('Invalid read_batch_size %s.' % read_batch_size)
  if not num_threads or num_threads <= 0:
    raise ValueError('Invalid num_threads %s.' % num_threads)
  if num_epochs is not None and num_epochs <= 0:
    raise ValueError('Invalid num_epochs %s.' % num_epochs)

  with ops.name_scope(name, 'read_batch_examples', [file_pattern]) as scope:
    with ops.name_scope('file_name_queue') as file_name_queue_scope:
      if not setup_shared_queue:
        file_name_queue = input_ops.string_input_producer(
            constant_op.constant(file_names, name='input'),
            shuffle=randomize_input,
            num_epochs=num_epochs,
            name=file_name_queue_scope)
      else:
        # Relay names from the shared queue into a local single-slot FIFO
        # queue, fed by its own queue runner.
        shared_queue = _get_shared_file_name_queue(
            file_names, randomize_input, num_epochs, file_name_queue_scope)
        file_name_queue = data_flow_ops.FIFOQueue(
            capacity=1, dtypes=[dtypes.string], shapes=[[]])
        relay_op = file_name_queue.enqueue(shared_queue.dequeue())
        queue_runner.add_queue_runner(
            queue_runner.QueueRunner(file_name_queue, [relay_op]))

    example_list = _get_examples(file_name_queue, reader, num_threads,
                                 read_batch_size, parse_fn)

    # Multi-record reads arrive as batches and must be enqueued element-wise.
    enqueue_many = read_batch_size > 1
    # A finite epoch count may leave a partial final batch.
    allow_smaller_final_batch = num_epochs is not None

    if randomize_input:
      # min_after_dequeue must be a Python int, so a `Tensor` batch size
      # falls back to 40% of the capacity.
      if isinstance(batch_size, ops.Tensor):
        min_after_dequeue = int(queue_capacity * 0.4)
      else:
        min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
      batched = input_ops.shuffle_batch_join(
          example_list,
          batch_size,
          capacity=queue_capacity,
          min_after_dequeue=min_after_dequeue,
          enqueue_many=enqueue_many,
          name=scope,
          allow_smaller_final_batch=allow_smaller_final_batch)
    else:
      batched = input_ops.batch_join(
          example_list,
          batch_size,
          capacity=queue_capacity,
          enqueue_many=enqueue_many,
          name=scope,
          allow_smaller_final_batch=allow_smaller_final_batch)

    if not (parse_fn and isinstance(batched, dict)):
      return batched
    # The keys ride inside the feature dict; split them back out.
    keys = batched.pop(KEY_FEATURE_NAME)
    return keys, batched
예제 #30
0
def read_keyed_batch_examples(file_pattern,
                              batch_size,
                              reader,
                              randomize_input=True,
                              num_epochs=None,
                              queue_capacity=10000,
                              num_threads=1,
                              read_batch_size=1,
                              parse_fn=None,
                              name=None):
    """Builds a queue-based pipeline yielding batches of keyed `Example`s.

    File names come from `file_pattern` (an explicit list, or a glob expanded
    with `gfile.Glob`) and feed a file-name queue. One reader per thread pulls
    records, together with their keys, from that queue, and a batching queue
    groups them into batches of `batch_size`.

    All queue runners are added to the queue-runner collection and can be
    started via `start_queue_runners`. All ops are added to the default graph.

    Use `parse_fn` if you need to parse or process single examples.

    Args:
      file_pattern: List of file paths, or a glob pattern (see `tf.gfile.Glob`
        for the syntax) matching files that hold `Example` records.
      batch_size: Python int or scalar `Tensor`; examples per batch. A Python
        int must lie in `(0, queue_capacity]`.
      reader: A function or class whose result has a `read(filename_queue)`
        method returning `(keys, examples)` tensors, and a
        `read_up_to(filename_queue, n)` method when `read_batch_size > 1`.
      randomize_input: If True, file names are shuffled by the producer and
        examples are batched with `shuffle_batch_join`. If False, files are
        read in sorted order and batched with `batch_join`.
      num_epochs: Number of passes over the data, or `None` to cycle forever.
        When set, a smaller final batch is allowed.
        NOTE - If specified, creates a variable that must be initialized, so
        call `tf.initialize_all_variables()` as shown in the tests.
      queue_capacity: Positive capacity of the batching queue.
      num_threads: Positive number of reader threads enqueuing examples.
      read_batch_size: Python int or scalar `Tensor`; number of records read
        at once. Values above 1 use `read_up_to` and `enqueue_many`.
      parse_fn: Optional function mapping an `Example` tensor to its parsed
        representation. `None` disables parsing.
      name: Name scope for the created ops.

    Returns:
      If `parse_fn` returns a dict: a tuple `(keys, features)` with
      `KEY_FEATURE_NAME` removed from `features`. Otherwise the batching op's
      output for the `(keys, examples)` pairs: string keys first, then the
      batched examples.

    Raises:
      ValueError: for invalid inputs.
    """
    # Resolve the input files: glob a pattern, or use an explicit list as is.
    if not isinstance(file_pattern, list):
        file_names = list(gfile.Glob(file_pattern))
        if not file_names:
            raise ValueError('No files match %s.' % file_pattern)
    elif not file_pattern:
        raise ValueError('No files given to dequeue_examples.')
    else:
        file_names = file_pattern

    # Sorted order keeps unit tests deterministic; when randomizing, the
    # shuffling happens in `string_input_producer` instead.
    file_names = file_names if randomize_input else sorted(file_names)

    # Argument validation; the first failing check determines the error.
    if not queue_capacity or queue_capacity <= 0:
        raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
    if batch_size is None or (not isinstance(batch_size, ops.Tensor) and
                              (batch_size <= 0 or
                               batch_size > queue_capacity)):
        raise ValueError('Invalid batch_size %s, with queue_capacity %s.' %
                         (batch_size, queue_capacity))
    if read_batch_size is None or (
            not isinstance(read_batch_size, ops.Tensor) and
            read_batch_size <= 0):
        raise ValueError('Invalid read_batch_size %s.' % read_batch_size)
    if not num_threads or num_threads <= 0:
        raise ValueError('Invalid num_threads %s.' % num_threads)
    if num_epochs is not None and num_epochs <= 0:
        raise ValueError('Invalid num_epochs %s.' % num_epochs)

    with ops.name_scope(name, 'read_batch_examples', [file_pattern]) as scope:
        with ops.name_scope('file_name_queue') as file_name_queue_scope:
            file_name_queue = input_ops.string_input_producer(
                constant_op.constant(file_names, name='input'),
                shuffle=randomize_input,
                num_epochs=num_epochs,
                name=file_name_queue_scope)

        # One reader per thread, all pulling from the same file-name queue.
        with ops.name_scope('read'):
            example_list = []
            for _ in range(num_threads):
                if read_batch_size > 1:
                    keys, serialized = reader().read_up_to(
                        file_name_queue, read_batch_size)
                else:
                    keys, serialized = reader().read(file_name_queue)
                if not parse_fn:
                    example_list.append((keys, serialized))
                    continue
                parsed = parse_fn(serialized)
                if isinstance(parsed, dict):
                    # batch_join cannot take a (Tensor, dict) tuple, so the
                    # keys travel inside the feature dict.
                    parsed[KEY_FEATURE_NAME] = keys
                    example_list.append(parsed)
                else:
                    example_list.append((keys, parsed))

        # Multi-record reads arrive as batches and must be enqueued
        # element-wise.
        enqueue_many = read_batch_size > 1
        # A finite epoch count may leave a partial final batch.
        allow_smaller_final_batch = num_epochs is not None

        batching_kwargs = dict(
            capacity=queue_capacity,
            enqueue_many=enqueue_many,
            name=scope,
            allow_smaller_final_batch=allow_smaller_final_batch)
        if randomize_input:
            # min_after_dequeue must be a Python int, so a `Tensor` batch
            # size falls back to 40% of the capacity.
            if isinstance(batch_size, ops.Tensor):
                min_after_dequeue = int(queue_capacity * 0.4)
            else:
                min_after_dequeue = max(queue_capacity - (3 * batch_size),
                                        batch_size)
            batched = input_ops.shuffle_batch_join(
                example_list,
                batch_size,
                min_after_dequeue=min_after_dequeue,
                **batching_kwargs)
        else:
            batched = input_ops.batch_join(example_list, batch_size,
                                           **batching_kwargs)

        if parse_fn and isinstance(batched, dict):
            # Split the keys back out of the feature dict.
            keys = batched.pop(KEY_FEATURE_NAME)
            return keys, batched
        return batched
예제 #31
0
def read_batch_examples(file_pattern,
                        batch_size,
                        reader,
                        randomize_input=True,
                        queue_capacity=10000,
                        num_threads=1,
                        name='dequeue_examples'):
    """Adds operations to read, queue, batch `Example` protos.

    Given file pattern (or list of files), will setup a queue for file names,
    read `Example` proto using provided `reader`, use batch queue to create
    batches of examples of size `batch_size`.

    All queue runners are added to the queue runners collection, and may be
    started via `start_queue_runners`.

    All ops are added to the default graph.

    Args:
      file_pattern: List of files or pattern of file paths containing
          `Example` records. See `tf.gfile.Glob` for pattern rules. Only a
          `list` is taken as literal file names; any other value is expanded
          with `gfile.Glob`.
      batch_size: An int or scalar `Tensor` specifying the batch size to use.
          An int must satisfy `0 < batch_size <= queue_capacity`.
      reader: A function or class that returns an object with a `read`
          method, (filename queue) -> (key tensor, example tensor). The key
          is discarded; only the example tensor is batched.
      randomize_input: Whether the input should be randomized. When False,
          file names are sorted so the read order is deterministic.
      queue_capacity: Capacity for input queue. Must be positive.
      num_threads: The number of threads enqueuing examples. Must be positive.
      name: Name of resulting op.

    Returns:
      String `Tensor` of batched `Example` proto.

    Raises:
      ValueError: if no files are given or matched, or if `queue_capacity`,
          `batch_size` or `num_threads` is invalid.
    """
    # Retrieve files to read.
    if isinstance(file_pattern, list):
        file_names = file_pattern
        if not file_names:
            raise ValueError('No files given to dequeue_examples.')
    else:
        file_names = list(gfile.Glob(file_pattern))
        if not file_names:
            raise ValueError('No files match %s.' % file_pattern)

    # Sort files so it will be deterministic for unit tests. They'll be shuffled
    # in `string_input_producer` if `randomize_input` is enabled.
    if not randomize_input:
        file_names = sorted(file_names)

    # Check input parameters are given and reasonable. The bounds check on
    # `batch_size` is only possible when it is a Python int; a `Tensor`'s
    # value is not known while the graph is being built.
    if (not queue_capacity) or (queue_capacity <= 0):
        raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
    if (batch_size is None) or (
        (not isinstance(batch_size, ops.Tensor)) and
        (batch_size <= 0 or batch_size > queue_capacity)):
        raise ValueError('Invalid batch_size %s, with queue_capacity %s.' %
                         (batch_size, queue_capacity))
    if (not num_threads) or (num_threads <= 0):
        raise ValueError('Invalid num_threads %s.' % num_threads)

    with ops.name_scope(name) as scope:
        # Setup filename queue with shuffling.
        with ops.name_scope('file_name_queue') as file_name_queue_scope:
            file_name_queue = input_ops.string_input_producer(
                constant_op.constant(file_names, name='input'),
                shuffle=randomize_input,
                name=file_name_queue_scope)

        # Create reader and set it to read from filename queue. `read`
        # yields a (key, value) pair; only the serialized example is kept.
        with ops.name_scope('read'):
            _, example_proto = reader().read(file_name_queue)

        # Setup batching queue.
        if randomize_input:
            if isinstance(batch_size, ops.Tensor):
                # Batch size unknown at graph-construction time: use a fixed
                # fraction of the queue capacity as the shuffle floor.
                min_after_dequeue = int(queue_capacity * 0.4)
            else:
                # Leave headroom for three batches above the shuffle floor,
                # but never let the floor drop below a single batch.
                min_after_dequeue = max(queue_capacity - (3 * batch_size),
                                        batch_size)
            examples = input_ops.shuffle_batch(
                [example_proto],
                batch_size,
                capacity=queue_capacity,
                num_threads=num_threads,
                min_after_dequeue=min_after_dequeue,
                name=scope)
        else:
            examples = input_ops.batch([example_proto],
                                       batch_size,
                                       capacity=queue_capacity,
                                       num_threads=num_threads,
                                       name=scope)

        return examples