def density_scatter(
    xs: NumArray,
    ys: NumArray,
    ax: Axes = None,
    color_map: str = "Blues",
    sort: bool = True,
    log: bool = True,
    density_bins: int = 100,
    xlabel: str = "Actual",
    ylabel: str = "Predicted",
    identity: bool = True,
    stats: bool = True,
    **kwargs: Any,
) -> Axes:
    """Scatter plot colored (and optionally sorted) by density.

    Args:
        xs (array): x values.
        ys (array): y values.
        ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None,
            in which case the current Axes (plt.gca()) is used.
        color_map (str, optional): plt color map or valid string name.
            Defaults to "Blues".
        sort (bool, optional): Whether to sort the data (forwarded to
            hist_density). Defaults to True.
        log (bool, optional): Whether to use a logarithmic color scale
            (matplotlib LogNorm). Defaults to True.
        density_bins (int, optional): How many density_bins to use for the
            density histogram, i.e. granularity of the density color scale.
            Defaults to 100.
        xlabel (str, optional): x-axis label. Defaults to "Actual".
        ylabel (str, optional): y-axis label. Defaults to "Predicted".
        identity (bool, optional): Whether to add an identity/parity line
            (y = x). Defaults to True.
        stats (bool, optional): Whether to display a text box with MAE and R^2.
            Defaults to True.
        **kwargs: Additional keyword arguments passed to ax.scatter().

    Returns:
        ax: The plot's matplotlib Axes.
    """
    if ax is None:
        ax = plt.gca()

    # hist_density returns the (possibly re-ordered) points plus a density
    # value per point, which is used as the scatter color.
    xs, ys, cs = hist_density(xs, ys, sort=sort, bins=density_bins)

    # Named color_norm rather than norm so it does not shadow the module-level
    # `norm` (used as norm.ppf in qq_gaussian).
    color_norm = mpl.colors.LogNorm() if log else None
    ax.scatter(xs, ys, c=cs, cmap=color_map, norm=color_norm, **kwargs)

    if identity:
        # Dashed parity line y = x, drawn underneath the data (zorder=0).
        ax.axline(
            (0, 0), (1, 1), alpha=0.5, zorder=0, linestyle="dashed", color="black"
        )

    if stats:
        add_mae_r2_box(xs, ys, ax)

    ax.set(xlabel=xlabel, ylabel=ylabel)

    return ax
def density_hexbin(
    xs: NumArray,
    yx: NumArray,
    ax: Axes = None,
    weights: NumArray = None,
    xlabel: str = "Actual",
    ylabel: str = "Predicted",
    identity: bool = True,
    stats: bool = True,
    **kwargs: Any,
) -> Axes:
    """Hexagonal-grid scatter plot colored by point density or by density in third
    dimension passed as weights.

    Args:
        xs (array): x values
        yx (array): y values (parameter name kept as-is for backward
            compatibility with keyword callers)
        ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None,
            in which case the current Axes (plt.gca()) is used.
        weights (array, optional): If given, these values are accumulated in the
            bins. Otherwise, every point has value 1. Must be of the same length
            as x and y. Defaults to None.
        xlabel (str, optional): x-axis label. Defaults to "Actual".
        ylabel (str, optional): y-axis label. Defaults to "Predicted".
        identity (bool, optional): Whether to add an identity/parity line
            (y = x). Defaults to True.
        stats (bool, optional): Whether to display a text box with MAE and R^2
            in the upper left corner. Defaults to True.
        **kwargs: Additional keyword arguments passed to ax.hexbin(). These may
            override the defaults gridsize=75, mincnt=1 and bins="log".

    Returns:
        ax: The plot's matplotlib Axes.
    """
    if ax is None:
        ax = plt.gca()

    # Defaults are merged with the caller's kwargs (caller wins). Passing them
    # as fixed keywords would make e.g. gridsize=... raise
    # "got multiple values for keyword argument".
    hexbin_kwargs = {"gridsize": 75, "mincnt": 1, "bins": "log", **kwargs}
    hexbin = ax.hexbin(xs, yx, C=weights, **hexbin_kwargs)

    # Colorbar in an inset along the right edge of the plot.
    cb_ax = ax.inset_axes([0.95, 0.03, 0.03, 0.7])  # [left, bottom, width, height]
    # Attach the colorbar via ax's own figure rather than pyplot's current
    # figure, so Axes not created through pyplot (or not in the current figure)
    # also work.
    ax.figure.colorbar(hexbin, cax=cb_ax)
    cb_ax.yaxis.set_ticks_position("left")

    if identity:
        # Dashed parity line y = x, drawn underneath the data (zorder=0).
        ax.axline(
            (0, 0), (1, 1), alpha=0.5, zorder=0, linestyle="dashed", color="black"
        )

    if stats:
        add_mae_r2_box(xs, yx, ax, loc="upper left")

    ax.set(xlabel=xlabel, ylabel=ylabel)

    return ax
