Exemplo n.º 1
0
    def function() -> Column:
        """Plot a histogram of the residuals ``targets - predictions``.

        The residuals are binned into ``n_bins`` bins and drawn as bars. The
        figure is returned in a column layout below a title built from
        ``title_rows``.
        """
        actual = data.targets
        predicted = data.predictions

        plot = figure(tools=TOOLS)

        errors = np.asarray(actual) - np.asarray(predicted)
        counts, bin_edges = np.histogram(errors, bins=n_bins)
        lefts, rights = bin_edges[:-1], bin_edges[1:]

        plot.quad(
            left=lefts,
            right=rights,
            bottom=0,
            top=counts,
            fill_color=DARK_RED,
            fill_alpha=0.5,
            line_color="white",
            line_alpha=0.5,
        )

        apply_default_style(plot)
        plot.xaxis.axis_label = "Residual"
        plot.yaxis.axis_label = "Number of Occurrences"

        header = title_div(title_rows)
        return column(header, plot, sizing_mode="scale_width")
    def figure() -> Figure:
        """Scatter ground-truth class indices against predicted ones, with jitter.

        Every point is shifted by independent noise, uniform in [-0.3, 0.3], on
        both axes. Many identical (truth, prediction) pairs therefore show up
        as a cloud instead of one dot. Ticks sit on the class indices and are
        labelled with ``class_names``. Grid lines are drawn between classes.
        """
        n_classes = len(class_names)
        extent = (-0.5, -0.5 + n_classes)

        p = plotting.figure(
            x_range=extent,
            y_range=extent,
            plot_width=350,
            plot_height=350,
            match_aspect=True,
            tools=TOOLS,
            toolbar_location=TOOLBAR_LOCATION,
        )

        def jitter() -> np.ndarray:
            # Beta(1, 1) is uniform on [0, 1]; centre it and scale to [-0.3, 0.3].
            return (np.random.beta(1, 1, size=len(y_true)) - 0.5) * 0.6

        # The x noise is drawn before the y noise.
        xs = y_true + jitter()
        ys = y_pred + jitter()
        marker_size = scatter_plot_circle_size(
            num_points=len(y_true),
            biggest=4.0,
            smallest=1.0,
            use_smallest_when_num_points_at_least=5000,
        )
        p.scatter(
            x=xs,
            y=ys,
            size=marker_size,
            color=DARK_BLUE,
            fill_alpha=SCATTER_CIRCLES_FILL_ALPHA,
            line_alpha=SCATTER_CIRCLES_LINE_ALPHA,
        )

        add_title_rows(p, title_rows)
        apply_default_style(p)

        p.xaxis.axis_label = "Ground Truth"
        p.yaxis.axis_label = "Prediction"

        class_ticks = np.arange(n_classes)
        axes_and_rotations = (
            (p.xaxis, x_label_rotation),
            (p.yaxis, y_label_rotation),
        )
        for axis, rotation in axes_and_rotations:
            axis.ticker = class_ticks
            axis.major_label_overrides = dict(enumerate(class_names))
            axis.major_label_orientation = rotation

        # Grid lines sit halfway between neighbouring classes, not on them.
        for grid in (p.xgrid, p.ygrid):
            grid.ticker = class_ticks[:-1] + 0.5
            grid.grid_line_width = 3

        # Stop the user from panning into empty regions.
        p.x_range.bounds = extent
        p.y_range.bounds = extent

        return p
    def figure() -> Figure:
        """Draw the ROC curve (TPR versus FPR) plus a dotted reference diagonal.

        Hovering over the curve shows TPR and FPR. It also shows sensitivity
        (the TPR column), specificity (computed as ``1.0 - fpr``) and the
        threshold.
        """
        roc_columns = {
            "FPR": fpr,
            "TPR": tpr,
            "threshold": thresholds,
            "specificity": 1.0 - fpr,
        }
        source = ColumnDataSource(data=roc_columns)

        p = plotting.figure(
            plot_width=350,
            plot_height=400,
            match_aspect=True,
            tools=TOOLS,
            toolbar_location=TOOLBAR_LOCATION,
        )

        p.xaxis.axis_label = "FPR"
        p.yaxis.axis_label = "TPR"

        add_title_rows(p, title_rows)
        apply_default_style(p)

        roc_curve = p.line(
            source=source,
            x="FPR",
            y="TPR",
            line_width=2,
            color=DARK_BLUE,
        )

        # Diagonal baseline from (0, 0) to (1, 1); excluded from the hover tool.
        p.line(
            x=[0.0, 1.0],
            y=[0.0, 1.0],
            line_alpha=0.75,
            color="grey",
            line_dash="dotted",
        )

        roc_tooltips = [
            ("TPR", "@TPR"),
            ("FPR", "@FPR"),
            ("Sensitivity", "@TPR"),
            ("Specificity", "@specificity"),
            ("Threshold", "@threshold"),
        ]
        hover = HoverTool(
            renderers=[roc_curve],
            tooltips=roc_tooltips,
            # Show a tooltip whenever the cursor is vertically aligned with the curve.
            mode="vline",
        )
        p.add_tools(hover)

        return p
    def figure() -> Figure:
        """Compare predicted and ground-truth class distributions as bar charts.

        Predictions are drawn as filled bars and ground truth as outlined bars
        on top. Both use ``np.histogram`` with ``bins``, ``sample_weights`` as
        weights and ``density=normalize``. The y label follows ``normalize``.
        """

        def class_counts(labels):
            # Only the histogram values are needed; the bin edges are discarded.
            values, _ = np.histogram(
                labels,
                bins=bins,
                weights=sample_weights,
                density=normalize,
            )
            return values

        p = plotting.figure(
            x_range=class_names,
            plot_width=350,
            plot_height=350,
            tools=TOOLS,
            toolbar_location=TOOLBAR_LOCATION,
        )

        # Predicted distribution: filled bars.
        predicted_counts = class_counts(y_pred)
        p.vbar(
            x=class_names,
            top=predicted_counts,
            width=0.85,
            color=DARK_BLUE,
            alpha=HISTOGRAM_ALPHA,
            legend_label="Prediction",
        )

        # Ground-truth distribution: black outlines with no fill.
        true_counts = class_counts(y_true)
        p.vbar(
            x=class_names,
            top=true_counts,
            width=0.85,
            alpha=0.6,
            legend_label="Ground Truth",
            fill_color=None,
            line_color="black",
            line_width=2.5,
        )

        add_title_rows(p, title_rows)
        apply_default_style(p)

        if normalize:
            p.yaxis.axis_label = "Fraction of Instances"
        else:
            p.yaxis.axis_label = "Number of Instances"

        p.xaxis.major_label_orientation = x_label_rotation
        p.xgrid.grid_line_color = None

        # Stop the user from panning into empty regions.
        p.x_range.bounds = (-0.5, 0.5 + len(class_names))
        return p
