Exemplo n.º 1
0
def density_scatter(
    xs: NumArray,
    ys: NumArray,
    ax: Axes | None = None,
    color_map: str = "Blues",
    sort: bool = True,
    log: bool = True,
    density_bins: int = 100,
    xlabel: str = "Actual",
    ylabel: str = "Predicted",
    identity: bool = True,
    stats: bool = True,
    **kwargs: Any,
) -> Axes:
    """Scatter plot colored (and optionally sorted) by density.

    Args:
        xs (array): x values.
        ys (array): y values.
        ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None,
            in which case the current Axes (plt.gca()) is used.
        color_map (str, optional): plt color map or valid string name.
            Defaults to "Blues".
        sort (bool, optional): Whether to sort the data (forwarded to
            hist_density). Defaults to True.
        log (bool, optional): Whether to use a logarithmic color scale
            (matplotlib LogNorm). Defaults to True.
        density_bins (int, optional): How many bins to use for the density
            histogram, i.e. granularity of the density color scale. Defaults to 100.
        xlabel (str, optional): x-axis label. Defaults to "Actual".
        ylabel (str, optional): y-axis label. Defaults to "Predicted".
        identity (bool, optional): Whether to add an identity/parity line (y = x).
            Defaults to True.
        stats (bool, optional): Whether to display a text box with MAE and R^2.
            Defaults to True.
        **kwargs: Additional keyword arguments passed to ax.scatter().

    Returns:
        ax: The plot's matplotlib Axes.
    """
    if ax is None:
        ax = plt.gca()

    # hist_density returns the (possibly reordered) points together with a
    # per-point density value, which is used as the scatter color below.
    xs, ys, cs = hist_density(xs, ys, sort=sort, bins=density_bins)

    # Named color_norm (not norm) to avoid shadowing the module-level `norm`
    # name inside this function.
    color_norm = mpl.colors.LogNorm() if log else None

    ax.scatter(xs, ys, c=cs, cmap=color_map, norm=color_norm, **kwargs)

    if identity:
        # dashed parity line drawn behind the points (zorder=0)
        ax.axline(
            (0, 0), (1, 1), alpha=0.5, zorder=0, linestyle="dashed", color="black"
        )

    if stats:
        add_mae_r2_box(xs, ys, ax)

    ax.set(xlabel=xlabel, ylabel=ylabel)

    return ax
Exemplo n.º 2
0
def density_hexbin(
    xs: NumArray,
    yx: NumArray,
    ax: Axes | None = None,
    weights: NumArray | None = None,
    xlabel: str = "Actual",
    ylabel: str = "Predicted",
    **kwargs: Any,
) -> Axes:
    """Hexagonal-grid scatter plot colored by point density or by density in third
    dimension passed as weights.

    Also draws an inset colorbar, a dashed identity/parity line (y = x) and an
    MAE/R^2 text box in the upper left corner.

    Args:
        xs (array): x values
        yx (array): y values (parameter name kept as `yx` for backward
            compatibility with keyword callers).
        ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None,
            in which case the current Axes (plt.gca()) is used.
        weights (array, optional): If given, these values are accumulated in the bins.
            Otherwise, every point has value 1. Must be of the same length as x and y.
            Defaults to None.
        xlabel (str, optional): x-axis label. Defaults to "Actual".
        ylabel (str, optional): y-axis label. Defaults to "Predicted".
        **kwargs: Additional keyword arguments passed to ax.hexbin(). These may
            override the defaults gridsize=75, mincnt=1 and bins="log".

    Returns:
        ax: The plot's matplotlib Axes.
    """
    if ax is None:
        ax = plt.gca()

    # Defaults go first so callers can override them via **kwargs instead of
    # hitting a "got multiple values for keyword argument" TypeError.
    hexbin_kwargs = {"gridsize": 75, "mincnt": 1, "bins": "log", **kwargs}
    hexbin = ax.hexbin(xs, yx, C=weights, **hexbin_kwargs)

    # inset colorbar in Axes-relative coords: [left, bottom, width, height]
    cb_ax = ax.inset_axes([0.95, 0.03, 0.03, 0.7])
    # Attach the colorbar via the Axes' own figure rather than pyplot's current
    # figure, which can differ when `ax` was passed in explicitly.
    ax.figure.colorbar(hexbin, cax=cb_ax)
    cb_ax.yaxis.set_ticks_position("left")

    # identity line
    ax.axline((0, 0), (1, 1), alpha=0.5, zorder=0, linestyle="dashed", color="black")

    # upper left keeps the stats box clear of the colorbar on the right
    add_mae_r2_box(xs, yx, ax, loc="upper left")

    ax.set(xlabel=xlabel, ylabel=ylabel)

    return ax