def scatter_with_err_bar(
    xs: NumArray,
    ys: NumArray,
    xerr: NumArray = None,
    yerr: NumArray = None,
    ax: Axes = None,
    xlabel: str = "Actual",
    ylabel: str = "Predicted",
    title: str = None,
    identity: bool = True,
    stats: bool = True,
    **kwargs: Any,
) -> Axes:
    """Scatter plot with optional x- and/or y-error bars. Useful when passing model
    uncertainties as yerr=y_std for checking if uncertainty correlates with error,
    i.e. if points farther from the parity line have larger uncertainty.

    Args:
        xs (array): x-values
        ys (array): y-values
        xerr (array, optional): Horizontal error bars. Defaults to None.
        yerr (array, optional): Vertical error bars. Defaults to None.
        ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None,
            in which case the current Axes (plt.gca()) is used.
        xlabel (str, optional): x-axis label. Defaults to "Actual".
        ylabel (str, optional): y-axis label. Defaults to "Predicted".
        title (str, optional): Plot title. Defaults to None.
        identity (bool, optional): Whether to add an identity/parity line
            (y = x). Defaults to True.
        stats (bool, optional): Whether to display a text box with MAE and R^2.
            Defaults to True.
        **kwargs: Additional keyword arguments passed to ax.errorbar(). These
            override the default styles (markersize=6, fmt="o", ecolor="g",
            capthick=2, elinewidth=2).

    Returns:
        ax: The plot's matplotlib Axes.
    """
    if ax is None:
        ax = plt.gca()

    # Default styling. Caller kwargs are merged last so they take precedence
    # instead of colliding ("got multiple values for keyword argument").
    styles = dict(markersize=6, fmt="o", ecolor="g", capthick=2, elinewidth=2)
    ax.errorbar(xs, ys, yerr=yerr, xerr=xerr, **{**styles, **kwargs})

    if identity:
        # Dashed parity line y = x, drawn underneath the data (zorder=0).
        ax.axline(
            (0, 0), (1, 1), alpha=0.5, zorder=0, linestyle="dashed", color="black"
        )

    if stats:
        add_mae_r2_box(xs, ys, ax)

    ax.set(xlabel=xlabel, ylabel=ylabel, title=title)

    return ax
