def figure() -> Figure:
    """Build a jittered scatter plot of ground-truth vs. predicted class index.

    Every instance is placed at (true class, predicted class) plus a small
    random offset on each axis, so that points falling into the same
    confusion cell spread out instead of overlapping. Reads ``class_names``,
    ``y_true``, ``y_pred``, ``title_rows``, ``x_label_rotation`` and
    ``y_label_rotation`` from the enclosing scope.
    """
    num_classes = len(class_names)
    # classes sit on integer coordinates 0..num_classes-1; keep half a unit of margin
    extent = (-0.5, num_classes - 0.5)

    p = plotting.figure(
        x_range=extent,
        y_range=extent,
        plot_height=350,
        plot_width=350,
        tools=TOOLS,
        toolbar_location=TOOLBAR_LOCATION,
        match_aspect=True,
    )

    def jitter() -> np.ndarray:
        # Beta(1, 1) is uniform on [0, 1], so offsets lie in [-0.3, 0.3]
        return (np.random.beta(1, 1, size=len(y_true)) - 0.5) * 0.6

    # x-jitter is drawn before y-jitter (same random-stream order as before)
    xs = y_true + jitter()
    ys = y_pred + jitter()
    marker_size = scatter_plot_circle_size(
        num_points=len(y_true),
        biggest=4.0,
        smallest=1.0,
        use_smallest_when_num_points_at_least=5000,
    )
    p.scatter(
        x=xs,
        y=ys,
        size=marker_size,
        color=DARK_BLUE,
        fill_alpha=SCATTER_CIRCLES_FILL_ALPHA,
        line_alpha=SCATTER_CIRCLES_LINE_ALPHA,
    )

    add_title_rows(p, title_rows)
    apply_default_style(p)

    # one tick per class, labelled with the class name instead of its index
    ticks = np.arange(num_classes)
    for axis, title, rotation in (
        (p.xaxis, "Ground Truth", x_label_rotation),
        (p.yaxis, "Prediction", y_label_rotation),
    ):
        axis.axis_label = title
        axis.ticker = ticks
        axis.major_label_overrides = dict(enumerate(class_names))
        axis.major_label_orientation = rotation

    # thick grid lines *between* classes (at k + 0.5), not through them
    for grid in (p.xgrid, p.ygrid):
        grid.ticker = ticks[:-1] + 0.5
        grid.grid_line_width = 3

    # prevent panning to empty regions
    p.x_range.bounds = extent
    p.y_range.bounds = extent
    return p
def figure() -> Figure:
    """Build the ROC curve figure (TPR plotted against FPR) with a hover readout.

    A dotted grey diagonal from (0, 0) to (1, 1) is drawn as a baseline; the
    hover tool is attached to the curve only. Reads ``fpr``, ``tpr``,
    ``thresholds`` and ``title_rows`` from the enclosing scope.
    """
    columns = {
        "FPR": fpr,
        "TPR": tpr,
        "threshold": thresholds,
        "specificity": 1.0 - fpr,  # specificity = 1 - FPR
    }
    source = ColumnDataSource(data=columns)

    p = plotting.figure(
        plot_height=400,
        plot_width=350,
        tools=TOOLS,
        toolbar_location=TOOLBAR_LOCATION,
        match_aspect=True,
    )
    p.xaxis.axis_label = "FPR"
    p.yaxis.axis_label = "TPR"
    add_title_rows(p, title_rows)
    apply_default_style(p)

    roc_renderer = p.line(
        x="FPR", y="TPR", line_width=2, color=DARK_BLUE, source=source
    )
    p.line(
        x=[0.0, 1.0],
        y=[0.0, 1.0],
        line_alpha=0.75,
        color="grey",
        line_dash="dotted",
    )

    hover = HoverTool(
        # restrict tooltips to the ROC curve so the diagonal baseline has none
        renderers=[roc_renderer],
        tooltips=[
            ("TPR", "@TPR"),
            ("FPR", "@FPR"),
            ("Sensitivity", "@TPR"),
            ("Specificity", "@specificity"),
            ("Threshold", "@threshold"),
        ],
        # show a tooltip whenever the cursor is vertically aligned with a glyph
        mode="vline",
    )
    p.add_tools(hover)
    return p
