-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_feature_extraction.py
80 lines (65 loc) · 2.83 KB
/
train_feature_extraction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import pickle
import tensorflow as tf
from sklearn.model_selection import train_test_split
from alexnet import AlexNet
from sklearn.utils import shuffle
import time
# Training configuration.
nb_classes = 43   # number of output classes of the new classifier layer
epochs = 10       # full passes over the training split
batch_size = 128  # examples per training / evaluation step

# Load the pickled traffic-sign dataset; the code below reads its
# 'features' and 'labels' entries.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
with open('./train.p', 'rb') as train_file:
    data = pickle.load(train_file)

# Hold out 33% of the examples for validation. The fixed `random_state`
# makes the split reproducible from run to run.
X_train, X_val, y_train, y_val = train_test_split(
    data['features'],
    data['labels'],
    test_size=0.33,
    random_state=0,
)
# Placeholders for one batch of integer class labels and of 32x32,
# 3-channel input images.
labels = tf.placeholder(tf.int64, None)
features = tf.placeholder(tf.float32, (None, 32, 32, 3))
# Upscale the images to 227x227 before handing them to AlexNet.
resized = tf.image.resize_images(features, (227, 227))

# Build AlexNet on the resized batch. With `feature_extract=True` the
# network presumably returns its fc7 activations rather than class
# scores -- confirm against alexnet.py.
fc7 = AlexNet(resized, feature_extract=True)
# NOTE: `tf.stop_gradient` prevents the gradient from flowing backwards
# past this point, keeping the weights before and up to `fc7` frozen.
# This also makes training faster, less work to do!
fc7 = tf.stop_gradient(fc7)

# New final layer for traffic sign classification: a linear map from the
# fc7 feature width to `nb_classes` logits.
shape = (fc7.get_shape().as_list()[-1], nb_classes)
fc8W = tf.Variable(tf.truncated_normal(shape, stddev=1e-2))
fc8b = tf.Variable(tf.zeros(nb_classes))
logits = tf.nn.xw_plus_b(fc7, fc8W, fc8b)

# Loss, training, and accuracy operations.
# HINT: Look back at your traffic signs project solution, you may
# be able to reuse some of the code.
# The arguments are passed by keyword: TensorFlow >= 1.0 rejects positional
# calls to this function and uses (labels, logits) order, while keywords
# are accepted by older releases as well.
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels, logits=logits)
loss_op = tf.reduce_mean(cross_entropy)
opt = tf.train.AdamOptimizer()

# Only the new layer's weights and bias are handed to the optimizer.
train_op = opt.minimize(loss_op, var_list=[fc8W, fc8b])
# Built after `minimize` so the initializer also covers any variables the
# optimizer created.
init_op = tf.global_variables_initializer()

# Accuracy: fraction of examples whose highest-scoring class equals the
# label. `tf.argmax` replaces the deprecated alias `tf.arg_max`.
preds = tf.argmax(logits, 1)
accuracy_op = tf.reduce_mean(tf.cast(tf.equal(preds, labels), tf.float32))
def eval(X, y, sess):
    """Return the mean loss and mean accuracy of the model over (X, y).

    The data is fed through `loss_op` and `accuracy_op` in slices of
    `batch_size` examples. Each batch's result is weighted by the number
    of examples in that batch, so a shorter final batch does not skew
    the averages.

    Args:
        X: array of input examples, indexed along its first axis.
        y: labels aligned with `X`.
        sess: the TensorFlow session used to run the ops.

    Returns:
        A `(mean_loss, mean_accuracy)` tuple.
    """
    n_examples = X.shape[0]
    loss_sum = 0
    acc_sum = 0
    for start in range(0, n_examples, batch_size):
        stop = start + batch_size
        batch_X, batch_y = X[start:stop], y[start:stop]
        batch_loss, batch_acc = sess.run(
            [loss_op, accuracy_op],
            feed_dict={features: batch_X, labels: batch_y},
        )
        # Per-batch means are turned back into sums before accumulating.
        n_batch = batch_X.shape[0]
        loss_sum += batch_loss * n_batch
        acc_sum += batch_acc * n_batch
    return loss_sum / n_examples, acc_sum / n_examples
# Train the new classifier layer and report validation metrics per epoch.
with tf.Session() as sess:
    sess.run(init_op)

    for epoch in range(1, epochs + 1):
        # Reshuffle every epoch so the mini-batches differ between passes.
        X_train, y_train = shuffle(X_train, y_train)
        t0 = time.time()

        n_train = X_train.shape[0]
        for start in range(0, n_train, batch_size):
            stop = start + batch_size
            batch = {
                features: X_train[start:stop],
                labels: y_train[start:stop],
            }
            sess.run(train_op, feed_dict=batch)

        val_loss, val_acc = eval(X_val, y_val, sess)

        # The reported time covers both the training pass and validation.
        print("Epoch", epoch)
        print("Time: %.3f seconds" % (time.time() - t0))
        print("Validation loss:", val_loss)
        print("Validation accuracy:", val_acc)
        print("")