-
Notifications
You must be signed in to change notification settings - Fork 1
/
predict.py
137 lines (116 loc) · 4.58 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import argparse
from glob import glob
import os
import json
import random
import chainer
import chainer.links as L
import chainer.functions as F
import numpy as np
from networks.mobilenetv2 import MobilenetV2
from networks.vgg16 import VGG16
from networks.resnet50 import ResNet50
from dataset import FoodDataset, preprocess
from chainer.links.model.vision import vgg
from chainer.links.model.vision import resnet
def find_latest(model_dir):
    """Return the path of the newest ``model_epoch_<N>.npz`` snapshot.

    Args:
        model_dir: directory containing snapshots named ``model_epoch_<N>.npz``.
            A path to an existing file is also accepted and returned unchanged,
            so a specific snapshot (e.g. ``pretrained/model_epoch_100.npz``)
            can be evaluated directly.

    Returns:
        Path of the snapshot whose integer epoch ``N`` is largest. Epochs are
        compared numerically, so epoch 10 wins over epoch 9.

    Raises:
        FileNotFoundError: if ``model_dir`` holds no snapshot whose epoch part
            parses as an integer.
    """
    if os.path.isfile(model_dir):
        return model_dir
    prefix, suffix = "model_epoch_", ".npz"
    candidates = []
    for path in glob(os.path.join(model_dir, prefix + "*" + suffix)):
        stem = os.path.basename(path)[len(prefix):-len(suffix)]
        try:
            epoch = int(stem)
        except ValueError:
            # e.g. model_epoch_best.npz: not an epoch-numbered snapshot.
            continue
        candidates.append((epoch, path))
    if not candidates:
        raise FileNotFoundError(
            "no model_epoch_<N>.npz snapshot found in {}".format(model_dir))
    return max(candidates)[1]
def prepare_setting(args):
    """Restore the trained classifier and everything needed to evaluate it.

    The newest snapshot under ``args.model_path`` is loaded into an
    ``L.Classifier``. The backbone is chosen from the ``model_name`` stored in
    the ``args.json`` that sits in the same directory as the snapshot.

    Args:
        args: parsed command-line namespace; reads ``model_path``, ``dataset``
            and ``device``.

    Returns:
        Tuple ``(model, preprocess_fn, xp, test_dataset)``:
        - the classifier, with its predictor moved to the GPU when
          ``args.device >= 0``;
        - a one-argument wrapper around ``preprocess`` bound to the
          training model name;
        - the array module (``cupy`` on GPU, ``numpy`` on CPU);
        - the non-training split of ``FoodDataset``.
    """
    snapshot = find_latest(args.model_path)
    print(snapshot)
    settings_path = os.path.join(os.path.dirname(snapshot), "args.json")
    print(settings_path)
    with open(settings_path, 'r') as settings_file:
        train_args = json.load(settings_file)
    arch = train_args["model_name"]

    # The classifier head size is fixed at 101 classes for every backbone.
    if arch == "mv2":
        # mv2 is built with an explicit depth_multiplier of 1.0.
        backbone = MobilenetV2(num_classes=101, depth_multiplier=1.0)
    else:
        constructors = {"vgg16": VGG16, "resnet50": ResNet50}
        backbone = constructors[arch](num_classes=101)
    model = L.Classifier(backbone)
    chainer.serializers.load_npz(snapshot, model)

    test_dataset = FoodDataset(args.dataset, model_name=arch, train=False)

    if args.device < 0:
        xp = np  # CPU
    else:
        chainer.backends.cuda.get_device_from_id(args.device).use()
        model.predictor.to_gpu()
        import cupy as xp

    def preprocess_(image):
        return preprocess(image, arch)

    return model, preprocess_, xp, test_dataset
