def _coxph_log_model(
    context,
    model,
    dataset_key: str = "coxhazard-summary",
    models_dest: str = "models",
    plot_cov_groups: bool = False,
    p_value: float = 0.005,
    plot_key: str = "km-cx",
    plots_dest: str = "plots",
    file_ext="csv",
    extra_data: dict = None,
):
    """Log a coxph model (and submodel locations).

    Logs the model's summary table as a dataset, the pickled model as a model
    artifact, and (optionally) one plot per significant covariate group.

    :param context:         the function context (artifact logging target)
    :param model:           estimated coxph model
    :param dataset_key:     key under which the model summary table is logged
    :param models_dest:     destination folder of the serialized model
    :param plot_cov_groups: (False) when True, plot covariate groups whose
                            p-value is at most ``p_value``
    :param p_value:         significance threshold for covariate-group plots
    :param plot_key:        plot artifact key (currently unused in this body)
    :param plots_dest:      destination folder of covariate-group plots
    :param file_ext:        file format of the logged summary table
    :param extra_data:      if this model wants to store the locations of
                            submodels use this
    """
    import matplotlib.pyplot as plt

    # BUG FIX: was a shared mutable default argument (`extra_data: dict = {}`)
    extra_data = {} if extra_data is None else extra_data

    sumtbl = model.summary
    context.log_dataset(dataset_key, df=sumtbl, index=True, format=file_ext)

    model_bin = dumps(model)
    context.log_model(
        "cx-model",
        body=model_bin,
        artifact_path=os.path.join(context.artifact_path, models_dest),
        model_file="model.pkl",
    )

    if plot_cov_groups:
        # BUG FIX: the original referenced an undefined name `summary`;
        # the model's summary table is bound to `sumtbl` above.
        select_covars = sumtbl[sumtbl.p <= p_value].index.values
        for group in select_covars:
            axs = model.plot_covariate_groups(group, values=[0, 1])
            for ix, ax in enumerate(axs):
                f = ax.get_figure()
                context.log_artifact(
                    PlotArtifact(f"cx-{group}-{ix}", body=plt.gcf()),
                    local_path=f"{plots_dest}/cx-{group}-{ix}.html",
                )
                gcf_clear(plt)
def plot_stat(context, stat_name, stat_df):
    """Render a horizontal bar chart of ``stat_name`` per feature and log it.

    Rows of ``stat_df`` are ordered by descending ``stat_name`` value and each
    bar is annotated with its numeric value; the resulting figure is logged as
    a plot artifact under plots/feature_selection/.
    """
    gcf_clear(plt)

    # Build the bar chart from the stat column, highest values first.
    ranked = stat_df.sort_values(stat_name, ascending=False).reset_index()
    axes = plt.axes()
    chart = sns.barplot(x=stat_name, y='index', data=ranked, ax=axes)
    plt.tight_layout()

    # Annotate every bar with its value, offset slightly past the bar end.
    for bar in chart.patches:
        bar_width = bar.get_width()
        label_x = 5 + bar_width
        label_y = bar.get_y() + 0.55 * bar.get_height()
        plt.text(label_x, label_y, f'{bar_width:1.2f}',
                 ha='center', va='center')

    context.log_artifact(
        PlotArtifact(f'{stat_name}', body=plt.gcf()),
        local_path=os.path.join('plots', 'feature_selection',
                                f'{stat_name}.html'),
    )
    gcf_clear(plt)