def figure() -> Figure:
    """Build a bar chart comparing predicted and ground-truth class distributions.

    Filled bars show the (optionally weighted) histogram of ``y_pred``;
    black-outlined bars drawn afterwards show the histogram of ``y_true``.
    ``normalize`` is passed to ``np.histogram`` as ``density``. Reads
    ``class_names``, ``y_pred``, ``y_true``, ``bins``, ``sample_weights``,
    ``normalize``, ``x_label_rotation`` and ``title_rows`` from the enclosing
    scope.
    """

    def heights(labels) -> np.ndarray:
        # per-bin values; the bin edges returned by np.histogram are unused
        values, _ = np.histogram(
            labels, bins=bins, weights=sample_weights, density=normalize
        )
        return values

    plot = plotting.figure(
        x_range=class_names,
        plot_height=350,
        plot_width=350,
        tools=TOOLS,
        toolbar_location=TOOLBAR_LOCATION,
    )
    bar_width = 0.85

    # predicted class distribution: filled bars
    plot.vbar(
        x=class_names,
        top=heights(y_pred),
        width=bar_width,
        color=DARK_BLUE,
        alpha=HISTOGRAM_ALPHA,
        legend_label="Prediction",
    )
    # ground-truth class distribution: outline only, drawn over the prediction
    plot.vbar(
        x=class_names,
        top=heights(y_true),
        width=bar_width,
        alpha=0.6,
        legend_label="Ground Truth",
        fill_color=None,
        line_color="black",
        line_width=2.5,
    )

    add_title_rows(plot, title_rows)
    apply_default_style(plot)

    if normalize:
        y_label = "Fraction of Instances"
    else:
        y_label = "Number of Instances"
    plot.yaxis.axis_label = y_label
    plot.xaxis.major_label_orientation = x_label_rotation
    plot.xgrid.grid_line_color = None

    # prevent panning to empty regions (half a unit of slack on each side;
    # NOTE(review): assumes Bokeh's categorical range spans [0, N] for N factors)
    plot.x_range.bounds = (-0.5, len(class_names) + 0.5)
    return plot
def figure() -> Figure:
    """Build the precision-recall curve figure with a hover readout.

    Plots ``precision`` against ``recall`` on fixed [-0.05, 1.05] axes.
    Reads ``precision``, ``recall``, ``thresholds`` and ``title_rows`` from
    the enclosing scope.
    """
    p = plotting.figure(
        plot_height=400,
        plot_width=350,
        x_range=(-0.05, 1.05),
        y_range=(-0.05, 1.05),
        tools=TOOLS,
        toolbar_location=TOOLBAR_LOCATION,
    )

    # ColumnDataSource requires equal-length columns. sklearn's
    # precision_recall_curve returns one threshold fewer than precision/recall
    # (its final point has no threshold). If that is the case here, pad with NaN
    # at the end so the columns align. Inputs that already match are unchanged.
    # A new local name is used so the closure variable `thresholds` is not shadowed.
    threshold_column = thresholds
    if len(threshold_column) == len(precision) - 1:
        threshold_column = np.append(threshold_column, np.nan)

    source = ColumnDataSource(data={
        "precision": precision,
        "recall": recall,
        "threshold": threshold_column,
    })
    # reminder: tpr == recall == sensitivity
    # same colour as the other single-curve figures
    curve = p.line(
        x="recall",
        y="precision",
        line_width=2,
        color=DARK_BLUE,
        source=source,
    )

    add_title_rows(p, title_rows)
    apply_default_style(p)
    p.xaxis.axis_label = "Recall"
    p.yaxis.axis_label = "Precision"

    p.add_tools(
        HoverTool(
            # scope tooltips to the curve, as the other curve figures do
            renderers=[curve],
            tooltips=[
                ("Precision", "@precision"),
                ("Recall", "@recall"),
                ("Threshold", "@threshold"),
            ],
            # display a tooltip whenever the cursor is vertically in line with a glyph
            mode="vline",
        ))
    return p
