Example No. 1
def HowToRunOneEpoch():
    '''Demonstrate how to run one epoch with shuffling.'''
    my_dataset = _GetDataset()
    provider = slim.dataset_data_provider.DatasetDataProvider(
        my_dataset,
        # The number of parallel readers that read data from the dataset.
        num_readers=4,
        # Whether to shuffle the data sources and common queue when reading.
        shuffle=True,
        # The number of times each data source is read. If left as None,
        # the data will be cycled through indefinitely.
        num_epochs=1,
        # The capacity of the common queue.
        common_queue_capacity=20 * cfg.BATCH_SIZE,
        # The minimum number of elements in the common queue after a dequeue.
        common_queue_min=10 * cfg.BATCH_SIZE,
    )

    [image, label] = provider.get(['image', 'label'])
    # Preprocessing: this demo only applies a one-hot transformation to the
    # label. In a real project, data augmentation would be performed here
    # (see the sketch after this example).
    label = slim.one_hot_encoding(label, cfg.NUM_CLASSES)

    # Note: num_epochs is tracked by a local variable, so it must be
    # initialized with the local variables initializer.
    op_init = tf.local_variables_initializer()
    with tf.Session() as sess:
        sess.run(op_init)
        with queues.QueueRunners(sess):
            _OneEpochTraining(sess, image, label, show=True)
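The preprocessing comment above is a stand-in for real augmentation. A minimal sketch of what that step could look like, assuming image is a decoded uint8 tensor of shape [height, width, 3]; the helper name _AugmentImage and the specific distortions are illustrative choices, not part of the original example:

import tensorflow as tf

def _AugmentImage(image):
    '''Hypothetical augmentation step; swap in project-specific ops.'''
    # Random horizontal flip on the raw uint8 image.
    image = tf.image.random_flip_left_right(image)
    # Photometric distortions operate on floats in [0, 1].
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    # Clip back into a valid range after the distortions.
    return tf.clip_by_value(image, 0.0, 1.0)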
Example No. 2
    def testOutOfRangeError(self):
        with self.test_session():
            [tfrecord_path] = test_utils.create_tfrecord_files(
                self.get_temp_dir(), num_files=1)

        key, value = parallel_reader.single_pass_read(
            tfrecord_path, reader_class=io_ops.TFRecordReader)
        init_op = variables.local_variables_initializer()

        with self.test_session() as sess:
            sess.run(init_op)
            with queues.QueueRunners(sess):
                num_reads = 11
                with self.assertRaises(errors_impl.OutOfRangeError):
                    for _ in range(num_reads):
                        sess.run([key, value])
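Because single_pass_read visits each record exactly once, code that does not know the record count ahead of time can simply drain the reader until OutOfRangeError fires. A minimal sketch in the style of the test above, assuming the same key/value tensors and imports:

with self.test_session() as sess:
    sess.run(variables.local_variables_initializer())
    with queues.QueueRunners(sess):
        num_records = 0
        try:
            while True:  # Read until the single pass is exhausted.
                sess.run([key, value])
                num_records += 1
        except errors_impl.OutOfRangeError:
            pass  # Expected: every record has been read exactly once.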
Example No. 3
    def testTFRecordDataset(self):
        dataset_dir = tempfile.mkdtemp(
            prefix=os.path.join(self.get_temp_dir(), 'tfrecord_dataset'))

        height = 300
        width = 280

        with self.test_session():
            provider = dataset_data_provider.DatasetDataProvider(
                _create_tfrecord_dataset(dataset_dir))
            image, label = provider.get(['image', 'label'])
            image = _resize_image(image, height, width)

            with session.Session('') as sess:
                with queues.QueueRunners(sess):
                    image, label = sess.run([image, label])
            self.assertListEqual([height, width, 3], list(image.shape))
            self.assertListEqual([1], list(label.shape))
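_resize_image is not shown in this excerpt. One plausible implementation consistent with the asserted [height, width, 3] shape; the body below is an assumption, not the test's actual helper:

import tensorflow as tf

def _resize_image(image, height, width):
    # Bilinear resize expects a 4-D batch, so wrap and unwrap the image.
    image = tf.expand_dims(image, 0)                 # [1, H, W, C]
    image = tf.image.resize_bilinear(image, [height, width])
    return tf.squeeze(image, [0])                    # [height, width, C]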
Example No. 4
def HowToRunOneEpochWithBatch():
    '''Demonstrate how to run one epoch and pack the data into batches.'''
    my_dataset = _GetDataset()
    provider = slim.dataset_data_provider.DatasetDataProvider(
        my_dataset,
        # The number of parallel readers that read data from the dataset.
        num_readers=4,
        # Whether to shuffle the data sources and common queue when reading.
        shuffle=True,
        # The number of times each data source is read. If left as None,
        # the data will be cycled through indefinitely.
        num_epochs=1,
        # The capacity of the common queue.
        common_queue_capacity=20 * cfg.BATCH_SIZE,
        # The minimum number of elements in the common queue after a dequeue.
        common_queue_min=10 * cfg.BATCH_SIZE,
    )

    [image, label] = provider.get(['image', 'label'])

    images, labels = tf.train.batch(
        [image, label],
        batch_size=cfg.BATCH_SIZE,
        # The number of threads used to create the batches.
        num_threads=4,
        # The maximum number of elements in the queue.
        capacity=5 * cfg.BATCH_SIZE,
        # If True, allow the final batch to be smaller if there are
        # insufficient items left in the queue.
        allow_smaller_final_batch=True,
    )

    batch_queue = slim.prefetch_queue.prefetch_queue([images, labels],
                                                     capacity=5)

    images, labels = batch_queue.dequeue()

    # Note: num_epochs is tracked by a local variable, so it must be
    # initialized with the local variables initializer.
    op_init = tf.local_variables_initializer()
    with tf.Session() as sess:
        sess.run(op_init)
        with queues.QueueRunners(sess):
            _OneEpochTraining(sess, images, labels)
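Both epoch examples call a _GetDataset helper that the excerpt does not define. A minimal sketch of such a helper using slim's Dataset abstraction; the file pattern, feature keys, and num_samples are placeholders that must match the real TFRecord files:

import tensorflow as tf
slim = tf.contrib.slim

def _GetDataset():
    '''Hypothetical helper: describe TFRecord files for a DatasetDataProvider.'''
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label': tf.FixedLenFeature([1], tf.int64),
    }
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)
    return slim.dataset.Dataset(
        data_sources='/path/to/train-*.tfrecord',  # placeholder pattern
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=10000,  # placeholder; must equal the true record count
        items_to_descriptions={
            'image': 'A color image.',
            'label': 'An integer class label.',
        })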
Example No. 5
    def testTFRecordReader(self):
        with self.test_session():
            [tfrecord_path] = test_utils.create_tfrecord_files(
                self.get_temp_dir(), num_files=1)

        key, value = parallel_reader.single_pass_read(
            tfrecord_path, reader_class=io_ops.TFRecordReader)
        init_op = variables.local_variables_initializer()

        with self.test_session() as sess:
            sess.run(init_op)
            with queues.QueueRunners(sess):
                flowers = 0
                num_reads = 9
                for _ in range(num_reads):
                    current_key, _ = sess.run([key, value])
                    if 'flowers' in str(current_key):
                        flowers += 1
                self.assertGreater(flowers, 0)
                self.assertEqual(flowers, num_reads)
Example No. 6
  def testTFRecordDataset(self):
    dataset_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
                                                       'tfrecord_dataset'))

    height = 300
    width = 280

    with self.cached_session():
      test_dataset = _create_tfrecord_dataset(dataset_dir)
      provider = dataset_data_provider.DatasetDataProvider(test_dataset)
      key, image, label = provider.get(['record_key', 'image', 'label'])
      image = _resize_image(image, height, width)

      with session.Session('') as sess:
        with queues.QueueRunners(sess):
          key, image, label = sess.run([key, image, label])
      split_key = key.decode('utf-8').split(':')
      self.assertEqual(2, len(split_key))
      self.assertEqual(test_dataset.data_sources[0], split_key[0])
      self.assertTrue(split_key[1].isdigit())
      self.assertListEqual([height, width, 3], list(image.shape))
      self.assertListEqual([1], list(label.shape))
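The split-on-':' assertions rely on TFRecordReader emitting keys of the form '<data_source>:<record_number>', which DatasetDataProvider surfaces as the record_key item. A quick illustration with a made-up key value:

# Hypothetical key for illustration only.
key = b'/tmp/tfrecord_dataset/tfrecords-00000-of-00001:12'
source_file, record_index = key.decode('utf-8').split(':')
assert record_index.isdigit()  # '12' -> the record's position in the file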