# imagenet_pretrained.py
# Classify an image with an ImageNet-pretrained Keras network and display the top label.
from keras.applications import ResNet50
from keras.applications import InceptionV3
from keras.applications import Xception
from keras.applications import VGG16
from keras.applications import VGG19
from keras.applications import imagenet_utils
from keras.applications.inception_v3 import preprocess_input
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
import numpy as np
import argparse
import cv2
# ---------------------------------------------------------------------------
# Command-line interface: the image to classify and the network to use.
# ---------------------------------------------------------------------------
ap = argparse.ArgumentParser()
ap.add_argument('-i', '--image', required=True,
                help="path to the image")
# --model is optional so that the "vgg16" default can actually take effect
# (combining required=True with a default makes the default unreachable).
ap.add_argument('-m', '--model', default="vgg16",
                help="name of pretrained model")
args = vars(ap.parse_args())

# Maps each accepted --model value to the Keras application class that
# builds the corresponding network.
MODELS = {
    "vgg16": VGG16,
    "vgg19": VGG19,
    "inception": InceptionV3,
    "xception": Xception,
    "resnet": ResNet50,
}

# Reject unknown model names before any (slow) weight loading happens.
if args['model'] not in MODELS:
    raise AssertionError("The --model command line argument should "
                         "be a key in the `MODELS` dictionary")

# Default input size and preprocessing; the Inception V3 and Xception
# branches below switch to 299x299 inputs and the inception_v3
# preprocessing function instead.
inputShape = (224, 224)
preprocess = imagenet_utils.preprocess_input
if args['model'] in ('xception', 'inception'):
    inputShape = (299, 299)
    preprocess = preprocess_input

# Instantiate the selected network with its ImageNet weights.
print('[INFO] loading {}...'.format(args['model']))
Network = MODELS[args['model']]
model = Network(weights='imagenet')

# Load the image at the network's input size, convert it to an array, add a
# leading batch axis (the model predicts on batches), then preprocess it.
print('[INFO] loading image and preprocessing')
image = load_img(args['image'], target_size=inputShape)
image = img_to_array(image)
image = np.expand_dims(image, axis=0)
image = preprocess(image)

# Run the forward pass and decode the raw scores into
# (class id, label, probability) tuples.
print('[INFO] Classifying Image')
preds = model.predict(image)
P = imagenet_utils.decode_predictions(preds)

# Print the ranked predictions for the single image in the batch.
for (i, (imageID, label, prob)) in enumerate(P[0]):
    print('[INFO] {}. {} --- {:.2f}%'.format(i + 1, label, prob * 100))

# Re-read the original (unresized) image with OpenCV so the best label can
# be drawn on it and shown in a window.
image = cv2.imread(args['image'])
if image is None:
    # cv2.imread signals failure by returning None instead of raising; this
    # can happen for formats PIL decoded above but OpenCV cannot. The
    # predictions are already printed, so stop with a clear message rather
    # than failing inside cv2.putText.
    raise SystemExit("[INFO] OpenCV could not read '{}'; skipping display"
                     .format(args['image']))

# The first entry of P[0] is the highest-ranked prediction.
(_, label, prob) = P[0][0]
cv2.putText(image, "Label: {}".format(label), (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
cv2.imshow('Preds', image)
cv2.waitKey(0)