コード例 #1
0
def yellow_brick_validation_curve(model, x, y, cpu_count, cv_count, param,
                                  scoring_metric):
    """
    Fit and show a yellowbrick validation curve for one hyperparameter.

    The hyperparameter ``param`` is swept over the integers 1 through 10
    using stratified k-fold cross-validation, and the resulting figure is
    shown immediately.

    :param model: Estimator handed to ``ValidationCurve``.
    :param x: Feature data passed to ``fit``.
    :param y: Target data passed to ``fit``.
    :param cpu_count: Forwarded to ``ValidationCurve`` as ``n_jobs``.
    :param cv_count: Number of splits for ``StratifiedKFold``.
    :param param: Name of the hyperparameter to sweep.
    :param scoring_metric: Forwarded to ``ValidationCurve`` as ``scoring``.
    :return: None.
    """
    # The visualizer used below is ValidationCurve, so that is the name
    # that has to be imported here (not LearningCurve).
    from yellowbrick.model_selection import ValidationCurve
    from sklearn.model_selection import StratifiedKFold

    # Cross-validation strategy used by the validation curve.
    cv = StratifiedKFold(n_splits=cv_count)

    # Global colour cycle for the lines drawn from here on.
    mpl.rcParams['axes.prop_cycle'] = cycler('color', ['purple', 'darkblue'])

    # Reuse the current figure, resize it and draw into subplot 411.
    fig = plt.gcf()
    fig.set_size_inches(10, 10)
    ax = plt.subplot(411)

    viz = ValidationCurve(model,
                          n_jobs=cpu_count,
                          ax=ax,
                          param_name=param,
                          param_range=np.arange(1, 11),
                          cv=cv,
                          scoring=scoring_metric)

    # Fit and show the visualizer
    viz.fit(x, y)
    viz.show()
コード例 #2
0
def generate_validation_curve(model, clf_name, param_name, param_range,
                              scoring, cv, dataset_name, X_train, y_train):
    """
    Plot a validation curve for ``model`` and save it under ``results/``.

    When ``clf_name`` contains ``'svm'`` or equals ``'nn'``, the scores are
    computed with ``validation_curve`` (presumably scikit-learn's) and the
    mean training / cross-validation scores are drawn on a logarithmic
    x-axis.  Otherwise a ``ValidationCurve`` visualizer is fitted and
    shown.  Both paths use the caller's ``scoring`` and ``cv`` and target
    the same output path.

    :param model: Estimator to evaluate.
    :param clf_name: Classifier name; selects the plotting path and is
        used in the plot title and the output file name.
    :param param_name: Name of the hyperparameter to sweep.
    :param param_range: Values of the hyperparameter to evaluate.
    :param scoring: Scoring metric forwarded to the curve computation.
    :param cv: Cross-validation strategy forwarded to the curve computation.
    :param dataset_name: Dataset label used in the output file name.
    :param X_train: Training features.
    :param y_train: Training targets.
    :return: None.
    """
    outpath = "results/{}_model_complexity_{}_{}.png".format(
        clf_name, dataset_name, param_name)

    if 'svm' in clf_name or 'nn' == clf_name:
        # Honour the caller's scoring and cv here as well, so both
        # branches evaluate the model the same way.
        train_scores, test_scores = validation_curve(model,
                                                     X_train,
                                                     y_train,
                                                     param_name=param_name,
                                                     param_range=param_range,
                                                     scoring=scoring,
                                                     cv=cv,
                                                     n_jobs=8)
        # Average over the cross-validation folds (axis 1).
        train_scores_mean = np.mean(train_scores, axis=1)
        test_scores_mean = np.mean(test_scores, axis=1)

        plt.title("Validation Curve with {}".format(clf_name))
        plt.xlabel(param_name)
        plt.ylabel("Score")
        plt.semilogx(param_range,
                     train_scores_mean,
                     label="Training score",
                     marker='o',
                     color="#0272a2")
        plt.semilogx(param_range,
                     test_scores_mean,
                     label="Cross-validation score",
                     marker='o',
                     color="#9fc377")
        plt.legend(loc="best")
        plt.savefig(outpath)
    else:
        viz = ValidationCurve(model,
                              param_name=param_name,
                              param_range=param_range,
                              scoring=scoring,
                              cv=cv)
        viz.fit(X_train, y_train)
        viz.show(outpath)

    # Clear the current figure so the next plot starts from a clean state.
    plt.clf()
コード例 #3
0
def validation_curve(model, x, y, param, rang, cv):
    """
    Fit and show a ROC-AUC validation curve for one model parameter.

    :param model: Model to be evaluated.
    :param x: Independent training variables.
    :param y: Dependent training variable.
    :param param: Model parameter to be evaluated.
    :param rang: Hypothesis space of the parameter being evaluated.
    :param cv: Number of splits for the cross-validation.
    :return: None; the validation-curve visualization is displayed via
        ``show()``.
    """
    curve_options = {
        "param_name": param,
        "param_range": rang,
        "cv": cv,
        "scoring": "roc_auc",
        "n_jobs": -1,
    }
    curve = ValidationCurve(model, **curve_options)

    curve.fit(x, y)
    curve.show()
