def subplot_fn(data, **kwargs):
    """Scatters one facet's models, issuing one `ax.scatter` call per marker.

    The x/y column names are read from the constant `data.x` / `data.y`
    columns and also stored on the current axes as `my_xlabel` / `my_ylabel`.
    If the y column name ends in "_residual", that suffix is removed and the
    plotted y values are replaced (on a copy of `data`) by
    `get_residual(data["downstream_error"], data[<y column>])`.

    Uses `cmap` and `family_order` from the module or enclosing scope.

    Args:
      data: Facet data with x, y, family_marker, model_size, family_index and
        ModelFamily columns.
      **kwargs: Ignored.
    """
    del kwargs  # Unused.
    axes = plt.gca()

    x_col = utils.assert_and_get_constant(data.x)
    axes.my_xlabel = x_col
    y_col = utils.assert_and_get_constant(data.y)
    axes.my_ylabel = y_col

    if y_col.endswith("_residual"):
        # Plot residual w.r.t. accuracy; work on a copy so the caller's frame
        # is not modified.
        y_col = y_col.replace("_residual", "")
        data = data.copy()
        data.loc[:, y_col] = get_residual(data["downstream_error"], data[y_col])

    for marker in data.family_marker.unique():
        subset = data[data.family_marker == marker]
        axes.scatter(
            subset[x_col],
            subset[y_col],
            s=plotting.model_to_scatter_size(subset.model_size),
            c=subset.family_index,
            cmap=cmap,
            vmin=0,
            vmax=len(family_order),
            marker=marker,
            alpha=0.7,
            zorder=30,
            label=utils.assert_and_get_constant(subset.ModelFamily),
            linewidth=0.0,
        )
def subplot_fn(**kwargs):
    """Draws a seaborn regression plot and annotates it with fit parameters.

    Intended as a facet callback. `kwargs` must contain "x", "y" and "data".
    The facet's (constant) `family_color` is added as "color", and everything
    is forwarded to `sns.regplot`.

    The text box reports `a = 10 ** median(intercept samples)` and
    `k = median(slope samples)`, each followed by the interval returned by
    `sns_utils.ci` over the bootstrap samples from `RegressionPlotter`.
    NOTE(review): the `10 **` back-transform suggests the x/y columns hold
    log10 values, i.e. a power-law fit y = a * x**k. Confirm against callers.
    """
    x = kwargs["x"]
    y = kwargs["y"]
    data = kwargs["data"]
    utils.assert_no_duplicates_in_condition(
        data, group_by=["DatasetName", "ModelName"])
    ax = plt.gca()
    kwargs["color"] = utils.assert_and_get_constant(data.family_color)
    ax = sns.regplot(ax=ax, **kwargs)

    # Get regression parameters and show in plot. Only the bootstrap samples
    # are needed; the point estimate evaluated on `grid` is discarded.
    plotter = RegressionPlotter(x, y, data=data)
    grid = np.linspace(data[x].min(), data[x].max(), 100)
    _, beta_boots = plotter.get_params(grid)
    beta_boots = np.array(beta_boots)
    # Row 0 holds intercept samples, row 1 holds slope samples.
    intercept = 10 ** np.median(beta_boots[0, :])
    intercept_ci = 10 ** sns_utils.ci(beta_boots[0, :])
    slope = np.median(beta_boots[1, :])
    slope_ci = sns_utils.ci(beta_boots[1, :])

    label_text = (f"a = {intercept:1.2f} ({intercept_ci[0]:1.2f}, "
                  f"{intercept_ci[1]:1.2f})\nk = {slope:1.2f} ({slope_ci[0]:1.2f}, "
                  f"{slope_ci[1]:1.2f})")
    ax.text(0.04, 0.96, label_text, va="top", ha="left",
            transform=ax.transAxes, fontsize=4, color=(0.3, 0.3, 0.3),
            bbox=dict(facecolor="w", alpha=0.8, boxstyle="square,pad=0.1"))
