forked from peret/visualize-bovw
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ensemble_visualization.py
58 lines (47 loc) · 2.07 KB
/
ensemble_visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from visualization import Visualization
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
from vcd import VisualConceptDetection
from itertools import izip
from datamanagers.CaltechManager import CaltechManager
import os
import numpy as np
class EnsembleVisualization(Visualization):
    """Visualization whose image titles describe a classifier's prediction
    outcome (true/false positive/negative) for each image."""

    def get_image_title(self, prediction, real):
        """Return a title describing the prediction outcome for one image.

        The title states whether the prediction is a true/false
        positive/negative and with what confidence it was made, e.g.
        "True positive - confidence: 0.80000".

        Args:
            prediction: Sequence of predicted class probabilities for a
                single sample, indexed by class index (one row of
                ``predict_proba`` output).
            real: The correct label of that single sample (a scalar, not a
                list). It is compared against the index of the most
                probable class.

        Returns:
            The title string.
        """
        # Index of the most probable class; index 1 is treated as the
        # positive class. NOTE(review): this assumes class indices coincide
        # with the label values in ``real`` (e.g. labels 0/1) - confirm.
        predicted = np.argmax(prediction)
        correctness = "True" if predicted == real else "False"
        polarity = "positive" if predicted == 1 else "negative"
        return "%s %s - confidence: %.5f" % (
            correctness, polarity, prediction[predicted])
def main(category="trilobite", dataset="all",
         results_dir_name="results_trilobite_rf_testing"):
    """Run the ensemble visualization for one category.

    Loads the classifier stored for ``category`` and predicts class
    probabilities for every sample of ``dataset``. It then passes the image
    names, the classifier's feature importances and one prediction-outcome
    title per image to ``EnsembleVisualization.visualize_images``.

    Args:
        category: Category whose stored classifier and images are used.
        dataset: Dataset selector passed to the data manager (e.g. "all").
        results_dir_name: Directory name, joined onto the data manager's
            ``BASE`` path, that is installed as its ``RESULTS`` path.
    """
    # Alternative previously tried: AdaBoostClassifier(n_estimators=50)
    # with base_estimator.max_depth = 1.
    # Presumably this unfitted forest only configures VisualConceptDetection;
    # the classifier actually used is the one returned by load_object below.
    random_forest = RandomForestClassifier(n_estimators=100)

    datamanager = CaltechManager()
    # NOTE(review): if PATHS is shared class-level state, this override
    # affects every CaltechManager in the process - confirm.
    datamanager.PATHS["RESULTS"] = os.path.join(
        datamanager.PATHS["BASE"], results_dir_name)

    vcd = VisualConceptDetection(random_forest, datamanager)
    clf = vcd.load_object("Classifier", category)
    feature_importances = clf.feature_importances_

    sample_matrix = vcd.datamanager.build_sample_matrix(dataset, category)
    class_vector = vcd.datamanager.build_class_vector(dataset, category)
    pred = clf.predict_proba(sample_matrix)

    vis = EnsembleVisualization(datamanager)
    # Drop references to the classifier and the (potentially large) sample
    # data before the visualization step runs.
    del clf
    image_titles = [vis.get_image_title(prediction, real)
                    for prediction, real in izip(pred, class_vector)]
    del class_vector
    del sample_matrix

    img_names = list(vcd.datamanager.get_image_names(dataset, category))
    vis.visualize_images(img_names, feature_importances, image_titles)


if __name__ == "__main__":
    main()