# Trains the road classifier model
#
# Summaries are recorded every training step, and the training loss is
# printed every 10 steps. Accuracy on the full validation set is checked
# every 100 steps, and execution metadata is recorded every 500 steps.
# Once training is complete, the model is saved.
#
# Summaries are saved for the following data:
# -Cross entropy
# -Prediction accuracy
# -Confusion matrix (TP, TN, FP, and FN)
# -IoU (Intersection over union)
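#
# Run directly (all settings are read from parameters.py):
#   python train.py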
import os, os.path
import scipy.misc
import numpy as np
import tensorflow as tf
import data_loader as data
import model_v6 as model
import parameters as params

# Check for old log data
if tf.gfile.Exists(params.log_dir):
  tf.gfile.DeleteRecursively(params.log_dir)
tf.gfile.MakeDirs(params.log_dir)
tf.gfile.MakeDirs(params.log_dir + "/images")

# Make the model checkpoint directory
if (params.save_model or params.early_stopping):
  if not os.path.exists(params.model_ckpt_dir):
    os.makedirs(params.model_ckpt_dir)

sess = tf.InteractiveSession()
# Use weighted cross entropy as the loss function
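# (Road pixels are typically the minority class, so positives are up-weighted
# by the negative-to-positive pixel ratio computed from each batch's labels.)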
with tf.name_scope('cross_entropy'):
  num_positives = tf.maximum(tf.reduce_sum(tf.cast(model.y_, tf.int32)), 1)
  num_negatives = tf.sub(tf.size(model.y_), num_positives)
  class_ratio = tf.cast(num_negatives, tf.float32)/tf.cast(num_positives, tf.float32)
  diff = tf.nn.weighted_cross_entropy_with_logits(tf.clip_by_value(model.y, 1e-10, 1.0), tf.cast(model.y_, tf.float32), class_ratio)
  with tf.name_scope('total'):
    cross_entropy = tf.reduce_mean(diff)
tf.scalar_summary('cross entropy', cross_entropy)

# Add optimizer to the graph to minimize cross entropy
with tf.name_scope('train'):
  train_step = tf.train.AdamOptimizer(params.learning_rate).minimize(cross_entropy)

with tf.name_scope('accuracy'):
  with tf.name_scope('correct_prediction'):
    correct_prediction = tf.equal(model.prediction, model.y_)
  with tf.name_scope('accuracy'):
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.scalar_summary('accuracy', accuracy)
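
# Track the pixel-wise confusion matrix (as rates) and IoU = TP / (TP + FP + FN)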
with tf.name_scope('prediction_stats') as scope:
  true_pos = tf.reduce_mean(tf.cast(tf.logical_and(model.prediction, model.y_), tf.float32))
  true_neg = tf.reduce_mean(tf.cast(tf.logical_and(tf.logical_not(model.prediction), tf.logical_not(model.y_)), tf.float32))
  false_pos = tf.reduce_mean(tf.cast(tf.logical_and(model.prediction, tf.logical_not(model.y_)), tf.float32))
  false_neg = tf.reduce_mean(tf.cast(tf.logical_and(tf.logical_not(model.prediction), model.y_), tf.float32))
  IoU = true_pos/(true_pos + false_pos + false_neg)

  tf.scalar_summary(scope + 'true_pos', true_pos)
  tf.scalar_summary(scope + 'true_neg', true_neg)
  tf.scalar_summary(scope + 'false_pos', false_pos)
  tf.scalar_summary(scope + 'false_neg', false_neg)
  tf.scalar_summary(scope + 'IoU', IoU)

# Merge all the summaries and write them out
merged = tf.merge_all_summaries()
train_writer = tf.train.SummaryWriter(params.log_dir + '/train', sess.graph)
val_writer = tf.train.SummaryWriter(params.log_dir + '/val', sess.graph)
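# (Training and validation summaries are written to separate subdirectories,
# so TensorBoard displays them as two runs.)
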
print "Initializing variables."
tf.initialize_all_variables().run()
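
# Saver for writing model checkpoints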
saver = tf.train.Saver()
# Initialize previous prediction
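# (When feedback is enabled the network receives its previous prediction as an
# extra input; a uniform 0.5 map stands in when no prediction exists yet.)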
null_pred = np.full((params.batch_size, params.res["height"], params.res["width"]), 0.5)
null_pred_single = np.full((1, params.res["height"], params.res["width"]), 0.5)
prev_pred = null_pred
val_prev_pred = null_pred_single

# Train the model, write summaries, and check accuracy on the entire validation set.
# Sample predictions and metadata will be periodically saved.
best_val_acc = 0
best_val_acc_step = 0
print "Beginning training (max {0} steps).".format(params.max_steps)

for i in range(params.max_steps):
  # Load the data
  x, y, prev_x = data.LoadTrainBatch(params.batch_size)

  if params.feedback:
    # Form prediction on the previous frame to feed back into the model
    feed_dict = {model.x: prev_x, model.prev_y: null_pred, model.keep_prob: 1.0}
    prev_pred = tf.sigmoid(sess.run(model.y, feed_dict=feed_dict)).eval()

  feed_dict = {model.x: x, model.y_: y, model.prev_y: prev_pred, model.keep_prob: params.dropout}
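
  # Every 500th step a full execution trace is recorded alongside the summary,
  # so per-op timing and memory use can be inspected in TensorBoard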
  # Train operation
  if i % 500 == 499:
    # Record execution stats
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    summary, loss, _ = sess.run([merged, cross_entropy, train_step], feed_dict=feed_dict, options=run_options, run_metadata=run_metadata)
    train_writer.add_run_metadata(run_metadata, 'step%03d' % i)
    train_writer.add_summary(summary, i)
    print "Saved metadata for step {0}.".format(i)
  else:
    # Record a summary
    summary, loss, _ = sess.run([merged, cross_entropy, train_step], feed_dict=feed_dict)
    train_writer.add_summary(summary, i)

  if i % 10 == 0:
    print "Training step {0} loss: {1:.3f}".format(i, loss)

  # Save sample predictions
  if i % 100 == 0:
    feed_dict = {model.x: x, model.y_: y, model.prev_y: prev_pred, model.keep_prob: 1.0}
    prediction = sess.run(model.prediction, feed_dict=feed_dict)

    scipy.misc.imsave((params.log_dir + "/images/step_" + str(i) + "_raw.png"), np.squeeze(x[0]))
    scipy.misc.imsave((params.log_dir + "/images/step_" + str(i) + "_label.png"), y[0])
    scipy.misc.imsave((params.log_dir + "/images/step_" + str(i) + "_pred.png"), prediction[0])
    print "Saved sample images."

  # Measure validation set accuracy (over the entire set)
  if (i % 100 == 0):
    acc_count = 0

    # Loop through each image in the validation set.
    # Write a summary for the last image, and run one extra time to rotate
    # through the dataset so the summary isn't always on the same image.
    for j in range(data.num_val_imgs + 1):
      x, y, prev_x = data.LoadValBatch(1)

      if params.feedback:
        # Form prediction on the previous frame to feed back into the model
        feed_dict = {model.x: prev_x, model.prev_y: null_pred_single, model.keep_prob: 1.0}
        val_prev_pred = tf.sigmoid(sess.run(model.y, feed_dict=feed_dict)).eval()

      feed_dict = {model.x: x, model.y_: y, model.prev_y: val_prev_pred, model.keep_prob: 1.0}

      if j == data.num_val_imgs:
        summary = sess.run(merged, feed_dict=feed_dict)
      else:
        acc_count += sess.run(accuracy, feed_dict=feed_dict)

    acc = acc_count/data.num_val_imgs
    print "Validation set accuracy: {0:.3f}".format(acc)
    val_writer.add_summary(summary, i)

    # Save the model when validation accuracy has improved
    if params.early_stopping and acc > best_val_acc:
      best_val_acc = acc
      best_val_acc_step = i

      checkpoint_path = os.path.join(params.model_ckpt_dir, "model.ckpt")
      filename = saver.save(sess, checkpoint_path)
      print "Model saved in file: {0}.".format(filename)

    # Early stopping when validation accuracy has not improved for 1000 steps
    if params.early_stopping and (i - best_val_acc_step) > 1000:
      print "Stopping at step {0}.".format(i)
      break

train_writer.close()
val_writer.close()
print "Training complete."

# Save the final model (if desired)
if (params.save_model and not params.early_stopping):
  checkpoint_path = os.path.join(params.model_ckpt_dir, "model.ckpt")
  filename = saver.save(sess, checkpoint_path)
  print "Model saved in file: {0}.".format(filename)