def add_metric_description_title(df_plot: pd.DataFrame,
                                 fig: mpl.figure.Figure,
                                 y: float = 1.0):
    """Sets a figure suptitle naming the ECE variant contained in `df_plot`.

    The binning scheme, number of bins and norm must each be constant across
    `df_plot`. Raw identifiers are shown with display names: "adaptive" ->
    "equal-mass", "even" -> "equal-width", "l1"/"l2" -> "L1"/"L2".

    Args:
      df_plot: Rows of a single metric, with binning_scheme, num_bins and norm
        columns.
      fig: Figure that receives the suptitle.
      y: Vertical position of the suptitle, in figure coordinates (passed to
        `fig.suptitle`; the title is bottom-aligned at this position).
    """
    assert df_plot.Metric.nunique() == 1, "More than one metric in DataFrame."
    scheme = utils.assert_and_get_constant(df_plot.binning_scheme)
    bins = utils.assert_and_get_constant(df_plot.num_bins)
    norm_name = utils.assert_and_get_constant(df_plot.norm)

    text = "ECE variant: {} binning, {:.0f} bins, {} norm".format(
        scheme, bins, norm_name)
    # Applied in this order, on the whole title string.
    for raw, pretty in (("adaptive", "equal-mass"),
                        ("even", "equal-width"),
                        ("l1", "L1"),
                        ("l2", "L2")):
        text = text.replace(raw, pretty)
    fig.suptitle(text, y=y, verticalalignment="bottom")
def subplot_fn(data, x, y, **kwargs):
    """Scatters `data[y]` against `data[x]`, issuing one call per marker.

    Uses `cmap` and `family_order` from the module or enclosing scope. Each
    marker group is labelled with its (constant) ModelFamily.

    Args:
      data: Facet data with family_marker, model_size, family_index and
        ModelFamily columns.
      x: Column plotted on the x axis.
      y: Column plotted on the y axis.
      **kwargs: Ignored.
    """
    del kwargs  # Unused.
    axes = plt.gca()
    markers = data.family_marker.unique()
    for marker in markers:
        rows = data[data.family_marker == marker]
        axes.scatter(
            rows[x],
            rows[y],
            s=plotting.model_to_scatter_size(rows.model_size),
            c=rows.family_index,
            cmap=cmap,
            vmin=0,
            vmax=len(family_order),
            marker=marker,
            linewidth=0,
            alpha=0.7,
            zorder=30,
            label=utils.assert_and_get_constant(rows.ModelFamily),
        )
