# eval.py
# forked from ml-lab/TensorBox
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import json
import argparse
import subprocess
import scipy as scp
from scipy.misc import imread, imread, imsave, imresize
import time
import sys
import logging
from random import shuffle
#from train import build_lstm_forward, build_overfeat_forward
from train import build_overfeat_forward
from utils import googlenet_load, train_utils
from utils.annolist import AnnotationLib as al
from utils.stitch_wrapper import stitch_rects
from utils.train_utils import add_rectangles, rescale_boxes
# TensorFlow command-line flag machinery; FLAGS is populated when
# tf.app.run() parses sys.argv in the __main__ guard below.
flags = tf.app.flags
FLAGS = flags.FLAGS
# reload() clears any logging configuration TensorFlow may already have
# installed so that basicConfig below takes effect (Python 2 idiom;
# Python 3 would need importlib.reload).
reload(logging)
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                    level=logging.INFO,
                    stream=sys.stdout)
# Path to the JSON hyperparameter ("hypes") file describing the model.
tf.app.flags.DEFINE_string('hypes', './hypes/default.json',
                           """HYPES""")
# Run directory containing the checkpoints to evaluate (required; checked
# in main()).
tf.app.flags.DEFINE_string('run', None,
                           """Run to Analyse.""")
def run_eval(H, checkpoint_dir, hypes_file, output_path):
    """Do Evaluation with full epoch of test data.

    Builds the forward graph described by the hypes, restores the most
    recent checkpoint from ``checkpoint_dir``, runs every image listed in
    the test idl file through the network, saves predicted and rescaled
    ground-truth annotations to ``output_path``, and shells out to the
    annolist tools to produce a precision/recall plot (roc.png).

    Args:
        H: hypes dict loaded from the JSON hypes file.
        checkpoint_dir: directory with TensorFlow checkpoint files.
        hypes_file: path of the hypes JSON file (used to name outputs).
        output_path: directory in which results are saved.
    """
    # Load ground truth annotations.
    true_idl = H['data']['test_idl']
    true_annos = al.parse(true_idl)

    # Output files are named after the hypes file so multiple runs can
    # share one eval directory without clobbering each other.
    pred_file = 'val_%s.idl' % os.path.basename(hypes_file).replace('.json', '')
    pred_idl = os.path.join(output_path, pred_file)
    true_file = 'true_%s.idl' % os.path.basename(hypes_file).replace('.json', '')
    true_idl_scaled = os.path.join(output_path, true_file)
    data_folder = os.path.dirname(os.path.realpath(true_idl))

    # Build the forward graph.
    tf.reset_default_graph()
    googlenet = googlenet_load.init(H)
    x_in = tf.placeholder(tf.float32, name='x_in')
    if H['arch']['use_lstm']:
        # BUGFIX: build_lstm_forward was referenced without being imported
        # (its module-level import is commented out at the top of the file),
        # which raised NameError whenever use_lstm was enabled. Import it
        # lazily so the overfeat path keeps working even if the LSTM code
        # is absent from train.py.
        from train import build_lstm_forward
        pred_boxes, pred_logits, pred_confidences = build_lstm_forward(
            H, tf.expand_dims(x_in, 0), googlenet, 'test', reuse=None)
    else:
        pred_boxes, pred_logits, pred_confidences = build_overfeat_forward(
            H, tf.expand_dims(x_in, 0), googlenet, 'test')

    start_time = time.time()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        logging.info("Starting Evaluation")
        sess.run(tf.initialize_all_variables())

        # Restore the most recent checkpoint, if one exists.
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            logging.info(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)

        annolist = al.AnnoList()
        trueanno = al.AnnoList()

        # Shuffle true_annos so the periodically plotted images are a
        # random sample of the test set.
        shuffle(true_annos)
        for i in range(len(true_annos)):
            true_anno = true_annos[i]
            img = imread(os.path.join(data_folder, true_anno.imageName))

            # Rescale ground-truth boxes to the network input resolution.
            trueanno.append(rescale_boxes(img.shape, true_annos[i],
                                          H["arch"]["image_height"],
                                          H["arch"]["image_width"]))

            # Rescale the image to the network input resolution.
            img = imresize(img, (H["arch"]["image_height"],
                                 H["arch"]["image_width"]), interp='cubic')
            feed = {x_in: img}
            (np_pred_boxes, np_pred_confidences) = sess.run(
                [pred_boxes, pred_confidences], feed_dict=feed)

            pred_anno = al.Annotation()
            pred_anno.imageName = true_anno.imageName
            new_img, rects = add_rectangles([img], np_pred_confidences,
                                            np_pred_boxes, H["arch"],
                                            use_stitching=True,
                                            rnn_len=H['arch']['rnn_len'],
                                            min_conf=0.3)
            pred_anno.rects = rects
            annolist.append(pred_anno)

            if i % 20 == 0:
                # Draw every 20th image; which images get plotted is
                # randomized by the shuffle above.
                duration = time.time() - start_time
                # BUGFIX: on the very first report (i == 0) only one step
                # has elapsed since start_time; dividing by 20 there
                # understated the per-step duration 20x.
                n_steps = 20 if i else 1
                duration = float(duration) * 1000 / n_steps
                out_img = os.path.join(output_path, 'test_%i.png' % i)
                scp.misc.imsave(out_img, new_img)
                logging.info('Step %d: Duration %.3f ms'
                             % (i, duration))
                start_time = time.time()

    annolist.save(pred_idl)
    trueanno.save(true_idl_scaled)

    # Write precision/recall results to disk via the external annolist
    # tooling.
    iou_threshold = 0.5
    # NOTE(review): shell=True with interpolated paths breaks on paths
    # containing spaces or shell metacharacters; the helper scripts are
    # invoked through the shell on purpose, so the command strings are
    # kept as-is.
    rpc_cmd = './utils/annolist/doRPC.py --minOverlap %f %s %s' % (
        iou_threshold, true_idl_scaled, pred_idl)
    rpc_output = subprocess.check_output(rpc_cmd, shell=True)
    # The last non-empty line of the RPC output names the result .txt file.
    txt_file = [line for line in rpc_output.split('\n') if line.strip()][-1]
    output_png = os.path.join(output_path, "roc.png")
    plot_cmd = './utils/annolist/plotSimple.py %s --output %s' % (
        txt_file, output_png)
    subprocess.check_output(plot_cmd, shell=True)
def main(_):
    """Parse command line arguments, load data and create output folder.

    Output will be stored in the rundir/output. The last checkpoint
    in rundir is loaded for evaluation.
    """
    run_dir = FLAGS.run
    if run_dir is None:
        logging.error("No Checkpoint dir is provided!")
        logging.error("Usage: eval.py --run=path/to/checkpointdir --hypes=HYPES")
        exit(1)

    # Results land in <rundir>/eval; create the folder on first use.
    output_path = os.path.realpath(os.path.join(run_dir,
                                                "eval"))
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    # Load the hyperparameters from the JSON hypes file.
    with open(FLAGS.hypes, 'r') as hypes_handle:
        hypes = json.load(hypes_handle)

    # Run the evaluation over the full test set.
    run_eval(hypes, run_dir, FLAGS.hypes, output_path)
    logging.info("Evaluation Complete. Results are saved in: %s",
                 output_path)
if __name__ == '__main__':
    # tf.app.run() parses the flags defined above and then calls main().
    tf.app.run()