示例#1
0
# Bagging: train L logistic-regression models, each on its own bootstrap
# training set, and tally their predicted labels on the full data set X.
for l in range(L):
    # Draw a bootstrap training set (random sampling with replacement from
    # X and y); `weights` is passed straight through to `bootstrap`
    X_train, y_train = bootstrap(X, y, N, weights)

    # Fit this round's model and store it for the ensemble built below
    logit_classifier = LogisticRegression()
    logit_classifier.fit(X_train, y_train)
    logits[l] = logit_classifier

    # Every model predicts every observation of X; add those to the tally
    y_est = logit_classifier.predict(X).T
    votes = votes + y_est

    # Report the error rate of this single model on X
    round_error = (y != y_est).sum(dtype=float) / N
    print(f'Error rate: {round_error * 100:2.2f}%')

# Combine the models by majority vote: more than half of the L models must
# vote for a sample (strict '>', so an exact tie yields False)
y_est_ensemble = votes > L / 2

# Error rate of the combined (majority-vote) classifier
ErrorRate = (y != y_est_ensemble).sum(dtype=float) / N
print(f'Error rate: {ErrorRate * 100:3.2f}%')

# Wrap the stored models as one ensemble and draw it with the toolbox
# plotting helpers
ce = BinClassifierEnsemble(logits)
figure(1)
dbprobplot(ce, X, y, 'auto', resolution=200)
figure(2)
dbplot(ce, X, y, 'auto', resolution=200)

show()

print('Ran Exercise 9.2.1')
示例#2
0
    print('Error rate: {:2.2f}%'.format(ErrorRate * 100))

# Combine the boosted models by alpha-weighted voting (not plain majority
# voting). After normalising the weights to sum to one, `y_all @ alpha` is
# the weighted share of models voting for class 1 -- this assumes y_all
# holds one 0/1 prediction column per round (TODO confirm against the
# training loop). A sample is assigned to class 1 when that share exceeds
# 0.5. ndarray.sum() is vectorised and totals every element whatever the
# array's shape, unlike the builtin sum().
alpha = alpha / alpha.sum()
y_est_ensemble = y_all @ alpha > 0.5

# Compute the error rate of the weighted ensemble: the fraction of the N
# observations whose ensemble label differs from y.
ErrorRateEnsemble = (y_est_ensemble != y).sum() / N
print('Error rate for ensemble classifier: {:.1f}%'.format(
    ErrorRateEnsemble * 100))

# Wrap the models together with their weights as one classifier and plot it.
# Exercise: build BinClassifierEnsemble(logits) without alpha and compare the
# resulting plots -- what changes when the weights are left out?
ce = BinClassifierEnsemble(logits, alpha)
plt.figure(1)
dbprobplot(ce, X, y, 'auto', resolution=200)
plt.figure(2)
dbplot(ce, X, y, 'auto', resolution=200)
# Tip: plt.figure(3); plt.plot(alpha) shows how the weights are distributed.

#%%
# Scatter the first two attributes of X, coloured by the ensemble's predicted
# class (class 0 as blue circles, class 1 as red circles).
plt.figure(4, figsize=(8, 8))
for i in range(2):
    predicted_i = y_est_ensemble == i  # compute the row mask once per class
    plt.plot(X[predicted_i, 0], X[predicted_i, 1], 'br'[i] + 'o')

## Uncomment the lines below to investigate misclassifications
#for i in range(2):