def plot(df_main: pd.DataFrame,
         df_reliability: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_method: str = "temperature_scaling",
         add_guo: bool = False) -> mpl.figure.Figure:
    """Plots acc/calib and reliability diagrams on clean ImageNet (Figure 1).

    The left panel(s) scatter ImageNet error against the calibration metric for
    every family in `display.get_model_families_sorted()`. The remaining
    columns show confidence and reliability panels for a fixed list of example
    models.

    Args:
      df_main: Results table passed to `_get_data` and also queried by
        ModelName / rescaling_method / Metric / DatasetName to look up each
        example model's RawModelName.
      df_reliability: Table passed to `_get_binned_reliability_data`.
      gce_prefix: Passed through to `_get_data`.
      rescaling_method: "both" draws one scatter panel for "none" and one for
        "temperature_scaling", followed by three example models. Any other
        value draws a single scatter panel for that method, followed by five
        example models.
      add_guo: If True, also plot the DenseNet161/ResNet152 numbers from
        Guo et al. (2017).

    Returns:
      The figure.
    """
    # Copy: appending "guo" must not mutate a list that `display` may share
    # between calls.
    family_order = list(display.get_model_families_sorted())
    if add_guo:
        family_order.append("guo")

    if rescaling_method == "both":
        rescaling_methods = ["none", "temperature_scaling"]
        widths = [1.75, 1.75, 1, 1, 1]
    else:
        rescaling_methods = [rescaling_method]
        widths = [1.8, 1, 1, 1, 1, 1]

    # Set up figure:
    fig = plt.figure(figsize=(display.FULL_WIDTH, 1.6))
    heights = [0.3, 1]
    spec = fig.add_gridspec(ncols=len(widths), nrows=len(heights),
                            width_ratios=widths, height_ratios=heights)

    # First panels: acc vs calib ImageNet. The loop variable is `method` so
    # that the `rescaling_method` argument stays available further down.
    for ax_i, method in enumerate(rescaling_methods):
        df_plot, cmap = _get_data(df_main, gce_prefix, family_order,
                                  rescaling_methods=[method])
        ax = fig.add_subplot(spec[0:2, ax_i], box_aspect=1.0)
        big_ax = ax
        for i, family in enumerate(family_order):
            if family == "guo":
                continue
            data_sub = df_plot[df_plot.ModelFamily == family]
            if data_sub.empty:
                continue
            ax.scatter(
                data_sub["downstream_error"],
                data_sub["MetricValue"],
                s=plotting.model_to_scatter_size(data_sub.model_size),
                c=data_sub.family_index,
                cmap=cmap,
                vmin=0,
                vmax=len(family_order),
                marker=utils.assert_and_get_constant(data_sub.family_marker),
                alpha=0.7,
                linewidth=0.0,
                zorder=100 - i,  # Z-order is same as model family order.
                label=family)

        # Manually add Guo et al data:
        # From Table 1 and Table S2 in https://arxiv.org/pdf/1706.04599.pdf.
        # First model is DenseNet161, second is ResNet152.
        if add_guo:
            size = plotting.model_to_scatter_size(1)
            # "guo" is the last family, so its family index is len - 1. Map it
            # through the same cmap/vmin/vmax as the other families; without
            # them matplotlib uses its default colormap with a vmin == vmax
            # norm.
            color = [len(family_order) - 1] * 2
            marker = "x"
            if method == "none":
                ax.scatter([0.2257, 0.2231], [0.0628, 0.0548], s=size, c=color,
                           cmap=cmap, vmin=0, vmax=len(family_order),
                           marker=marker, alpha=0.7, label="guo")
            if method == "temperature_scaling":
                ax.scatter([0.2257, 0.2231], [0.0199, 0.0186], s=size, c=color,
                           cmap=cmap, vmin=0, vmax=len(family_order),
                           marker=marker, alpha=0.7, label="guo")

        plotting.show_spines(ax)
        ax.set_anchor("N")
        ax.grid(False, which="minor")
        ax.grid(True, axis="both")
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
        ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.01))
        ax.set_ylim(bottom=0.01, top=0.09)
        ax.set_xlim(0.05, 0.55)
        ax.set_xlabel(display.XLABEL_INET_ERROR)
        if len(rescaling_methods) == 1:
            # Showing just one rescaling method.
            if method == "none":
                ax.set_ylabel(display.YLABEL_ECE_UNSCALED)
            elif method == "temperature_scaling":
                ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED)
        else:
            # Showing both rescaling methods.
            if method == "none":
                ax.set_title("Unscaled")
            elif method == "temperature_scaling":
                ax.set_title("Temperature-scaled")
            # The panel sits in grid column `ax_i`. This replaces
            # `ax.is_first_col()`, which was removed in Matplotlib 3.6.
            if ax_i == 0:
                ax.set_ylabel("ECE")

    # Remaining panels: Reliability diagrams:
    offset = len(rescaling_methods)
    model_names = [
        "mixer/jft-300m/H/14",
        "vit-h/14",
        "bit-jft-r152-x4-480",
    ]
    if offset == 1:
        model_names += ["wsl_32x48d", "simclr-4x-fine-tuned-100"]
    dataset_name = "imagenet(split='validation[20%:]')"
    for i, model_name in enumerate(model_names):
        # Get predictions. Rows are filtered to rescaling_method == "none" only
        # to look up a single RawModelName.
        # NOTE(review): confirm that `_get_binned_reliability_data` returns
        # data matching the x-label chosen below.
        mask = df_main.ModelName == model_name
        mask &= df_main.rescaling_method == "none"
        mask &= df_main.Metric == "accuracy"
        mask &= df_main.DatasetName == dataset_name
        raw_model_name = df_main[mask].RawModelName
        assert len(raw_model_name) <= 1, df_main[mask]
        if len(raw_model_name) == 0:  # pylint: disable=g-explicit-length-test
            continue  # Model not present in df_main: leave its column empty.
        binned = _get_binned_reliability_data(
            df_reliability, utils.assert_and_get_constant(raw_model_name),
            dataset_name)
        rel_ax = fig.add_subplot(spec[1, i + offset])
        # NOTE(review): `first_col` receives the int column offset (always
        # >= 1), not a per-column flag. Confirm this is intended.
        _plot_confidence_and_reliability(
            conf_ax=fig.add_subplot(spec[0, i + offset]),
            rel_ax=rel_ax,
            binned=binned,
            model_name=model_name,
            first_col=offset)
        # Uses the function argument. The previous check compared against the
        # misspelled "temperature_scaled", so the temp-scaled label was never
        # set.
        if rescaling_method == "none":
            rel_ax.set_xlabel("Confidence\n(unscaled)")
        elif rescaling_method == "temperature_scaling":
            rel_ax.set_xlabel("Confidence\n(temp. scaled)")

    # Model family legend:
    handles, labels = plotting.get_model_family_legend(big_ax, family_order)
    legend = big_ax.legend(handles=handles, labels=labels, loc="upper right",
                           frameon=True, labelspacing=0.25, handletextpad=0.1,
                           borderpad=0.3, fontsize=4)
    legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
    legend.get_frame().set_edgecolor("lightgray")
    plotting.apply_to_fig_text(fig, display.prettify)

    # Shift every axes except the first one to the right (figure coordinates).
    x_shift = 0.05
    for ax in fig.axes[1:]:
        box = ax.get_position()
        box.x0 += x_shift
        box.x1 += x_shift
        ax.set_position(box)
    return fig
