Ejemplo n.º 1
0
def _coxph_log_model(
    context,
    model,
    dataset_key: str = "coxhazard-summary",
    models_dest: str = "models",
    plot_cov_groups: bool = False,
    p_value: float = 0.005,
    plot_key: str = "km-cx",
    plots_dest: str = "plots",
    file_ext="csv",
    extra_data: dict = None,
):
    """Log a coxph model, its summary table and optional covariate plots.

    The model's ``summary`` table is logged as a dataset and the serialized
    model is logged under the key "cx-model". When ``plot_cov_groups`` is set,
    one html plot artifact is logged per axis returned by
    ``model.plot_covariate_groups`` for every covariate whose ``p`` value in
    the summary table is at most ``p_value``.

    :param context:          the function context used for logging
    :param model:            estimated coxph model
    :param dataset_key:      key under which the summary table is logged
    :param models_dest:      sub-folder of ``context.artifact_path`` where the
                             model file is stored
    :param plot_cov_groups:  when True, log covariate group plots
    :param p_value:          upper bound on the summary ``p`` column used to
                             select the covariates to plot
    :param plot_key:         accepted for compatibility; not used by this
                             function
    :param plots_dest:       folder used in the local path of the html plots
    :param file_ext:         format passed to ``context.log_dataset``
    :param extra_data:       if this model wants to store the locations of
                             submodels use this (accepted for compatibility;
                             not used by this function at present)
    """
    sumtbl = model.summary

    context.log_dataset(dataset_key, df=sumtbl, index=True, format=file_ext)

    model_bin = dumps(model)
    context.log_model(
        "cx-model",
        body=model_bin,
        artifact_path=os.path.join(context.artifact_path, models_dest),
        model_file="model.pkl",
    )

    if not plot_cov_groups:
        return

    # Imported lazily so that logging a model without plots does not require
    # matplotlib to be importable.
    import matplotlib.pyplot as plt

    # Covariates are the index entries of the summary rows that pass the
    # p-value threshold.
    select_covars = sumtbl[sumtbl.p <= p_value].index.values
    for group in select_covars:
        axs = model.plot_covariate_groups(group, values=[0, 1])
        for ix, _ in enumerate(axs):
            # NOTE(review): the artifact body is the current pyplot figure,
            # which is presumably the one the axis was drawn on - confirm.
            context.log_artifact(
                PlotArtifact(f"cx-{group}-{ix}", body=plt.gcf()),
                local_path=f"{plots_dest}/cx-{group}-{ix}.html",
            )
            gcf_clear(plt)
Ejemplo n.º 2
0
def plot_stat(context,
              stat_name,
              stat_df):
    """Draw a bar chart of one statistic and log it as an html plot artifact.

    :param context:    the function context used to log the artifact
    :param stat_name:  column of ``stat_df`` holding the values to plot; also
                       used as the artifact key and the html file name
    :param stat_df:    dataframe containing the ``stat_name`` column
    """
    gcf_clear(plt)

    # Order rows from the largest to the smallest value and move the index into
    # a regular column so it can be used for the y axis.
    ranked = stat_df.sort_values(stat_name, ascending=False)
    ranked = ranked.reset_index()

    axes = plt.axes()
    chart = sns.barplot(x=stat_name, y='index', data=ranked, ax=axes)
    plt.tight_layout()

    # Annotate every bar with its value, placed 5 units past the bar's width.
    for bar in chart.patches:
        bar_width = bar.get_width()
        label_x = 5 + bar_width
        label_y = bar.get_y() + 0.55 * bar.get_height()
        plt.text(label_x, label_y,
                 '{:1.2f}'.format(bar_width),
                 ha='center', va='center')

    artifact = PlotArtifact(f'{stat_name}', body=plt.gcf())
    target = os.path.join('plots', 'feature_selection', f'{stat_name}.html')
    context.log_artifact(artifact, local_path=target)
    gcf_clear(plt)
