# image_input.py — TFRecord-based input pipeline for 64x64 training images.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
# Side length (in pixels) of the square input images stored in the records.
IMAGE_SIZE = 64
# Number of target classes; NOTE(review): not referenced in this file — confirm against model code.
NUM_CLASSES = 2
# Dataset sizes, used below to size the example queue before training starts.
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 40000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 1770
def read_cifar10(filename_queue):
    """Read and parse one serialized example from the filename queue.

    A TFRecordReader pulls a single serialized example off the queue; the
    raw image bytes and the float label vector are then decoded from its
    'image_raw' and 'label' string features.

    Args:
        filename_queue: queue of tfrecords filenames for the reader.

    Returns:
        A record object with fields: height, width, depth (ints),
        uint8image (a [height, width, depth] uint8 tensor) and
        label (a float32 vector of length 300).
    """
    class CIFAR10Record(object):
        pass

    record = CIFAR10Record()
    record.height, record.width, record.depth = 64, 64, 3
    label_length = 300

    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)
    parsed = tf.parse_single_example(
        serialized,
        features={
            'label': tf.FixedLenFeature([], tf.string),
            'image_raw': tf.FixedLenFeature([], tf.string)
        })

    raw_image = tf.decode_raw(parsed['image_raw'], tf.uint8)
    # Records are stored HWC, so no transpose is needed after reshaping.
    record.uint8image = tf.reshape(
        raw_image, [record.height, record.width, record.depth])

    record.label = tf.reshape(
        tf.decode_raw(parsed['label'], tf.float32), [label_length])
    return record
def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size, shuffle):
    """Assemble single (image, label) pairs into mini-batches.

    Args:
        image: a single preprocessed image tensor.
        label: the label tensor paired with `image`.
        min_queue_examples: minimum number of examples retained in the
            queue, which controls how well shuffling mixes examples.
        batch_size: number of examples per returned batch.
        shuffle: if True, draw examples from the queue in shuffled order.

    Returns:
        A pair (images, labels) of batched tensors.
    """
    thread_count = 16
    queue_capacity = min_queue_examples + 3 * batch_size
    if shuffle:
        images, label_batch = tf.train.shuffle_batch(
            [image, label],
            batch_size=batch_size,
            num_threads=thread_count,
            capacity=queue_capacity,
            min_after_dequeue=min_queue_examples)
    else:
        images, label_batch = tf.train.batch(
            [image, label],
            batch_size=batch_size,
            num_threads=thread_count,
            capacity=queue_capacity)

    # Surface a sample of the batched images in the visualizer.
    tf.image_summary('images', images)
    return images, label_batch
def distorted_inputs(data_dir, batch_size):
    """Build a queue of randomly distorted training images and batch them.

    Every file in `data_dir` is treated as a tfrecords shard. Each decoded
    image is randomly cropped, flipped, and jittered in brightness and
    contrast before being whitened and batched.

    Args:
        data_dir: directory holding the tfrecords image files.
        batch_size: number of images consumed per training step.

    Returns:
        A pair (images, labels) of batched tensors.
    """
    filenames = [os.path.join(data_dir, name)
                 for name in os.listdir(data_dir)]
    missing = [f for f in filenames if not tf.gfile.Exists(f)]
    if missing:
        raise ValueError('Failed to find file: ' + missing[0])

    filename_queue = tf.train.string_input_producer(filenames)
    example = read_cifar10(filename_queue)
    image = tf.cast(example.uint8image, tf.float32)

    crop_height, crop_width = 50, 50
    # Random crop plus photometric jitter for data augmentation.
    image = tf.random_crop(image, [crop_height, crop_width, 3])
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=63)
    image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
    # Normalize each image to zero mean / unit variance.
    float_image = tf.image.per_image_whitening(image)

    # Keep 3% of an epoch queued so shuffling mixes examples reasonably.
    min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 0.03)
    print('Filling queue with %d CIFAR images before starting to train. '
          'This will take a few minutes.' % min_queue_examples)

    # Generate a batch of images and labels by building up a queue of examples.
    return _generate_image_and_label_batch(float_image, example.label,
                                           min_queue_examples, batch_size,
                                           shuffle=True)