def explain_text(text, probas, target_names=None):
    """Visualize an eli5 LIME-based TextExplainer for a black-box classifier.

    Arguments:
        text (str): Text to explain.
        probas (callable): Black-box classification pipeline — a function
            that takes a list of strings (documents) and returns a matrix
            of shape (n_samples, n_classes) with probability values, i.e.
            a row per document and a column per output label.
        target_names (list, optional): Names for the output classes.

    Returns:
        A wandb.Html object with the rendered explanation, or None when the
        inputs are missing. To see plots, go to your W&B run page.

    Example:
        wandb.log({'roc': wandb.plots.ExplainText(text, probas)})
    """
    deprecation_notice()

    eli5 = util.get_module(
        "eli5",
        required="explain_text requires the eli5 library, install with `pip install eli5`",
    )
    # Guard clause: bail out (implicitly returning None) when inputs are missing.
    if not test_missing(text=text, probas=probas):
        return None
    wandb.termlog("Visualizing TextExplainer.")
    explainer = eli5.lime.TextExplainer(random_state=42)
    explainer.fit(text, probas)
    rendered = explainer.show_prediction(target_names=target_names)
    return wandb.Html(rendered.data)
def part_of_speech(docs):
    """Render spaCy's dependency visualizer showing part-of-speech tags
    and syntactic dependencies.

    Arguments:
        docs (list, Doc, Span): Document(s) to visualize.

    Returns:
        A wandb.Html object with the rendered parse, or None when the input
        is missing. To see plots, go to your W&B run page.

    Example:
        wandb.log({'part_of_speech': wandb.plots.POS(docs=doc)})
    """
    deprecation_notice()

    spacy = util.get_module(
        "spacy",
        required="part_of_speech requires the spacy library, install with `pip install spacy`",
    )
    en_core_web_md = util.get_module(
        "en_core_web_md",
        required="part_of_speech requires the en_core_web_md library, install with `python -m spacy download en_core_web_md`",
    )
    nlp = en_core_web_md.load()

    # Guard clause: bail out (implicitly returning None) when input is missing.
    if not test_missing(docs=docs):
        return None
    wandb.termlog("Visualizing part of speech.")
    # Rendering options for displacy's dependency ("dep") view.
    render_options = {
        "compact": True,
        "color": "#1a1c1f",
        "font": "Source Sans Pro",
        "collapse_punct": True,
        "collapse_phrases": True,
    }
    markup = spacy.displacy.render(
        nlp(str(docs)),
        style="dep",
        minify=True,
        options=render_options,
        page=True,
    )
    return wandb.Html(markup)
def named_entity(docs):
    """Render spaCy's entity visualizer, which highlights named entities
    and their labels in a text.

    Arguments:
        docs (list, Doc, Span): Document(s) to visualize.

    Returns:
        A wandb.Html object with the rendered entities, or None when the
        input is missing. To see plots, go to your W&B run page.

    Example:
        wandb.log({'NER': wandb.plots.NER(docs=doc)})
    """
    deprecation_notice()
    # Fixed copy-paste bug: these error messages previously said
    # "part_of_speech requires ..." even though this is named_entity.
    spacy = util.get_module(
        "spacy",
        required="named_entity requires the spacy library, install with `pip install spacy`",
    )
    en_core_web_md = util.get_module(
        "en_core_web_md",
        required="named_entity requires the en_core_web_md library, install with `python -m spacy download en_core_web_md`",
    )
    nlp = en_core_web_md.load()

    if test_missing(docs=docs):
        wandb.termlog("Visualizing named entity recognition.")
        # "ent" style highlights named entities inline in the text.
        html = spacy.displacy.render(
            nlp(str(docs)), style="ent", page=True, minify=True
        )
        return wandb.Html(html)