def plot_reliability_diagrams(
        df_main: pd.DataFrame,
        df_reliability: pd.DataFrame,
        family: str,
        rescaling_method: str = "temperature_scaling",
        dataset_name: str = "imagenet(split='validation[20%:]')",
        gce_prefix=plotting.STD_GCE_PREFIX) -> mpl.figure.Figure:
    """Plots confidence and reliability panels for each model of one family.

    The models of `family` returned by `_get_data` for `dataset_name` and
    `rescaling_method` are laid out left to right by increasing `model_size`.
    Each column holds a confidence panel (top row) and a reliability diagram
    (bottom row), both drawn by `_plot_confidence_and_reliability`. The grid
    always has at least five columns; columns without a model get invisible
    placeholder axes.

    Args:
      df_main: Results table passed to `_get_data` and queried for each
        model's RawModelName (exactly one row must match per model).
      df_reliability: Table passed to `_get_binned_reliability_data`.
      family: Model family to plot.
      rescaling_method: Used to select rows of `df_main`. The values "none"
        and "temperature_scaling" also set the reliability x-axis label.
      dataset_name: Dataset whose results are plotted.
      gce_prefix: Passed through to `_get_data`.

    Returns:
      The figure.

    Raises:
      AssertionError: If a model does not match exactly one row of `df_main`.
    """
    df_plot, _ = _get_data(df_main, gce_prefix, [family],
                           rescaling_methods=[rescaling_method],
                           dataset_name=dataset_name)
    model_names = (df_plot.drop_duplicates(subset=["ModelName"])
                   .sort_values(by="model_size")
                   .ModelName.to_list())

    # Set up figure:
    num_cols = max(5, len(model_names))
    fig = plt.figure(figsize=(num_cols * 0.80, 1.4))
    spec = fig.add_gridspec(ncols=num_cols, nrows=2, height_ratios=[0.4, 1])

    xlabel_by_method = {
        "none": "Confidence\n(unscaled)",
        "temperature_scaling": "Confidence\n(temp. scaled)",
    }

    for col, model_name in enumerate(model_names):
        # Get predictions:
        selected = ((df_main.ModelName == model_name)
                    & (df_main.rescaling_method == rescaling_method)
                    & (df_main.Metric == "accuracy")
                    & (df_main.DatasetName == dataset_name))
        raw_model_name = df_main[selected].RawModelName
        assert len(raw_model_name) == 1
        binned = _get_binned_reliability_data(
            df_reliability, utils.assert_and_get_constant(raw_model_name),
            dataset_name)
        rel_ax = fig.add_subplot(spec[1, col])
        conf_ax = fig.add_subplot(spec[0, col])
        _plot_confidence_and_reliability(conf_ax=conf_ax,
                                         rel_ax=rel_ax,
                                         binned=binned,
                                         model_name=model_name)
        xlabel = xlabel_by_method.get(rescaling_method)
        if xlabel is not None:
            rel_ax.set_xlabel(xlabel)

    # Columns past the last model: placeholder axes for formatting, hidden.
    for col in range(len(model_names), num_cols):
        fig.add_subplot(spec[0, col]).set_visible(False)
        fig.add_subplot(spec[1, col]).set_visible(False)

    def _prettify_label(text):
        # Standard prettifying, plus a line break after "MLP-Mixer-".
        return display.prettify(text).replace("MLP-Mixer-", "MLP-Mixer\n")

    plotting.apply_to_fig_text(fig, _prettify_label)
    plotting.apply_to_fig_text(
        fig, lambda x: x.replace("EfficientNet", "EffNet"))
    fig.subplots_adjust(hspace=-0.05)
    return fig