def qq_gaussian(
    y_true: NumArray,
    y_pred: NumArray,
    y_std: NumArray | dict[str, NumArray],
    ax: Axes = None,
) -> Axes:
    """Plot the Gaussian quantile-quantile (Q-Q) plot of one (passed as array)
    or multiple (passed as dict) sets of uncertainty estimates for a single pair of
    ground truth targets `y_true` and model predictions `y_pred`.

    Overconfidence relative to a Gaussian distribution is visualized as shaded
    areas below the parity line, underconfidence (oversized uncertainties) as
    shaded areas above the parity line.

    The measure of calibration is how well the uncertainty percentiles conform
    to those of a normal distribution.

    Inspired by https://git.io/JufOz.
    Info on Q-Q plots: https://wikipedia.org/wiki/Q-Q_plot

    Args:
        y_true (array): ground truth targets
        y_pred (array): model predictions
        y_std (array | dict[str, array]): model uncertainties. A single 1-D
            array-like (ndarray, list, pandas Series) is plotted as one curve
            labeled "std". A dict maps legend labels to uncertainty arrays and
            yields one curve per entry.
        ax (Axes): matplotlib Axes on which to plot. Defaults to None, in which
            case the current Axes (plt.gca()) is used.

    Returns:
        ax: The plot's matplotlib Axes.

    Raises:
        ValueError: If y_std contains no sets of uncertainties (e.g. empty dict).
    """
    if ax is None:
        ax = plt.gca()

    # Wrap a single set of uncertainties in a dict. 1-D inputs must be caught
    # explicitly: a pandas Series has an .items() method that would otherwise
    # be iterated element-wise in the loop below.
    if (
        isinstance(y_std, np.ndarray)
        or not hasattr(y_std, "items")
        or np.ndim(y_std) == 1
    ):
        y_std = {"std": y_std}

    # Convert once so lists and pandas objects work and the arithmetic is positional.
    res = np.abs(np.asarray(y_pred) - np.asarray(y_true))

    resolution = 100
    # The expected proportions and their Gaussian z-score bounds are the same
    # for every uncertainty set, so they are computed once, outside the loop.
    exp_proportions = np.linspace(0, 1, resolution)
    gaussian_upper_bound = norm.ppf(0.5 + exp_proportions / 2)

    # np.trapz was renamed np.trapezoid in NumPy 2.0, where trapz is deprecated.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    # (line, miscalibration area) pairs for the second legend
    line_area_pairs = []
    for key, std in y_std.items():
        # Absolute z-scores as a column vector so they broadcast against the
        # row of bounds. np.asarray keeps pandas inputs from turning this into
        # a Series, which has no .reshape.
        z_scored = (res / np.asarray(std)).reshape(-1, 1)
        # obs_proportions[i] = fraction of z-scores <= gaussian_upper_bound[i]
        obs_proportions = np.mean(z_scored <= gaussian_upper_bound, axis=0)

        [line] = ax.plot(
            exp_proportions, obs_proportions, linewidth=2, alpha=0.8, label=key
        )
        ax.fill_between(
            exp_proportions, y1=obs_proportions, y2=exp_proportions, alpha=0.2
        )
        # Integrate over the actual sample points. linspace(0, 1, n) has spacing
        # 1 / (n - 1), so using dx=1/n would under-estimate the area.
        miscal_area = trapezoid(
            np.abs(obs_proportions - exp_proportions), x=exp_proportions
        )
        line_area_pairs.append((line, miscal_area))

    if not line_area_pairs:
        raise ValueError("y_std must contain at least one set of uncertainties")

    # identity line
    ax.axline((0, 0), (1, 1), alpha=0.5, zorder=0, linestyle="dashed", color="black")

    ax.set(xlim=(0, 1), ylim=(0, 1))
    ax.set(xlabel="Theoretical Quantile", ylabel="Observed Quantile")

    legend1 = ax.legend(loc="upper left", frameon=False)
    # A second ax.legend() call replaces the first, so re-add it as an artist.
    # Multiple legends on the same axes:
    # https://matplotlib.org/3.3.3/tutorials/intermediate/legend_guide.html#multiple-legends-on-the-same-axes
    ax.add_artist(legend1)

    lines, areas = zip(*line_area_pairs)

    if len(lines) > 1:
        legend2 = ax.legend(
            lines,
            [f"{area:.2f}" for area in areas],
            title="Miscalibration areas",
            loc="lower right",
            ncol=2,
            frameon=False,
        )
        legend2._legend_box.align = "left"  # https://stackoverflow.com/a/44620643
    else:
        ax.legend(
            lines,
            [f"Miscalibration area: {areas[0]:.2f}"],
            loc="lower right",
            frameon=False,
        )

    return ax