def heatmap(x_labels, y_labels, matrix_values, show_text=False):
    """Generate a heatmap.

    Arguments:
        x_labels (list): Named labels for rows (x_axis).
        y_labels (list): Named labels for columns (y_axis).
        matrix_values (arr): 2D dataset of shape x_labels * y_labels,
            containing heatmap values that can be coerced into an ndarray.
        show_text (bool): Show text values in heatmap cells.

    Returns:
        A wandb.visualize(...) object, or None when inputs are missing or
        of the wrong type. To see plots, go to your W&B run page then
        expand the 'media' tab under 'auto visualizations'.

    Example:
        wandb.log({'heatmap': wandb.plots.HeatMap(x_labels, y_labels, matrix_values)})
    """
    deprecation_notice()
    # Fixed: the error message previously said "roc requires ...".
    np = util.get_module(
        "numpy",
        required="heatmap requires the numpy library, install with `pip install numpy`",
    )
    # NOTE(review): a scikit-learn import previously loaded here was never
    # used by this function and has been removed.
    if test_missing(
        x_labels=x_labels, y_labels=y_labels, matrix_values=matrix_values
    ) and test_types(
        x_labels=x_labels, y_labels=y_labels, matrix_values=matrix_values
    ):
        matrix_values = np.array(matrix_values)
        wandb.termlog("Visualizing heatmap.")

        x_axis = []
        y_axis = []
        values = []
        count = 0
        limit_exceeded = False
        for i, x in enumerate(x_labels):
            for j, y in enumerate(y_labels):
                x_axis.append(x)
                y_axis.append(y)
                # matrix_values is indexed [j][i]: rows follow y_labels,
                # columns follow x_labels.
                values.append(matrix_values[j][i])
                count += 1
                if count >= chart_limit:
                    wandb.termwarn(
                        "wandb uses only the first %d datapoints to create the plots."
                        % wandb.Table.MAX_ROWS
                    )
                    limit_exceeded = True
                    break
            # Fixed: the original `break` only exited the inner loop, so
            # rows kept accumulating past the chart limit and the warning
            # was printed repeatedly. Stop both loops here.
            if limit_exceeded:
                break

        heatmap_key = "wandb/heatmap/v1" if show_text else "wandb/heatmap_no_text/v1"
        return wandb.visualize(
            heatmap_key,
            wandb.Table(
                columns=["x_axis", "y_axis", "values"],
                data=[
                    [x_axis[i], y_axis[i], round(values[i], 2)]
                    for i in range(len(x_axis))
                ],
            ),
        )
def roc(
    y_true=None,
    y_probas=None,
    labels=None,
    plot_micro=True,
    plot_macro=True,
    classes_to_plot=None,
):
    """Calculate receiver operating characteristic scores and visualize
    them as the ROC curve.

    Arguments:
        y_true (arr): Test set labels.
        y_probas (arr): Test set predicted probabilities, one column per class.
        labels (list): Named labels for target variable (y). Makes plots
            easier to read by replacing target values with corresponding
            index. For example labels=['dog', 'cat', 'owl'] all 0s are
            replaced by 'dog', 1s by 'cat'.
        plot_micro (bool): Unused; kept for backward compatibility.
        plot_macro (bool): Unused; kept for backward compatibility.
        classes_to_plot (list): Subset of classes to plot; defaults to all.

    Returns:
        A wandb.visualize(...) object, or None when inputs are missing or
        of the wrong type. To see plots, go to your W&B run page then
        expand the 'media' tab under 'auto visualizations'.

    Example:
        wandb.log({'roc': wandb.plots.ROC(y_true, y_probas, labels)})
    """
    deprecation_notice()
    np = util.get_module(
        "numpy",
        required="roc requires the numpy library, install with `pip install numpy`",
    )
    # Loaded to produce a friendly error before the from-import below.
    sklearn = util.get_module(
        "sklearn",
        required="roc requires the scikit library, install with `pip install scikit-learn`",
    )
    from sklearn.metrics import roc_curve

    if test_missing(y_true=y_true, y_probas=y_probas) and test_types(
        y_true=y_true, y_probas=y_probas
    ):
        y_true = np.array(y_true)
        y_probas = np.array(y_probas)
        classes = np.unique(y_true)
        if classes_to_plot is None:
            classes_to_plot = classes
        indices_to_plot = np.in1d(classes, classes_to_plot)

        data = []
        count = 0
        limit_exceeded = False
        for i, to_plot in enumerate(indices_to_plot):
            if not to_plot:
                continue
            fpr, tpr, _ = roc_curve(y_true, y_probas[:, i], pos_label=classes[i])
            # Fixed: the original checked `isinstance(classes[0], np.integer)`
            # (index 0 instead of i); use the tuple form like precision_recall.
            if labels is not None and isinstance(classes[i], (int, np.integer)):
                class_label = labels[classes[i]]
            else:
                class_label = classes[i]
            for j in range(len(fpr)):
                data.append([class_label, round(fpr[j], 3), round(tpr[j], 3)])
                count += 1
                if count >= chart_limit:
                    wandb.termwarn(
                        "wandb uses only the first %d datapoints to create the plots."
                        % wandb.Table.MAX_ROWS
                    )
                    limit_exceeded = True
                    break
            # Fixed: the original `break` only exited the inner loop, so
            # further classes kept adding rows past the chart limit.
            if limit_exceeded:
                break
        # NOTE(review): the original also computed an unused `roc_auc`
        # via sklearn's auc(); removed as dead code.
        return wandb.visualize(
            "wandb/roc/v1",
            wandb.Table(columns=["class", "fpr", "tpr"], data=data),
        )