def plot(df_main: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_method: str = "temperature_scaling",
         add_guo: bool = False) -> mpl.figure.Figure:
    """Plots ImageNet error vs. ECE separately for each BiT pretraining set.

    Produces a 2x3 grid. The columns are BiT-ImageNet, BiT-ImageNet21k and
    BiT-JFT; the rows are "none" (unscaled) and `rescaling_method`. For each
    column, `display.get_standard_model_list` is temporarily replaced by a
    function that omits the other two BiT variants (presumably consumed by
    `_get_data`). The original function is restored before returning, even if
    plotting raises, so the patch no longer leaks into later plots.

    Args:
      df_main: Results table passed to `_get_data`.
      gce_prefix: Passed through to `_get_data`.
      rescaling_method: Rescaling method shown in the second row.
      add_guo: If True, also plot the DenseNet161/ResNet152 numbers from
        Guo et al. (2017).

    Returns:
      The figure.
    """
    rescaling_methods = ["none", rescaling_method]
    # Copy: appending "guo" must not mutate a list that `display` may share.
    family_order = list(display.get_model_families_sorted(
        ["mixer", "vit", "bit", "simclr"]))
    if add_guo:
        family_order.append("guo")

    bit_versions = ["BiT-ImageNet", "BiT-ImageNet21k", "BiT-JFT"]
    # Model-name prefixes hidden in each column, so each column shows a single
    # BiT variant. "bit-imagenet21k-" does not start with "bit-imagenet-".
    excluded_prefixes = {
        "BiT-ImageNet": ("bit-imagenet21k-", "bit-jft-"),
        "BiT-ImageNet21k": ("bit-imagenet-", "bit-jft-"),
        "BiT-JFT": ("bit-imagenet-", "bit-imagenet21k-"),
    }

    def _standard_models_excluding(prefixes):
        # Stand-in for `display.get_standard_model_list`: every key of
        # `display.MODEL_SIZE` that starts with none of `prefixes`.
        return lambda: [
            m for m in display.MODEL_SIZE.keys() if not m.startswith(prefixes)
        ]

    # Set up figure:
    fig = plt.figure(figsize=(display.FULL_WIDTH / 2, 2))
    spec = fig.add_gridspec(ncols=3, nrows=2)

    saved_model_list_fn = getattr(display, "get_standard_model_list", None)
    try:
        for col, bit_version in enumerate(bit_versions):
            if bit_version not in excluded_prefixes:
                raise ValueError(f"Unknown BiT version: {bit_version}")
            display.get_standard_model_list = _standard_models_excluding(
                excluded_prefixes[bit_version])

            # Loop variable is `method` so the argument is not shadowed.
            for row, method in enumerate(rescaling_methods):
                df_plot, cmap = _get_data(df_main, gce_prefix, family_order,
                                          rescaling_methods=[method])
                ax = fig.add_subplot(spec[row, col])
                big_ax = ax
                for i, family in enumerate(family_order):
                    if family == "guo":
                        continue
                    data_sub = df_plot[df_plot.ModelFamily == family]
                    if data_sub.empty:
                        continue
                    ax.scatter(
                        data_sub["downstream_error"],
                        data_sub["MetricValue"],
                        s=plotting.model_to_scatter_size(data_sub.model_size),
                        c=data_sub.family_index,
                        cmap=cmap,
                        vmin=0,
                        vmax=len(family_order),
                        marker=utils.assert_and_get_constant(
                            data_sub.family_marker),
                        linewidth=0.5,
                        alpha=1.0 if "bit" in family else 0.5,
                        zorder=100 - i,  # Z-order is same as model family order.
                        label=family)

                # Manually add Guo et al data:
                # From Table 1 and Table S2 in https://arxiv.org/pdf/1706.04599.pdf.
                # First model is DenseNet161, second is ResNet152.
                if add_guo:
                    size = plotting.model_to_scatter_size(1)
                    # "guo" is the last family (index len - 1). Map it through
                    # the same cmap/vmin/vmax as the other families.
                    color = [len(family_order) - 1] * 2
                    marker = "x"
                    if method == "none":
                        ax.scatter([0.2257, 0.2231], [0.0628, 0.0548], s=size,
                                   c=color, cmap=cmap, vmin=0,
                                   vmax=len(family_order), marker=marker,
                                   alpha=0.7, label="guo")
                    if method == "temperature_scaling":
                        ax.scatter([0.2257, 0.2231], [0.0199, 0.0186], s=size,
                                   c=color, cmap=cmap, vmin=0,
                                   vmax=len(family_order), marker=marker,
                                   alpha=0.7, label="guo")

                plotting.show_spines(ax)
                # Aspect ratios are tuned manually for display in the paper:
                ax.set_anchor("N")
                ax.grid(False, which="minor")
                ax.grid(True, axis="both")
                ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.01))
                ax.set_ylim(bottom=0.0, top=0.09)
                ax.set_xlim(0.05, 0.3)
                ax.set_xlabel(display.XLABEL_INET_ERROR)
                # `row`/`col` are the axes' grid position. These checks replace
                # `ax.is_first_row()`/`ax.is_first_col()`, which were removed
                # in Matplotlib 3.6.
                if row == 0:
                    ax.set_title(bit_version, fontsize=6)
                    ax.set_xlabel("")
                    ax.set_xticklabels("")
                else:
                    ax.set_ylim(bottom=0.0, top=0.05)
                if col == 0:
                    if method == "none":
                        ax.set_ylabel(display.YLABEL_ECE_UNSCALED)
                    elif method == "temperature_scaling":
                        ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED)
                else:
                    ax.set_yticklabels("")

        fig.tight_layout(pad=0.5)

        # Model family legend:
        handles, labels = plotting.get_model_family_legend(big_ax, family_order)
        plotting.apply_to_fig_text(fig, display.prettify)
        plotting.apply_to_fig_text(
            fig, lambda x: x.replace("EfficientNet", "EffNet"))
        legend = fig.axes[0].legend(handles=handles, labels=labels,
                                    loc="upper center", title="Model family",
                                    bbox_to_anchor=(0.55, -0.025), frameon=True,
                                    bbox_transform=fig.transFigure,
                                    ncol=len(family_order), handletextpad=0.1)
        legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
        legend.get_frame().set_edgecolor("lightgray")
        # Prettify again so the newly created legend text is included.
        plotting.apply_to_fig_text(fig, display.prettify)
        return fig
    finally:
        # Undo the monkey-patch so later plots see the full model list.
        if saved_model_list_fn is not None:
            display.get_standard_model_list = saved_model_list_fn
