# test_ensemble.py
# import the necessary packages
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from keras.models import load_model
from keras.datasets import cifar10
import numpy as np
import argparse
import glob
import os
# construct the argument parser, register the single required switch
# (the directory holding the serialized models), then parse the
# command line into a plain dictionary
parser = argparse.ArgumentParser()
parser.add_argument(
    "-m",
    "--models",
    required=True,
    help="path to models directory",
)
parsedArgs = parser.parse_args()
args = vars(parsedArgs)
# human-readable names for the CIFAR-10 classes, used when printing the
# classification report
labelNames = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# keep only the second split returned by the CIFAR-10 loader (the
# testing data), then scale the images into the range [0, 1]
testSplit = cifar10.load_data()[1]
testX = testSplit[0].astype("float") / 255.0
testY = testSplit[1]

# fit the binarizer on the integer labels and use it to turn them into
# vectors
lb = LabelBinarizer()
lb.fit(testY)
testY = lb.transform(testY)
# collect every "*.model" file in the models directory; glob returns its
# matches in arbitrary order, so sort them to make the load order (and
# the progress output) reproducible between runs
modelPattern = os.path.join(args["models"], "*.model")
modelPaths = sorted(glob.glob(modelPattern))

# stop right away with a clear message when nothing matched -- otherwise
# the script would only fail later, while averaging an empty list of
# predictions
if not modelPaths:
    raise SystemExit(
        "[ERROR] no '*.model' files found in {}".format(args["models"]))

# loop over the model paths, loading each model from disk and adding it
# to the list of models that make up the ensemble
models = []
for (i, modelPath) in enumerate(modelPaths):
    print("[INFO] loading model {}/{}".format(i + 1,
        len(modelPaths)))
    models.append(load_model(modelPath))
# run every model of the ensemble over the testing data, keeping one
# array of predictions per model
print("[INFO] evaluating ensemble...")
predictions = [model.predict(testX, batch_size=64) for model in models]

# average the per-model probabilities into a single ensemble prediction
predictions = np.average(predictions, axis=0)

# reduce both the ground-truth vectors and the averaged probabilities
# to class indices, then show a classification report
trueLabels = testY.argmax(axis=1)
predictedLabels = predictions.argmax(axis=1)
print(classification_report(trueLabels, predictedLabels,
    target_names=labelNames))