Exemplo n.º 5
0
    def function() -> Column:
        """Overlay the distributions of predicted and actual values.

        Both histograms use the same bin edges, computed from targets and
        predictions together. Predictions are drawn as semi-transparent bars
        and actual values as a step-shaped outline. The figure is returned
        below a title div.
        """
        targets = data.targets
        predictions = data.predictions

        p = figure(tools=TOOLS)

        # Shared edges give both histograms identical bins, which reads better.
        combined = list(targets) + list(predictions)
        _, common_edges = np.histogram(combined, bins=n_bins)

        pred_counts, pred_edges = np.histogram(predictions, bins=common_edges)
        np.testing.assert_allclose(pred_edges, common_edges)
        p.quad(
            left=pred_edges[:-1],
            right=pred_edges[1:],
            bottom=0,
            top=pred_counts,
            fill_color=DARK_BLUE,
            fill_alpha=0.5,
            line_color="white",
            line_alpha=0.5,
            legend_label="Predicted  ",
        )

        actual_counts, actual_edges = np.histogram(targets, bins=common_edges)
        np.testing.assert_allclose(actual_edges, common_edges)
        # Repeat every edge and every count twice, with a 0.0 at each end of
        # the counts. The resulting polyline traces the bar tops as a staircase.
        step_x = list(np.repeat(actual_edges, 2))
        step_y = [0.0, *np.repeat(actual_counts, 2), 0.0]
        p.line(
            x=step_x,
            y=step_y,
            color=GROUND_TRUTH_HISTOGRAM_ENVELOPE_COLOR,
            alpha=GROUND_TRUTH_HISTOGRAM_ENVELOPE_ALPHA,
            legend_label="Actual  ",
        )

        # Leave 25% headroom above the tallest bar so the legend fits.
        tallest = max(actual_counts.max(), pred_counts.max())
        p.y_range.end = 1.25 * tallest

        apply_default_style(p)
        p.xaxis.axis_label = "Value"
        p.yaxis.axis_label = "Number of Occurrences"
        p.legend.padding = 4
        p.legend.orientation = "horizontal"

        header = title_div(title_rows)
        return column(header, p, sizing_mode="scale_width")
    def figure() -> Figure:
        """Plot precision against recall, with hover tooltips showing thresholds.

        Both axes are fixed to the interval [-0.05, 1.05].
        """
        p = plotting.figure(
            x_range=(-0.05, 1.05),
            y_range=(-0.05, 1.05),
            plot_width=350,
            plot_height=400,
            tools=TOOLS,
            toolbar_location=TOOLBAR_LOCATION,
        )

        pr_columns = {
            "precision": precision,
            "recall": recall,
            "threshold": thresholds,
        }
        source = ColumnDataSource(data=pr_columns)

        # reminder: tpr == recall == sensitivity
        p.line(source=source, x="recall", y="precision", line_width=2)

        add_title_rows(p, title_rows)
        apply_default_style(p)

        p.xaxis.axis_label = "Recall"
        p.yaxis.axis_label = "Precision"

        pr_tooltips = [
            ("Precision", "@precision"),
            ("Recall", "@recall"),
            ("Threshold", "@threshold"),
        ]
        hover = HoverTool(
            tooltips=pr_tooltips,
            # Show a tooltip whenever the cursor is vertically aligned with the curve.
            mode="vline",
        )
        p.add_tools(hover)

        return p