コード例 #4
0
def validation_curve_classifier_knn(path="images/validation_curve_classifier_knn.png"):
    """
    Draw a validation curve for ``KNeighborsClassifier`` over ``n_neighbors``.

    The data comes from ``load_game()`` and the features are one-hot
    encoded before fitting.  ``n_neighbors`` is swept over the odd values
    3 through 19 with 4-fold stratified cross-validation and
    ``f1_weighted`` scoring.

    :param path: Forwarded to the visualizer's ``show`` as ``outpath``.
    :return: None.
    """
    features, target = load_game()
    encoded_features = OneHotEncoder().fit_transform(features)

    axis = plt.subplots()[1]
    folds = StratifiedKFold(4)
    neighbor_counts = np.arange(3, 20, 2)

    print("warning: generating the KNN validation curve can take a very long time!")

    visualizer = ValidationCurve(
        KNeighborsClassifier(),
        ax=axis,
        param_name="n_neighbors",
        param_range=neighbor_counts,
        cv=folds,
        scoring="f1_weighted",
        n_jobs=8,
    )
    visualizer.fit(encoded_features, target)
    visualizer.show(outpath=path)
コード例 #5
0
def validation_curve_sklearn_example(
    path="images/validation_curve_sklearn_example.png"
):
    """
    Draw a validation curve for ``SVC`` over ``gamma`` on the digits data.

    ``gamma`` is swept over five log-spaced values from 1e-6 to 1e-1 with
    ``cv=10`` and accuracy scoring, plotted with ``logx=True``.

    :param path: Forwarded to the visualizer's ``show`` as ``outpath``.
    :return: None.
    """
    dataset = load_digits()

    axis = plt.subplots()[1]

    gammas = np.logspace(-6, -1, 5)
    visualizer = ValidationCurve(
        SVC(),
        ax=axis,
        param_name="gamma",
        param_range=gammas,
        logx=True,
        cv=10,
        scoring="accuracy",
        n_jobs=4,
    )
    visualizer.fit(dataset.data, dataset.target)
    visualizer.show(outpath=path)
コード例 #6
0
def validation_curve_classifier_svc(path="images/validation_curve_classifier_svc.png"):
    """
    Draw a validation curve for ``SVC`` over ``gamma``.

    The data comes from ``load_game()`` and the features are one-hot
    encoded before fitting.  ``gamma`` is swept over twelve log-spaced
    values from 1e-6 to 1e-1 with 12-fold stratified cross-validation and
    ``f1_weighted`` scoring, plotted with ``logx=True``.

    :param path: Forwarded to the visualizer's ``show`` as ``outpath``.
    :return: None.
    """
    features, target = load_game()
    encoded_features = OneHotEncoder().fit_transform(features)

    axis = plt.subplots()[1]
    folds = StratifiedKFold(12)
    gammas = np.logspace(-6, -1, 12)

    print("warning: generating the SVC validation curve can take a very long time!")

    visualizer = ValidationCurve(
        SVC(),
        ax=axis,
        param_name="gamma",
        param_range=gammas,
        logx=True,
        cv=folds,
        scoring="f1_weighted",
        n_jobs=8,
    )
    visualizer.fit(encoded_features, target)
    visualizer.show(outpath=path)
コード例 #7
0
# Fit every model in `models`, predict on the test features and draw
# residual / prediction-error plots onto the entries of `axes`.
# NOTE(review): `ind` is never used, and the inner loop visits every entry of
# `axes` for every model — confirm one row of axes per model was not intended.
for ind, model in enumerate(models):
    model.fit(x_train, y_train)
    preds = model.predict(x_test)
    for index, ax in enumerate(axes):
        # NOTE(review): `preds` is passed where these helpers take the target
        # values — confirm the true test targets were not intended here.
        # NOTE(review): `ax` is indexed (`ax[index]`) on the first call but
        # passed whole on the second — confirm the shape of `axes`.
        residuals_plot(model, x_test, preds, hist=False, ax=ax[index])
        prediction_error(model, x_test, preds, ax=ax)

# Do some scoring on XGB estimators
# Validation curve: sweep max_depth over 1..10 with cv=5, scored by R^2.
viz = ValidationCurve(XGBRegressor(objective="reg:squarederror"),
                      param_name="max_depth",
                      param_range=np.arange(1, 11),
                      cv=5,
                      scoring="r2")
viz.fit(x_train, y_train)
viz.show()

# Learning curve for a fresh XGBRegressor, scored by R^2.
model = XGBRegressor(objective="reg:squarederror")
viz_2 = LearningCurve(model, scoring="r2")
viz_2.fit(x_train, y_train)
viz_2.show()

# Recursive feature elimination with cross-validation around LassoCV.
# NOTE(review): `show()` is called on the result, so this is presumably the
# yellowbrick RFECV visualizer rather than scikit-learn's — confirm the import.
model = RFECV(LassoCV(), cv=5, scoring='r2')
model.fit(x_train, y_train)
model.show()
"""
Section: 5
Time-Series Algorithms
"""
# Fitting ARIMA