def precision_recall(
    y_true=None, y_probas=None, labels=None, plot_micro=True, classes_to_plot=None
):
    """Compute the tradeoff between precision and recall for different
    thresholds.

    A high area under the curve represents both high recall and high
    precision, where high precision relates to a low false positive rate,
    and high recall relates to a low false negative rate. High scores for
    both show that the classifier is returning accurate results (high
    precision), as well as returning a majority of all positive results
    (high recall). PR curve is useful when the classes are very imbalanced.

    Arguments:
        y_true (arr): Test set labels.
        y_probas (arr): Test set predicted probabilities, one column per class.
        labels (list): Named labels for target variable (y). Makes plots
            easier to read by replacing target values with corresponding
            index. For example labels=['dog', 'cat', 'owl'] all 0s are
            replaced by 'dog', 1s by 'cat'.
        plot_micro (bool): Unused; kept for backward compatibility.
        classes_to_plot (list): Subset of classes to plot; defaults to all.

    Returns:
        A wandb.visualize(...) object, or None when inputs are missing or
        of the wrong type. To see plots, go to your W&B run page then
        expand the 'media' tab under 'auto visualizations'.

    Example:
        wandb.log({'pr': wandb.plots.precision_recall(y_true, y_probas, labels)})
    """
    deprecation_notice()
    # Fixed: error messages previously said "roc requires ...".
    np = util.get_module(
        "numpy",
        required="precision_recall requires the numpy library, install with `pip install numpy`",
    )
    scikit = util.get_module(
        "sklearn",
        required="precision_recall requires the scikit library, install with `pip install scikit-learn`",
    )

    if test_missing(y_true=y_true, y_probas=y_probas) and test_types(
        y_true=y_true, y_probas=y_probas
    ):
        # Fixed: convert AFTER validation — the original called np.array()
        # first, turning a missing (None) input into a 0-d array and
        # defeating the test_missing check.
        y_true = np.array(y_true)
        y_probas = np.array(y_probas)
        classes = np.unique(y_true)
        if classes_to_plot is None:
            classes_to_plot = classes

        binarized_y_true = scikit.preprocessing.label_binarize(y_true, classes=classes)
        # label_binarize returns a single column for binary problems;
        # expand to one column per class.
        if len(classes) == 2:
            binarized_y_true = np.hstack((1 - binarized_y_true, binarized_y_true))

        pr_curves = {}
        indices_to_plot = np.in1d(classes, classes_to_plot)
        for i, to_plot in enumerate(indices_to_plot):
            if not to_plot:
                continue
            precision, recall, _ = scikit.metrics.precision_recall_curve(
                y_true, y_probas[:, i], pos_label=classes[i]
            )
            # Downsample each curve to a fixed number of points.
            samples = 20
            sample_precision = [
                precision[int(len(precision) * k / samples)] for k in range(samples)
            ]
            sample_recall = [
                recall[int(len(recall) * k / samples)] for k in range(samples)
            ]
            pr_curves[classes[i]] = (sample_precision, sample_recall)
        # NOTE(review): the original also computed an unused
        # `average_precision_score`; removed as dead code.

        data = []
        count = 0
        limit_exceeded = False
        for class_name, (precision, recall) in pr_curves.items():
            # Map integer class values to provided label names. Fixed: the
            # original reassigned `class_name` itself inside the inner loop.
            if labels is not None and isinstance(class_name, (int, np.integer)):
                class_label = labels[class_name]
            else:
                class_label = class_name
            for p, r in zip(precision, recall):
                data.append([class_label, round(p, 3), round(r, 3)])
                count += 1
                if count >= chart_limit:
                    wandb.termwarn(
                        "wandb uses only the first %d datapoints to create the plots."
                        % wandb.Table.MAX_ROWS
                    )
                    limit_exceeded = True
                    break
            # Fixed: the original `break` only exited the inner loop, so
            # further classes kept adding rows past the chart limit.
            if limit_exceeded:
                break
        return wandb.visualize(
            "wandb/pr_curve/v1",
            wandb.Table(columns=["class", "precision", "recall"], data=data),
        )