Exemplo n.º 7
0
def _bokeh_scatter_with_histograms(
    x: Floats,
    y: Floats,
    x_label: str,
    y_label: str,
    x_histogram_fill_color: Optional[str],
    y_histogram_fill_color: Optional[str],
    y_histogram_envelope_color: Optional[str],
    title_rows: Sequence[str],
    n_bins: int,
) -> Tuple[LayoutDOM, Figure]:
    """
    Scatter plot with small histograms attached to the axes.

    The histogram of ``x`` sits above the scatter plot and shares its x range.
    The histogram of ``y`` sits to the right and shares its y range. Both
    histograms use the same count range, so their bar lengths are comparable.

    Args:
        x: x coordinates of the scatter points.
        y: y coordinates of the scatter points; same length as ``x``.
        x_label: Label of the scatter plot's x axis.
        y_label: Label of the scatter plot's y axis.
        x_histogram_fill_color: Fill colour of the top histogram, or None to
            draw it without fill or outline.
        y_histogram_fill_color: Fill colour of the right histogram, or None to
            draw it without fill or outline.
        y_histogram_envelope_color: If not None, a step-shaped outline in this
            colour is drawn around the right histogram.
        title_rows: Rows rendered in a title div above the plots.
        n_bins: Number of bins used for each histogram.

    Returns:
        A tuple ``(layout, p_scatter)``: the complete layout (title plus grid)
        and the scatter figure itself.
    """
    assert len(x) == len(y), "x and y must have the same length"

    # --- Create the scatter plot ---
    p_scatter = figure(
        tools=TOOLS,
        plot_width=250,
        plot_height=250,
        min_border=0,
        min_border_left=20,
        sizing_mode="stretch_width",
    )

    p_scatter.scatter(
        x,
        y,
        color=DARK_BLUE,
        size=(scatter_plot_circle_size(
            len(x),
            biggest=3,
            smallest=1,
            use_smallest_when_num_points_at_least=5000,
        )),
        fill_alpha=SCATTER_CIRCLES_FILL_ALPHA,
        line_alpha=SCATTER_CIRCLES_LINE_ALPHA,
    )

    p_scatter.xaxis.axis_label = x_label
    p_scatter.yaxis.axis_label = y_label

    # --- Compute histograms ---
    hist_x, hist_x_edges = np.histogram(x, bins=n_bins)
    hist_y, hist_y_edges = np.histogram(y, bins=n_bins)
    # Make sure the two histograms are similarly scaled
    hist_range = (-0.1, 1.05 * max(max(hist_x), max(hist_y)))

    # --- Create the horizontal histogram of x, above the scatter plot ---
    p_hist_x_above = figure(
        plot_width=p_scatter.plot_width,
        plot_height=30,
        x_range=p_scatter.x_range,
        y_range=hist_range,
        min_border=0,
        min_border_left=0,
        x_axis_location=None,
        y_axis_location=None,
        sizing_mode="stretch_width",
        tools=TOOLS,
    )
    p_hist_x_above.quad(
        bottom=0,
        left=hist_x_edges[:-1],
        right=hist_x_edges[1:],
        top=hist_x,
        color=x_histogram_fill_color,
        alpha=HISTOGRAM_ALPHA,
        # Without a fill there is nothing to separate, so draw no outline
        # (same rule as the right-hand histogram).
        line_color="white" if x_histogram_fill_color is not None else None,
    )

    # --- Create the vertical histogram of y, to the right of the scatter plot ---
    p_hist_y_right = figure(
        plot_width=40,
        plot_height=p_scatter.plot_height,
        x_range=hist_range,
        y_range=p_scatter.y_range,
        x_axis_location=None,
        y_axis_location=None,
        sizing_mode="fixed",
        tools=TOOLS,
    )
    p_hist_y_right.quad(
        left=0,
        bottom=hist_y_edges[:-1],
        top=hist_y_edges[1:],
        right=hist_y,
        color=y_histogram_fill_color,
        alpha=HISTOGRAM_ALPHA,
        line_color="white" if y_histogram_fill_color is not None else None,
    )

    if y_histogram_envelope_color is not None:
        # Staircase outline: every edge and every count appears twice, and the
        # counts are framed by zeros.
        p_hist_y_right.line(
            x=[0.0] + [h for h in hist_y for _ in range(2)] + [0.0],
            y=[e for e in hist_y_edges for _ in range(2)],
            # Use the caller's colour; it was previously ignored in favour
            # of a hard-coded constant.
            color=y_histogram_envelope_color,
            alpha=GROUND_TRUTH_HISTOGRAM_ENVELOPE_ALPHA,
        )

    # --- Style and layout ---
    apply_default_style(p_scatter)
    apply_default_style(p_hist_x_above)
    apply_default_style(p_hist_y_right)
    # Overwrite some of the defaults
    p_hist_x_above.ygrid.grid_line_color = None
    p_hist_y_right.xgrid.grid_line_color = None

    grid = gridplot(
        # fmt:off
        [[p_hist_x_above, None], [p_scatter, p_hist_y_right]],
        merge_tools=True,
        toolbar_location="right",
        toolbar_options={
            "logo": None,
        },
        sizing_mode="scale_width",
        # fmt:on
    )

    layout = column(title_div(title_rows), grid, sizing_mode="scale_width")

    return layout, p_scatter
    def figure() -> Figure:
        """Plot accuracy and threshold against automation rate.

        Both series are lines over ``automation_rate`` taken from
        ``chart_data``. Their first and last points are also drawn as markers
        in the same colour as the line. The hover tool is attached to the
        accuracy line only.
        """
        # ----- bokeh plot -----
        p = plotting.figure(
            plot_height=400,
            plot_width=350,
            x_range=(-0.05, 1.05),
            y_range=(-0.05, 1.05),
            tools=TOOLS,
            toolbar_location=TOOLBAR_LOCATION,
        )

        source = ColumnDataSource(
            data={key: np.array(lst)
                  for key, lst in chart_data.items()})

        accuracy_line = p.line(
            x="automation_rate",
            y="accuracy",
            line_width=2,
            color=DARK_BLUE,
            source=source,
            legend_label="Accuracy",
        )

        p.line(
            x="automation_rate",
            y="threshold",
            line_width=2,
            color="grey",
            source=source,
            legend_label="Threshold",
        )

        # Mark the end points of both lines so something is still visible when
        # a line consists of just a single point. Empty data is skipped,
        # because indexing it with [0, -1] would raise IndexError.
        automation_rate = source.data["automation_rate"]
        if len(automation_rate) > 0:
            endpoints = [0, -1]
            p.scatter(
                x=automation_rate[endpoints],
                y=source.data["accuracy"][endpoints],
                # Match the accuracy line; without an explicit colour the
                # markers would get the default grey and look like threshold
                # markers.
                color=DARK_BLUE,
            )
            p.scatter(
                x=automation_rate[endpoints],
                y=source.data["threshold"][endpoints],
                color="grey",
            )

        add_title_rows(p, title_rows)
        apply_default_style(p)

        p.xaxis.axis_label = "Automation Rate"
        p.legend.location = "bottom_left"

        p.add_tools(
            HoverTool(
                renderers=[accuracy_line],
                tooltips=[
                    ("Accuracy", "@accuracy"),
                    ("Threshold", "@threshold"),
                    ("Automation Rate", "@automation_rate"),
                ],
                # display a tooltip whenever the cursor is vertically in line with a glyph
                mode="vline",
            ))

        return p
    def figure() -> Figure:
        """Render the confusion matrix as a heatmap with a colour bar.

        Each (ground truth, prediction) cell is a rectangle coloured by
        ``normalized_by_true``. The colour comes from a linear Viridis256
        mapping over [0.0, 1.0]. Hovering shows the count and all three
        normalised values.
        """
        cell_columns = {
            "predicted": predicted,
            "actual": actual,
            "count": count,
            "normalized": normalized,
            "normalized_by_pred": normalized_by_pred,
            "normalized_by_true": normalized_by_true,
        }
        source = ColumnDataSource(
            data={name: np.array(values) for name, values in cell_columns.items()})

        p = plotting.figure(tools=TOOLS, x_range=class_names, y_range=class_names)

        mapper = LinearColorMapper(palette="Viridis256", low=0.0, high=1.0)
        heat_fill = {"field": "normalized_by_true", "transform": mapper}

        p.rect(
            source=source,
            x="actual",
            y="predicted",
            width=0.95,
            height=0.95,
            fill_color=heat_fill,
            line_width=0,
            line_color="black",
        )

        p.xaxis.axis_label = "Ground Truth"
        p.yaxis.axis_label = "Prediction"

        p.xaxis.major_label_orientation = x_label_rotation
        p.yaxis.major_label_orientation = y_label_rotation

        cell_tooltips = [
            ("Predicted", "@predicted"),
            ("Ground truth", "@actual"),
            ("Count", "@count"),
            ("Normalized", "@normalized"),
            ("Normalized by prediction", "@normalized_by_pred"),
            ("Normalized by ground truth", "@normalized_by_true"),
        ]
        p.add_tools(HoverTool(tooltips=cell_tooltips))

        legend_bar = ColorBar(
            color_mapper=mapper,
            major_label_text_font_size=FONT_SIZE,
            ticker=BasicTicker(desired_num_ticks=10),
            formatter=PrintfTickFormatter(format="%.1f"),
            label_standoff=5,
            border_line_color=None,
            location=(0, 0),
        )
        p.add_layout(legend_bar, "right")

        add_title_rows(p, title_rows)
        apply_default_style(p)

        return p