def figure() -> Figure:
    """Build the accuracy and threshold vs. automation-rate figure.

    Draws two lines over ``automation_rate``: accuracy (dark blue) and
    threshold (grey). Each line's first and last points are marked in the
    line's own colour. The hover tool is attached to the accuracy line only.
    Reads ``chart_data`` (a mapping whose values are converted with
    ``np.array``; it must provide ``automation_rate``, ``accuracy`` and
    ``threshold``) and ``title_rows`` from the enclosing scope.
    """
    # ----- bokeh plot -----
    p = plotting.figure(
        plot_height=400,
        plot_width=350,
        x_range=(-0.05, 1.05),
        y_range=(-0.05, 1.05),
        tools=TOOLS,
        toolbar_location=TOOLBAR_LOCATION,
    )
    source = ColumnDataSource(
        data={key: np.array(values) for key, values in chart_data.items()})

    accuracy_line = p.line(
        x="automation_rate",
        y="accuracy",
        line_width=2,
        color=DARK_BLUE,
        source=source,
        legend_label="Accuracy",
    )
    p.line(
        x="automation_rate",
        y="threshold",
        line_width=2,
        color="grey",
        source=source,
        legend_label="Threshold",
    )

    # Make sure something is visible if a line consists of just a single point:
    # mark the first and last point of each line in that line's colour.
    # Skip this when there is no data, where indexing [0, -1] would raise IndexError.
    automation_rate = source.data["automation_rate"]
    if len(automation_rate) > 0:
        endpoints = [0, -1]
        p.scatter(
            x=automation_rate[endpoints],
            y=source.data["accuracy"][endpoints],
            color=DARK_BLUE,
        )
        p.scatter(
            x=automation_rate[endpoints],
            y=source.data["threshold"][endpoints],
            color="grey",
        )

    add_title_rows(p, title_rows)
    apply_default_style(p)
    p.xaxis.axis_label = "Automation Rate"
    p.legend.location = "bottom_left"

    p.add_tools(
        HoverTool(
            renderers=[accuracy_line],
            tooltips=[
                ("Accuracy", "@accuracy"),
                ("Threshold", "@threshold"),
                ("Automation Rate", "@automation_rate"),
            ],
            # display a tooltip whenever the cursor is vertically in line with a glyph
            mode="vline",
        ))
    return p
def figure() -> Figure:
    """Build the confusion-matrix heatmap figure.

    Draws one rectangle per (ground truth, prediction) cell. Each cell is
    coloured by ``normalized_by_true`` on a Viridis palette with a fixed
    [0, 1] colour range, and a colour bar is shown on the right. The hover
    shows the raw count and every normalisation. Reads ``predicted``,
    ``actual``, ``count``, ``normalized``, ``normalized_by_pred``,
    ``normalized_by_true``, ``class_names``, ``x_label_rotation``,
    ``y_label_rotation`` and ``title_rows`` from the enclosing scope.
    """
    source = ColumnDataSource(
        data={
            "predicted": np.array(predicted),
            "actual": np.array(actual),
            "count": np.array(count),
            "normalized": np.array(normalized),
            "normalized_by_pred": np.array(normalized_by_pred),
            "normalized_by_true": np.array(normalized_by_true),
        })
    # pass toolbar_location explicitly for consistent toolbar placement
    p = plotting.figure(
        tools=TOOLS,
        toolbar_location=TOOLBAR_LOCATION,
        x_range=class_names,
        y_range=class_names,
    )
    # fixed [0, 1] colour range rather than one derived from the data
    mapper = LinearColorMapper(palette="Viridis256", low=0.0, high=1.0)
    p.rect(
        x="actual",
        y="predicted",
        width=0.95,
        height=0.95,
        source=source,
        fill_color={
            "field": "normalized_by_true",
            "transform": mapper
        },
        line_width=0,
        line_color="black",
    )
    p.xaxis.axis_label = "Ground Truth"
    p.yaxis.axis_label = "Prediction"
    p.xaxis.major_label_orientation = x_label_rotation
    p.yaxis.major_label_orientation = y_label_rotation
    p.add_tools(
        HoverTool(tooltips=[
            ("Predicted", "@predicted"),
            ("Ground truth", "@actual"),
            ("Count", "@count"),
            ("Normalized", "@normalized"),
            ("Normalized by prediction", "@normalized_by_pred"),
            ("Normalized by ground truth", "@normalized_by_true"),
        ]))
    color_bar = ColorBar(
        color_mapper=mapper,
        major_label_text_font_size=FONT_SIZE,
        ticker=BasicTicker(desired_num_ticks=10),
        formatter=PrintfTickFormatter(format="%.1f"),
        label_standoff=5,
        border_line_color=None,
        location=(0, 0),
    )
    p.add_layout(color_bar, "right")
    add_title_rows(p, title_rows)
    apply_default_style(p)
    return p