forked from zygmuntz/classifier-calibration
/
platts_scaling.py
77 lines (49 loc) · 1.86 KB
/
platts_scaling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#!/usr/bin/env python
"calibrate a classifier's predictions using Platt's scaling (logistic regression)"
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression as LR
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score as AUC
from log_loss import log_loss
from get_diagram_data import get_diagram_data
from load_data_adult import y, p
###
# train/test split (in half)
# y and p are imported from load_data_adult; they are sliced in parallel here,
# so they are assumed to be equally long, index-aligned arrays (TODO confirm).

# Floor division keeps the split index an integer on Python 2 and Python 3
# alike; a float index would make the slices below raise.
train_end = y.shape[0] // 2
# Slice ends are exclusive, so the test half starts exactly at train_end.
# Starting at train_end + 1 would silently drop one sample from both halves.
test_start = train_end

y_train = y[:train_end]
y_test = y[test_start:]

p_train = p[:train_end]
p_test = p[test_start:]
###
# Platt's scaling: fit a logistic regression (default parameter values) on the
# training-half predictions, then use it to re-map the test-half predictions.
lr = LR()

# LR needs X to be 2-dimensional, hence the single-column reshape.
lr.fit(np.reshape(p_train, (-1, 1)), y_train)

# predict_proba returns one column per class; column 1 is kept as the
# calibrated probability (the positive class, assuming 0/1 labels).
_test_class_probabilities = lr.predict_proba(np.reshape(p_test, (-1, 1)))
p_calibrated = _test_class_probabilities[:, 1]
###
# Score the raw and the calibrated predictions on the test half.

# Accuracy needs hard labels, so probabilities are rounded to 0/1 first.
acc = accuracy_score(y_test, np.round(p_test))
acc_calibrated = accuracy_score(y_test, np.round(p_calibrated))

auc = AUC(y_test, p_test)
auc_calibrated = AUC(y_test, p_calibrated)

ll = log_loss(y_test, p_test)
ll_calibrated = log_loss(y_test, p_calibrated)

# A single pre-formatted string passed to print(...) is valid on both Python 2
# and Python 3, and %s renders each value with str() just as the print
# statement did. Labels are kept verbatim so the output text is unchanged.
for _label, _before, _after in (
    ("accuracy - before/after:", acc, acc_calibrated),
    ("AUC - before/after: ", auc, auc_calibrated),
    ("log loss - before/after:", ll, ll_calibrated),
):
    print("%s %s / %s" % (_label, _before, _after))

# Reference output recorded in the original script (presumably from a run on
# the adult dataset; exact digits may differ between runs and versions):
#
#   accuracy - before/after: 0.847788697789 / 0.846805896806
#   AUC - before/after: 0.878139845077 / 0.878139845077
#   log loss - before/after: 0.630525772871 / 0.364873617584
###
# Draw the calibration diagram: one curve for the raw predictions, one for the
# calibrated ones, and the diagonal for reference.

# Parenthesised single-argument form prints the same text on Python 2 and 3.
print("creating diagrams...")

n_bins = 10

# Each entry pairs a set of predictions with the extra positional arguments
# for plt.plot: the uncalibrated curve uses matplotlib's default colour, the
# calibrated curve is drawn in green.
for _probabilities, _plot_args in ((p_test, ()), (p_calibrated, ("green",))):
    # get_diagram_data is a project helper; judging by the names it returns
    # per-bin mean predictions and observed positive fractions (TODO confirm).
    mean_predicted_values, true_fractions = get_diagram_data(
        y_test, _probabilities, n_bins
    )
    plt.plot(mean_predicted_values, true_fractions, *_plot_args)

# perfect calibration line (y = x on [0, 1])
_diagonal = np.linspace(0, 1)
plt.plot(_diagonal, _diagonal, "gray")

plt.show()