def plot(df_main: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_method: str = "temperature_scaling",
         add_guo: bool = False) -> mpl.figure.Figure:
    """Plots ImageNet error vs. ECE, unscaled and rescaled side by side.

    Draws two scatter panels, one per method in ["none", rescaling_method],
    with every family from `display.get_model_families_sorted()`. The legend
    is placed on the right panel.

    Args:
      df_main: Results table passed to `_get_data`.
      gce_prefix: Passed through to `_get_data`.
      rescaling_method: Rescaling method shown in the right panel.
      add_guo: If True, also plot the DenseNet161/ResNet152 numbers from
        Guo et al. (2017).

    Returns:
      The figure.
    """
    rescaling_methods = ["none", rescaling_method]
    # Copy: appending "guo" must not mutate a list that `display` may share.
    family_order = list(display.get_model_families_sorted())
    if add_guo:
        family_order.append("guo")

    # Set up figure:
    fig = plt.figure(figsize=(display.FULL_WIDTH/2, 1.4))
    spec = fig.add_gridspec(ncols=2, nrows=1)

    # Loop variable is `method` so the `rescaling_method` argument is not
    # shadowed.
    for ax_i, method in enumerate(rescaling_methods):
        df_plot, cmap = _get_data(df_main, gce_prefix, family_order,
                                  rescaling_methods=[method])
        ax = fig.add_subplot(spec[:, ax_i])
        big_ax = ax
        for i, family in enumerate(family_order):
            if family == "guo":
                continue
            data_sub = df_plot[df_plot.ModelFamily == family]
            if data_sub.empty:
                continue
            ax.scatter(
                data_sub["downstream_error"],
                data_sub["MetricValue"],
                s=plotting.model_to_scatter_size(data_sub.model_size),
                c=data_sub.family_index,
                cmap=cmap,
                vmin=0,
                vmax=len(family_order),
                marker=utils.assert_and_get_constant(data_sub.family_marker),
                alpha=0.7,
                linewidth=0.0,
                zorder=100 - i,  # Z-order is same as model family order.
                label=family)

        # Manually add Guo et al data:
        # From Table 1 and Table S2 in https://arxiv.org/pdf/1706.04599.pdf.
        # First model is DenseNet161, second is ResNet152.
        if add_guo:
            size = plotting.model_to_scatter_size(1)
            # "guo" is the last family (index len - 1). Map it through the same
            # cmap/vmin/vmax as the other families; without them matplotlib
            # uses its default colormap with a vmin == vmax norm.
            color = [len(family_order) - 1] * 2
            marker = "x"
            if method == "none":
                ax.scatter([0.2257, 0.2231], [0.0628, 0.0548], s=size, c=color,
                           cmap=cmap, vmin=0, vmax=len(family_order),
                           marker=marker, alpha=0.7, label="guo")
            if method == "temperature_scaling":
                ax.scatter([0.2257, 0.2231], [0.0199, 0.0186], s=size, c=color,
                           cmap=cmap, vmin=0, vmax=len(family_order),
                           marker=marker, alpha=0.7, label="guo")

        plotting.show_spines(ax)
        # Aspect ratios are tuned manually for display in the paper:
        ax.set_anchor("N")
        ax.grid(False, which="minor")
        ax.grid(True, axis="both")
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
        ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.01))
        ax.set_ylim(bottom=0.01, top=0.09)
        ax.set_xlim(0.05, 0.5)
        ax.set_xlabel(display.XLABEL_INET_ERROR)
        if method == "none":
            ax.set_title("Unscaled")
        elif method == "temperature_scaling":
            ax.set_title("Temperature-scaled")
            # Same y range as the unscaled panel, so its tick labels are hidden.
            ax.set_yticklabels("")
        # The panel sits in grid column `ax_i`. This replaces
        # `ax.is_first_col()`, which was removed in Matplotlib 3.6.
        if ax_i == 0:
            ax.set_ylabel("ECE")

    # Model family legend:
    handles, labels = plotting.get_model_family_legend(big_ax, family_order)
    legend = big_ax.legend(
        handles=handles, labels=labels, loc="upper right", frameon=True,
        labelspacing=0.3, handletextpad=0.1, borderpad=0.3, fontsize=4)
    legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
    legend.get_frame().set_edgecolor("lightgray")
    plotting.apply_to_fig_text(fig, display.prettify)
    plotting.apply_to_fig_text(fig, lambda x: x.replace("EfficientNet", "EffNet"))
    return fig