-
Notifications
You must be signed in to change notification settings - Fork 3
/
W2V_GaussianNB.py
64 lines (52 loc) · 1.92 KB
/
W2V_GaussianNB.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
from sklearn.naive_bayes import GaussianNB
from sklearn import metrics
from sklearn.decomposition import NMF
import datetime
import matplotlib.pyplot as plt
if __name__ == "__main__":
    # End-to-end experiment script: load the pre-computed word2vec feature
    # arrays, fit a Gaussian Naive Bayes classifier, report accuracy and the
    # confusion matrix on the test split, then plot and save the
    # precision-recall curve.
    startTime = datetime.datetime.now()

    # Load training data. Targets are cast to int and flattened to 1-D.
    x = np.load('data/train_w2v_data_array.npy')
    y = np.load('data/train_w2v_target_array.npy').astype('int').flatten()

    # Load test data, with the same target preprocessing as the training split.
    z = np.load('data/test_w2v_data_array.npy')
    t = np.load('data/test_w2v_target_array.npy').astype('int').flatten()

    # Fit the Naive Bayes model and predict hard labels for the test split.
    clf = GaussianNB()
    # Optional dimensionality reduction, currently disabled:
    # nmf = NMF(n_components=500, init='random', random_state=0)
    # x_500d = nmf.fit_transform(x)
    # z_500d = nmf.transform(z)
    clf.fit(x, y)
    p = clf.predict(z)

    # NOTE: the interval measured here spans data loading, fitting and
    # prediction, not only the call to fit().
    elapsed = datetime.datetime.now() - startTime
    print("Total time taken to train: ", elapsed)
    print("\n")
    print("W2V Gaussian Naive Bayes")

    # Accuracy as a percentage: count of correct predictions over test size.
    accuracy = metrics.accuracy_score(t, p, normalize=False)
    print("Accuracy: ", (accuracy/len(t)) * 100)

    # Confusion matrix, computed on the original (un-remapped) labels.
    confusion_matrix = metrics.confusion_matrix(t, p)
    print("Confusion Matrix:\n", confusion_matrix)

    # Remap the label 4 to 1 so the ground truth is {0, 1} for the binary
    # ranking metrics below.
    # NOTE(review): assumes the only labels are 0 and 4 -- confirm against
    # the dataset.
    t[t == 4] = 1

    # Probability assigned to the second entry of clf.classes_, presumably
    # the original label 4 (i.e. the positive class) -- TODO confirm.
    y_score = clf.predict_proba(z)
    positive_scores = y_score[:, 1]

    # Plot the Precision-Recall curve
    precision, recall, _ = metrics.precision_recall_curve(t, positive_scores)
    plt.step(recall, precision, color='b', alpha=0.2, where='post')
    plt.fill_between(recall, precision, step='post', alpha=0.2, color='b')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])

    # Average precision must be computed from the continuous scores, the
    # same ones used for the curve above. Feeding hard predictions collapses
    # the ranking to a single threshold and reports an AP that does not
    # describe the plotted curve.
    average_precision = metrics.average_precision_score(t, positive_scores)
    plt.title('W2V Gaussian NB Precision-Recall curve: AP={0:0.2f}'.format(average_precision))

    # Save before show(): show() may leave an empty figure once the window
    # is closed.
    plt.savefig('data/w2v_GaussianNB_precisionRecall.png')
    plt.show()