/
oneVsAll.py
30 lines (22 loc) · 926 Bytes
/
oneVsAll.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import numpy as np
from ann import ANN
class OneVsAll:
    """Multiclass classifier built from one ANN per label (one-vs-all).

    ``train`` fits one ``ANN`` per label returned by ``dp_train.labels()``;
    ``predict`` asks every fitted ANN for its output on ``x`` and returns the
    label whose ANN produced the largest value (via ``np.argmax``).
    """

    def __init__(self, dp_train):
        """Create an untrained one-vs-all wrapper.

        :param dp_train: training data provider. It must expose ``labels()``
            and ``binarizeU(label, upsampled=True)``; the latter is expected to
            return a 2-D array whose last column is the per-label target and
            whose other columns are features (NOTE(review): inferred from the
            slicing in ``train`` -- confirm against the provider).
        """
        # Fitted classifiers, index-aligned with the labels snapshot below.
        self.logistic_classifiers = []
        self.dp_train = dp_train
        # Labels in the order the classifiers were trained; None until train().
        self._labels = None

    def train(self, hidden_layers, epochs, learning_rate):
        """Fit one ANN per label, replacing any previously trained models.

        :param hidden_layers: forwarded to the ``ANN`` constructor.
        :param epochs: forwarded to the ``ANN`` constructor.
        :param learning_rate: forwarded to the ``ANN`` constructor.
        """
        labels = list(self.dp_train.labels())
        classifiers = []
        for label in labels:
            data = self.dp_train.binarizeU(label, upsampled=True)
            features = data[:, :-1]
            targets = data[:, -1]
            classifier = ANN(hidden_layers, epochs, learning_rate, verbose=False)
            classifier.train(features, targets)
            classifiers.append(classifier)
        # Replace rather than extend: appending across repeated train() calls
        # would leave more classifiers than labels, so the argmax index in
        # predict() could run past the end of the label list.
        self.logistic_classifiers = classifiers
        self._labels = labels

    def predict(self, x):
        """Return the label whose classifier gives the highest output for ``x``.

        :param x: input forwarded unchanged to each classifier's ``predict``.
        :returns: the element of the label list at the winning index.
        :raises ValueError: if no classifiers have been trained.
        """
        if not self.logistic_classifiers:
            raise ValueError("OneVsAll.predict called before train()")
        # Prefer the labels snapshot taken at training time so the
        # classifier-index -> label mapping cannot drift afterwards.
        labels = getattr(self, "_labels", None)
        if labels is None:
            labels = self.dp_train.labels()
        scores = [classifier.predict(x) for classifier in self.logistic_classifiers]
        # np.argmax with no axis works on the flattened outputs; this maps to a
        # label index only if each classifier returns a single value for x
        # (NOTE(review): confirm ANN.predict output shape for batched x).
        return labels[int(np.argmax(scores))]