def _export_onnx(predictor, xp, filename):
    """Export ``predictor`` to an ONNX file using a dummy 1x3x224x224 input.

    Args:
        predictor: the bare network. Do not pass the ``L.Classifier``
            wrapper: it treats its last positional argument as the label, so
            calling it with only an image leaves the predictor without input.
        xp: array module matching the device the predictor lives on.
        filename: destination path of the ONNX file.
    """
    # Imported lazily so that plain evaluation does not need onnx_chainer.
    import onnx_chainer
    x = xp.zeros((1, 3, 224, 224), dtype=xp.float32)
    onnx_chainer.export(predictor, x, filename=filename)


def predict(args):
    """Evaluate a trained snapshot on the test split and report top-k accuracy.

    Shuffles the test indices and evaluates the first ``args.sample`` of them,
    or all of them when ``args.sample`` is negative. Values larger than the
    dataset are clamped to its size. For each image it prints:
    - the top-5 class names;
    - their softmax scores;
    - a match message.

    It then prints top-1/5/10 accuracy over the evaluated images.

    Args:
        args: parsed command-line namespace (see ``parse_argument``). Reads
            ``dataset``, ``sample``, ``device`` and ``model_path``. The
            optional ``onnx_path`` attribute, when set, exports the predictor
            to that ONNX file before evaluation.
    """
    classes = np.genfromtxt(os.path.join(args.dataset, "meta", "classes.txt"),
                            str,
                            delimiter="\n")
    model, _, xp, test_dataset = prepare_setting(args)

    onnx_path = getattr(args, "onnx_path", None)
    if onnx_path:
        _export_onnx(model.predictor, xp, onnx_path)

    indices = list(range(len(test_dataset)))
    if args.sample < 0:
        num_iteration = len(indices)
    else:
        # Never count more samples than are actually evaluated; otherwise the
        # accuracy denominators below would include images never scored.
        num_iteration = min(args.sample, len(indices))
    if num_iteration == 0:
        print('no test samples to evaluate')
        return
    random.shuffle(indices)

    top_1_counter = 0
    top_5_counter = 0
    top_10_counter = 0
    with chainer.function.no_backprop_mode(), chainer.using_config('train', False):
        for i in indices[:num_iteration]:
            img, label = test_dataset.get_example(i)
            batch = xp.expand_dims(xp.array(img), axis=0)  # batch of one image
            h = model.predictor(batch)
            # Softmax scores of the single sample, as a NumPy array on the host.
            prediction = chainer.backends.cuda.to_cpu(
                chainer.functions.softmax(h).array[0])
            top_ten = np.argsort(-prediction)[:10]
            top_five = top_ten[:5]
            if top_five[0] == label:
                top_1_counter += 1
                top_5_counter += 1
                top_10_counter += 1
                msg = "Bingo!"
            elif label in top_five:
                top_5_counter += 1
                top_10_counter += 1
                msg = "matched top 5"
            elif label in top_ten:
                top_10_counter += 1
                msg = "matched top 10"
            else:
                msg = "Boo, actual {}".format(classes[label])
            print(classes[top_five], prediction[top_five], msg)
    print('top1 accuracy', top_1_counter / num_iteration)
    print('top5 accuracy', top_5_counter / num_iteration)
    print('top10 accuracy', top_10_counter / num_iteration)


def parse_argument(argv=None):
    """Parse the command-line options of this evaluation script.

    Args:
        argv: optional list of argument strings; ``None`` (the default) parses
            ``sys.argv[1:]`` as before.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "model_path", type=str,
        help="directory holding model_epoch_*.npz snapshots and args.json; "
             "the snapshot with the highest epoch is evaluated")
    parser.add_argument("--sample", type=int, default=-1,
                        help="select num of --sample from test dataset to evaluate accuracy")
    parser.add_argument("--device", type=int, default=0,
                        help="specify GPU_ID. If negative, use CPU")
    parser.add_argument("--dataset", type=str, default=".")
    parser.add_argument("--onnx_path", type=str, default=None,
                        help="if given, export the predictor to this ONNX file before evaluating")
    args = parser.parse_args(argv)
    return args
if __name__ == '__main__':
    # Command-line entry point: parse the flags, then run the evaluation.
    predict(parse_argument())