forked from renmengye/deep-tracker
/
run_matching.py
105 lines (77 loc) · 2.62 KB
/
run_matching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Evaluate a trained matching model: restore a checkpoint (folder given on
# the command line), run it on a few dataset examples, print the outputs and
# save a plot of the image pairs annotated with target/predicted scores.
import cslab_environ  # NOTE(review): not referenced below; presumably imported for side effects — confirm.
from saver import Saver
import logger
import numpy as np
import sys
import tensorflow as tf
import matplotlib
# Select the non-interactive Agg backend before pyplot is imported so the
# choice takes effect and figures can be saved without a display.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import plot_utils as pu
import matching_model as model
import matching_data as data
# Module-level logger used throughout the script.
log = logger.get()
def get_model(opt, device='/cpu:0'):
    """Build the matching model described by ``opt``.

    Thin delegate to ``matching_model.get_model``.

    Args:
        opt: Model options (in ``__main__`` this is the ``'model_opt'`` entry
            of the checkpoint info).
        device: Device string forwarded unchanged to the model builder.

    Returns:
        Whatever ``matching_model.get_model`` returns. The ``__main__``
        section indexes it by ``'x1'``, ``'x2'``, ``'phase_train'`` and
        ``'y_out'``.
    """
    built = model.get_model(opt, device)
    return built
def get_dataset(opt, folder='/ais/gobi3/u/mren/data/kitti/tracking/training',
                seqs=None):
    """Load the matching dataset used for evaluation.

    Args:
        opt: Data options (in ``__main__`` this is the ``'data_opt'`` entry
            of the checkpoint info), forwarded to
            ``matching_data.get_dataset``.
        folder: Root folder of the KITTI tracking training data. Defaults to
            the path this script was originally written against.
        seqs: Sequence ids to load. ``None`` (the default) means ``[20]``,
            the single sequence the original script evaluated on.

    Returns:
        The object returned by ``matching_data.get_dataset`` (loaded with
        ``split=None``). Callers in this file index it by ``'images_0'``,
        ``'images_1'`` and ``'labels'``.
    """
    if seqs is None:
        # Build the default per call instead of using a mutable default arg.
        seqs = [20]
    return data.get_dataset(folder, opt, split=None, seqs=seqs)
def get_batch_fn(dataset):
    """Return a closure that fetches and preprocesses one batch.

    Args:
        dataset: Mapping with ``'images_0'``, ``'images_1'`` and ``'labels'``
            entries that can be indexed by ``idx``. They are looked up on
            every call, not captured once.

    Returns:
        A function ``fetch_batch(idx)`` that gathers the entries at ``idx``
        from all three arrays and returns ``preprocess`` applied to them:
        the tuple ``(x1, x2, y)``.
    """
    def fetch_batch(idx):
        raw = (dataset['images_0'][idx],
               dataset['images_1'][idx],
               dataset['labels'][idx])
        first, second, target = preprocess(*raw)
        return first, second, target

    return fetch_batch
def plot_output(fname, x1, x2, y_gt, y_out):
    """Save a grid of image pairs annotated with target and predicted scores.

    Each example occupies two grid cells placed by ``pu.calc_row_col`` (at
    most 9 cells per row): ``x1[i]`` in the first and ``x2[i]`` in the
    second. The second cell is labelled ``"<y_gt[i]> <y_out[i]>"`` with two
    decimals.

    Args:
        fname: Output image path passed to ``plt.savefig`` (saved at 150 dpi).
        x1: First image of each pair, indexable by example, shown via
            ``imshow``.
        x2: Second image of each pair, same indexing as ``x1``.
        y_gt: Ground-truth score per example.
        y_out: Model output per example; ``y_out.shape[0]`` sets the number of
            examples drawn.
    """
    n_examples = y_out.shape[0]
    cells_per_example = 2
    n_rows, n_cols, cell_pos = pu.calc_row_col(
        n_examples, cells_per_example, max_items_per_row=9)
    _fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, n_rows))
    pu.set_axis_off(axes, n_rows, n_cols)
    # NOTE(review): axes is indexed 2-D below, but plt.subplots returns a
    # squeezed 1-D array when n_rows or n_cols is 1 -- confirm calc_row_col
    # never produces that layout.
    for i in range(n_examples):
        for j, images in enumerate((x1, x2)):
            r, c = cell_pos(i, j)
            ax = axes[r, c]
            ax.imshow(images[i])
            if j > 0:
                # Only the second cell of each pair carries the score label.
                ax.text(0, 0, '{:.2f} {:.2f}'.format(y_gt[i], y_out[i]),
                        color=(0, 0, 0), size=8)
    plt.tight_layout(pad=2.0, w_pad=0.0, h_pad=0.0)
    plt.savefig(fname, dpi=150)
    plt.close('all')
def preprocess(x1, x2, y):
    """Convert one batch to float32.

    Both image arrays are cast to float32 and divided by 255 (presumably
    uint8 pixel data, giving values in [0, 1] -- confirm against the
    dataset). The labels are only cast to float32, without rescaling.

    Returns:
        The tuple ``(x1, x2, y)`` of new float32 arrays.
    """
    scale = 255
    images = tuple(arr.astype('float32') / scale for arr in (x1, x2))
    return images + (y.astype('float32'),)
if __name__ == '__main__':
    # Usage: run_matching.py RESTORE_FOLDER [OUTPUT_PNG]
    if len(sys.argv) < 2:
        sys.exit('Usage: {} RESTORE_FOLDER [OUTPUT_PNG]'.format(sys.argv[0]))
    restore_folder = sys.argv[1]
    # Optional output path; defaults to the path the script originally used.
    output_fname = (sys.argv[2] if len(sys.argv) > 2
                    else '/u/mren/test_matching.png')

    saver = Saver(restore_folder)
    ckpt_info = saver.get_ckpt_info()
    model_opt = ckpt_info['model_opt']
    data_opt = ckpt_info['data_opt']
    ckpt_fname = ckpt_info['ckpt_fname']
    step = ckpt_info['step']
    model_id = ckpt_info['model_id']

    log.info('Building model')
    m = get_model(model_opt)
    log.info('Loading dataset')
    dataset = get_dataset(data_opt)

    # The context manager closes the session and releases its resources.
    with tf.Session() as sess:
        log.info('Restoring model {} at step {} from {}'.format(
            model_id, step, ckpt_fname))
        saver.restore(sess, ckpt_fname)

        # Evaluate on at most the first 10 examples; clamping avoids an
        # IndexError on a small dataset (assumes dataset['labels'] supports
        # len(), e.g. a numpy array).
        num_ex = min(10, len(dataset['labels']))
        idx = np.arange(num_ex)
        get_batch = get_batch_fn(dataset)
        x1, x2, y_gt = get_batch(idx)
        y_out = sess.run(m['y_out'],
                         feed_dict={
                             m['x1']: x1,
                             m['x2']: x2,
                             m['phase_train']: False,
                         })

    # Single-argument print(...) behaves the same under Python 2 and 3.
    print(y_out)
    print(y_gt)
    plot_output(output_fname, x1, x2, y_gt, y_out)