Exemplo n.º 3
0
def scatter_with_err_bar(
    xs: NumArray,
    ys: NumArray,
    xerr: NumArray | None = None,
    yerr: NumArray | None = None,
    ax: Axes | None = None,
    xlabel: str = "Actual",
    ylabel: str = "Predicted",
    title: str | None = None,
    **kwargs: Any,
) -> Axes:
    """Scatter plot with optional x- and/or y-error bars. Useful when passing model
    uncertainties as yerr=y_std for checking if uncertainty correlates with error,
    i.e. if points farther from the parity line have larger uncertainty.

    Also draws a dashed identity/parity line (y = x) and an MAE/R^2 text box.

    Args:
        xs (array): x-values
        ys (array): y-values
        xerr (array, optional): Horizontal error bars. Defaults to None.
        yerr (array, optional): Vertical error bars. Defaults to None.
        ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None,
            in which case the current Axes (plt.gca()) is used.
        xlabel (str, optional): x-axis label. Defaults to "Actual".
        ylabel (str, optional): y-axis label. Defaults to "Predicted".
        title (str, optional): Plot title. Defaults to None.
        **kwargs: Additional keyword arguments passed to ax.errorbar(). These may
            override the default styles (markersize=6, fmt="o", ecolor="g",
            capthick=2, elinewidth=2).

    Returns:
        ax: The plot's matplotlib Axes.
    """
    if ax is None:
        ax = plt.gca()

    styles = dict(markersize=6, fmt="o", ecolor="g", capthick=2, elinewidth=2)
    # Merge so user kwargs win over the default styles; unpacking both dicts
    # separately raises TypeError as soon as a caller passes any style key.
    ax.errorbar(xs, ys, yerr=yerr, xerr=xerr, **{**styles, **kwargs})

    # identity line
    ax.axline((0, 0), (1, 1), alpha=0.5, zorder=0, linestyle="dashed", color="black")

    add_mae_r2_box(xs, ys, ax)

    ax.set(xlabel=xlabel, ylabel=ylabel, title=title)

    return ax
Exemplo n.º 4
0
def qq_gaussian(
    y_true: NumArray,
    y_pred: NumArray,
    y_std: NumArray | dict[str, NumArray],
    ax: Axes | None = None,
) -> Axes:
    """Plot the Gaussian quantile-quantile (Q-Q) plot of one (passed as array)
    or multiple (passed as dict) sets of uncertainty estimates for a single
    pair of ground truth targets `y_true` and model predictions `y_pred`.

    Overconfidence relative to a Gaussian distribution is visualized as shaded
    areas below the parity line, underconfidence (oversized uncertainties) as
    shaded areas above the parity line.

    The measure of calibration is how well the uncertainty percentiles conform
    to those of a normal distribution.

    Inspired by https://git.io/JufOz.
    Info on Q-Q plots: https://wikipedia.org/wiki/Q-Q_plot

    Args:
        y_true (array): ground truth targets
        y_pred (array): model predictions
        y_std (array | dict[str, array]): model uncertainties. Anything that is
            not a dict (numpy array, list, pandas Series, ...) is treated as a
            single set of uncertainties labeled "std".
        ax (Axes): matplotlib Axes on which to plot. Defaults to None, in which
            case the current Axes (plt.gca()) is used.

    Returns:
        ax: The plot's matplotlib Axes.

    Raises:
        ValueError: If y_std is an empty dict.
    """
    # Checking only for np.ndarray would send e.g. a pandas Series down the dict
    # path, where .items() yields (index, scalar) pairs instead of one curve.
    if not isinstance(y_std, dict):
        y_std = {"std": y_std}
    if not y_std:
        raise ValueError("y_std must contain at least one set of uncertainties")

    if ax is None:
        ax = plt.gca()

    abs_err = np.abs(np.asarray(y_pred) - np.asarray(y_true))

    # Shared by all curves: expected proportions on [0, 1] and the matching
    # two-sided Gaussian bound, i.e. a fraction p of a standard normal lies
    # within +/- norm.ppf(0.5 + p / 2).
    resolution = 100
    exp_proportions = np.linspace(0, 1, resolution)
    gaussian_upper_bound = norm.ppf(0.5 + exp_proportions / 2)

    # plotted lines and their areas, used for the second legend
    handles, areas = [], []
    for label, std in y_std.items():
        # column vector so the comparison broadcasts against every bound
        z_scores = (abs_err / np.asarray(std)).reshape(-1, 1)
        # observed fraction of |z| inside each Gaussian interval
        obs_proportions = np.mean(z_scores <= gaussian_upper_bound, axis=0)

        [line] = ax.plot(
            exp_proportions, obs_proportions, linewidth=2, alpha=0.8, label=label
        )
        ax.fill_between(
            exp_proportions, y1=obs_proportions, y2=exp_proportions, alpha=0.2
        )

        # Trapezoidal area between the observed curve and the parity line,
        # integrated over the actual grid (spacing 1 / (resolution - 1)).
        abs_gap = np.abs(obs_proportions - exp_proportions)
        miscal_area = float(
            np.sum((abs_gap[1:] + abs_gap[:-1]) / 2 * np.diff(exp_proportions))
        )

        handles.append(line)
        areas.append(miscal_area)

    # identity line
    ax.axline((0, 0), (1, 1), alpha=0.5, zorder=0, linestyle="dashed", color="black")

    ax.set(xlim=(0, 1), ylim=(0, 1))
    ax.set(xlabel="Theoretical Quantile", ylabel="Observed Quantile")

    legend1 = ax.legend(loc="upper left", frameon=False)
    # Multiple legends on the same axes: the first must be re-added as an artist
    # because the next ax.legend() call replaces the Axes' legend.
    # https://matplotlib.org/3.3.3/tutorials/intermediate/legend_guide.html#multiple-legends-on-the-same-axes
    ax.add_artist(legend1)

    if len(handles) > 1:
        legend2 = ax.legend(
            handles,
            [f"{area:.2f}" for area in areas],
            title="Miscalibration areas",
            loc="lower right",
            ncol=2,
            frameon=False,
        )
        legend2._legend_box.align = "left"  # https://stackoverflow.com/a/44620643
    else:
        ax.legend(
            handles,
            [f"Miscalibration area: {areas[0]:.2f}"],
            loc="lower right",
            frameon=False,
        )

    return ax