def learning_curves(model):
    """model-class dependent (WIP)

    Get training-history plots for xgboost and lightgbm.

    returns a list of PlotArtifacts, can be empty if no history is found
    """
    plots = []

    # do this here and not in the call to learning_curve plots,
    # this is the default approach for xgboost and lightgbm
    if hasattr(model, "evals_result"):
        results = model.evals_result()
        train_set = list(results.items())[0]
        valid_set = list(results.items())[1]

        learning_curves = pd.DataFrame(
            {
                "train_error": train_set[1]["error"],
                "train_auc": train_set[1]["auc"],
                "valid_error": valid_set[1]["error"],
                "valid_auc": valid_set[1]["auc"],
            }
        )

        plt.clf()  # gcf_clear(plt)
        fig, ax = plt.subplots()
        plt.xlabel("# training examples")
        plt.ylabel("auc")
        plt.title("learning curve - auc")
        ax.plot(learning_curves.train_auc, label="train")
        ax.plot(learning_curves.valid_auc, label="valid")
        ax.legend(loc="lower left")
        plots.append(PlotArtifact("learning curve - auc", body=plt.gcf()))

        plt.clf()  # gcf_clear(plt)
        fig, ax = plt.subplots()
        plt.xlabel("# training examples")
        plt.ylabel("error rate")
        plt.title("learning curve - error")
        ax.plot(learning_curves.train_error, label="train")
        ax.plot(learning_curves.valid_error, label="valid")
        ax.legend(loc="lower left")
        plots.append(PlotArtifact("learning curve - error", body=plt.gcf()))

    # elif some other model history api...

    return plots
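# Usage sketch (illustrative, not part of the original module): it assumes the
# xgboost scikit-learn API (>= 1.6, where eval_metric is a constructor argument)
# and a model trained with an eval_set so that evals_result() holds the "error"
# and "auc" series expected by learning_curves().
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=500, n_features=10, random_state=42)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=42)

clf = xgb.XGBClassifier(n_estimators=50, eval_metric=["error", "auc"])
clf.fit(
    X_train,
    y_train,
    eval_set=[(X_train, y_train), (X_valid, y_valid)],
    verbose=False,
)
plots = learning_curves(clf)  # list of PlotArtifact objects; empty if no history found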
def feature_importances(model, header):
    """Display estimated feature importances

    Only works for models with the `feature_importances_` attribute.

    :param model:  fitted model
    :param header: feature labels
    """
    if not hasattr(model, "feature_importances_"):
        raise Exception("feature importances are only available for some models")

    # create a feature importance table with desired labels
    zipped = zip(model.feature_importances_, header)
    feature_imp = pd.DataFrame(sorted(zipped), columns=["freq", "feature"]).sort_values(
        by="freq", ascending=False
    )

    plt.clf()  # gcf_clear(plt)
    plt.figure(figsize=(20, 10))
    sns.barplot(x="freq", y="feature", data=feature_imp)
    plt.title("features")
    plt.tight_layout()

    return (
        PlotArtifact("feature-importances", body=plt.gcf()),
        TableArtifact("feature-importances-tbl", df=feature_imp),
    )
def _kaplan_meier_log_model( context, model, time_column: str = "tenure", dataset_key: str = "km-timelines", plot_key: str = "km-survival", plots_dest: str = "plots", models_dest: str = "models", file_ext: str = "csv", ): import matplotlib.pyplot as plt o = [] for obj in model.__dict__.keys(): if isinstance(model.__dict__[obj], pd.DataFrame): o.append(model.__dict__[obj]) df = pd.concat(o, axis=1) df.index.name = time_column context.log_dataset(dataset_key, df=df, index=True, format=file_ext) model.plot() context.log_artifact( PlotArtifact(plot_key, body=plt.gcf()), local_path=f"{plots_dest}/{plot_key}.html", ) context.log_model( "km-model", body=dumps(model), model_dir=f"{models_dest}/km", model_file="model.pkl", )
def feature_importances(model, header): """Display estimated feature importances Only works for models with attribute 'feature_importances_` :param model: fitted model :param header: feature labels """ if not hasattr(model, "feature_importances_"): raise Exception( "feature importances are only available for some models, if you got " "here then please make sure to check your estimated model for a " "`feature_importances_` attribute before calling this method") # create a feature importance table with desired labels zipped = zip(model.feature_importances_, header) feature_imp = pd.DataFrame(sorted(zipped), columns=["freq", "feature" ]).sort_values(by="freq", ascending=False) plt.clf() # gcf_clear(plt) plt.figure() sns.barplot(x="freq", y="feature", data=feature_imp) plt.title("features") plt.tight_layout() return ( PlotArtifact("feature-importances", body=plt.gcf(), title="Feature Importances"), feature_imp, )
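# Usage sketch (illustrative only): any estimator exposing `feature_importances_`
# works, e.g. a RandomForestClassifier; `header` is the list of feature names.
# This version returns a (PlotArtifact, DataFrame) pair.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=200, n_features=6, random_state=0)
names = [f"feat_{i}" for i in range(X.shape[1])]

rf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)
fi_plot, fi_tbl = feature_importances(rf, names)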
def precision_recall_multi(ytest_b, yprob, labels, scoring="micro"): """""" n_classes = len(labels) precision = dict() recall = dict() avg_prec = dict() for i in range(n_classes): precision[i], recall[i], _ = metrics.precision_recall_curve( ytest_b[:, i], yprob[:, i]) avg_prec[i] = metrics.average_precision_score(ytest_b[:, i], yprob[:, i]) precision["micro"], recall["micro"], _ = metrics.precision_recall_curve( ytest_b.ravel(), yprob.ravel()) avg_prec["micro"] = metrics.average_precision_score(ytest_b, yprob, average="micro") ap_micro = avg_prec["micro"] # model_metrics.update({'precision-micro-avg-classes': ap_micro}) # gcf_clear(plt) colors = cycle( ["navy", "turquoise", "darkorange", "cornflowerblue", "teal"]) plt.figure() f_scores = np.linspace(0.2, 0.8, num=4) lines = [] labels = [] for f_score in f_scores: x = np.linspace(0.01, 1) y = f_score * x / (2 * x - f_score) (l, ) = plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2) plt.annotate(f"f1={f_score:0.1f}", xy=(0.9, y[45] + 0.02)) lines.append(l) labels.append("iso-f1 curves") (l, ) = plt.plot(recall["micro"], precision["micro"], color="gold", lw=10) lines.append(l) labels.append(f"micro-average precision-recall (area = {ap_micro:0.2f})") for i, color in zip(range(n_classes), colors): (l, ) = plt.plot(recall[i], precision[i], color=color, lw=2) lines.append(l) labels.append( f"precision-recall for class {i} (area = {avg_prec[i]:0.2f})") # fig = plt.gcf() # fig.subplots_adjust(bottom=0.25) plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.xlabel("recall") plt.ylabel("precision") plt.title("precision recall - multiclass") plt.legend(lines, labels, loc=(0, -0.41), prop=dict(size=10)) return PlotArtifact( "precision-recall-multiclass", body=plt.gcf(), title="Multiclass Precision Recall", )
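# Usage sketch (illustrative only): ytest_b must be the binarized (one-hot) ground
# truth and yprob the per-class probabilities, both shaped (n_samples, n_classes).
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer

X, y = make_classification(n_samples=400, n_features=8, n_informative=5,
                           n_classes=3, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)

clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
yprob = clf.predict_proba(X_test)
ytest_b = LabelBinarizer().fit_transform(y_test)

pr_artifact = precision_recall_multi(ytest_b, yprob, labels=[0, 1, 2])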
def plot_iter(context, iterations, col='accuracy', num_bins=10):
    df = pd.read_csv(BytesIO(iterations.get()))
    x = df['output.{}'.format(col)]
    fig, ax = plt.subplots(figsize=(6, 6))
    n, bins, patches = ax.hist(x, num_bins, density=1)
    ax.set_xlabel('Accuracy')
    ax.set_ylabel('Count')
    context.log_artifact(PlotArtifact('myfig', body=fig))
def plot_roc(
    context,
    y_labels,
    y_probs,
    key="roc",
    plots_dir: str = "plots",
    fmt="png",
    fpr_label: str = "false positive rate",
    tpr_label: str = "true positive rate",
    title: str = "roc curve",
    legend_loc: str = "best",
):
    """plot roc curves

    TODO:  add averaging method (as string) that was used to create probs,
    display in legend

    :param context:      the function context
    :param y_labels:     ground truth labels, hot encoded for multiclass
    :param y_probs:      model prediction probabilities
    :param key:          ("roc") key of plot in artifact store
    :param plots_dir:    ("plots") destination folder relative path to artifact path
    :param fmt:          ("png") plot format
    :param fpr_label:    ("false positive rate") x-axis labels
    :param tpr_label:    ("true positive rate") y-axis labels
    :param title:        ("roc curve") title of plot
    :param legend_loc:   ("best") location of plot legend
    """
    # clear matplotlib current figure
    gcf_clear(plt)

    # draw 45 degree line
    plt.plot([0, 1], [0, 1], "k--")

    # labelling
    plt.xlabel(fpr_label)
    plt.ylabel(tpr_label)
    plt.title(title)

    # single ROC or multiple
    if y_labels.shape[1] > 1:
        # data accumulators by class
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        for i in range(y_labels[:, :-1].shape[1]):
            fpr[i], tpr[i], _ = metrics.roc_curve(y_labels[:, i],
                                                  y_probs[:, i],
                                                  pos_label=1)
            roc_auc[i] = metrics.auc(fpr[i], tpr[i])
            plt.plot(fpr[i], tpr[i], label=f"class {i}")
    else:
        fpr, tpr, _ = metrics.roc_curve(y_labels, y_probs[:, 1], pos_label=1)
        plt.plot(fpr, tpr, label="positive class")

    # draw the legend after plotting so the per-class labels are included
    plt.legend(loc=legend_loc)

    fname = f"{plots_dir}/{key}.html"
    context.log_artifact(PlotArtifact(key, body=plt.gcf()), local_path=fname)
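# Usage sketch (illustrative; `mlrun.get_or_create_ctx` is a real mlrun helper,
# the data setup below is an assumption for demonstration). For multiclass input,
# y_labels is expected to be hot-encoded.
import mlrun
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize

context = mlrun.get_or_create_ctx("roc-demo")

X, y = make_classification(n_samples=300, n_informative=5, n_classes=3, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

plot_roc(
    context,
    y_labels=label_binarize(y_test, classes=[0, 1, 2]),  # hot-encoded ground truth
    y_probs=clf.predict_proba(X_test),
)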
def confusion_matrix(model, xtest, ytest): cmd = metrics.plot_confusion_matrix(model, xtest, ytest, normalize='all', values_format='.2g', cmap=plt.cm.Blues) # for now only 1, add different views to this array for display in UI return PlotArtifact("confusion-matrix-normalized", body=cmd.figure_)
def precision_recall_multi(ytest_b, yprob, labels, scoring="micro"):
    """ """
    n_classes = len(labels)

    precision = dict()
    recall = dict()
    avg_prec = dict()
    for i in range(n_classes):
        precision[i], recall[i], _ = metrics.precision_recall_curve(
            ytest_b[:, i], yprob[:, i])
        avg_prec[i] = metrics.average_precision_score(ytest_b[:, i],
                                                      yprob[:, i])
    precision["micro"], recall["micro"], _ = metrics.precision_recall_curve(
        ytest_b.ravel(), yprob.ravel())
    avg_prec["micro"] = metrics.average_precision_score(ytest_b, yprob,
                                                        average="micro")
    ap_micro = avg_prec["micro"]
    # model_metrics is not defined in this scope; callers should log ap_micro
    # themselves (the newer version of this function does the same):
    # model_metrics.update({'precision-micro-avg-classes': ap_micro})

    gcf_clear(plt)
    colors = cycle(
        ['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])
    plt.figure(figsize=(7, 8))
    f_scores = np.linspace(0.2, 0.8, num=4)
    lines = []
    labels = []
    for f_score in f_scores:
        x = np.linspace(0.01, 1)
        y = f_score * x / (2 * x - f_score)
        l, = plt.plot(x[y >= 0], y[y >= 0], color='gray', alpha=0.2)
        plt.annotate('f1={0:0.1f}'.format(f_score), xy=(0.9, y[45] + 0.02))

    lines.append(l)
    labels.append('iso-f1 curves')

    l, = plt.plot(recall["micro"], precision["micro"], color='gold', lw=10)
    lines.append(l)
    labels.append(f'micro-average precision-recall (area = {ap_micro:0.2f})')

    for i, color in zip(range(n_classes), colors):
        l, = plt.plot(recall[i], precision[i], color=color, lw=2)
        lines.append(l)
        labels.append(
            f'precision-recall for class {i} (area = {avg_prec[i]:0.2f})')

    fig = plt.gcf()
    fig.subplots_adjust(bottom=0.25)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('recall')
    plt.ylabel('precision')
    plt.title('precision recall - multiclass')
    plt.legend(lines, labels, loc=(0, -.41), prop=dict(size=10))
    return PlotArtifact("precision-recall-multiclass", body=plt.gcf())
def handler(context, p1=1, p2="xx"):
    """this is a simple function

    :param context: handler context
    :param p1:      first param
    :param p2:      another param
    """
    # access input metadata, values, and inputs
    print(f"Run: {context.name} (uid={context.uid})")
    print(f"Params: p1={p1}, p2={p2}")
    time.sleep(1)

    # log the run results (scalar values)
    context.log_result("accuracy", p1 * 2)
    context.log_result("loss", p1 * 3)

    # add a label/tag to this run
    context.set_label("category", "tests")

    # create a matplotlib figure and store as artifact
    fig, ax = plt.subplots()
    np.random.seed(0)
    x, y = np.random.normal(size=(2, 200))
    color, size = np.random.random((2, 200))
    ax.scatter(x, y, c=color, s=500 * size, alpha=0.3)
    ax.grid(color="lightgray", alpha=0.7)
    context.log_artifact(PlotArtifact("myfig", body=fig, title="my plot"))

    # create a dataframe artifact
    df = pd.DataFrame([{"A": 10, "B": 100}, {"A": 11, "B": 110}, {"A": 12, "B": 120}])
    context.log_dataset("mydf", df=df)

    # Log an ML Model artifact
    context.log_model(
        "mymodel",
        body=b"abc is 123",
        model_file="model.txt",
        model_dir="data",
        metrics={"accuracy": 0.85},
        parameters={"xx": "abc"},
        labels={"framework": "xgboost"},
    )

    return "my resp"
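# Usage sketch (illustrative only): the handler can be exercised locally with a
# context created via mlrun.get_or_create_ctx; running it as a remote mlrun job
# would normally go through code_to_function / run, which is out of scope here.
import mlrun

context = mlrun.get_or_create_ctx("handler-demo")
resp = handler(context, p1=2, p2="hello")  # logs results, a plot, a dataset, a model
print(resp)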
def precision_recall_bin(model, xtest, ytest, yprob): """ """ # precision-recall #gcf_clear(plt) disp = metrics.plot_precision_recall_curve(model, xtest, ytest) disp.ax_.set_title( f'precision recall: AP={metrics.average_precision_score(ytest, yprob):0.2f}' ) return PlotArtifact("precision-recall-binary", body=disp.figure_)
def precision_recall_bin(model, xtest, ytest, yprob, clear=False): """""" if clear: gcf_clear(plt) disp = metrics.plot_precision_recall_curve(model, xtest, ytest) disp.ax_.set_title( f"precision recall: AP={metrics.average_precision_score(ytest, yprob):0.2f}" ) return PlotArtifact("precision-recall-binary", body=disp.figure_, title="Binary Precision Recall")
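# Usage sketch (illustrative; note that `metrics.plot_precision_recall_curve` is
# only available in scikit-learn versions that still ship it, roughly 0.22-1.1).
# `yprob` is the positive-class probability vector.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=300, random_state=7)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=7)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

yprob_pos = clf.predict_proba(X_test)[:, 1]
pr_artifact = precision_recall_bin(clf, X_test, y_test, yprob_pos)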
def roc_bin(ytest, yprob): """ """ # ROC plot #gcf_clear(plt) fpr, tpr, _ = metrics.roc_curve(ytest, yprob) plt.figure(1) plt.plot([0, 1], [0, 1], 'k--') plt.plot(fpr, tpr, label='a label') plt.xlabel('false positive rate') plt.ylabel('true positive rate') plt.title('roc curve') plt.legend(loc='best') return PlotArtifact("roc-binary", body=plt.gcf())
def roc_bin(ytest, yprob, clear: bool = False):
    """"""
    # ROC plot
    if clear:
        gcf_clear(plt)
    fpr, tpr, _ = metrics.roc_curve(ytest, yprob)
    plt.figure()
    plt.plot([0, 1], [0, 1], "k--")
    plt.plot(fpr, tpr, label="positive class")
    plt.xlabel("false positive rate")
    plt.ylabel("true positive rate")
    plt.title("roc curve")
    plt.legend(loc="best")
    return PlotArtifact("roc-binary", body=plt.gcf(), title="Binary ROC Curve")
def confusion_matrix(model, xtest, ytest, cmap="Blues"): cmd = metrics.plot_confusion_matrix( model, xtest, ytest, normalize="all", values_format=".2g", cmap=plt.get_cmap(cmap), ) # for now only 1, add different views to this array for display in UI cmd.plot() return PlotArtifact( "confusion-matrix-normalized", body=cmd.figure_, title="Confusion Matrix - Normalized Plot", )
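# Usage sketch (illustrative; `metrics.plot_confusion_matrix` requires a
# scikit-learn version that still provides it, roughly 0.22-1.1).
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=300, random_state=5)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=5)
model = RandomForestClassifier(n_estimators=30, random_state=5).fit(X_train, y_train)

cm_artifact = confusion_matrix(model, X_test, y_test, cmap="Blues")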
def learning_curves(
    context: MLClientCtx,
    results: dict,
    figsz: Tuple[int, int] = (10, 10),
    plots_dest: str = "plots",
) -> None:
    """plot xgb learning curves

    this will also log a model's learning curves
    """
    plt.clf()
    plt.figure(figsize=figsz)
    plt.plot(results["train"]["my_rmsle"], label="train-my-rmsle")
    plt.plot(results["valid"]["my_rmsle"], label="valid-my-rmsle")
    plt.title("learning curves")
    plt.legend()

    context.log_artifact(
        PlotArtifact("learning-curves", body=plt.gcf()),
        local_path=f"{plots_dest}/learning-curves.html",
    )
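# Usage sketch (illustrative only): `results` is expected to look like the
# evals_result() of an xgboost training run that used a custom metric named
# "my_rmsle", i.e. a dict keyed by "train"/"valid" holding per-iteration metric
# lists. The values below are dummy data and the context is created locally.
import mlrun

context = mlrun.get_or_create_ctx("learning-curves-demo")
results = {
    "train": {"my_rmsle": [0.92, 0.71, 0.55, 0.43, 0.38]},
    "valid": {"my_rmsle": [0.95, 0.78, 0.64, 0.57, 0.55]},
}
learning_curves(context, results)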
def _coxph_log_model(
    context,
    model,
    dataset_key: str = "coxhazard-summary",
    models_dest: str = "models",
    plot_cov_groups: bool = False,
    p_value: float = 0.005,
    plot_key: str = "km-cx",
    plots_dest: str = "plots",
    file_ext="csv",
    extra_data: dict = {},
):
    """log a coxph model (and submodel locations)

    :param model:      estimated coxph model
    :param extra_data: if this model wants to store the locations of submodels,
                       use this
    """
    import matplotlib.pyplot as plt

    sumtbl = model.summary
    context.log_dataset(dataset_key, df=sumtbl, index=True, format=file_ext)

    model_bin = dumps(model)
    context.log_model(
        "cx-model",
        body=model_bin,
        artifact_path=os.path.join(context.artifact_path, models_dest),
        model_file="model.pkl",
    )
    if plot_cov_groups:
        select_covars = sumtbl[sumtbl.p <= p_value].index.values
        for group in select_covars:
            axs = model.plot_covariate_groups(group, values=[0, 1])
            for ix, ax in enumerate(axs):
                f = ax.get_figure()
                context.log_artifact(
                    PlotArtifact(f"cx-{group}-{ix}", body=plt.gcf()),
                    local_path=f"{plots_dest}/cx-{group}-{ix}.html",
                )
                gcf_clear(plt)
def plot_stat(context, stat_name, stat_df): gcf_clear(plt) # Add chart ax = plt.axes() stat_chart = sns.barplot(x=stat_name, y='index', data=stat_df.sort_values(stat_name, ascending=False).reset_index(), ax=ax) plt.tight_layout() for p in stat_chart.patches: width = p.get_width() plt.text(5 + p.get_width(), p.get_y() + 0.55 * p.get_height(), '{:1.2f}'.format(width), ha='center', va='center') context.log_artifact(PlotArtifact(f'{stat_name}', body=plt.gcf()), local_path=os.path.join('plots', 'feature_selection', f'{stat_name}.html')) gcf_clear(plt)
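# Usage sketch (illustrative only): stat_df is expected to carry feature names in
# its index and the statistic values in a column named `stat_name`; reset_index()
# inside the function turns that index into the 'index' column used for the y-axis.
# `gcf_clear` and `sns` are module-level helpers assumed available here.
import mlrun
import pandas as pd

context = mlrun.get_or_create_ctx("plot-stat-demo")
stat_df = pd.DataFrame(
    {"f_score": [12.4, 7.1, 3.8]},
    index=["age", "income", "tenure"],
)
plot_stat(context, "f_score", stat_df)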
def plot_importance(
    context,
    model,
    key: str = "feature-importances",
    plots_dest: str = "plots",
):
    """Display estimated feature importances

    Only works for models with the `feature_importances_` attribute.

    **legacy version, please deprecate in functions and demos**

    :param context:    function context
    :param model:      fitted model
    :param key:        key of feature importances plot and table in artifact store
    :param plots_dest: subfolder in artifact store
    """
    if not hasattr(model, "feature_importances_"):
        raise Exception("feature importances are only available for some models")

    # create a feature importance table with desired labels
    zipped = zip(model.feature_importances_, context.header)
    feature_imp = pd.DataFrame(sorted(zipped), columns=["freq", "feature"]).sort_values(
        by="freq", ascending=False
    )

    gcf_clear(plt)
    plt.figure(figsize=(20, 10))
    sns.barplot(x="freq", y="feature", data=feature_imp)
    plt.title("features")
    plt.tight_layout()

    fname = f"{plots_dest}/{key}.html"
    context.log_artifact(PlotArtifact(key, body=plt.gcf()), local_path=fname)

    # feature importances are also saved as a csv table (generally small):
    fname = key + "-tbl.csv"
    return context.log_dataset(key + "-tbl", df=feature_imp, local_path=fname)
def plot_confusion_matrix(
    context: MLClientCtx,
    labels,
    predictions,
    key: str = "confusion_matrix",
    plots_dir: str = "plots",
    colormap: str = "Blues",
    fmt: str = "png",
    sample_weight=None,
):
    """Create a confusion matrix.

    Plot and save a confusion matrix using test data from a pipeline step.

    See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html

    TODO: fix label alignment
    TODO: consider using another packaged version
    TODO: refactor to take params dict for plot options

    :param context:       function context
    :param labels:        validation data ground-truth labels
    :param predictions:   validation data predictions
    :param key:           str
    :param plots_dir:     relative path of plots in artifact store
    :param colormap:      colourmap for confusion matrix
    :param fmt:           plot format
    :param sample_weight: sample weights
    """
    _gcf_clear(plt)

    cm = metrics.confusion_matrix(labels, predictions, sample_weight=sample_weight)
    sns.heatmap(cm, annot=True, cmap=colormap, square=True)

    fig = plt.gcf()
    fname = f"{plots_dir}/{key}.{fmt}"
    fig.savefig(os.path.join(context.artifact_path, fname))
    context.log_artifact(PlotArtifact(key, body=fig), local_path=fname)
def summarize( context, dask_key: str = "dask_key", dataset: mlrun.DataItem = None, label_column: str = "label", plots_dest: str = "plots", dask_function: str = None, dask_client=None, ) -> None: """Summarize a table Connects to dask client through the function context, or through an optional user-supplied scheduler. :param context: the function context :param dask_key: key of dataframe in dask client "datasets" attribute :param label_column: ground truth column label :param plots_dest: destination folder of summary plots (relative to artifact_path) :param dask_function: dask function url (db://..) :param dask_client: dask client object """ if dask_function: client = mlrun.import_function(dask_function).client elif dask_client: client = dask_client else: raise ValueError('dask client was not provided') if dask_key in client.datasets: table = client.get_dataset(dask_key) elif dataset: #table = dataset.as_df(df_module=dd) table = dataset.as_df() else: context.logger.info(f"only these datasets are available {client.datasets} in client {client}") raise Exception("dataset not found on dask cluster") df = table header = df.columns.values extra_data = {} try: gcf_clear(plt) snsplt = sns.pairplot(df, hue=label_column) # , diag_kws={"bw": 1.5}) extra_data["histograms"] = context.log_artifact( PlotArtifact("histograms", body=plt.gcf()), local_path=f"{plots_dest}/hist.html", db_key=False, ) except Exception as e: context.logger.error(f"Failed to create pairplot histograms due to: {e}") try: gcf_clear(plt) plot_cols = 3 plot_rows = int((len(header) - 1) / plot_cols) + 1 fig, ax = plt.subplots(plot_rows, plot_cols, figsize=(15, 4)) fig.tight_layout(pad=2.0) for i in range(plot_rows * plot_cols): if i < len(header): sns.violinplot( x=df[header[i]], ax=ax[int(i / plot_cols)][i % plot_cols], orient="h", width=0.7, inner="quartile", ) else: fig.delaxes(ax[int(i / plot_cols)][i % plot_cols]) i += 1 extra_data["violin"] = context.log_artifact( PlotArtifact("violin", body=plt.gcf(), title="Violin Plot"), local_path=f"{plots_dest}/violin.html", db_key=False, ) except Exception as e: context.logger.warn(f"Failed to create violin distribution plots due to: {e}") if label_column: labels = df.pop(label_column) imbtable = labels.value_counts(normalize=True).sort_index() try: gcf_clear(plt) balancebar = imbtable.plot(kind="bar", title="class imbalance - labels") balancebar.set_xlabel("class") balancebar.set_ylabel("proportion of total") extra_data["imbalance"] = context.log_artifact( PlotArtifact("imbalance", body=plt.gcf()), local_path=f"{plots_dest}/imbalance.html", ) except Exception as e: context.logger.warn(f"Failed to create class imbalance plot due to: {e}") context.log_artifact( TableArtifact( "imbalance-weights-vec", df=pd.DataFrame({"weights": imbtable}) ), local_path=f"{plots_dest}/imbalance-weights-vec.csv", db_key=False, ) tblcorr = df.corr() mask = np.zeros_like(tblcorr, dtype=np.bool) mask[np.triu_indices_from(mask)] = True dfcorr = pd.DataFrame(data=tblcorr, columns=header, index=header) dfcorr = dfcorr[np.arange(dfcorr.shape[0])[:, None] > np.arange(dfcorr.shape[1])] context.log_artifact( TableArtifact("correlation-matrix", df=tblcorr, visible=True), local_path=f"{plots_dest}/correlation-matrix.csv", db_key=False, ) try: gcf_clear(plt) ax = plt.axes() sns.heatmap(tblcorr, ax=ax, mask=mask, annot=False, cmap=plt.cm.Reds) ax.set_title("features correlation") extra_data["correlation"] = context.log_artifact( PlotArtifact("correlation", body=plt.gcf(), title="Correlation Matrix"), 
local_path=f"{plots_dest}/corr.html", db_key=False, ) except Exception as e: context.logger.warn(f"Failed to create features correlation plot due to: {e}") gcf_clear(plt)
def summarize( context: MLClientCtx, table: DataItem, label_column: str = None, class_labels: List[str] = [], plot_hist: bool = True, plots_dest: str = "plots", update_dataset=False, ) -> None: """Summarize a table :param context: the function context :param table: MLRun input pointing to pandas dataframe (csv/parquet file path) :param label_column: ground truth column label :param class_labels: label for each class in tables and plots :param plot_hist: (True) set this to False for large tables :param plots_dest: destination folder of summary plots (relative to artifact_path) :param update_dataset: when the table is a registered dataset update the charts in-place """ df = table.as_df() header = df.columns.values extra_data = {} try: gcf_clear(plt) snsplt = sns.pairplot(df, hue=label_column) # , diag_kws={"bw": 1.5}) extra_data["histograms"] = context.log_artifact( PlotArtifact("histograms", body=plt.gcf()), local_path=f"{plots_dest}/hist.html", db_key=False, ) except Exception as e: context.logger.error( f"Failed to create pairplot histograms due to: {e}") try: gcf_clear(plt) plot_cols = 3 plot_rows = int((len(header) - 1) / plot_cols) + 1 fig, ax = plt.subplots(plot_rows, plot_cols, figsize=(15, 4)) fig.tight_layout(pad=2.0) for i in range(plot_rows * plot_cols): if i < len(header): sns.violinplot( x=df[header[i]], ax=ax[int(i / plot_cols)][i % plot_cols], orient="h", width=0.7, inner="quartile", ) else: fig.delaxes(ax[int(i / plot_cols)][i % plot_cols]) i += 1 extra_data["violin"] = context.log_artifact( PlotArtifact("violin", body=plt.gcf(), title="Violin Plot"), local_path=f"{plots_dest}/violin.html", db_key=False, ) except Exception as e: context.logger.warn( f"Failed to create violin distribution plots due to: {e}") if label_column: labels = df.pop(label_column) imbtable = labels.value_counts(normalize=True).sort_index() try: gcf_clear(plt) balancebar = imbtable.plot(kind="bar", title="class imbalance - labels") balancebar.set_xlabel("class") balancebar.set_ylabel("proportion of total") extra_data["imbalance"] = context.log_artifact( PlotArtifact("imbalance", body=plt.gcf()), local_path=f"{plots_dest}/imbalance.html", ) except Exception as e: context.logger.warn( f"Failed to create class imbalance plot due to: {e}") context.log_artifact( TableArtifact("imbalance-weights-vec", df=pd.DataFrame({"weights": imbtable})), local_path=f"{plots_dest}/imbalance-weights-vec.csv", db_key=False, ) tblcorr = df.corr() mask = np.zeros_like(tblcorr, dtype=np.bool) mask[np.triu_indices_from(mask)] = True dfcorr = pd.DataFrame(data=tblcorr, columns=header, index=header) dfcorr = dfcorr[ np.arange(dfcorr.shape[0])[:, None] > np.arange(dfcorr.shape[1])] context.log_artifact( TableArtifact("correlation-matrix", df=tblcorr, visible=True), local_path=f"{plots_dest}/correlation-matrix.csv", db_key=False, ) try: gcf_clear(plt) ax = plt.axes() sns.heatmap(tblcorr, ax=ax, mask=mask, annot=False, cmap=plt.cm.Reds) ax.set_title("features correlation") extra_data["correlation"] = context.log_artifact( PlotArtifact("correlation", body=plt.gcf(), title="Correlation Matrix"), local_path=f"{plots_dest}/corr.html", db_key=False, ) except Exception as e: context.logger.warn( f"Failed to create features correlation plot due to: {e}") gcf_clear(plt) if update_dataset and table.meta and table.meta.kind == "dataset": from mlrun.artifacts import update_dataset_meta update_dataset_meta(table.meta, extra_data=extra_data)
def train_model(context: MLClientCtx, dataset: DataItem, model_pkg_class: str, label_column: str = "label", train_validation_size: float = 0.75, sample: float = 1.0, models_dest: str = "models", test_set_key: str = "test_set", plots_dest: str = "plots", dask_key: str = "dask_key", dask_persist: bool = False, scheduler_key: str = '', file_ext: str = "parquet", random_state: int = 42) -> None: """ Train a sklearn classifier with Dask :param context: Function context. :param dataset: Raw data file. :param model_pkg_class: Model to train, e.g, "sklearn.ensemble.RandomForestClassifier", or json model config. :param label_column: (label) Ground-truth y labels. :param train_validation_size: (0.75) Train validation set proportion out of the full dataset. :param sample: (1.0) Select sample from dataset (n-rows/% of total), randomzie rows as default. :param models_dest: (models) Models subfolder on artifact path. :param test_set_key: (test_set) Mlrun db key of held out data in artifact store. :param plots_dest: (plots) Plot subfolder on artifact path. :param dask_key: (dask key) Key of dataframe in dask client "datasets" attribute. :param dask_persist: (False) Should the data be persisted (through the `client.persist`) :param scheduler_key: (scheduler) Dask scheduler configuration, json also logged as an artifact. :param file_ext: (parquet) format for test_set_key hold out data :param random_state: (42) sklearn seed """ if scheduler_key: client = Client(scheduler_key) else: client = Client() context.logger.info("Read Data") df = dataset.as_df(df_module=dd) context.logger.info("Prep Data") numerics = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64'] df = df.select_dtypes(include=numerics) if df.isna().any().any().compute() == True: raise Exception('NAs valus found') df_header = df.columns df = df.sample(frac=sample).reset_index(drop=True) encoder = LabelEncoder() encoder = encoder.fit(df[label_column]) X = df.drop(label_column, axis=1).to_dask_array(lengths=True) y = encoder.transform(df[label_column]) classes = df[label_column].drop_duplicates() # no unique values in dask classes = [str(i) for i in classes] context.logger.info("Split and Train") X_train, X_test, y_train, y_test = model_selection.train_test_split( X, y, train_size=train_validation_size, random_state=random_state) scaler = StandardScaler() scaler = scaler.fit(X_train) X_train_transformed = scaler.transform(X_train) X_test_transformed = scaler.transform(X_test) model_config = gen_sklearn_model(model_pkg_class, context.parameters.items()) model_config["FIT"].update({"X": X_train_transformed, "y": y_train}) ClassifierClass = create_class(model_config["META"]["class"]) model = ClassifierClass(**model_config["CLASS"]) with joblib.parallel_backend("dask"): model = model.fit(**model_config["FIT"]) artifact_path = context.artifact_subpath(models_dest) plots_path = context.artifact_subpath(models_dest, plots_dest) context.logger.info("Evaluate") extra_data_dict = {} for report in (ROCAUC, ClassificationReport, ConfusionMatrix): report_name = str(report.__name__) plt.cla() plt.clf() plt.close() viz = report(model, classes=classes, per_class=True, is_fitted=True) viz.fit(X_train_transformed, y_train) # Fit the training data to the visualizer viz.score(X_test_transformed, y_test.compute()) # Evaluate the model on the test data plot = context.log_artifact(PlotArtifact(report_name, body=viz.fig, title=report_name), db_key=False) extra_data_dict[str(report)] = plot if report_name == 'ROCAUC': context.log_results({ "micro": 
viz.roc_auc.get("micro"), "macro": viz.roc_auc.get("macro") }) elif report_name == 'ClassificationReport': for score_name in viz.scores_: for score_class in viz.scores_[score_name]: context.log_results({ score_name + "-" + score_class: viz.scores_[score_name].get(score_class) }) viz = FeatureImportances(model, classes=classes, per_class=True, is_fitted=True, labels=df_header.delete( df_header.get_loc(label_column))) viz.fit(X_train_transformed, y_train) viz.score(X_test_transformed, y_test) plot = context.log_artifact(PlotArtifact("FeatureImportances", body=viz.fig, title="FeatureImportances"), db_key=False) extra_data_dict[str("FeatureImportances")] = plot plt.cla() plt.clf() plt.close() context.logger.info("Log artifacts") artifact_path = context.artifact_subpath(models_dest) plots_path = context.artifact_subpath(models_dest, plots_dest) context.set_label('class', model_pkg_class) context.log_model("model", body=dumps(model), artifact_path=artifact_path, model_file="model.pkl", extra_data=extra_data_dict, metrics=context.results, labels={"class": model_pkg_class}) context.log_artifact("standard_scaler", body=dumps(scaler), artifact_path=artifact_path, model_file="scaler.gz", label="standard_scaler") context.log_artifact("label_encoder", body=dumps(encoder), artifact_path=artifact_path, model_file="encoder.gz", label="label_encoder") df_to_save = delayed(np.column_stack)((X_test, y_test)).compute() context.log_dataset( test_set_key, df=pd.DataFrame(df_to_save, columns=df_header), # improve log dataset ability format=file_ext, index=False, labels={"data-type": "held-out"}, artifact_path=context.artifact_subpath('data')) context.logger.info("Done!")
def permutation_importance(
    context: MLClientCtx,
    model: DataItem,
    dataset: DataItem,
    labels: str,
    figsz=(10, 5),
    plots_dest: str = "plots",
    fitype: str = "permute",
) -> pd.DataFrame:
    """calculate change in metric

    type 'permute' uses a pre-estimated model
    type 'dropcol' uses a re-estimated model

    :param context:    the function's execution context
    :param model:      a trained model
    :param dataset:    features and ground truths, regression targets
    :param labels:     name of the ground-truths column
    :param figsz:      matplotlib figure size
    :param plots_dest: path within artifact store
    """
    model_file, model_data, _ = get_model(model.url, suffix=".pkl")
    model = load(open(str(model_file), "rb"))

    X = dataset.as_df()
    y = X.pop(labels)
    header = X.columns

    metric = _oob_classifier_accuracy

    baseline = metric(model, X, y)

    imp = []
    for col in X.columns:
        if fitype == "permute":
            save = X[col].copy()
            X[col] = np.random.permutation(X[col])
            m = metric(model, X, y)
            X[col] = save
            imp.append(baseline - m)
        elif fitype == "dropcol":
            X_ = X.drop(col, axis=1)
            model_ = clone(model)
            # model_.random_state = random_state
            model_.fit(X_, y)
            o = model_.oob_score_
            imp.append(baseline - o)
        else:
            raise ValueError("unknown fitype, only 'permute' or 'dropcol' permitted")

    zipped = zip(imp, header)
    feature_imp = pd.DataFrame(sorted(zipped), columns=["importance", "feature"])
    feature_imp.sort_values(by="importance", ascending=False, inplace=True)

    plt.clf()
    plt.figure(figsize=figsz)
    sns.barplot(x="importance", y="feature", data=feature_imp)
    plt.title(f"feature importances-{fitype}")
    plt.tight_layout()

    context.log_artifact(
        PlotArtifact(f"feature importances-{fitype}", body=plt.gcf()),
        local_path=f"{plots_dest}/feature-permutations.html",
    )
    context.log_dataset(
        f"feature-importances-{fitype}-tbl", df=feature_imp, index=False
    )
def describe( context: MLClientCtx, table: Union[DataItem, str], label_column: str, class_labels: List[str], key: str = "table-summary", ) -> None: """Summarize a table TODO: merge with dask version :param context: the function context :param table: pandas dataframe :param key: key of table summary in artifact store """ _gcf_clear(plt) base_path = context.artifact_path os.makedirs(base_path, exist_ok=True) os.makedirs(base_path + "/plots", exist_ok=True) print(f'TABLE {table}') table = pd.read_parquet(str(table)) header = table.columns.values # describe table sumtbl = table.describe() sumtbl = sumtbl.append(len(table.index) - table.count(), ignore_index=True) sumtbl.insert( 0, "metric", ["count", "mean", "std", "min", "25%", "50%", "75%", "max", "nans"]) sumtbl.to_csv(os.path.join(base_path, key + ".csv"), index=False) context.log_artifact(key, local_path=key + ".csv") # plot class balance, record relative class weight _gcf_clear(plt) labels = table.pop(label_column) class_balance_model = ClassBalance(labels=class_labels) class_balance_model.fit(labels) scale_pos_weight = class_balance_model.support_[ 0] / class_balance_model.support_[1] #context.log_artifact("scale_pos_weight", f"{scale_pos_weight:0.2f}") context.log_artifact("scale_pos_weight", str(scale_pos_weight)) class_balance_model.show( outpath=os.path.join(base_path, "plots/imbalance.png")) context.log_artifact(PlotArtifact("imbalance", body=plt.gcf()), local_path="plots/imbalance.html") # plot feature correlation _gcf_clear(plt) tblcorr = table.corr() ax = plt.axes() sns.heatmap(tblcorr, ax=ax, annot=False, cmap=plt.cm.Reds) ax.set_title("features correlation") plt.savefig(os.path.join(base_path, "plots/corr.png")) context.log_artifact(PlotArtifact("correlation", body=plt.gcf()), local_path="plots/corr.html") # plot histogram _gcf_clear(plt) """
def roc_multi(ytest_b, yprob, labels): """ """ n_classes = len(labels) # Compute ROC curve and ROC area for each class fpr = dict() tpr = dict() roc_auc = dict() for i in range(n_classes): fpr[i], tpr[i], _ = metrics.roc_curve(ytest_b[:, i], yprob[:, i]) roc_auc[i] = metrics.auc(fpr[i], tpr[i]) # Compute micro-average ROC curve and ROC area fpr["micro"], tpr["micro"], _ = metrics.roc_curve(ytest_b.ravel(), yprob.ravel()) roc_auc["micro"] = metrics.auc(fpr["micro"], tpr["micro"]) # First aggregate all false positive rates all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)])) # Then interpolate all ROC curves at this points mean_tpr = np.zeros_like(all_fpr) for i in range(n_classes): mean_tpr += interp(all_fpr, fpr[i], tpr[i]) # Finally average it and compute AUC mean_tpr /= n_classes fpr["macro"] = all_fpr tpr["macro"] = mean_tpr roc_auc["macro"] = metrics.auc(fpr["macro"], tpr["macro"]) # Plot all ROC curves #gcf_clear(plt) plt.figure() plt.plot(fpr["micro"], tpr["micro"], label='micro-average ROC curve (area = {0:0.2f})' ''.format(roc_auc["micro"]), color='deeppink', linestyle=':', linewidth=4) plt.plot(fpr["macro"], tpr["macro"], label='macro-average ROC curve (area = {0:0.2f})' ''.format(roc_auc["macro"]), color='navy', linestyle=':', linewidth=4) colors = cycle(['aqua', 'darkorange', 'cornflowerblue']) for i, color in zip(range(n_classes), colors): plt.plot(fpr[i], tpr[i], color=color, lw=2, label='ROC curve of class {0} (area = {1:0.2f})' ''.format(i, roc_auc[i])) plt.plot([0, 1], [0, 1], 'k--', lw=2) plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') plt.title('receiver operating characteristic - multiclass') plt.legend(loc=(0, -.68), prop=dict(size=10)) return PlotArtifact("roc-multiclass", body=plt.gcf())
def eval_model_v2( context, xtest, ytest, model, pcurve_bins: int = 10, pcurve_names: List[str] = ["my classifier"], plots_artifact_path: str = "", pred_params: dict = {}, cmap="Blues", is_xgb=False, ): """generate predictions and validation stats pred_params are non-default, scikit-learn api prediction-function parameters. For example, a tree-type of model may have a tree depth limit for its prediction function. :param xtest: features array type Union(DataItem, DataFrame, numpy array) :param ytest: ground-truth labels Union(DataItem, DataFrame, Series, numpy array, List) :param model: estimated model :param pcurve_bins: (10) subdivide [0,1] interval into n bins, x-axis :param pcurve_names: label for each calibration curve :param pred_params: (None) dict of predict function parameters :param cmap: ('Blues') matplotlib color map :param is_xgb """ if hasattr(model, "get_xgb_params"): is_xgb = True def df_blob(df): return bytes(df.to_csv(index=False), encoding="utf-8") if isinstance(ytest, np.ndarray): unique_labels = np.unique(ytest) elif isinstance(ytest, list): unique_labels = set(ytest) else: try: ytest = ytest.values unique_labels = np.unique(ytest) except Exception as exc: raise Exception(f"unrecognized data type for ytest {exc}") n_classes = len(unique_labels) is_multiclass = True if n_classes > 2 else False # INIT DICT...OR SOME OTHER COLLECTOR THAT CAN BE ACCESSED plots_path = plots_artifact_path or context.artifact_subpath("plots") extra_data = {} ypred = model.predict(xtest, **pred_params) if isinstance(ypred.flat[0], np.floating): accuracy = mean_absolute_error(ytest, ypred) else: accuracy = float(metrics.accuracy_score(ytest, ypred)) context.log_results({ "accuracy": accuracy, "test-error": np.sum(ytest != ypred) / ytest.shape[0] }) # PROBABILITIES if hasattr(model, "predict_proba"): yprob = model.predict_proba(xtest, **pred_params) if not is_multiclass: fraction_of_positives, mean_predicted_value = calibration_curve( ytest, yprob[:, -1], n_bins=pcurve_bins, strategy="uniform") cmd = plot_calibration_curve(ytest, [yprob], pcurve_names) calibration = context.log_artifact( PlotArtifact( "probability-calibration", body=cmd.get_figure(), title="probability calibration plot", ), artifact_path=plots_path, db_key=False, ) extra_data["probability calibration"] = calibration # CONFUSION MATRIX if is_classifier(model): cm = sklearn_confusion_matrix(ytest, ypred, normalize="all") df = pd.DataFrame(data=cm) extra_data["confusion matrix table.csv"] = df_blob(df) cmd = metrics.plot_confusion_matrix( model, xtest, ytest, normalize="all", values_format=".2g", cmap=plt.get_cmap(cmap), ) confusion = context.log_artifact( PlotArtifact( "confusion-matrix", body=cmd.figure_, title="Confusion Matrix - Normalized Plot", ), artifact_path=plots_path, db_key=False, ) extra_data["confusion matrix"] = confusion # LEARNING CURVES if hasattr(model, "evals_result") and is_xgb is False: results = model.evals_result() train_set = list(results.items())[0] valid_set = list(results.items())[1] learning_curves_df = None if is_multiclass: if hasattr(train_set[1], "merror"): learning_curves_df = pd.DataFrame({ "train_error": train_set[1]["merror"], "valid_error": valid_set[1]["merror"], }) else: if hasattr(train_set[1], "error"): learning_curves_df = pd.DataFrame({ "train_error": train_set[1]["error"], "valid_error": valid_set[1]["error"], }) if learning_curves_df: extra_data["learning curve table.csv"] = df_blob( learning_curves_df) _, ax = plt.subplots() plt.xlabel("# training examples") plt.ylabel("error rate") 
plt.title("learning curve - error") ax.plot(learning_curves_df["train_error"], label="train") ax.plot(learning_curves_df["valid_error"], label="valid") learning = context.log_artifact( PlotArtifact("learning-curve", body=plt.gcf(), title="Learning Curve - error"), artifact_path=plots_path, db_key=False, ) extra_data["learning curve"] = learning # FEATURE IMPORTANCES if hasattr(model, "feature_importances_"): (fi_plot, fi_tbl) = feature_importances(model, xtest.columns) extra_data["feature importances"] = context.log_artifact( fi_plot, db_key=False, artifact_path=plots_path) extra_data["feature importances table.csv"] = df_blob(fi_tbl) # AUC - ROC - PR CURVES if is_multiclass and is_classifier(model): lb = LabelBinarizer() ytest_b = lb.fit_transform(ytest) extra_data["precision_recall_multi"] = context.log_artifact( precision_recall_multi(ytest_b, yprob, unique_labels), artifact_path=plots_path, db_key=False, ) extra_data["roc_multi"] = context.log_artifact( roc_multi(ytest_b, yprob, unique_labels), artifact_path=plots_path, db_key=False, ) # AUC multiclass aucmicro = metrics.roc_auc_score(ytest_b, yprob, multi_class="ovo", average="micro") aucweighted = metrics.roc_auc_score(ytest_b, yprob, multi_class="ovo", average="weighted") context.log_results({ "auc-micro": aucmicro, "auc-weighted": aucweighted }) # others (todo - macro, micro...) f1 = metrics.f1_score(ytest, ypred, average="macro") ps = metrics.precision_score(ytest, ypred, average="macro") rs = metrics.recall_score(ytest, ypred, average="macro") context.log_results({ "f1-score": f1, "precision_score": ps, "recall_score": rs }) elif is_classifier(model): yprob_pos = yprob[:, 1] extra_data["precision_recall_bin"] = context.log_artifact( precision_recall_bin(model, xtest, ytest, yprob_pos), artifact_path=plots_path, db_key=False, ) extra_data["roc_bin"] = context.log_artifact( roc_bin(ytest, yprob_pos, clear=True), artifact_path=plots_path, db_key=False, ) rocauc = metrics.roc_auc_score(ytest, yprob_pos) brier_score = metrics.brier_score_loss(ytest, yprob_pos, pos_label=ytest.max()) f1 = metrics.f1_score(ytest, ypred) ps = metrics.precision_score(ytest, ypred) rs = metrics.recall_score(ytest, ypred) context.log_results({ "rocauc": rocauc, "brier_score": brier_score, "f1-score": f1, "precision_score": ps, "recall_score": rs, }) elif is_regressor(model): r_squared = r2_score(ytest, ypred) rmse = mean_squared_error(ytest, ypred, squared=False) mse = mean_squared_error(ytest, ypred, squared=True) mae = mean_absolute_error(ytest, ypred) context.log_results({ "R2": r_squared, "root_mean_squared_error": rmse, "mean_squared_error": mse, "mean_absolute_error": mae, }) # return all model metrics and plots return extra_data
def summarize( context, dask_key: str = "dask_key", dataset: mlrun.DataItem = None, label_column: str = "label", class_labels: List[str] = [], plot_hist: bool = True, plots_dest: str = "plots", dask_function: str = None, dask_client=None, ) -> None: """Summarize a table Connects to dask client through the function context, or through an optional user-supplied scheduler. :param context: the function context :param dask_key: key of dataframe in dask client "datasets" attribute :param label_column: ground truth column label :param class_labels: label for each class in tables and plots :param plot_hist: (True) set this to False for large tables :param plots_dest: destination folder of summary plots (relative to artifact_path) :param dask_function: dask function url (db://..) :param dask_client: dask client object """ if dask_function: client = mlrun.import_function(dask_function).client elif dask_client: client = dask_client else: raise ValueError('dask client was not provided') if dask_key in client.datasets: table = client.get_dataset(dask_key) elif dataset: table = dataset.as_df(df_module=dd) else: context.logger.info( f"only these datasets are available {client.datasets} in client {client}" ) raise Exception("dataset not found on dask cluster") header = table.columns.values gcf_clear(plt) table = table.compute() snsplt = sns.pairplot(table, hue=label_column, diag_kws={'bw': 1.5}) context.log_artifact(PlotArtifact('histograms', body=plt.gcf()), local_path=f"{plots_dest}/hist.html") gcf_clear(plt) labels = table.pop(label_column) if not class_labels: class_labels = labels.unique() class_balance_model = ClassBalance(labels=class_labels) class_balance_model.fit(labels) scale_pos_weight = class_balance_model.support_[ 0] / class_balance_model.support_[1] context.log_result("scale_pos_weight", f"{scale_pos_weight:0.2f}") context.log_artifact(PlotArtifact("imbalance", body=plt.gcf()), local_path=f"{plots_dest}/imbalance.html") gcf_clear(plt) tblcorr = table.corr() ax = plt.axes() sns.heatmap(tblcorr, ax=ax, annot=False, cmap=plt.cm.Reds) ax.set_title("features correlation") context.log_artifact(PlotArtifact("correlation", body=plt.gcf()), local_path=f"{plots_dest}/corr.html") gcf_clear(plt)
def eval_class_model(context, xtest, ytest, model, plots_dest: str = "plots", pred_params: dict = {}): """generate predictions and validation stats pred_params are non-default, scikit-learn api prediction-function parameters. For example, a tree-type of model may have a tree depth limit for its prediction function. :param xtest: features array type Union(DataItem, DataFrame, np. Array) :param ytest: ground-truth labels Union(DataItem, DataFrame, Series, np. Array, List) :param model: estimated model :param pred_params: (None) dict of predict function parameters """ if isinstance(ytest, np.ndarray): unique_labels = np.unique(ytest) elif isinstance(ytest, list): unique_labels = set(ytest) else: try: ytest = ytest.values unique_labels = np.unique(ytest) except: raise Exception("unrecognized data type for ytest") n_classes = len(unique_labels) is_multiclass = True if n_classes > 2 else False # INIT DICT...OR SOME OTHER COLLECTOR THAT CAN BE ACCESSED mm_plots = [] mm_tables = [] mm = {} ypred = model.predict(xtest, **pred_params) mm.update({ "test-accuracy": float(metrics.accuracy_score(ytest, ypred)), "test-error": np.sum(ytest != ypred) / ytest.shape[0] }) # GEN PROBS (INCL CALIBRATED PROBABILITIES) if hasattr(model, "predict_proba"): yprob = model.predict_proba(xtest, **pred_params) else: # todo if decision fn... raise Exception("not implemented for this classifier") plot_calibration_curve(ytest, [yprob], ['xgboost']) context.log_artifact(PlotArtifact("calibration curve", body=plt.gcf()), local_path=f"{plots_dest}/calibration curve.html") # start evaluating: # mm_plots.extend(learning_curves(model)) if hasattr(model, "evals_result"): results = model.evals_result() train_set = list(results.items())[0] valid_set = list(results.items())[1] learning_curves = pd.DataFrame({ "train_error": train_set[1]["error"], "train_auc": train_set[1]["auc"], "valid_error": valid_set[1]["error"], "valid_auc": valid_set[1]["auc"] }) plt.clf() #gcf_clear(plt) fig, ax = plt.subplots() plt.xlabel('# training examples') plt.ylabel('auc') plt.title('learning curve - auc') ax.plot(learning_curves.train_auc, label='train') ax.plot(learning_curves.valid_auc, label='valid') legend = ax.legend(loc='lower left') context.log_artifact( PlotArtifact("learning curve - auc", body=plt.gcf()), local_path=f"{plots_dest}/learning curve - auc.html") plt.clf() #gcf_clear(plt) fig, ax = plt.subplots() plt.xlabel('# training examples') plt.ylabel('error rate') plt.title('learning curve - error') ax.plot(learning_curves.train_error, label='train') ax.plot(learning_curves.valid_error, label='valid') legend = ax.legend(loc='lower left') context.log_artifact( PlotArtifact("learning curve - erreur", body=plt.gcf()), local_path=f"{plots_dest}/learning curve - erreur.html") (fi_plot, fi_tbl) = feature_importances(model, xtest.columns) mm_plots.append(fi_plot) mm_tables.append(fi_tbl) mm_plots.append(confusion_matrix(model, xtest, ytest)) if is_multiclass: lb = LabelBinarizer() ytest_b = lb.fit_transform(ytest) mm_plots.append(precision_recall_multi(ytest_b, yprob, unique_labels)) mm_plots.append(roc_multi(ytest_b, yprob, unique_labels)) # AUC multiclass mm.update({ "auc-micro": metrics.roc_auc_score(ytest_b, yprob, multi_class="ovo", average="micro"), "auc-weighted": metrics.roc_auc_score(ytest_b, yprob, multi_class="ovo", average="weighted") }) # others (todo - macro, micro...) 
mm.update({ "f1-score": metrics.f1_score(ytest, ypred, average="micro"), "precision_score": metrics.precision_score(ytest, ypred, average="micro"), "recall_score": metrics.recall_score(ytest, ypred, average="micro") }) else: # extract the positive label yprob_pos = yprob[:, 1] mm_plots.append(roc_bin(ytest, yprob_pos)) mm_plots.append(precision_recall_bin(model, xtest, ytest, yprob_pos)) mm.update({ "rocauc": metrics.roc_auc_score(ytest, yprob_pos), "brier_score": metrics.brier_score_loss(ytest, yprob_pos, pos_label=ytest.max()), "f1-score": metrics.f1_score(ytest, ypred), "precision_score": metrics.precision_score(ytest, ypred), "recall_score": metrics.recall_score(ytest, ypred) }) # return all model metrics and plots mm.update({"plots": mm_plots, "tables": mm_tables}) return mm
def roc_multi(ytest_b, yprob, labels): """""" n_classes = len(labels) # Compute ROC curve and ROC area for each class fpr = dict() tpr = dict() roc_auc = dict() for i in range(n_classes): fpr[i], tpr[i], _ = metrics.roc_curve(ytest_b[:, i], yprob[:, i]) roc_auc[i] = metrics.auc(fpr[i], tpr[i]) # Compute micro-average ROC curve and ROC area fpr["micro"], tpr["micro"], _ = metrics.roc_curve(ytest_b.ravel(), yprob.ravel()) roc_auc["micro"] = metrics.auc(fpr["micro"], tpr["micro"]) # First aggregate all false positive rates all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)])) # Then interpolate all ROC curves at this points mean_tpr = np.zeros_like(all_fpr) for i in range(n_classes): mean_tpr += interp(all_fpr, fpr[i], tpr[i]) # Finally average it and compute AUC mean_tpr /= n_classes fpr["macro"] = all_fpr tpr["macro"] = mean_tpr roc_auc["macro"] = metrics.auc(fpr["macro"], tpr["macro"]) # Plot all ROC curves gcf_clear(plt) plt.figure() plt.plot( fpr["micro"], tpr["micro"], label=f"micro-average ROC curve (area = {roc_auc['micro']:0.2f})", color="deeppink", linestyle=":", linewidth=4, ) plt.plot( fpr["macro"], tpr["macro"], label=f"macro-average ROC curve (area = {roc_auc['macro']:0.2f})", color="navy", linestyle=":", linewidth=4, ) colors = cycle(["aqua", "darkorange", "cornflowerblue"]) for i, color in zip(range(n_classes), colors): plt.plot( fpr[i], tpr[i], color=color, lw=2, label=f"ROC curve of class {i} (area = {roc_auc[i]:0.2f})", ) plt.plot([0, 1], [0, 1], "k--", lw=2) plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.xlabel("False Positive Rate") plt.ylabel("True Positive Rate") plt.title("receiver operating characteristic - multiclass") plt.legend(loc=(0, -0.68), prop=dict(size=10)) return PlotArtifact("roc-multiclass", body=plt.gcf(), title="Multiclass ROC Curve")
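# Usage sketch (illustrative only): ytest_b is the binarized ground truth and
# yprob the per-class probabilities; `interp`, `cycle`, and `gcf_clear` are
# module-level helpers assumed to be available here.
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=2)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

roc_artifact = roc_multi(
    label_binarize(y_test, classes=[0, 1, 2]),
    clf.predict_proba(X_test),
    labels=[0, 1, 2],
)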