-
Notifications
You must be signed in to change notification settings - Fork 2
/
__init__.py
104 lines (85 loc) · 3.28 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import pandas as pd
__author__ = 'JasonLiu'
class Taxonomy:
    """
    Taxonomy
    ~~~~~~~~
    This class is used to compose a modular Classification Taxonomy: a tree of
    classifiers in which rows that a node predicts as a given label can be
    handed to a child node for a finer-grained prediction.

    :usage:

        root = Taxonomy(
            "alcohol",
            {
                0: "non_drinking",
                1: "drinking"
            },
            alcohol_classifier
        )
        first_person = Taxonomy(
            "first_person",
            {
                0: "alcohol_related",
                1: "first_person"
            },
            first_person_classifier
        )
        first_person_level = Taxonomy(
            "first_person_level",
            {
                0: "first_person_casual",
                1: "first_person_looking",
                2: "first_person_reflecting",
                3: "first_person_heavy"
            },
            first_person_level_classifier
        )
        first_person.add_children({1: first_person_level})
        root.add_children({1: first_person})
        y_predict = root.predict(X, deep=True)
    """

    def __init__(self, name, label2name, sklearn_classifier):
        """
        :param name: (str) name of the node/classifier
        :param label2name: (dict[int, str]) maps the label to the name
        :param sklearn_classifier: sklearn classifier that implements predict
        :raises TypeError: if sklearn_classifier has no ``predict`` attribute
        """
        # Raise rather than assert so the check is not stripped under ``python -O``.
        if not hasattr(sklearn_classifier, "predict"):
            raise TypeError(
                "sklearn_classifier must implement predict(), got %s"
                % type(sklearn_classifier).__name__
            )
        self.name = name
        self.label2name = label2name
        self.clf = sklearn_classifier
        # dict[label, Taxonomy] once add_children is called; None means "leaf node".
        self.children = None

    def add_children(self, label2node):
        """
        Attach child nodes, replacing any previously attached children.

        :param label2node: (dict[int, Taxonomy]) rows this node predicts as
            ``label`` are re-classified by ``label2node[label]`` in a deep predict
        """
        self.children = label2node

    def _name(self):
        """
        Return a function mapping a label to ``(label, label2name[label])``.

        Labels not present in ``label2name`` are returned unchanged.
        """
        def get(label):
            return (label, self.label2name[label]) if label in self.label2name else label
        return get

    def _log_label_counts(self, labels):
        """Log, at DEBUG level, how many rows received each label at this node."""
        import logging

        logger = logging.getLogger(__name__)
        # value_counts() is not free, so only compute it when it will be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("%s label counts:\n%s", self.name, labels.value_counts())

    def predict(self, X, deep=False):
        """
        Predict labels for ``X``, optionally descending through the taxonomy.

        :param X: rows to classify, passed to ``self.clf.predict``. It must expose
            ``.index`` and ``.iloc`` (presumably a pandas DataFrame; confirm against callers).
        :param deep: (bool) if False, only this node's classifier is used. If True,
            rows whose label has a child node are re-classified by that child,
            depth-first.
        :returns: (pandas.Series) indexed by ``X.index``, in ``X``'s row order.
            Shallow, or deep on a node without children: the raw labels from this
            node's classifier. Deep on a node with children: for each row, the
            ``(label, name)`` pair from the deepest node that classified it. A
            label missing from that node's ``label2name`` is kept as the bare label.
        """
        # Shallow prediction means that we only want to predict for a single node.
        current_labels = pd.Series(self.clf.predict(X), index=X.index)
        if not deep:
            return current_labels

        self._log_label_counts(current_labels)
        if not self.children:
            return current_labels

        name = self._name()
        # tolist() yields plain Python scalars, matching what Series.apply passed before.
        labels = current_labels.tolist()
        # Start by naming every row with this node's own label. Rows routed to a
        # child are overwritten below. Labels that are neither children nor in
        # label2name are therefore kept instead of being dropped.
        resolved = [name(label) for label in labels]

        for label, child in self.children.items():
            positions = [i for i, value in enumerate(labels) if value == label]
            # Estimators typically reject 0-sample input; nothing to refine anyway.
            if not positions:
                continue
            child_labels = child.predict(X.iloc[positions], deep=True)
            child_name = child._name()
            # Child results come back in the order of X.iloc[positions]. Write them
            # back by position so the output keeps X's order, even with a duplicate index.
            for pos, value in zip(positions, child_labels):
                resolved[pos] = child_name(value)

        return pd.Series(resolved, index=X.index, dtype=object)