Ejemplo n.º 3
0
def summarize(
    context,
    dask_key: str = "dask_key",
    dataset: mlrun.DataItem = None,
    label_column: str = "label",
    plots_dest: str = "plots",
    dask_function: str = None,
    dask_client=None,
) -> None:
    """Summarize a table

    The dask client is taken from ``dask_function`` (imported with
    ``mlrun.import_function``) or, failing that, from ``dask_client``. The
    table is read from the client's datasets under ``dask_key``; if it is not
    there, ``dataset.as_df()`` is used instead.

    Logged artifacts: a pairplot ("histograms"), a grid of violin plots
    ("violin"), a class imbalance bar chart and weights table (only when
    ``label_column`` is set), the correlation matrix table and a correlation
    heatmap. Every plot is best-effort: a failure is logged and the remaining
    artifacts are still produced.

    :param context:         the function context
    :param dask_key:        key of dataframe in dask client "datasets" attribute
    :param dataset:         fallback data item used when ``dask_key`` is not
                            found on the client
    :param label_column:    ground truth column label
    :param plots_dest:      destination folder of summary plots (relative to artifact_path)
    :param dask_function:   dask function url (db://..)
    :param dask_client:     dask client object
    :raises ValueError:     if neither ``dask_function`` nor ``dask_client`` is given
    :raises Exception:      if the table is neither on the client nor in ``dataset``
    """
    if dask_function:
        client = mlrun.import_function(dask_function).client
    elif dask_client:
        client = dask_client
    else:
        raise ValueError('dask client was not provided')

    if dask_key in client.datasets:
        df = client.get_dataset(dask_key)
    elif dataset:
        df = dataset.as_df()
    else:
        context.logger.info(f"only these datasets are available {client.datasets} in client {client}")
        raise Exception("dataset not found on dask cluster")
    # Captured before the label column is popped below, so it still lists it.
    header = df.columns.values

    try:
        gcf_clear(plt)
        sns.pairplot(df, hue=label_column)
        context.log_artifact(
            PlotArtifact("histograms", body=plt.gcf()),
            local_path=f"{plots_dest}/hist.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.error(f"Failed to create pairplot histograms due to: {e}")

    try:
        gcf_clear(plt)
        plot_cols = 3
        plot_rows = int((len(header) - 1) / plot_cols) + 1
        # squeeze=False keeps the axes array 2-D even when there is a single
        # row; otherwise axes[row][col] fails for tables with <= 3 columns.
        fig, axes = plt.subplots(
            plot_rows, plot_cols, figsize=(15, 4), squeeze=False
        )
        fig.tight_layout(pad=2.0)
        for i in range(plot_rows * plot_cols):
            row, col = divmod(i, plot_cols)
            if i < len(header):
                sns.violinplot(
                    x=df[header[i]],
                    ax=axes[row][col],
                    orient="h",
                    width=0.7,
                    inner="quartile",
                )
            else:
                # Remove the unused trailing cells of the grid.
                fig.delaxes(axes[row][col])
        context.log_artifact(
            PlotArtifact("violin", body=plt.gcf(), title="Violin Plot"),
            local_path=f"{plots_dest}/violin.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.warn(f"Failed to create violin distribution plots due to: {e}")

    if label_column:
        # pop() removes the label column from df, so the correlation below is
        # computed on the remaining columns only.
        labels = df.pop(label_column)
        imbtable = labels.value_counts(normalize=True).sort_index()
        try:
            gcf_clear(plt)
            balancebar = imbtable.plot(kind="bar", title="class imbalance - labels")
            balancebar.set_xlabel("class")
            balancebar.set_ylabel("proportion of total")
            context.log_artifact(
                PlotArtifact("imbalance", body=plt.gcf()),
                local_path=f"{plots_dest}/imbalance.html",
            )
        except Exception as e:
            context.logger.warn(f"Failed to create class imbalance plot due to: {e}")
        context.log_artifact(
            TableArtifact(
                "imbalance-weights-vec", df=pd.DataFrame({"weights": imbtable})
            ),
            local_path=f"{plots_dest}/imbalance-weights-vec.csv",
            db_key=False,
        )

    tblcorr = df.corr()
    # Builtin bool instead of the np.bool alias, which NumPy 1.24 removed.
    # The mask hides the upper triangle (diagonal included) in the heatmap.
    mask = np.zeros_like(tblcorr, dtype=bool)
    mask[np.triu_indices_from(mask)] = True

    context.log_artifact(
        TableArtifact("correlation-matrix", df=tblcorr, visible=True),
        local_path=f"{plots_dest}/correlation-matrix.csv",
        db_key=False,
    )

    try:
        gcf_clear(plt)
        ax = plt.axes()
        sns.heatmap(tblcorr, ax=ax, mask=mask, annot=False, cmap=plt.cm.Reds)
        ax.set_title("features correlation")
        context.log_artifact(
            PlotArtifact("correlation", body=plt.gcf(), title="Correlation Matrix"),
            local_path=f"{plots_dest}/corr.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.warn(f"Failed to create features correlation plot due to: {e}")

    gcf_clear(plt)
Ejemplo n.º 4
0
def summarize(
    context: MLClientCtx,
    table: DataItem,
    label_column: str = None,
    class_labels: List[str] = None,
    plot_hist: bool = True,
    plots_dest: str = "plots",
    update_dataset=False,
) -> None:
    """Summarize a table

    Logged artifacts: a pairplot ("histograms", only when ``plot_hist`` is
    True), a grid of violin plots ("violin"), a class imbalance bar chart and
    weights table (only when ``label_column`` is set), the correlation matrix
    table and a correlation heatmap. Every plot is best-effort: a failure is
    logged and the remaining artifacts are still produced.

    :param context:         the function context
    :param table:           MLRun input pointing to pandas dataframe (csv/parquet file path)
    :param label_column:    ground truth column label
    :param class_labels:    label for each class in tables and plots (accepted
                            for compatibility; not used by this function)
    :param plot_hist:       (True) set this to False for large tables
    :param plots_dest:      destination folder of summary plots (relative to artifact_path)
    :param update_dataset:  when the table is a registered dataset update the charts in-place
    """
    df = table.as_df()
    # Captured before the label column is popped below, so it still lists it.
    header = df.columns.values
    # Logged plot artifacts by name, attached to the dataset metadata at the
    # end when update_dataset is requested.
    extra_data = {}

    if plot_hist:
        try:
            gcf_clear(plt)
            sns.pairplot(df, hue=label_column)
            extra_data["histograms"] = context.log_artifact(
                PlotArtifact("histograms", body=plt.gcf()),
                local_path=f"{plots_dest}/hist.html",
                db_key=False,
            )
        except Exception as e:
            context.logger.error(
                f"Failed to create pairplot histograms due to: {e}")

    try:
        gcf_clear(plt)
        plot_cols = 3
        plot_rows = int((len(header) - 1) / plot_cols) + 1
        # squeeze=False keeps the axes array 2-D even when there is a single
        # row; otherwise axes[row][col] fails for tables with <= 3 columns.
        fig, axes = plt.subplots(
            plot_rows, plot_cols, figsize=(15, 4), squeeze=False
        )
        fig.tight_layout(pad=2.0)
        for i in range(plot_rows * plot_cols):
            row, col = divmod(i, plot_cols)
            if i < len(header):
                sns.violinplot(
                    x=df[header[i]],
                    ax=axes[row][col],
                    orient="h",
                    width=0.7,
                    inner="quartile",
                )
            else:
                # Remove the unused trailing cells of the grid.
                fig.delaxes(axes[row][col])
        extra_data["violin"] = context.log_artifact(
            PlotArtifact("violin", body=plt.gcf(), title="Violin Plot"),
            local_path=f"{plots_dest}/violin.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.warn(
            f"Failed to create violin distribution plots due to: {e}")

    if label_column:
        # pop() removes the label column from df, so the correlation below is
        # computed on the remaining columns only.
        labels = df.pop(label_column)
        imbtable = labels.value_counts(normalize=True).sort_index()
        try:
            gcf_clear(plt)
            balancebar = imbtable.plot(kind="bar",
                                       title="class imbalance - labels")
            balancebar.set_xlabel("class")
            balancebar.set_ylabel("proportion of total")
            extra_data["imbalance"] = context.log_artifact(
                PlotArtifact("imbalance", body=plt.gcf()),
                local_path=f"{plots_dest}/imbalance.html",
            )
        except Exception as e:
            context.logger.warn(
                f"Failed to create class imbalance plot due to: {e}")
        context.log_artifact(
            TableArtifact("imbalance-weights-vec",
                          df=pd.DataFrame({"weights": imbtable})),
            local_path=f"{plots_dest}/imbalance-weights-vec.csv",
            db_key=False,
        )

    tblcorr = df.corr()
    # Builtin bool instead of the np.bool alias, which NumPy 1.24 removed.
    # The mask hides the upper triangle (diagonal included) in the heatmap.
    mask = np.zeros_like(tblcorr, dtype=bool)
    mask[np.triu_indices_from(mask)] = True

    context.log_artifact(
        TableArtifact("correlation-matrix", df=tblcorr, visible=True),
        local_path=f"{plots_dest}/correlation-matrix.csv",
        db_key=False,
    )

    try:
        gcf_clear(plt)
        ax = plt.axes()
        sns.heatmap(tblcorr, ax=ax, mask=mask, annot=False, cmap=plt.cm.Reds)
        ax.set_title("features correlation")
        extra_data["correlation"] = context.log_artifact(
            PlotArtifact("correlation",
                         body=plt.gcf(),
                         title="Correlation Matrix"),
            local_path=f"{plots_dest}/corr.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.warn(
            f"Failed to create features correlation plot due to: {e}")

    gcf_clear(plt)
    if update_dataset and table.meta and table.meta.kind == "dataset":
        from mlrun.artifacts import update_dataset_meta

        update_dataset_meta(table.meta, extra_data=extra_data)
Ejemplo n.º 5
0
def summarize(
    context,
    dask_key: str = "dask_key",
    dataset: mlrun.DataItem = None,
    label_column: str = "label",
    class_labels: List[str] = None,
    plot_hist: bool = True,
    plots_dest: str = "plots",
    dask_function: str = None,
    dask_client=None,
) -> None:
    """Summarize a table

    The dask client is taken from ``dask_function`` (imported with
    ``mlrun.import_function``) or, failing that, from ``dask_client``. The
    table is read from the client's datasets under ``dask_key``; if it is not
    there, ``dataset`` is loaded with ``as_df(df_module=dd)``. The table is
    then materialized with ``compute()`` before plotting.

    Logged outputs: a pairplot ("histograms", only when ``plot_hist`` is True),
    a class balance plot ("imbalance"), the ``scale_pos_weight`` result and a
    correlation heatmap ("correlation").

    :param context:         the function context
    :param dask_key:        key of dataframe in dask client "datasets" attribute
    :param dataset:         fallback data item used when ``dask_key`` is not
                            found on the client
    :param label_column:    ground truth column label
    :param class_labels:    label for each class in tables and plots; when
                            empty or None the unique label values are used
    :param plot_hist:       (True) set this to False for large tables
    :param plots_dest:      destination folder of summary plots (relative to artifact_path)
    :param dask_function:   dask function url (db://..)
    :param dask_client:     dask client object
    :raises ValueError:     if neither ``dask_function`` nor ``dask_client`` is given
    :raises Exception:      if the table is neither on the client nor in ``dataset``
    """
    if dask_function:
        client = mlrun.import_function(dask_function).client
    elif dask_client:
        client = dask_client
    else:
        raise ValueError('dask client was not provided')

    if dask_key in client.datasets:
        table = client.get_dataset(dask_key)
    elif dataset:
        table = dataset.as_df(df_module=dd)
    else:
        context.logger.info(
            f"only these datasets are available {client.datasets} in client {client}"
        )
        raise Exception("dataset not found on dask cluster")

    gcf_clear(plt)
    table = table.compute()
    if plot_hist:
        # NOTE(review): the 'bw' keyword is deprecated in recent seaborn
        # releases - confirm against the pinned seaborn version.
        sns.pairplot(table, hue=label_column, diag_kws={'bw': 1.5})
        context.log_artifact(PlotArtifact('histograms', body=plt.gcf()),
                             local_path=f"{plots_dest}/hist.html")

    gcf_clear(plt)
    # pop() removes the label column, so the correlation below covers the
    # remaining columns only.
    labels = table.pop(label_column)
    # Explicit emptiness test: also works when an array is passed, where
    # `not class_labels` would raise.
    if class_labels is None or len(class_labels) == 0:
        class_labels = labels.unique()
    class_balance_model = ClassBalance(labels=class_labels)
    class_balance_model.fit(labels)
    support = class_balance_model.support_
    # The ratio needs two class counts and a non-zero denominator.
    if len(support) >= 2 and support[1]:
        scale_pos_weight = support[0] / support[1]
        context.log_result("scale_pos_weight", f"{scale_pos_weight:0.2f}")
    else:
        context.logger.warn(
            "scale_pos_weight was not logged: fewer than two classes with samples")
    context.log_artifact(PlotArtifact("imbalance", body=plt.gcf()),
                         local_path=f"{plots_dest}/imbalance.html")

    gcf_clear(plt)
    tblcorr = table.corr()
    ax = plt.axes()
    sns.heatmap(tblcorr, ax=ax, annot=False, cmap=plt.cm.Reds)
    ax.set_title("features correlation")
    context.log_artifact(PlotArtifact("correlation", body=plt.gcf()),
                         local_path=f"{plots_dest}/corr.html")
    gcf_clear(plt)