import tensorflow as tf
from tensorflow.contrib.slim.python.slim import queues

slim = tf.contrib.slim

# `cfg` is the project's config module; it is assumed to define
# BATCH_SIZE and NUM_CLASSES.


def HowToRunOneEpoch():
    ''' demonstrate how to run one epoch with shuffling '''
    my_dataset = _GetDataset()
    provider = slim.dataset_data_provider.DatasetDataProvider(
        my_dataset,
        # The number of parallel readers that read data from the dataset.
        num_readers=4,
        # Whether to shuffle the data sources and common queue when reading.
        shuffle=True,
        # The number of times each data source is read. If left as None,
        # the data is cycled through indefinitely.
        num_epochs=1,
        # The capacity of the common queue.
        common_queue_capacity=20 * cfg.BATCH_SIZE,
        # The minimum number of elements in the common queue after a dequeue.
        common_queue_min=10 * cfg.BATCH_SIZE)
    [image, label] = provider.get(['image', 'label'])

    # The preprocessing here is just a demo with a one-hot transformation;
    # in a real project, perform data augmentation at this point.
    label = slim.one_hot_encoding(label, cfg.NUM_CLASSES)

    # Note: num_epochs is backed by a local variable, so it must be
    # initialized by the local-variables initializer.
    op_init = tf.local_variables_initializer()
    with tf.Session() as sess:
        sess.run(op_init)
        with queues.QueueRunners(sess):
            _OneEpochTraining(sess, image, label, show=True)
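# _GetDataset() is not shown in this section. Below is a minimal,
# hypothetical sketch of what it might look like, assuming the TFRecords
# store the conventional 'image/encoded', 'image/format', and
# 'image/class/label' features; the file pattern and num_samples are
# placeholders, not values from the original project.
def _GetDataset():
    ''' hypothetical sketch: build a slim Dataset over TFRecord files '''
    # How each serialized tf.train.Example is parsed.
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label': tf.FixedLenFeature([], tf.int64),
    }
    # How parsed features map to the items the provider serves.
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)
    return slim.dataset.Dataset(
        data_sources='/tmp/data/train-*.tfrecord',  # placeholder pattern
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=10000,  # placeholder; must match the real record count
        items_to_descriptions={'image': 'A color image.',
                               'label': 'An integer class label.'})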
def testOutOfRangeError(self):
    with self.test_session():
        [tfrecord_path] = test_utils.create_tfrecord_files(
            self.get_temp_dir(), num_files=1)

    key, value = parallel_reader.single_pass_read(
        tfrecord_path, reader_class=io_ops.TFRecordReader)
    init_op = variables.local_variables_initializer()

    with self.test_session() as sess:
        sess.run(init_op)
        with queues.QueueRunners(sess):
            # Reading more records than the file holds must raise
            # OutOfRangeError, since single_pass_read visits each record
            # exactly once.
            num_reads = 11
            with self.assertRaises(errors_impl.OutOfRangeError):
                for _ in range(num_reads):
                    sess.run([key, value])
def testTFRecordDataset(self):
    dataset_dir = tempfile.mkdtemp(
        prefix=os.path.join(self.get_temp_dir(), 'tfrecord_dataset'))
    height = 300
    width = 280

    with self.test_session():
        provider = dataset_data_provider.DatasetDataProvider(
            _create_tfrecord_dataset(dataset_dir))
        image, label = provider.get(['image', 'label'])
        image = _resize_image(image, height, width)

        with session.Session('') as sess:
            with queues.QueueRunners(sess):
                image, label = sess.run([image, label])
        self.assertListEqual([height, width, 3], list(image.shape))
        self.assertListEqual([1], list(label.shape))
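# _resize_image() and _create_tfrecord_dataset() are helpers defined
# elsewhere in the test module. A plausible sketch of _resize_image,
# assuming a single decoded HxWx3 image tensor and bilinear resizing
# (resize_bilinear operates on batches, hence the expand/squeeze pair):
def _resize_image(image, height, width):
    # Add a leading batch dimension, resize, then strip the batch
    # dimension again so a single image tensor comes back out.
    image = tf.expand_dims(image, 0)
    image = tf.image.resize_bilinear(image, [height, width])
    return tf.squeeze(image, [0])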
def HowToRunOneEpochWithBatch():
    ''' demonstrate how to run one epoch and pack the data into batches '''
    my_dataset = _GetDataset()
    provider = slim.dataset_data_provider.DatasetDataProvider(
        my_dataset,
        # The number of parallel readers that read data from the dataset.
        num_readers=4,
        # Whether to shuffle the data sources and common queue when reading.
        shuffle=True,
        # The number of times each data source is read. If left as None,
        # the data is cycled through indefinitely.
        num_epochs=1,
        # The capacity of the common queue.
        common_queue_capacity=20 * cfg.BATCH_SIZE,
        # The minimum number of elements in the common queue after a dequeue.
        common_queue_min=10 * cfg.BATCH_SIZE)
    [image, label] = provider.get(['image', 'label'])

    images, labels = tf.train.batch(
        [image, label],
        batch_size=cfg.BATCH_SIZE,
        # The number of threads used to create the batches.
        num_threads=4,
        # The maximum number of elements in the queue.
        capacity=5 * cfg.BATCH_SIZE,
        # If True, the final batch may be smaller when there are not
        # enough items left in the queue.
        allow_smaller_final_batch=True)
    batch_queue = slim.prefetch_queue.prefetch_queue(
        [images, labels], capacity=5)
    images, labels = batch_queue.dequeue()

    # Note: num_epochs is backed by a local variable, so it must be
    # initialized by the local-variables initializer.
    op_init = tf.local_variables_initializer()
    with tf.Session() as sess:
        sess.run(op_init)
        with queues.QueueRunners(sess):
            _OneEpochTraining(sess, images, labels)
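# _OneEpochTraining() is also defined elsewhere. A minimal sketch of the
# loop both demos assume: keep dequeuing until the pipeline signals the
# end of the epoch. The `show` handling and print format are
# illustrative, not the original implementation.
def _OneEpochTraining(sess, images, labels, show=False):
    ''' hypothetical sketch: drain the input pipeline for one epoch '''
    step = 0
    try:
        while True:
            np_images, np_labels = sess.run([images, labels])
            step += 1
            if show:
                print('step %d: images %s, labels %s'
                      % (step, np_images.shape, np_labels.shape))
    except tf.errors.OutOfRangeError:
        # With num_epochs=1, the queue runners mark the end of the epoch
        # by raising OutOfRangeError from sess.run.
        print('epoch complete after %d steps' % step)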
def testTFRecordReader(self):
    with self.test_session():
        [tfrecord_path] = test_utils.create_tfrecord_files(
            self.get_temp_dir(), num_files=1)

    key, value = parallel_reader.single_pass_read(
        tfrecord_path, reader_class=io_ops.TFRecordReader)
    init_op = variables.local_variables_initializer()

    with self.test_session() as sess:
        sess.run(init_op)
        with queues.QueueRunners(sess):
            # Every record comes from the single 'flowers' file, so every
            # key read back should contain that name.
            flowers = 0
            num_reads = 9
            for _ in range(num_reads):
                current_key, _ = sess.run([key, value])
                if 'flowers' in str(current_key):
                    flowers += 1
        self.assertGreater(flowers, 0)
        self.assertEqual(flowers, num_reads)
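# single_pass_read visits each record exactly once, deterministically,
# with a single reader. For the shuffled, multi-reader case the same
# module offers parallel_read; a hypothetical usage sketch, assuming the
# same imports as the tests above (the file list is a placeholder):
data_files = ['/tmp/data/flowers.tfrecord']  # placeholder path
key, value = parallel_reader.parallel_read(
    data_files,
    reader_class=io_ops.TFRecordReader,
    num_epochs=1,    # read each source once; None would cycle forever
    num_readers=4,   # parallel readers feeding a shared common queue
    shuffle=True)    # use a shuffling queue as the common queue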
def testTFRecordDataset(self):
    dataset_dir = tempfile.mkdtemp(
        prefix=os.path.join(self.get_temp_dir(), 'tfrecord_dataset'))
    height = 300
    width = 280

    with self.cached_session():
        test_dataset = _create_tfrecord_dataset(dataset_dir)
        provider = dataset_data_provider.DatasetDataProvider(test_dataset)
        key, image, label = provider.get(['record_key', 'image', 'label'])
        image = _resize_image(image, height, width)

        with session.Session('') as sess:
            with queues.QueueRunners(sess):
                key, image, label = sess.run([key, image, label])
        split_key = key.decode('utf-8').split(':')
        self.assertEqual(2, len(split_key))
        self.assertEqual(test_dataset.data_sources[0], split_key[0])
        self.assertTrue(split_key[1].isdigit())
        self.assertListEqual([height, width, 3], list(image.shape))
        self.assertListEqual([1], list(label.shape))
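# As the test above shows, the provider serves a bookkeeping item
# (named 'record_key' by default) of the form '<data_source>:<offset>'.
# If the dataset already defines an item with that name, the provider's
# record_key constructor argument renames it; a hypothetical example,
# where 'tfrecord_offset' is an illustrative name:
provider = dataset_data_provider.DatasetDataProvider(
    test_dataset, record_key='tfrecord_offset')
[key] = provider.get(['tfrecord_offset'])  # still '<file>:<record offset>'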