Example #1
def learning_curves(model):
    """model class dependent

    WIP

    get training history plots for xgboost, lightgbm

    returns list of PlotArtifacts, can be empty if no history
    is found
    """
    plots = []

    # check for a training history here rather than in the caller;
    # evals_result() is the standard history API for xgboost and lightgbm
    if hasattr(model, "evals_result"):
        results = model.evals_result()
        train_set = list(results.items())[0]
        valid_set = list(results.items())[1]

        learning_curves = pd.DataFrame({
            "train_error": train_set[1]["error"],
            "train_auc": train_set[1]["auc"],
            "valid_error": valid_set[1]["error"],
            "valid_auc": valid_set[1]["auc"],
        })

        plt.clf()  # gcf_clear(plt)
        fig, ax = plt.subplots()
        plt.xlabel("# training examples")
        plt.ylabel("auc")
        plt.title("learning curve - auc")
        ax.plot(learning_curves.train_auc, label="train")
        ax.plot(learning_curves.valid_auc, label="valid")
        ax.legend(loc="lower left")
        plots.append(PlotArtifact("learning curve - auc", body=plt.gcf()))

        plt.clf()  # gcf_clear(plt)
        fig, ax = plt.subplots()
        plt.xlabel("# training examples")
        plt.ylabel("error rate")
        plt.title("learning curve - error")
        ax.plot(learning_curves.train_error, label="train")
        ax.plot(learning_curves.valid_error, label="valid")
        ax.legend(loc="lower left")
        plots.append(PlotArtifact("learning curve - taoot", body=plt.gcf()))

    # elif some other model history api...

    return plots
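A minimal way to exercise this helper, assuming the module-level imports the snippet relies on (pandas as pd, matplotlib.pyplot as plt, PlotArtifact) and an xgboost classifier trained with both "error" and "auc" metrics on a train and a validation set; the dataset and names below are illustrative only (recent xgboost takes eval_metric on the constructor, older versions on fit()):

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier

X, y = make_classification(n_samples=500, random_state=0)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)

# evals_result() will then expose "error" and "auc" series for both sets,
# which is the shape learning_curves() above expects
clf = XGBClassifier(n_estimators=50, eval_metric=["error", "auc"])
clf.fit(X_train, y_train, eval_set=[(X_train, y_train), (X_valid, y_valid)])

plots = learning_curves(clf)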
Example #2
def feature_importances(model, header):
    """Display estimated feature importances
    Only works for models with attribute `feature_importances_`
    :param model:       fitted model
    :param header:      feature labels
    """
    if not hasattr(model, "feature_importances_"):
        raise Exception(
            "feature importaces are only available for some models")

    # create a feature importance table with desired labels
    zipped = zip(model.feature_importances_, header)
    feature_imp = pd.DataFrame(sorted(zipped),
                               columns=["freq", "feature"
                                        ]).sort_values(by="freq",
                                                       ascending=False)

    plt.clf()  #gcf_clear(plt)
    plt.figure(figsize=(20, 10))
    sns.barplot(x="freq", y="feature", data=feature_imp)
    plt.title("features")
    plt.tight_layout()

    return (PlotArtifact("feature-importances", body=plt.gcf()),
            TableArtifact("feature-importances-tbl", df=feature_imp))
Example #3
def _kaplan_meier_log_model(
    context,
    model,
    time_column: str = "tenure",
    dataset_key: str = "km-timelines",
    plot_key: str = "km-survival",
    plots_dest: str = "plots",
    models_dest: str = "models",
    file_ext: str = "csv",
):
    import matplotlib.pyplot as plt

    o = []
    for obj in model.__dict__.keys():
        if isinstance(model.__dict__[obj], pd.DataFrame):
            o.append(model.__dict__[obj])
    df = pd.concat(o, axis=1)
    df.index.name = time_column
    context.log_dataset(dataset_key, df=df, index=True, format=file_ext)
    model.plot()
    context.log_artifact(
        PlotArtifact(plot_key, body=plt.gcf()),
        local_path=f"{plots_dest}/{plot_key}.html",
    )
    context.log_model(
        "km-model",
        body=dumps(model),
        model_dir=f"{models_dest}/km",
        model_file="model.pkl",
    )
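dumps is not defined in this snippet; the model is logged as a pickle blob, so a standard-library import is a reasonable assumption (the original module may use cloudpickle instead):

from pickle import dumps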
Example #4
def feature_importances(model, header):
    """Display estimated feature importances
    Only works for models with attribute `feature_importances_`
    :param model:       fitted model
    :param header:      feature labels
    """
    if not hasattr(model, "feature_importances_"):
        raise Exception(
            "feature importances are only available for some models, if you got "
            "here then please make sure to check your estimated model for a "
            "`feature_importances_` attribute before calling this method")

    # create a feature importance table with desired labels
    zipped = zip(model.feature_importances_, header)
    feature_imp = pd.DataFrame(sorted(zipped),
                               columns=["freq", "feature"
                                        ]).sort_values(by="freq",
                                                       ascending=False)

    plt.clf()  # gcf_clear(plt)
    plt.figure()
    sns.barplot(x="freq", y="feature", data=feature_imp)
    plt.title("features")
    plt.tight_layout()

    return (
        PlotArtifact("feature-importances",
                     body=plt.gcf(),
                     title="Feature Importances"),
        feature_imp,
    )
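A hedged usage sketch for this helper, assuming the module-level imports from the original source (pandas as pd, matplotlib.pyplot as plt, seaborn as sns, PlotArtifact) and using a scikit-learn forest purely as an example:

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

data = load_iris()
rf = RandomForestClassifier(n_estimators=50, random_state=0).fit(data.data, data.target)

# returns the PlotArtifact plus the importance table as a DataFrame
plot_artifact, importance_df = feature_importances(rf, data.feature_names)
print(importance_df.head())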
Example #5
def precision_recall_multi(ytest_b, yprob, labels, scoring="micro"):
    """"""
    n_classes = len(labels)

    precision = dict()
    recall = dict()
    avg_prec = dict()
    for i in range(n_classes):
        precision[i], recall[i], _ = metrics.precision_recall_curve(
            ytest_b[:, i], yprob[:, i])
        avg_prec[i] = metrics.average_precision_score(ytest_b[:, i], yprob[:,
                                                                           i])
    precision["micro"], recall["micro"], _ = metrics.precision_recall_curve(
        ytest_b.ravel(), yprob.ravel())
    avg_prec["micro"] = metrics.average_precision_score(ytest_b,
                                                        yprob,
                                                        average="micro")
    ap_micro = avg_prec["micro"]
    # model_metrics.update({'precision-micro-avg-classes': ap_micro})

    # gcf_clear(plt)
    colors = cycle(
        ["navy", "turquoise", "darkorange", "cornflowerblue", "teal"])
    plt.figure()
    f_scores = np.linspace(0.2, 0.8, num=4)
    lines = []
    labels = []
    for f_score in f_scores:
        x = np.linspace(0.01, 1)
        y = f_score * x / (2 * x - f_score)
        (l, ) = plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2)
        plt.annotate(f"f1={f_score:0.1f}", xy=(0.9, y[45] + 0.02))

    lines.append(l)
    labels.append("iso-f1 curves")
    (l, ) = plt.plot(recall["micro"], precision["micro"], color="gold", lw=10)
    lines.append(l)
    labels.append(f"micro-average precision-recall (area = {ap_micro:0.2f})")

    for i, color in zip(range(n_classes), colors):
        (l, ) = plt.plot(recall[i], precision[i], color=color, lw=2)
        lines.append(l)
        labels.append(
            f"precision-recall for class {i} (area = {avg_prec[i]:0.2f})")

    # fig = plt.gcf()
    # fig.subplots_adjust(bottom=0.25)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("recall")
    plt.ylabel("precision")
    plt.title("precision recall - multiclass")
    plt.legend(lines, labels, loc=(0, -0.41), prop=dict(size=10))

    return PlotArtifact(
        "precision-recall-multiclass",
        body=plt.gcf(),
        title="Multiclass Precision Recall",
    )
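The function expects one-hot (binarized) ground-truth labels and per-class probabilities. A sketch of preparing those inputs, assuming the surrounding module imports (numpy as np, matplotlib.pyplot as plt, sklearn.metrics as metrics, itertools.cycle, PlotArtifact); the estimator and dataset are illustrative:

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

ytest_b = LabelBinarizer().fit_transform(y_test)  # shape (n_samples, n_classes)
yprob = clf.predict_proba(X_test)                 # shape (n_samples, n_classes)
artifact = precision_recall_multi(ytest_b, yprob, labels=[0, 1, 2])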
Example #6
def plot_iter(context, iterations, col='accuracy', num_bins=10):
    df = pd.read_csv(BytesIO(iterations.get()))
    x = df['output.{}'.format(col)]
    fig, ax = plt.subplots(figsize=(6, 6))
    n, bins, patches = ax.hist(x, num_bins, density=1)
    ax.set_xlabel('Accuracy')
    ax.set_ylabel('Count')
    context.log_artifact(PlotArtifact('myfig', body=fig))
Example #7
def plot_roc(
    context,
    y_labels,
    y_probs,
    key="roc",
    plots_dir: str = "plots",
    fmt="png",
    fpr_label: str = "false positive rate",
    tpr_label: str = "true positive rate",
    title: str = "roc curve",
    legend_loc: str = "best",
):
    """plot roc curves

    TODO:  add averaging method (as string) that was used to create probs, 
    display in legend

    :param context:      the function context
    :param y_labels:     ground truth labels, hot encoded for multiclass  
    :param y_probs:      model prediction probabilities
    :param key:          ("roc") key of plot in artifact store
    :param plots_dir:    ("plots") destination folder relative path to artifact path
    :param fmt:          ("png") plot format
    :param fpr_label:    ("false positive rate") x-axis labels
    :param tpr_label:    ("true positive rate") y-axis labels
    :param title:        ("roc curve") title of plot
    :param legend_loc:   ("best") location of plot legend
    """
    # clear matplotlib current figure
    gcf_clear(plt)

    # draw 45 degree line
    plt.plot([0, 1], [0, 1], "k--")

    # labelling
    plt.xlabel(fpr_label)
    plt.ylabel(tpr_label)
    plt.title(title)

    # single ROC or multiple
    if y_labels.shape[1] > 1:
        # data accumulators by class
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        for i in range(y_labels[:, :-1].shape[1]):
            fpr[i], tpr[i], _ = metrics.roc_curve(y_labels[:, i],
                                                  y_probs[:, i],
                                                  pos_label=1)
            roc_auc[i] = metrics.auc(fpr[i], tpr[i])
            plt.plot(fpr[i], tpr[i], label=f"class {i}")
    else:
        fpr, tpr, _ = metrics.roc_curve(y_labels, y_probs[:, 1], pos_label=1)
        plt.plot(fpr, tpr, label=f"positive class")

    fname = f"{plots_dir}/{key}.html"
    context.log_artifact(PlotArtifact(key, body=plt.gcf()), local_path=fname)
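gcf_clear is not defined in these snippets; a plausible stand-in (an assumption about the helper, not its original source) that resets matplotlib's implicit figure state between plots:

def gcf_clear(plt):
    """reset matplotlib's current figure so earlier plots don't leak into new ones"""
    plt.cla()    # clear the current axes
    plt.clf()    # clear the current figure
    plt.close()  # close the figure object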
Example #8
def confusion_matrix(model, xtest, ytest):
    cmd = metrics.plot_confusion_matrix(model,
                                        xtest,
                                        ytest,
                                        normalize='all',
                                        values_format='.2g',
                                        cmap=plt.cm.Blues)
    # for now only 1, add different views to this array for display in UI
    return PlotArtifact("confusion-matrix-normalized", body=cmd.figure_)
Example #9
def precision_recall_multi(ytest_b, yprob, labels, scoring="micro"):
    """
    """
    n_classes = len(labels)

    precision = dict()
    recall = dict()
    avg_prec = dict()
    for i in range(n_classes):
        precision[i], recall[i], _ = metrics.precision_recall_curve(
            ytest_b[:, i], yprob[:, i])
        avg_prec[i] = metrics.average_precision_score(ytest_b[:, i], yprob[:,
                                                                           i])
    precision["micro"], recall["micro"], _ = metrics.precision_recall_curve(
        ytest_b.ravel(), yprob.ravel())
    avg_prec["micro"] = metrics.average_precision_score(ytest_b,
                                                        yprob,
                                                        average="micro")
    ap_micro = avg_prec["micro"]
    model_metrics.update({'precision-micro-avg-classes': ap_micro})

    gcf_clear(plt)
    colors = cycle(
        ['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])
    plt.figure(figsize=(7, 8))
    f_scores = np.linspace(0.2, 0.8, num=4)
    lines = []
    labels = []
    for f_score in f_scores:
        x = np.linspace(0.01, 1)
        y = f_score * x / (2 * x - f_score)
        l, = plt.plot(x[y >= 0], y[y >= 0], color='gray', alpha=0.2)
        plt.annotate('f1={0:0.1f}'.format(f_score), xy=(0.9, y[45] + 0.02))

    lines.append(l)
    labels.append('iso-f1 curves')
    l, = plt.plot(recall["micro"], precision["micro"], color='gold', lw=10)
    lines.append(l)
    labels.append(f'micro-average precision-recall (area = {ap_micro:0.2f})')

    for i, color in zip(range(n_classes), colors):
        l, = plt.plot(recall[i], precision[i], color=color, lw=2)
        lines.append(l)
        labels.append(
            f'precision-recall for class {i} (area = {avg_prec[i]:0.2f})')

    fig = plt.gcf()
    fig.subplots_adjust(bottom=0.25)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('recall')
    plt.ylabel('precision')
    plt.title('precision recall - multiclass')
    plt.legend(lines, labels, loc=(0, -.41), prop=dict(size=10))

    return PlotArtifact("precision-recall-multiclass", body=plt.gcf())
Example #10
        def handler(context, p1=1, p2="xx"):
            """this is a simple function

            :param context: handler context
            :param p1:  first param
            :param p2:  another param
            """
            # access input metadata, values, and inputs
            print(f"Run: {context.name} (uid={context.uid})")
            print(f"Params: p1={p1}, p2={p2}")

            time.sleep(1)

            # log the run results (scalar values)
            context.log_result("accuracy", p1 * 2)
            context.log_result("loss", p1 * 3)

            # add a label/tag to this run
            context.set_label("category", "tests")

            # create a matplot figure and store as artifact
            fig, ax = plt.subplots()
            np.random.seed(0)
            x, y = np.random.normal(size=(2, 200))
            color, size = np.random.random((2, 200))
            ax.scatter(x, y, c=color, s=500 * size, alpha=0.3)
            ax.grid(color="lightgray", alpha=0.7)

            context.log_artifact(
                PlotArtifact("myfig", body=fig, title="my plot"))

            # create a dataframe artifact
            df = pd.DataFrame([{
                "A": 10,
                "B": 100
            }, {
                "A": 11,
                "B": 110
            }, {
                "A": 12,
                "B": 120
            }])
            context.log_dataset("mydf", df=df)

            # Log an ML Model artifact
            context.log_model(
                "mymodel",
                body=b"abc is 123",
                model_file="model.txt",
                model_dir="data",
                metrics={"accuracy": 0.85},
                parameters={"xx": "abc"},
                labels={"framework": "xgboost"},
            )

            return "my resp"
Example #11
def precision_recall_bin(model, xtest, ytest, yprob):
    """
    """
    # precision-recall
    #gcf_clear(plt)
    disp = metrics.plot_precision_recall_curve(model, xtest, ytest)
    disp.ax_.set_title(
        f'precision recall: AP={metrics.average_precision_score(ytest, yprob):0.2f}'
    )

    return PlotArtifact("precision-recall-binary", body=disp.figure_)
Example #12
def precision_recall_bin(model, xtest, ytest, yprob, clear=False):
    """"""
    if clear:
        gcf_clear(plt)
    disp = metrics.plot_precision_recall_curve(model, xtest, ytest)
    disp.ax_.set_title(
        f"precision recall: AP={metrics.average_precision_score(ytest, yprob):0.2f}"
    )

    return PlotArtifact("precision-recall-binary",
                        body=disp.figure_,
                        title="Binary Precision Recall")
Example #13
def roc_bin(ytest, yprob):
    """
    """
    # ROC plot
    #gcf_clear(plt)
    fpr, tpr, _ = metrics.roc_curve(ytest, yprob)
    plt.figure(1)
    plt.plot([0, 1], [0, 1], 'k--')
    plt.plot(fpr, tpr, label='a label')
    plt.xlabel('false positive rate')
    plt.ylabel('true positive rate')
    plt.title('roc curve')
    plt.legend(loc='best')

    return PlotArtifact("roc-binary", body=plt.gcf())
Example #14
def roc_bin(ytest, yprob, clear: bool = False):
    """"""
    # ROC plot
    if clear:
        gcf_clear(plt)
    fpr, tpr, _ = metrics.roc_curve(ytest, yprob)
    plt.figure()
    plt.plot([0, 1], [0, 1], "k--")
    plt.plot(fpr, tpr, label="a label")
    plt.xlabel("false positive rate")
    plt.ylabel("true positive rate")
    plt.title("roc curve")
    plt.legend(loc="best")

    return PlotArtifact("roc-binary", body=plt.gcf(), title="Binary ROC Curve")
Example #15
def confusion_matrix(model, xtest, ytest, cmap="Blues"):
    cmd = metrics.plot_confusion_matrix(
        model,
        xtest,
        ytest,
        normalize="all",
        values_format=".2g",
        cmap=plt.get_cmap(cmap),
    )
    # for now only 1, add different views to this array for display in UI
    cmd.plot()
    return PlotArtifact(
        "confusion-matrix-normalized",
        body=cmd.figure_,
        title="Confusion Matrix - Normalized Plot",
    )
Example #16
def learning_curves(context: MLClientCtx,
                    results: dict,
                    figsz: Tuple[int, int] = (10, 10),
                    plots_dest: str = "plots") -> None:
    """plot xgb learning curves

    this will also log a model's learning curves
    """
    plt.clf()
    plt.figure(figsize=figsz)
    plt.plot(results["train"]["my_rmsle"], label="train-my-rmsle")
    plt.plot(results["valid"]["my_rmsle"], label="valid-my-rmsle")
    plt.title(f"learning curves")
    plt.legend()

    context.log_artifact(PlotArtifact(f"learning-curves", body=plt.gcf()),
                         local_path=f"{plots_dest}/learning-curves.html")
Example #17
def _coxph_log_model(
    context,
    model,
    dataset_key: str = "coxhazard-summary",
    models_dest: str = "models",
    plot_cov_groups: bool = False,
    p_value: float = 0.005,
    plot_key: str = "km-cx",
    plots_dest: str = "plots",
    file_ext="csv",
    extra_data: dict = {},
):
    """log a coxph model (and submodel locations)

    :param model:        estimated coxph model
    :param extra_data:   if this model wants to store the locations of submodels
                         use this
    """
    import matplotlib.pyplot as plt

    sumtbl = model.summary

    context.log_dataset(dataset_key, df=sumtbl, index=True, format=file_ext)

    model_bin = dumps(model)
    context.log_model(
        "cx-model",
        body=model_bin,
        artifact_path=os.path.join(context.artifact_path, models_dest),
        model_file="model.pkl",
    )
    if plot_cov_groups:
        select_covars = sumtbl[sumtbl.p <= p_value].index.values
        for group in select_covars:
            axs = model.plot_covariate_groups(group, values=[0, 1])
            for ix, ax in enumerate(axs):
                f = ax.get_figure()
                context.log_artifact(
                    PlotArtifact(f"cx-{group}-{ix}", body=plt.gcf()),
                    local_path=f"{plots_dest}/cx-{group}-{ix}.html",
                )
                gcf_clear(plt)
Example #18
def plot_stat(context,
              stat_name,
              stat_df):
    gcf_clear(plt)

    # Add chart
    ax = plt.axes()
    stat_chart = sns.barplot(x=stat_name,
                             y='index',
                             data=stat_df.sort_values(stat_name, ascending=False).reset_index(),
                             ax=ax)
    plt.tight_layout()

    for p in stat_chart.patches:
        width = p.get_width()
        plt.text(5 + p.get_width(), p.get_y() + 0.55 * p.get_height(),
                 '{:1.2f}'.format(width),
                 ha='center', va='center')

    context.log_artifact(PlotArtifact(f'{stat_name}', body=plt.gcf()),
                         local_path=os.path.join('plots', 'feature_selection', f'{stat_name}.html'))
    gcf_clear(plt)
Example #19
def plot_importance(context,
                    model,
                    key: str = "feature-importances",
                    plots_dest: str = "plots"):
    """Display estimated feature importances
    Only works for models with attribute `feature_importances_`

    **legacy version please deprecate in functions and demos**

    :param context:     function context
    :param model:       fitted model
    :param key:         key of feature importances plot and table in artifact
                        store
    :param plots_dest:  subfolder  in artifact store
    """
    if not hasattr(model, "feature_importances_"):
        raise Exception(
            "feature importaces are only available for some models")

    # create a feature importance table with desired labels
    zipped = zip(model.feature_importances_, context.header)
    feature_imp = pd.DataFrame(sorted(zipped),
                               columns=["freq", "feature"
                                        ]).sort_values(by="freq",
                                                       ascending=False)

    gcf_clear(plt)
    plt.figure(figsize=(20, 10))
    sns.barplot(x="freq", y="feature", data=feature_imp)
    plt.title("features")
    plt.tight_layout()

    fname = f"{plots_dest}/{key}.html"
    context.log_artifact(PlotArtifact(key, body=plt.gcf()), local_path=fname)

    # feature importances are also saved as a csv table (generally small):
    fname = key + "-tbl.csv"
    return context.log_dataset(key + "-tbl", df=feature_imp, local_path=fname)
Example #20
def plot_confusion_matrix(context: MLClientCtx,
                          labels,
                          predictions,
                          key: str = "confusion_matrix",
                          plots_dir: str = "plots",
                          colormap: str = "Blues",
                          fmt: str = "png",
                          sample_weight=None):
    """Create a confusion matrix.
    Plot and save a confusion matrix using test data from a
    modeling step.
    
    See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html
    
    TODO: fix label alignment
    TODO: consider using another packaged version
    TODO: refactor to take params dict for plot options

    :param context:         function context
    :param labels:          validation data ground-truth labels
    :param predictions:     validation data predictions
    :param key:             str
    :param plots_dir:       relative path of plots in artifact store
    :param colormap:        colourmap for confusion matrix
    :param fmt:             plot format
    :param sample_weight:   sample weights
    """
    _gcf_clear(plt)

    cm = metrics.confusion_matrix(labels, predictions, sample_weight=sample_weight)
    sns.heatmap(cm, annot=True, cmap=colormap, square=True)

    fig = plt.gcf()
    fname = f"{plots_dir}/{key}.{fmt}"
    fig.savefig(os.path.join(context.artifact_path, fname))
    context.log_artifact(PlotArtifact(key, body=fig), local_path=fname)
Example #21
def summarize(
    context,
    dask_key: str = "dask_key",
    dataset: mlrun.DataItem = None,
    label_column: str = "label",
    plots_dest: str = "plots",
    dask_function: str = None,
    dask_client=None,
) -> None:
    """Summarize a table
    
    Connects to dask client through the function context, or through an optional
    user-supplied scheduler.

    :param context:         the function context
    :param dask_key:        key of dataframe in dask client "datasets" attribute
    :param label_column:    ground truth column label
    :param plots_dest:      destination folder of summary plots (relative to artifact_path)
    :param dask_function:   dask function url (db://..)
    :param dask_client:     dask client object
    """
    if dask_function:
        client = mlrun.import_function(dask_function).client
    elif dask_client:
        client = dask_client
    else:
        raise ValueError('dask client was not provided')
        
    if dask_key in client.datasets:
        table = client.get_dataset(dask_key)
    elif dataset:
        #table = dataset.as_df(df_module=dd)
        table = dataset.as_df()
    else:
        context.logger.info(f"only these datasets are available {client.datasets} in client {client}")
        raise Exception("dataset not found on dask cluster")
    df = table
    header = df.columns.values
    extra_data = {}

    try:
        gcf_clear(plt)
        snsplt = sns.pairplot(df, hue=label_column)  # , diag_kws={"bw": 1.5})
        extra_data["histograms"] = context.log_artifact(
            PlotArtifact("histograms", body=plt.gcf()),
            local_path=f"{plots_dest}/hist.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.error(f"Failed to create pairplot histograms due to: {e}")

    try:
        gcf_clear(plt)
        plot_cols = 3
        plot_rows = int((len(header) - 1) / plot_cols) + 1
        fig, ax = plt.subplots(plot_rows, plot_cols, figsize=(15, 4))
        fig.tight_layout(pad=2.0)
        for i in range(plot_rows * plot_cols):
            if i < len(header):
                sns.violinplot(
                    x=df[header[i]],
                    ax=ax[int(i / plot_cols)][i % plot_cols],
                    orient="h",
                    width=0.7,
                    inner="quartile",
                )
            else:
                fig.delaxes(ax[int(i / plot_cols)][i % plot_cols])
            i += 1
        extra_data["violin"] = context.log_artifact(
            PlotArtifact("violin", body=plt.gcf(), title="Violin Plot"),
            local_path=f"{plots_dest}/violin.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.warn(f"Failed to create violin distribution plots due to: {e}")

    if label_column:
        labels = df.pop(label_column)
        imbtable = labels.value_counts(normalize=True).sort_index()
        try:
            gcf_clear(plt)
            balancebar = imbtable.plot(kind="bar", title="class imbalance - labels")
            balancebar.set_xlabel("class")
            balancebar.set_ylabel("proportion of total")
            extra_data["imbalance"] = context.log_artifact(
                PlotArtifact("imbalance", body=plt.gcf()),
                local_path=f"{plots_dest}/imbalance.html",
            )
        except Exception as e:
            context.logger.warn(f"Failed to create class imbalance plot due to: {e}")
        context.log_artifact(
            TableArtifact(
                "imbalance-weights-vec", df=pd.DataFrame({"weights": imbtable})
            ),
            local_path=f"{plots_dest}/imbalance-weights-vec.csv",
            db_key=False,
        )

    tblcorr = df.corr()
    mask = np.zeros_like(tblcorr, dtype=bool)
    mask[np.triu_indices_from(mask)] = True

    dfcorr = pd.DataFrame(data=tblcorr, columns=header, index=header)
    dfcorr = dfcorr[np.arange(dfcorr.shape[0])[:, None] > np.arange(dfcorr.shape[1])]
    context.log_artifact(
        TableArtifact("correlation-matrix", df=tblcorr, visible=True),
        local_path=f"{plots_dest}/correlation-matrix.csv",
        db_key=False,
    )

    try:
        gcf_clear(plt)
        ax = plt.axes()
        sns.heatmap(tblcorr, ax=ax, mask=mask, annot=False, cmap=plt.cm.Reds)
        ax.set_title("features correlation")
        extra_data["correlation"] = context.log_artifact(
            PlotArtifact("correlation", body=plt.gcf(), title="Correlation Matrix"),
            local_path=f"{plots_dest}/corr.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.warn(f"Failed to create features correlation plot due to: {e}")

    gcf_clear(plt)
Example #22
def summarize(
    context: MLClientCtx,
    table: DataItem,
    label_column: str = None,
    class_labels: List[str] = [],
    plot_hist: bool = True,
    plots_dest: str = "plots",
    update_dataset=False,
) -> None:
    """Summarize a table

    :param context:         the function context
    :param table:           MLRun input pointing to pandas dataframe (csv/parquet file path)
    :param label_column:    ground truth column label
    :param class_labels:    label for each class in tables and plots
    :param plot_hist:       (True) set this to False for large tables
    :param plots_dest:      destination folder of summary plots (relative to artifact_path)
    :param update_dataset:  when the table is a registered dataset update the charts in-place
    """
    df = table.as_df()
    header = df.columns.values
    extra_data = {}

    try:
        gcf_clear(plt)
        snsplt = sns.pairplot(df, hue=label_column)  # , diag_kws={"bw": 1.5})
        extra_data["histograms"] = context.log_artifact(
            PlotArtifact("histograms", body=plt.gcf()),
            local_path=f"{plots_dest}/hist.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.error(
            f"Failed to create pairplot histograms due to: {e}")

    try:
        gcf_clear(plt)
        plot_cols = 3
        plot_rows = int((len(header) - 1) / plot_cols) + 1
        fig, ax = plt.subplots(plot_rows, plot_cols, figsize=(15, 4))
        fig.tight_layout(pad=2.0)
        for i in range(plot_rows * plot_cols):
            if i < len(header):
                sns.violinplot(
                    x=df[header[i]],
                    ax=ax[int(i / plot_cols)][i % plot_cols],
                    orient="h",
                    width=0.7,
                    inner="quartile",
                )
            else:
                fig.delaxes(ax[int(i / plot_cols)][i % plot_cols])
            i += 1
        extra_data["violin"] = context.log_artifact(
            PlotArtifact("violin", body=plt.gcf(), title="Violin Plot"),
            local_path=f"{plots_dest}/violin.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.warn(
            f"Failed to create violin distribution plots due to: {e}")

    if label_column:
        labels = df.pop(label_column)
        imbtable = labels.value_counts(normalize=True).sort_index()
        try:
            gcf_clear(plt)
            balancebar = imbtable.plot(kind="bar",
                                       title="class imbalance - labels")
            balancebar.set_xlabel("class")
            balancebar.set_ylabel("proportion of total")
            extra_data["imbalance"] = context.log_artifact(
                PlotArtifact("imbalance", body=plt.gcf()),
                local_path=f"{plots_dest}/imbalance.html",
            )
        except Exception as e:
            context.logger.warn(
                f"Failed to create class imbalance plot due to: {e}")
        context.log_artifact(
            TableArtifact("imbalance-weights-vec",
                          df=pd.DataFrame({"weights": imbtable})),
            local_path=f"{plots_dest}/imbalance-weights-vec.csv",
            db_key=False,
        )

    tblcorr = df.corr()
    mask = np.zeros_like(tblcorr, dtype=bool)
    mask[np.triu_indices_from(mask)] = True

    dfcorr = pd.DataFrame(data=tblcorr, columns=header, index=header)
    dfcorr = dfcorr[
        np.arange(dfcorr.shape[0])[:, None] > np.arange(dfcorr.shape[1])]
    context.log_artifact(
        TableArtifact("correlation-matrix", df=tblcorr, visible=True),
        local_path=f"{plots_dest}/correlation-matrix.csv",
        db_key=False,
    )

    try:
        gcf_clear(plt)
        ax = plt.axes()
        sns.heatmap(tblcorr, ax=ax, mask=mask, annot=False, cmap=plt.cm.Reds)
        ax.set_title("features correlation")
        extra_data["correlation"] = context.log_artifact(
            PlotArtifact("correlation",
                         body=plt.gcf(),
                         title="Correlation Matrix"),
            local_path=f"{plots_dest}/corr.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.warn(
            f"Failed to create features correlation plot due to: {e}")

    gcf_clear(plt)
    if update_dataset and table.meta and table.meta.kind == "dataset":
        from mlrun.artifacts import update_dataset_meta

        update_dataset_meta(table.meta, extra_data=extra_data)
Example #23
def train_model(context: MLClientCtx,
                dataset: DataItem,
                model_pkg_class: str,
                label_column: str = "label",
                train_validation_size: float = 0.75,
                sample: float = 1.0,
                models_dest: str = "models",
                test_set_key: str = "test_set",
                plots_dest: str = "plots",
                dask_key: str = "dask_key",
                dask_persist: bool = False,
                scheduler_key: str = '',
                file_ext: str = "parquet",
                random_state: int = 42) -> None:
    """
    Train a sklearn classifier with Dask
    
    :param context:                 Function context.
    :param dataset:                 Raw data file.
    :param model_pkg_class:         Model to train, e.g, "sklearn.ensemble.RandomForestClassifier", 
                                    or json model config.
    :param label_column:            (label) Ground-truth y labels.
    :param train_validation_size:   (0.75) Train validation set proportion out of the full dataset.
    :param sample:                  (1.0) Select sample from dataset (n-rows/% of total), randomize rows by default.
    :param models_dest:             (models) Models subfolder on artifact path.
    :param test_set_key:            (test_set) Mlrun db key of held out data in artifact store.
    :param plots_dest:              (plots) Plot subfolder on artifact path.
    :param dask_key:                (dask key) Key of dataframe in dask client "datasets" attribute.
    :param dask_persist:            (False) Should the data be persisted (through the `client.persist`)
    :param scheduler_key:           (scheduler) Dask scheduler configuration, json also logged as an artifact.
    :param file_ext:                (parquet) format for test_set_key hold out data
    :param random_state:            (42) sklearn seed
    """

    if scheduler_key:
        client = Client(scheduler_key)

    else:
        client = Client()

    context.logger.info("Read Data")
    df = dataset.as_df(df_module=dd)

    context.logger.info("Prep Data")
    numerics = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']
    df = df.select_dtypes(include=numerics)

    if df.isna().any().any().compute():
        raise Exception('NA values found')

    df_header = df.columns

    df = df.sample(frac=sample).reset_index(drop=True)
    encoder = LabelEncoder()
    encoder = encoder.fit(df[label_column])
    X = df.drop(label_column, axis=1).to_dask_array(lengths=True)
    y = encoder.transform(df[label_column])

    classes = df[label_column].drop_duplicates()  # no unique values in dask
    classes = [str(i) for i in classes]

    context.logger.info("Split and Train")
    X_train, X_test, y_train, y_test = model_selection.train_test_split(
        X, y, train_size=train_validation_size, random_state=random_state)

    scaler = StandardScaler()
    scaler = scaler.fit(X_train)
    X_train_transformed = scaler.transform(X_train)
    X_test_transformed = scaler.transform(X_test)

    model_config = gen_sklearn_model(model_pkg_class,
                                     context.parameters.items())

    model_config["FIT"].update({"X": X_train_transformed, "y": y_train})

    ClassifierClass = create_class(model_config["META"]["class"])

    model = ClassifierClass(**model_config["CLASS"])

    with joblib.parallel_backend("dask"):

        model = model.fit(**model_config["FIT"])

    artifact_path = context.artifact_subpath(models_dest)

    plots_path = context.artifact_subpath(models_dest, plots_dest)

    context.logger.info("Evaluate")
    extra_data_dict = {}
    for report in (ROCAUC, ClassificationReport, ConfusionMatrix):

        report_name = str(report.__name__)
        plt.cla()
        plt.clf()
        plt.close()

        viz = report(model, classes=classes, per_class=True, is_fitted=True)
        viz.fit(X_train_transformed,
                y_train)  # Fit the training data to the visualizer
        viz.score(X_test_transformed,
                  y_test.compute())  # Evaluate the model on the test data

        plot = context.log_artifact(PlotArtifact(report_name,
                                                 body=viz.fig,
                                                 title=report_name),
                                    db_key=False)
        extra_data_dict[str(report)] = plot

        if report_name == 'ROCAUC':
            context.log_results({
                "micro": viz.roc_auc.get("micro"),
                "macro": viz.roc_auc.get("macro")
            })

        elif report_name == 'ClassificationReport':
            for score_name in viz.scores_:
                for score_class in viz.scores_[score_name]:

                    context.log_results({
                        score_name + "-" + score_class:
                        viz.scores_[score_name].get(score_class)
                    })

    viz = FeatureImportances(model,
                             classes=classes,
                             per_class=True,
                             is_fitted=True,
                             labels=df_header.delete(
                                 df_header.get_loc(label_column)))
    viz.fit(X_train_transformed, y_train)
    viz.score(X_test_transformed, y_test)

    plot = context.log_artifact(PlotArtifact("FeatureImportances",
                                             body=viz.fig,
                                             title="FeatureImportances"),
                                db_key=False)
    extra_data_dict[str("FeatureImportances")] = plot

    plt.cla()
    plt.clf()
    plt.close()

    context.logger.info("Log artifacts")
    artifact_path = context.artifact_subpath(models_dest)

    plots_path = context.artifact_subpath(models_dest, plots_dest)

    context.set_label('class', model_pkg_class)

    context.log_model("model",
                      body=dumps(model),
                      artifact_path=artifact_path,
                      model_file="model.pkl",
                      extra_data=extra_data_dict,
                      metrics=context.results,
                      labels={"class": model_pkg_class})

    context.log_artifact("standard_scaler",
                         body=dumps(scaler),
                         artifact_path=artifact_path,
                         model_file="scaler.gz",
                         label="standard_scaler")

    context.log_artifact("label_encoder",
                         body=dumps(encoder),
                         artifact_path=artifact_path,
                         model_file="encoder.gz",
                         label="label_encoder")

    df_to_save = delayed(np.column_stack)((X_test, y_test)).compute()
    context.log_dataset(
        test_set_key,
        df=pd.DataFrame(df_to_save,
                        columns=df_header),  # improve log dataset ability
        format=file_ext,
        index=False,
        labels={"data-type": "held-out"},
        artifact_path=context.artifact_subpath('data'))

    context.logger.info("Done!")
Example #24
def permutation_importance(
    context: MLClientCtx,
    model: DataItem,
    dataset: DataItem,
    labels: str,
    figsz=(10, 5),
    plots_dest: str = "plots",
    fitype: str = "permute",
) -> pd.DataFrame:
    """calculate change in metric

    type 'permute' uses a pre-estimated model
    type 'dropcol' uses a re-estimated model

    :param context:     the function's execution context
    :param model:       a trained model
    :param dataset:     features and ground truths, regression targets
    :param labels:      name of the ground-truth column
    :param figsz:       matplotlib figure size
    :param plots_dest:  path within artifact store
    """
    model_file, model_data, _ = get_model(model.url, suffix=".pkl")
    model = load(open(str(model_file), "rb"))

    X = dataset.as_df()
    y = X.pop(labels)
    header = X.columns

    metric = _oob_classifier_accuracy

    baseline = metric(model, X, y)

    imp = []
    for col in X.columns:
        if fitype is "permute":
            save = X[col].copy()
            X[col] = np.random.permutation(X[col])
            m = metric(model, X, y)
            X[col] = save
            imp.append(baseline - m)
        elif fitype is "dropcol":
            X_ = X.drop(col, axis=1)
            model_ = clone(model)
            #model_.random_state = random_state
            model_.fit(X_, y)
            o = model_.oob_score_
            imp.append(baseline - o)
        else:
            raise ValueError(
                "unknown fitype, only 'permute' or 'dropcol' permitted")

    zipped = zip(imp, header)
    feature_imp = pd.DataFrame(sorted(zipped),
                               columns=["importance", "feature"])
    feature_imp.sort_values(by="importance", ascending=False, inplace=True)

    plt.clf()
    plt.figure(figsize=figsz)
    sns.barplot(x="importance", y="feature", data=feature_imp)
    plt.title(f"feature importances-{fitype}")
    plt.tight_layout()

    context.log_artifact(
        PlotArtifact(f"feature importances-{fitype}", body=plt.gcf()),
        local_path=f"{plots_dest}/feature-permutations.html",
    )
    context.log_dataset(f"feature-importances-{fitype}-tbl",
                        df=feature_imp,
                        index=False)
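The metric helper _oob_classifier_accuracy is not shown in this snippet. As a loudly hypothetical stand-in (the original presumably scores a random forest on its out-of-bag samples), a plain accuracy function with the same call shape could be used for experimentation:

import numpy as np

def _oob_classifier_accuracy(model, X, y):
    # hypothetical stand-in: ordinary accuracy on the given rows;
    # the real helper likely uses true out-of-bag estimates instead
    return float(np.mean(model.predict(X) == y))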
Example #25
def describe(
    context: MLClientCtx,
    table: Union[DataItem, str],
    label_column: str,
    class_labels: List[str],
    key: str = "table-summary",
) -> None:
    """Summarize a table

    TODO: merge with dask version

    :param context:         the function context
    :param table:           pandas dataframe
    :param key:             key of table summary in artifact store
    """
    _gcf_clear(plt)

    base_path = context.artifact_path
    os.makedirs(base_path, exist_ok=True)
    os.makedirs(base_path + "/plots", exist_ok=True)

    print(f'TABLE {table}')
    table = pd.read_parquet(str(table))
    header = table.columns.values

    # describe table
    sumtbl = table.describe()
    sumtbl = sumtbl.append(len(table.index) - table.count(), ignore_index=True)
    sumtbl.insert(
        0, "metric",
        ["count", "mean", "std", "min", "25%", "50%", "75%", "max", "nans"])

    sumtbl.to_csv(os.path.join(base_path, key + ".csv"), index=False)
    context.log_artifact(key, local_path=key + ".csv")

    # plot class balance, record relative class weight
    _gcf_clear(plt)

    labels = table.pop(label_column)
    class_balance_model = ClassBalance(labels=class_labels)
    class_balance_model.fit(labels)

    scale_pos_weight = class_balance_model.support_[
        0] / class_balance_model.support_[1]
    #context.log_artifact("scale_pos_weight", f"{scale_pos_weight:0.2f}")
    context.log_artifact("scale_pos_weight", str(scale_pos_weight))

    class_balance_model.show(
        outpath=os.path.join(base_path, "plots/imbalance.png"))
    context.log_artifact(PlotArtifact("imbalance", body=plt.gcf()),
                         local_path="plots/imbalance.html")

    # plot feature correlation
    _gcf_clear(plt)
    tblcorr = table.corr()
    ax = plt.axes()
    sns.heatmap(tblcorr, ax=ax, annot=False, cmap=plt.cm.Reds)
    ax.set_title("features correlation")
    plt.savefig(os.path.join(base_path, "plots/corr.png"))
    context.log_artifact(PlotArtifact("correlation", body=plt.gcf()),
                         local_path="plots/corr.html")

    # plot histogram
    _gcf_clear(plt)
    """
Example #26
def roc_multi(ytest_b, yprob, labels):
    """
    """
    n_classes = len(labels)

    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = metrics.roc_curve(ytest_b[:, i], yprob[:, i])
        roc_auc[i] = metrics.auc(fpr[i], tpr[i])

    # Compute micro-average ROC curve and ROC area
    fpr["micro"], tpr["micro"], _ = metrics.roc_curve(ytest_b.ravel(),
                                                      yprob.ravel())
    roc_auc["micro"] = metrics.auc(fpr["micro"], tpr["micro"])

    # First aggregate all false positive rates
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

    # Then interpolate all ROC curves at this points
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += interp(all_fpr, fpr[i], tpr[i])

    # Finally average it and compute AUC
    mean_tpr /= n_classes

    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = metrics.auc(fpr["macro"], tpr["macro"])

    # Plot all ROC curves
    #gcf_clear(plt)
    plt.figure()
    plt.plot(fpr["micro"],
             tpr["micro"],
             label='micro-average ROC curve (area = {0:0.2f})'
             ''.format(roc_auc["micro"]),
             color='deeppink',
             linestyle=':',
             linewidth=4)

    plt.plot(fpr["macro"],
             tpr["macro"],
             label='macro-average ROC curve (area = {0:0.2f})'
             ''.format(roc_auc["macro"]),
             color='navy',
             linestyle=':',
             linewidth=4)

    colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
    for i, color in zip(range(n_classes), colors):
        plt.plot(fpr[i],
                 tpr[i],
                 color=color,
                 lw=2,
                 label='ROC curve of class {0} (area = {1:0.2f})'
                 ''.format(i, roc_auc[i]))

    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('receiver operating characteristic - multiclass')
    plt.legend(loc=(0, -.68), prop=dict(size=10))

    return PlotArtifact("roc-multiclass", body=plt.gcf())
Example #27
def eval_model_v2(
    context,
    xtest,
    ytest,
    model,
    pcurve_bins: int = 10,
    pcurve_names: List[str] = ["my classifier"],
    plots_artifact_path: str = "",
    pred_params: dict = {},
    cmap="Blues",
    is_xgb=False,
):
    """generate predictions and validation stats

    pred_params are non-default, scikit-learn api prediction-function
    parameters. For example, a tree-type of model may have a tree depth
    limit for its prediction function.

    :param xtest:               features array type Union(DataItem, DataFrame,
                                numpy array)
    :param ytest:               ground-truth labels Union(DataItem, DataFrame,
                                Series, numpy array, List)
    :param model:               estimated model
    :param pcurve_bins:         (10) subdivide [0,1] interval into n bins, x-axis
    :param pcurve_names:        label for each calibration curve
    :param pred_params:         (None) dict of predict function parameters
    :param cmap:                ('Blues') matplotlib color map
    :param is_xgb:             (False) flag for xgboost models; auto-detected when
                                the model exposes get_xgb_params
    """

    if hasattr(model, "get_xgb_params"):
        is_xgb = True

    def df_blob(df):
        return bytes(df.to_csv(index=False), encoding="utf-8")

    if isinstance(ytest, np.ndarray):
        unique_labels = np.unique(ytest)
    elif isinstance(ytest, list):
        unique_labels = set(ytest)
    else:
        try:
            ytest = ytest.values
            unique_labels = np.unique(ytest)
        except Exception as exc:
            raise Exception(f"unrecognized data type for ytest {exc}")

    n_classes = len(unique_labels)
    is_multiclass = True if n_classes > 2 else False

    # INIT DICT...OR SOME OTHER COLLECTOR THAT CAN BE ACCESSED
    plots_path = plots_artifact_path or context.artifact_subpath("plots")
    extra_data = {}

    ypred = model.predict(xtest, **pred_params)

    if isinstance(ypred.flat[0], np.floating):
        accuracy = mean_absolute_error(ytest, ypred)

    else:
        accuracy = float(metrics.accuracy_score(ytest, ypred))

    context.log_results({
        "accuracy": accuracy,
        "test-error": np.sum(ytest != ypred) / ytest.shape[0]
    })

    # PROBABILITIES
    if hasattr(model, "predict_proba"):
        yprob = model.predict_proba(xtest, **pred_params)
        if not is_multiclass:
            fraction_of_positives, mean_predicted_value = calibration_curve(
                ytest, yprob[:, -1], n_bins=pcurve_bins, strategy="uniform")
            cmd = plot_calibration_curve(ytest, [yprob], pcurve_names)
            calibration = context.log_artifact(
                PlotArtifact(
                    "probability-calibration",
                    body=cmd.get_figure(),
                    title="probability calibration plot",
                ),
                artifact_path=plots_path,
                db_key=False,
            )
            extra_data["probability calibration"] = calibration

    # CONFUSION MATRIX
    if is_classifier(model):
        cm = sklearn_confusion_matrix(ytest, ypred, normalize="all")
        df = pd.DataFrame(data=cm)
        extra_data["confusion matrix table.csv"] = df_blob(df)

        cmd = metrics.plot_confusion_matrix(
            model,
            xtest,
            ytest,
            normalize="all",
            values_format=".2g",
            cmap=plt.get_cmap(cmap),
        )
        confusion = context.log_artifact(
            PlotArtifact(
                "confusion-matrix",
                body=cmd.figure_,
                title="Confusion Matrix - Normalized Plot",
            ),
            artifact_path=plots_path,
            db_key=False,
        )
        extra_data["confusion matrix"] = confusion

    # LEARNING CURVES
    if hasattr(model, "evals_result") and is_xgb is False:
        results = model.evals_result()
        train_set = list(results.items())[0]
        valid_set = list(results.items())[1]

        learning_curves_df = None
        if is_multiclass:
            if hasattr(train_set[1], "merror"):
                learning_curves_df = pd.DataFrame({
                    "train_error":
                    train_set[1]["merror"],
                    "valid_error":
                    valid_set[1]["merror"],
                })
        else:
            if hasattr(train_set[1], "error"):
                learning_curves_df = pd.DataFrame({
                    "train_error":
                    train_set[1]["error"],
                    "valid_error":
                    valid_set[1]["error"],
                })

        if learning_curves_df is not None:
            extra_data["learning curve table.csv"] = df_blob(
                learning_curves_df)

            _, ax = plt.subplots()
            plt.xlabel("# training examples")
            plt.ylabel("error rate")
            plt.title("learning curve - error")
            ax.plot(learning_curves_df["train_error"], label="train")
            ax.plot(learning_curves_df["valid_error"], label="valid")
            learning = context.log_artifact(
                PlotArtifact("learning-curve",
                             body=plt.gcf(),
                             title="Learning Curve - error"),
                artifact_path=plots_path,
                db_key=False,
            )
            extra_data["learning curve"] = learning

    # FEATURE IMPORTANCES
    if hasattr(model, "feature_importances_"):
        (fi_plot, fi_tbl) = feature_importances(model, xtest.columns)
        extra_data["feature importances"] = context.log_artifact(
            fi_plot, db_key=False, artifact_path=plots_path)
        extra_data["feature importances table.csv"] = df_blob(fi_tbl)

    # AUC - ROC - PR CURVES
    if is_multiclass and is_classifier(model):
        lb = LabelBinarizer()
        ytest_b = lb.fit_transform(ytest)

        extra_data["precision_recall_multi"] = context.log_artifact(
            precision_recall_multi(ytest_b, yprob, unique_labels),
            artifact_path=plots_path,
            db_key=False,
        )
        extra_data["roc_multi"] = context.log_artifact(
            roc_multi(ytest_b, yprob, unique_labels),
            artifact_path=plots_path,
            db_key=False,
        )

        # AUC multiclass
        aucmicro = metrics.roc_auc_score(ytest_b,
                                         yprob,
                                         multi_class="ovo",
                                         average="micro")
        aucweighted = metrics.roc_auc_score(ytest_b,
                                            yprob,
                                            multi_class="ovo",
                                            average="weighted")

        context.log_results({
            "auc-micro": aucmicro,
            "auc-weighted": aucweighted
        })

        # others (todo - macro, micro...)
        f1 = metrics.f1_score(ytest, ypred, average="macro")
        ps = metrics.precision_score(ytest, ypred, average="macro")
        rs = metrics.recall_score(ytest, ypred, average="macro")
        context.log_results({
            "f1-score": f1,
            "precision_score": ps,
            "recall_score": rs
        })

    elif is_classifier(model):
        yprob_pos = yprob[:, 1]
        extra_data["precision_recall_bin"] = context.log_artifact(
            precision_recall_bin(model, xtest, ytest, yprob_pos),
            artifact_path=plots_path,
            db_key=False,
        )
        extra_data["roc_bin"] = context.log_artifact(
            roc_bin(ytest, yprob_pos, clear=True),
            artifact_path=plots_path,
            db_key=False,
        )

        rocauc = metrics.roc_auc_score(ytest, yprob_pos)
        brier_score = metrics.brier_score_loss(ytest,
                                               yprob_pos,
                                               pos_label=ytest.max())
        f1 = metrics.f1_score(ytest, ypred)
        ps = metrics.precision_score(ytest, ypred)
        rs = metrics.recall_score(ytest, ypred)
        context.log_results({
            "rocauc": rocauc,
            "brier_score": brier_score,
            "f1-score": f1,
            "precision_score": ps,
            "recall_score": rs,
        })

    elif is_regressor(model):
        r_squared = r2_score(ytest, ypred)
        rmse = mean_squared_error(ytest, ypred, squared=False)
        mse = mean_squared_error(ytest, ypred, squared=True)
        mae = mean_absolute_error(ytest, ypred)
        context.log_results({
            "R2": r_squared,
            "root_mean_squared_error": rmse,
            "mean_squared_error": mse,
            "mean_absolute_error": mae,
        })
    # return all model metrics and plots
    return extra_data
Example #28
def summarize(
    context,
    dask_key: str = "dask_key",
    dataset: mlrun.DataItem = None,
    label_column: str = "label",
    class_labels: List[str] = [],
    plot_hist: bool = True,
    plots_dest: str = "plots",
    dask_function: str = None,
    dask_client=None,
) -> None:
    """Summarize a table
    
    Connects to dask client through the function context, or through an optional
    user-supplied scheduler.

    :param context:         the function context
    :param dask_key:        key of dataframe in dask client "datasets" attribute
    :param label_column:    ground truth column label
    :param class_labels:    label for each class in tables and plots
    :param plot_hist:       (True) set this to False for large tables
    :param plots_dest:      destination folder of summary plots (relative to artifact_path)
    :param dask_function:   dask function url (db://..)
    :param dask_client:     dask client object
    """
    if dask_function:
        client = mlrun.import_function(dask_function).client
    elif dask_client:
        client = dask_client
    else:
        raise ValueError('dask client was not provided')

    if dask_key in client.datasets:
        table = client.get_dataset(dask_key)
    elif dataset:
        table = dataset.as_df(df_module=dd)
    else:
        context.logger.info(
            f"only these datasets are available {client.datasets} in client {client}"
        )
        raise Exception("dataset not found on dask cluster")
    header = table.columns.values

    gcf_clear(plt)
    table = table.compute()
    snsplt = sns.pairplot(table, hue=label_column, diag_kws={'bw': 1.5})
    context.log_artifact(PlotArtifact('histograms', body=plt.gcf()),
                         local_path=f"{plots_dest}/hist.html")

    gcf_clear(plt)
    labels = table.pop(label_column)
    if not class_labels:
        class_labels = labels.unique()
    class_balance_model = ClassBalance(labels=class_labels)
    class_balance_model.fit(labels)
    scale_pos_weight = class_balance_model.support_[
        0] / class_balance_model.support_[1]
    context.log_result("scale_pos_weight", f"{scale_pos_weight:0.2f}")
    context.log_artifact(PlotArtifact("imbalance", body=plt.gcf()),
                         local_path=f"{plots_dest}/imbalance.html")

    gcf_clear(plt)
    tblcorr = table.corr()
    ax = plt.axes()
    sns.heatmap(tblcorr, ax=ax, annot=False, cmap=plt.cm.Reds)
    ax.set_title("features correlation")
    context.log_artifact(PlotArtifact("correlation", body=plt.gcf()),
                         local_path=f"{plots_dest}/corr.html")
    gcf_clear(plt)
Example #29
def eval_class_model(context,
                     xtest,
                     ytest,
                     model,
                     plots_dest: str = "plots",
                     pred_params: dict = {}):
    """generate predictions and validation stats
    
    pred_params are non-default, scikit-learn api prediction-function parameters.
    For example, a tree-type of model may have a tree depth limit for its prediction
    function.
    
    :param xtest:        features array type Union(DataItem, DataFrame, np. Array)
    :param ytest:        ground-truth labels Union(DataItem, DataFrame, Series, np. Array, List)
    :param model:        estimated model
    :param pred_params:  (None) dict of predict function parameters
    """
    if isinstance(ytest, np.ndarray):
        unique_labels = np.unique(ytest)
    elif isinstance(ytest, list):
        # convert to an array so elementwise comparison and .shape work below
        ytest = np.asarray(ytest)
        unique_labels = np.unique(ytest)
    else:
        try:
            ytest = ytest.values
            unique_labels = np.unique(ytest)
        except AttributeError:
            raise Exception("unrecognized data type for ytest")

    n_classes = len(unique_labels)
    is_multiclass = n_classes > 2

    # INIT DICT...OR SOME OTHER COLLECTOR THAT CAN BE ACCESSED
    mm_plots = []
    mm_tables = []
    mm = {}

    ypred = model.predict(xtest, **pred_params)
    mm.update({
        "test-accuracy": float(metrics.accuracy_score(ytest, ypred)),
        "test-error": np.sum(ytest != ypred) / ytest.shape[0]
    })

    # GEN PROBS (INCL CALIBRATED PROBABILITIES)
    if hasattr(model, "predict_proba"):
        yprob = model.predict_proba(xtest, **pred_params)
    else:
        # todo if decision fn...
        raise Exception("not implemented for this classifier")
    plot_calibration_curve(ytest, [yprob], ['xgboost'])
    context.log_artifact(PlotArtifact("calibration curve", body=plt.gcf()),
                         local_path=f"{plots_dest}/calibration curve.html")

    # start evaluating:
    # mm_plots.extend(learning_curves(model))
    if hasattr(model, "evals_result"):
        results = model.evals_result()
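        # for the xgboost/lightgbm sklearn API, evals_result() maps each eval set
        # (e.g. "validation_0", "validation_1") to its metric histories; the first
        # entry is assumed here to be the training set and the second the validation set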
        train_set = list(results.items())[0]
        valid_set = list(results.items())[1]

        learning_curves = pd.DataFrame({
            "train_error": train_set[1]["error"],
            "train_auc": train_set[1]["auc"],
            "valid_error": valid_set[1]["error"],
            "valid_auc": valid_set[1]["auc"]
        })

        plt.clf()  #gcf_clear(plt)
        fig, ax = plt.subplots()
        plt.xlabel('# training examples')
        plt.ylabel('auc')
        plt.title('learning curve - auc')
        ax.plot(learning_curves.train_auc, label='train')
        ax.plot(learning_curves.valid_auc, label='valid')
        ax.legend(loc='lower left')
        context.log_artifact(
            PlotArtifact("learning curve - auc", body=plt.gcf()),
            local_path=f"{plots_dest}/learning curve - auc.html")

        plt.clf()  #gcf_clear(plt)
        fig, ax = plt.subplots()
        plt.xlabel('# training examples')
        plt.ylabel('error rate')
        plt.title('learning curve - error')
        ax.plot(learning_curves.train_error, label='train')
        ax.plot(learning_curves.valid_error, label='valid')
        ax.legend(loc='lower left')
        context.log_artifact(
            PlotArtifact("learning curve - erreur", body=plt.gcf()),
            local_path=f"{plots_dest}/learning curve - erreur.html")

    # feature importances are only available for models exposing feature_importances_
    if hasattr(model, "feature_importances_"):
        (fi_plot, fi_tbl) = feature_importances(model, xtest.columns)
        mm_plots.append(fi_plot)
        mm_tables.append(fi_tbl)

    mm_plots.append(confusion_matrix(model, xtest, ytest))

    if is_multiclass:
        lb = LabelBinarizer()
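        # binarize labels into an (n_samples, n_classes) indicator matrix,
        # as required by the per-class ROC/PR helpers and roc_auc_score below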
        ytest_b = lb.fit_transform(ytest)

        mm_plots.append(precision_recall_multi(ytest_b, yprob, unique_labels))
        mm_plots.append(roc_multi(ytest_b, yprob, unique_labels))

        # AUC multiclass
        mm.update({
            "auc-micro":
            metrics.roc_auc_score(ytest_b,
                                  yprob,
                                  multi_class="ovo",
                                  average="micro"),
            "auc-weighted":
            metrics.roc_auc_score(ytest_b,
                                  yprob,
                                  multi_class="ovo",
                                  average="weighted")
        })

        # others (todo - macro, micro...)
        mm.update({
            "f1-score":
            metrics.f1_score(ytest, ypred, average="micro"),
            "precision_score":
            metrics.precision_score(ytest, ypred, average="micro"),
            "recall_score":
            metrics.recall_score(ytest, ypred, average="micro")
        })

    else:
        # extract the positive label
        yprob_pos = yprob[:, 1]

        mm_plots.append(roc_bin(ytest, yprob_pos))
        mm_plots.append(precision_recall_bin(model, xtest, ytest, yprob_pos))

        mm.update({
            "rocauc":
            metrics.roc_auc_score(ytest, yprob_pos),
            "brier_score":
            metrics.brier_score_loss(ytest, yprob_pos, pos_label=ytest.max()),
            "f1-score":
            metrics.f1_score(ytest, ypred),
            "precision_score":
            metrics.precision_score(ytest, ypred),
            "recall_score":
            metrics.recall_score(ytest, ypred)
        })

    # return all model metrics and plots
    mm.update({"plots": mm_plots, "tables": mm_tables})

    return mm
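A minimal usage sketch for eval_class_model (toy data; names are illustrative; assumes xgboost >= 1.6, an mlrun context, and the plotting helpers used above are in scope):

import pandas as pd
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=300, n_features=6, random_state=1)
xtrain, xtest, ytrain, ytest = train_test_split(
    pd.DataFrame(X), y, test_size=0.25, random_state=1)

clf = xgb.XGBClassifier(eval_metric=["error", "auc"])
clf.fit(xtrain, ytrain, eval_set=[(xtrain, ytrain), (xtest, ytest)], verbose=False)

ctx = mlrun.get_or_create_ctx("eval")   # assumes an mlrun context is available
mm = eval_class_model(ctx, xtest, ytest, clf)
ctx.log_results({k: v for k, v in mm.items() if k not in ("plots", "tables")})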
Example #30
0
def roc_multi(ytest_b, yprob, labels):
    """"""
    n_classes = len(labels)

    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = metrics.roc_curve(ytest_b[:, i], yprob[:, i])
        roc_auc[i] = metrics.auc(fpr[i], tpr[i])

    # Compute micro-average ROC curve and ROC area
    fpr["micro"], tpr["micro"], _ = metrics.roc_curve(ytest_b.ravel(),
                                                      yprob.ravel())
    roc_auc["micro"] = metrics.auc(fpr["micro"], tpr["micro"])

    # First aggregate all false positive rates
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

    # Then interpolate all ROC curves at this points
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

    # Finally average it and compute AUC
    mean_tpr /= n_classes

    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = metrics.auc(fpr["macro"], tpr["macro"])

    # Plot all ROC curves
    gcf_clear(plt)
    plt.figure()
    plt.plot(
        fpr["micro"],
        tpr["micro"],
        label=f"micro-average ROC curve (area = {roc_auc['micro']:0.2f})",
        color="deeppink",
        linestyle=":",
        linewidth=4,
    )

    plt.plot(
        fpr["macro"],
        tpr["macro"],
        label=f"macro-average ROC curve (area = {roc_auc['macro']:0.2f})",
        color="navy",
        linestyle=":",
        linewidth=4,
    )

    colors = cycle(["aqua", "darkorange", "cornflowerblue"])
    for i, color in zip(range(n_classes), colors):
        plt.plot(
            fpr[i],
            tpr[i],
            color=color,
            lw=2,
            label=f"ROC curve of class {i} (area = {roc_auc[i]:0.2f})",
        )

    plt.plot([0, 1], [0, 1], "k--", lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("receiver operating characteristic - multiclass")
    plt.legend(loc=(0, -0.68), prop=dict(size=10))

    return PlotArtifact("roc-multiclass",
                        body=plt.gcf(),
                        title="Multiclass ROC Curve")