def summarize(
    context,
    dask_key: str = "dask_key",
    dataset: mlrun.DataItem = None,
    label_column: str = "label",
    plots_dest: str = "plots",
    dask_function: str = None,
    dask_client=None,
) -> None:
    """Summarize a table

    Connects to dask client through the function context, or through an
    optional user-supplied scheduler.

    :param context:       the function context
    :param dask_key:      key of dataframe in dask client "datasets" attribute
    :param dataset:       fallback input table, read when ``dask_key`` is not
                          published on the client
    :param label_column:  ground truth column label
    :param plots_dest:    destination folder of summary plots (relative to
                          artifact_path)
    :param dask_function: dask function url (db://..)
    :param dask_client:   dask client object
    """
    if dask_function:
        client = mlrun.import_function(dask_function).client
    elif dask_client:
        client = dask_client
    else:
        raise ValueError('dask client was not provided')

    if dask_key in client.datasets:
        table = client.get_dataset(dask_key)
    elif dataset:
        # table = dataset.as_df(df_module=dd)
        table = dataset.as_df()
    else:
        context.logger.info(
            f"only these datasets are available {client.datasets} in client {client}"
        )
        raise Exception("dataset not found on dask cluster")

    df = table
    header = df.columns.values
    extra_data = {}

    # Each plot is best-effort: failures are logged, not raised.
    try:
        gcf_clear(plt)
        sns.pairplot(df, hue=label_column)  # , diag_kws={"bw": 1.5})
        extra_data["histograms"] = context.log_artifact(
            PlotArtifact("histograms", body=plt.gcf()),
            local_path=f"{plots_dest}/hist.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.error(f"Failed to create pairplot histograms due to: {e}")

    try:
        gcf_clear(plt)
        plot_cols = 3
        plot_rows = int((len(header) - 1) / plot_cols) + 1
        fig, ax = plt.subplots(plot_rows, plot_cols, figsize=(15, 4))
        fig.tight_layout(pad=2.0)
        for i in range(plot_rows * plot_cols):
            if i < len(header):
                sns.violinplot(
                    x=df[header[i]],
                    ax=ax[int(i / plot_cols)][i % plot_cols],
                    orient="h",
                    width=0.7,
                    inner="quartile",
                )
            else:
                # drop the unused axes in the last (partial) grid row
                fig.delaxes(ax[int(i / plot_cols)][i % plot_cols])
        extra_data["violin"] = context.log_artifact(
            PlotArtifact("violin", body=plt.gcf(), title="Violin Plot"),
            local_path=f"{plots_dest}/violin.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.warn(f"Failed to create violin distribution plots due to: {e}")

    if label_column:
        labels = df.pop(label_column)
        imbtable = labels.value_counts(normalize=True).sort_index()
        try:
            gcf_clear(plt)
            balancebar = imbtable.plot(kind="bar", title="class imbalance - labels")
            balancebar.set_xlabel("class")
            balancebar.set_ylabel("proportion of total")
            extra_data["imbalance"] = context.log_artifact(
                PlotArtifact("imbalance", body=plt.gcf()),
                local_path=f"{plots_dest}/imbalance.html",
            )
        except Exception as e:
            context.logger.warn(f"Failed to create class imbalance plot due to: {e}")
        context.log_artifact(
            TableArtifact(
                "imbalance-weights-vec", df=pd.DataFrame({"weights": imbtable})
            ),
            local_path=f"{plots_dest}/imbalance-weights-vec.csv",
            db_key=False,
        )

    tblcorr = df.corr()
    # BUG FIX: `np.bool` was deprecated in NumPy 1.20 and removed in 1.24;
    # use the builtin `bool` dtype instead.
    mask = np.zeros_like(tblcorr, dtype=bool)
    mask[np.triu_indices_from(mask)] = True
    context.log_artifact(
        TableArtifact("correlation-matrix", df=tblcorr, visible=True),
        local_path=f"{plots_dest}/correlation-matrix.csv",
        db_key=False,
    )

    try:
        gcf_clear(plt)
        ax = plt.axes()
        sns.heatmap(tblcorr, ax=ax, mask=mask, annot=False, cmap=plt.cm.Reds)
        ax.set_title("features correlation")
        extra_data["correlation"] = context.log_artifact(
            PlotArtifact("correlation", body=plt.gcf(), title="Correlation Matrix"),
            local_path=f"{plots_dest}/corr.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.warn(f"Failed to create features correlation plot due to: {e}")

    gcf_clear(plt)
def summarize(
    context: MLClientCtx,
    table: DataItem,
    label_column: str = None,
    class_labels: List[str] = [],
    plot_hist: bool = True,
    plots_dest: str = "plots",
    update_dataset=False,
) -> None:
    """Summarize a table

    :param context:        the function context
    :param table:          MLRun input pointing to pandas dataframe
                           (csv/parquet file path)
    :param label_column:   ground truth column label
    :param class_labels:   label for each class in tables and plots
    :param plot_hist:      (True) set this to False for large tables
    :param plots_dest:     destination folder of summary plots
                           (relative to artifact_path)
    :param update_dataset: when the table is a registered dataset update the
                           charts in-place
    """
    df = table.as_df()
    header = df.columns.values
    extra_data = {}

    # Each plot is best-effort: failures are logged, not raised.
    try:
        gcf_clear(plt)
        sns.pairplot(df, hue=label_column)  # , diag_kws={"bw": 1.5})
        extra_data["histograms"] = context.log_artifact(
            PlotArtifact("histograms", body=plt.gcf()),
            local_path=f"{plots_dest}/hist.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.error(
            f"Failed to create pairplot histograms due to: {e}")

    try:
        gcf_clear(plt)
        plot_cols = 3
        plot_rows = int((len(header) - 1) / plot_cols) + 1
        fig, ax = plt.subplots(plot_rows, plot_cols, figsize=(15, 4))
        fig.tight_layout(pad=2.0)
        for i in range(plot_rows * plot_cols):
            if i < len(header):
                sns.violinplot(
                    x=df[header[i]],
                    ax=ax[int(i / plot_cols)][i % plot_cols],
                    orient="h",
                    width=0.7,
                    inner="quartile",
                )
            else:
                # drop the unused axes in the last (partial) grid row
                fig.delaxes(ax[int(i / plot_cols)][i % plot_cols])
        extra_data["violin"] = context.log_artifact(
            PlotArtifact("violin", body=plt.gcf(), title="Violin Plot"),
            local_path=f"{plots_dest}/violin.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.warn(
            f"Failed to create violin distribution plots due to: {e}")

    if label_column:
        labels = df.pop(label_column)
        imbtable = labels.value_counts(normalize=True).sort_index()
        try:
            gcf_clear(plt)
            balancebar = imbtable.plot(kind="bar",
                                       title="class imbalance - labels")
            balancebar.set_xlabel("class")
            balancebar.set_ylabel("proportion of total")
            extra_data["imbalance"] = context.log_artifact(
                PlotArtifact("imbalance", body=plt.gcf()),
                local_path=f"{plots_dest}/imbalance.html",
            )
        except Exception as e:
            context.logger.warn(
                f"Failed to create class imbalance plot due to: {e}")
        context.log_artifact(
            TableArtifact("imbalance-weights-vec",
                          df=pd.DataFrame({"weights": imbtable})),
            local_path=f"{plots_dest}/imbalance-weights-vec.csv",
            db_key=False,
        )

    tblcorr = df.corr()
    # BUG FIX: `np.bool` was deprecated in NumPy 1.20 and removed in 1.24;
    # use the builtin `bool` dtype instead.
    mask = np.zeros_like(tblcorr, dtype=bool)
    mask[np.triu_indices_from(mask)] = True
    context.log_artifact(
        TableArtifact("correlation-matrix", df=tblcorr, visible=True),
        local_path=f"{plots_dest}/correlation-matrix.csv",
        db_key=False,
    )

    try:
        gcf_clear(plt)
        ax = plt.axes()
        sns.heatmap(tblcorr, ax=ax, mask=mask, annot=False, cmap=plt.cm.Reds)
        ax.set_title("features correlation")
        extra_data["correlation"] = context.log_artifact(
            PlotArtifact("correlation", body=plt.gcf(),
                         title="Correlation Matrix"),
            local_path=f"{plots_dest}/corr.html",
            db_key=False,
        )
    except Exception as e:
        context.logger.warn(
            f"Failed to create features correlation plot due to: {e}")

    gcf_clear(plt)
    if update_dataset and table.meta and table.meta.kind == "dataset":
        from mlrun.artifacts import update_dataset_meta

        update_dataset_meta(table.meta, extra_data=extra_data)
def summarize(
    context,
    dask_key: str = "dask_key",
    dataset: mlrun.DataItem = None,
    label_column: str = "label",
    class_labels: List[str] = None,
    plot_hist: bool = True,
    plots_dest: str = "plots",
    dask_function: str = None,
    dask_client=None,
) -> None:
    """Summarize a table

    Connects to dask client through the function context, or through an
    optional user-supplied scheduler.

    :param context:       the function context
    :param dask_key:      key of dataframe in dask client "datasets" attribute
    :param dataset:       fallback input table, read when ``dask_key`` is not
                          published on the client
    :param label_column:  ground truth column label
    :param class_labels:  label for each class in tables and plots
                          (defaults to the unique label values)
    :param plot_hist:     (True) set this to False for large tables
    :param plots_dest:    destination folder of summary plots (relative to
                          artifact_path)
    :param dask_function: dask function url (db://..)
    :param dask_client:   dask client object
    """
    if dask_function:
        client = mlrun.import_function(dask_function).client
    elif dask_client:
        client = dask_client
    else:
        raise ValueError('dask client was not provided')

    if dask_key in client.datasets:
        table = client.get_dataset(dask_key)
    elif dataset:
        table = dataset.as_df(df_module=dd)
    else:
        context.logger.info(
            f"only these datasets are available {client.datasets} in client {client}"
        )
        raise Exception("dataset not found on dask cluster")

    gcf_clear(plt)
    # Materialize the dask dataframe locally before plotting.
    table = table.compute()
    sns.pairplot(table, hue=label_column, diag_kws={'bw': 1.5})
    context.log_artifact(PlotArtifact('histograms', body=plt.gcf()),
                         local_path=f"{plots_dest}/hist.html")

    gcf_clear(plt)
    labels = table.pop(label_column)
    # BUG FIX: the default was a mutable `[]`; `None` is falsy too, so the
    # fallback below behaves identically for callers passing nothing.
    if not class_labels:
        class_labels = labels.unique()
    class_balance_model = ClassBalance(labels=class_labels)
    class_balance_model.fit(labels)
    scale_pos_weight = (
        class_balance_model.support_[0] / class_balance_model.support_[1]
    )
    context.log_result("scale_pos_weight", f"{scale_pos_weight:0.2f}")
    context.log_artifact(PlotArtifact("imbalance", body=plt.gcf()),
                         local_path=f"{plots_dest}/imbalance.html")

    gcf_clear(plt)
    tblcorr = table.corr()
    ax = plt.axes()
    sns.heatmap(tblcorr, ax=ax, annot=False, cmap=plt.cm.Reds)
    ax.set_title("features correlation")
    context.log_artifact(PlotArtifact("correlation", body=plt.gcf()),
                         local_path=f"{plots_dest}/corr.html")
    gcf_clear(plt)