def _plot_bit_version_grid(df_main: pd.DataFrame,
                           gce_prefix: str,
                           rescaling_method: str,
                           add_guo: bool) -> mpl.figure.Figure:
    """Draws the 2x3 BiT-version grid used by `plot`.

    Side effect: replaces `display.get_standard_model_list` once per column
    and does NOT restore it; `plot` is responsible for restoring it.
    """
    rescaling_methods = ["none", rescaling_method]
    family_order = display.get_model_families_sorted(
        ["mixer", "vit", "bit", "simclr"])
    if add_guo:
        # New list instead of an in-place append, so a list object owned by
        # `display` can never be modified.
        family_order = family_order + ["guo"]

    # (column title, model-name prefixes of the BiT variants hidden in it):
    bit_versions = (
        ("BiT-ImageNet", ("bit-imagenet21k-", "bit-jft-")),
        ("BiT-ImageNet21k", ("bit-imagenet-", "bit-jft-")),
        ("BiT-JFT", ("bit-imagenet-", "bit-imagenet21k-")),
    )

    def model_list_without(prefixes):
        """Returns a zero-arg model-list function that drops `prefixes`."""

        def get_model_list():
            return [m for m in display.MODEL_SIZE.keys()
                    if not m.startswith(prefixes)]

        return get_model_list

    # Set up figure:
    fig = plt.figure(figsize=(display.FULL_WIDTH / 2, 2))
    spec = fig.add_gridspec(ncols=3, nrows=2)

    for col, (bit_version, hidden_prefixes) in enumerate(bit_versions):
        # Restrict the model list for this column (presumably read by
        # `_get_data`; TODO confirm).
        display.get_standard_model_list = model_list_without(hidden_prefixes)

        for row, row_method in enumerate(rescaling_methods):
            df_plot, cmap = _get_data(df_main, gce_prefix, family_order,
                                      rescaling_methods=[row_method])
            ax = fig.add_subplot(spec[row, col])
            big_ax = ax
            for i, family in enumerate(family_order):
                if family == "guo":
                    continue
                data_sub = df_plot[df_plot.ModelFamily == family]
                if data_sub.empty:
                    continue
                ax.scatter(
                    data_sub["downstream_error"],
                    data_sub["MetricValue"],
                    s=plotting.model_to_scatter_size(data_sub.model_size),
                    c=data_sub.family_index,
                    cmap=cmap,
                    vmin=0,
                    vmax=len(family_order),
                    marker=utils.assert_and_get_constant(
                        data_sub.family_marker),
                    linewidth=0.5,
                    alpha=1.0 if "bit" in family else 0.5,
                    zorder=100 - i,  # Z-order is same as model family order.
                    label=family)

            # Manually add Guo et al data:
            # From Table 1 and Table S2 in
            # https://arxiv.org/pdf/1706.04599.pdf.
            # First model is DenseNet161, second is ResNet152.
            guo_ece = {
                "none": [0.0628, 0.0548],
                "temperature_scaling": [0.0199, 0.0186],
            }.get(row_method)
            if add_guo and guo_ece is not None:
                ax.scatter([0.2257, 0.2231], guo_ece,
                           s=plotting.model_to_scatter_size(1),
                           c=[len(family_order) - 1] * 2,
                           # Same colormap and normalization as the model
                           # families above ("guo" is last in family_order).
                           cmap=cmap, vmin=0, vmax=len(family_order),
                           marker="x", alpha=0.7, label="guo")

            plotting.show_spines(ax)
            # Aspect ratios are tuned manually for display in the paper:
            ax.set_anchor("N")
            ax.grid(False, which="minor")
            ax.grid(True, axis="both")
            ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.01))
            ax.set_ylim(bottom=0.0, top=0.09)
            ax.set_xlim(0.05, 0.3)
            ax.set_xlabel(display.XLABEL_INET_ERROR)
            if ax.is_first_row():
                ax.set_title(bit_version, fontsize=6)
                ax.set_xlabel("")
                ax.set_xticklabels([])
            else:
                ax.set_ylim(bottom=0.0, top=0.05)

            if ax.is_first_col():
                if row_method == "none":
                    ax.set_ylabel(display.YLABEL_ECE_UNSCALED)
                elif row_method == "temperature_scaling":
                    ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED)
            else:
                ax.set_yticklabels([])

    fig.tight_layout(pad=0.5)

    # Model family legend:
    handles, labels = plotting.get_model_family_legend(big_ax, family_order)
    plotting.apply_to_fig_text(fig, display.prettify)
    plotting.apply_to_fig_text(
        fig, lambda x: x.replace("EfficientNet", "EffNet"))
    legend = fig.axes[0].legend(
        handles=handles,
        labels=labels,
        loc="upper center",
        title="Model family",
        bbox_to_anchor=(0.55, -0.025),
        frameon=True,
        bbox_transform=fig.transFigure,
        ncol=len(family_order),
        handletextpad=0.1)
    legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
    legend.get_frame().set_edgecolor("lightgray")
    # Applied again so the legend text just created is prettified too.
    plotting.apply_to_fig_text(fig, display.prettify)
    return fig


def plot(df_main: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_method: str = "temperature_scaling",
         add_guo: bool = False) -> mpl.figure.Figure:
    """Plots acc/calib on clean ImageNet, one column per BiT pretraining set.

    The figure is a 2x3 grid. Columns show BiT pretrained on ImageNet,
    ImageNet-21k and JFT (the other two BiT variants are hidden in each
    column). The first row shows unscaled ECE; the second row shows ECE
    after `rescaling_method`.

    The per-column filtering works by temporarily replacing
    `display.get_standard_model_list`. The original function is restored
    before returning, also if plotting raises.

    Args:
      df_main: Results dataframe, passed through to `_get_data`.
      gce_prefix: Metric prefix, passed through to `_get_data`.
      rescaling_method: Rescaling method shown in the second row.
      add_guo: If True, also plots the DenseNet161/ResNet152 values reported
        by Guo et al. (2017).

    Returns:
      The matplotlib figure.
    """
    original_get_standard_model_list = display.get_standard_model_list
    try:
        return _plot_bit_version_grid(df_main, gce_prefix, rescaling_method,
                                      add_guo)
    finally:
        # Undo the per-column patching so later plots see the full model list.
        display.get_standard_model_list = original_get_standard_model_list
def plot(df_main: pd.DataFrame,
         df_reliability: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_method: str = "temperature_scaling",
         add_guo: bool = False) -> mpl.figure.Figure:
    """Plots acc/calib and reliability diagrams on clean ImageNet (Figure 1).

    The left panel(s) show ImageNet error vs. ECE for all model families.
    Each remaining column shows one selected model, with a `conf_ax` panel
    (top row) and a `rel_ax` panel (bottom row). Both panels are drawn by
    `_plot_confidence_and_reliability`.

    Args:
      df_main: Results dataframe. It is used through `_get_data` and is
        filtered on ModelName, rescaling_method, Metric and DatasetName to
        look up RawModelName.
      df_reliability: Data passed to `_get_binned_reliability_data`.
      gce_prefix: Metric prefix, passed through to `_get_data`.
      rescaling_method: "none", "temperature_scaling" or "both". "both" draws
        two acc/calib panels and up to three model columns. Any other value
        draws one panel and up to five model columns. Models missing from
        `df_main` are skipped.
      add_guo: If True, adds the DenseNet161/ResNet152 values reported by
        Guo et al. (2017).

    Returns:
      The matplotlib figure.
    """
    family_order = display.get_model_families_sorted()
    if add_guo:
        # New list instead of an in-place append, so a list object owned by
        # `display` can never be modified.
        family_order = family_order + ["guo"]

    if rescaling_method == "both":
        rescaling_methods = ["none", "temperature_scaling"]
        widths = [1.75, 1.75, 1, 1, 1]
    else:
        rescaling_methods = [rescaling_method]
        widths = [1.8, 1, 1, 1, 1, 1]
    heights = [0.3, 1]

    # Set up figure:
    fig = plt.figure(figsize=(display.FULL_WIDTH, 1.6))
    spec = fig.add_gridspec(ncols=len(widths), nrows=len(heights),
                            width_ratios=widths, height_ratios=heights)

    # First panels: acc vs calib ImageNet:
    for ax_i, panel_method in enumerate(rescaling_methods):
        df_plot, cmap = _get_data(df_main, gce_prefix, family_order,
                                  rescaling_methods=[panel_method])
        ax = fig.add_subplot(spec[0:2, ax_i], box_aspect=1.0)
        big_ax = ax
        for i, family in enumerate(family_order):
            if family == "guo":
                continue
            data_sub = df_plot[df_plot.ModelFamily == family]
            if data_sub.empty:
                continue
            ax.scatter(
                data_sub["downstream_error"],
                data_sub["MetricValue"],
                s=plotting.model_to_scatter_size(data_sub.model_size),
                c=data_sub.family_index,
                cmap=cmap,
                vmin=0,
                vmax=len(family_order),
                marker=utils.assert_and_get_constant(data_sub.family_marker),
                alpha=0.7,
                linewidth=0.0,
                zorder=100 - i,  # Z-order is same as model family order.
                label=family)

        # Manually add Guo et al data:
        # From Table 1 and Table S2 in https://arxiv.org/pdf/1706.04599.pdf.
        # First model is DenseNet161, second is ResNet152.
        guo_ece = {
            "none": [0.0628, 0.0548],
            "temperature_scaling": [0.0199, 0.0186],
        }.get(panel_method)
        if add_guo and guo_ece is not None:
            ax.scatter([0.2257, 0.2231], guo_ece,
                       s=plotting.model_to_scatter_size(1),
                       c=[len(family_order) - 1] * 2,
                       # Same colormap and normalization as the model
                       # families above ("guo" is last in family_order).
                       cmap=cmap, vmin=0, vmax=len(family_order),
                       marker="x", alpha=0.7, label="guo")

        plotting.show_spines(ax)
        ax.set_anchor("N")
        ax.grid(False, which="minor")
        ax.grid(True, axis="both")
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
        ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.01))
        ax.set_ylim(bottom=0.01, top=0.09)
        ax.set_xlim(0.05, 0.55)
        ax.set_xlabel(display.XLABEL_INET_ERROR)
        if len(rescaling_methods) == 1:
            # Showing just one rescaling method: name it in the y-label.
            if panel_method == "none":
                ax.set_ylabel(display.YLABEL_ECE_UNSCALED)
            elif panel_method == "temperature_scaling":
                ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED)
        else:
            # Showing both rescaling methods: name them in the titles.
            if panel_method == "none":
                ax.set_title("Unscaled")
            elif panel_method == "temperature_scaling":
                ax.set_title("Temperature-scaled")
            if ax.is_first_col():
                ax.set_ylabel("ECE")

    # Remaining panels: Reliability diagrams:
    num_calib_panels = len(rescaling_methods)
    model_names = [
        "mixer/jft-300m/H/14",
        "vit-h/14",
        "bit-jft-r152-x4-480",
    ]
    if num_calib_panels == 1:
        model_names += ["wsl_32x48d", "simclr-4x-fine-tuned-100"]
    dataset_name = "imagenet(split='validation[20%:]')"
    # Method of the last acc/calib panel; used to label the reliability axes.
    label_method = rescaling_methods[-1]
    for i, model_name in enumerate(model_names):
        # Get predictions:
        mask = df_main.ModelName == model_name
        mask &= df_main.rescaling_method == "none"
        mask &= df_main.Metric == "accuracy"
        mask &= df_main.DatasetName == dataset_name
        raw_model_name = df_main[mask].RawModelName
        assert len(raw_model_name) <= 1, df_main[mask]
        if len(raw_model_name) == 0:  # pylint: disable=g-explicit-length-test
            continue
        binned = _get_binned_reliability_data(
            df_reliability, utils.assert_and_get_constant(raw_model_name),
            dataset_name)
        column = i + num_calib_panels
        rel_ax = fig.add_subplot(spec[1, column])
        _plot_confidence_and_reliability(
            conf_ax=fig.add_subplot(spec[0, column]),
            rel_ax=rel_ax,
            binned=binned,
            model_name=model_name,
            first_col=num_calib_panels)
        # NOTE(review): every other comparison in this function uses the name
        # "temperature_scaling", so the "temperature_scaled" branch below can
        # never match. It is left unchanged because `binned` does not depend
        # on the rescaling method here, so it is unclear which label is
        # correct. Confirm before fixing.
        if label_method == "none":
            rel_ax.set_xlabel("Confidence\n(unscaled)")
        elif label_method == "temperature_scaled":
            rel_ax.set_xlabel("Confidence\n(temp. scaled)")

    # Model family legend:
    handles, labels = plotting.get_model_family_legend(big_ax, family_order)
    legend = big_ax.legend(handles=handles, labels=labels, loc="upper right",
                           frameon=True, labelspacing=0.25, handletextpad=0.1,
                           borderpad=0.3, fontsize=4)
    legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
    legend.get_frame().set_edgecolor("lightgray")
    plotting.apply_to_fig_text(fig, display.prettify)

    # Shift every axis except the first one to the right (figure coords).
    x_shift = 0.05
    for ax in fig.axes[1:]:
        box = ax.get_position()
        box.x0 += x_shift
        box.x1 += x_shift
        ax.set_position(box)

    return fig
def plot(df_main: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_methods: Optional[List[str]] = None,
         plot_confidence: bool = False,
         add_legend: bool = True,
         legend_position: str = "right",
         height: float = 1.05,
         aspect: float = 1.2,
         add_metric_description: bool = False,
         ) -> sns.axisgrid.FacetGrid:
    """Plots acc/calib for ImageNet and natural OOD datasets.

    The grid has one column per dataset in `display.OOD_DATASET_ORDER` and
    one row per rescaling method. Each facet scatters the metric value
    against `downstream_error`.

    Args:
      df_main: Results dataframe.
      gce_prefix: Only rows whose Metric starts with this prefix are used.
      rescaling_methods: Row order. Defaults to
        ["none", "temperature_scaling"]. The caller's list is not modified.
      plot_confidence: If True, appends a "tau" row that shows
        `tau_on_eval_data` of the temperature-scaled rows.
      add_legend: Whether to add the model family legend.
      legend_position: "below" or "right". Only used if `add_legend`.
      height: Facet height, passed to `plotting.FacetGrid`.
      aspect: Facet aspect ratio, passed to `plotting.FacetGrid`.
      add_metric_description: Whether to add a metric description title.

    Returns:
      The facet grid.

    Raises:
      ValueError: If `add_legend` is set and `legend_position` is unknown.
    """
    if add_legend and legend_position not in ("below", "right"):
        raise ValueError(f"Unknown legend_position: {legend_position!r}")

    # Settings:
    # Copy, so that appending "tau" never mutates the caller's list.
    rescaling_methods = list(
        rescaling_methods or ["none", "temperature_scaling"])
    if plot_confidence:
        rescaling_methods.append("tau")
    family_order = display.get_model_families_sorted()
    dataset_order = display.OOD_DATASET_ORDER

    # Select data:
    mask = df_main.Metric.str.startswith(gce_prefix)
    mask &= df_main.ModelName.isin(display.get_standard_model_list())
    mask &= df_main.DatasetName.isin(dataset_order)
    mask &= df_main.rescaling_method.isin(rescaling_methods)
    mask &= df_main.ModelFamily.isin(family_order)
    # Keep ImageNet-A/-R rows only when use_dataset_labelset is True:
    for prefix in ("imagenet_a", "imagenet_r"):
        mask &= ~(df_main.DatasetName.str.startswith(prefix)
                  & (~df_main.use_dataset_labelset.eq(True)))
    df_plot = df_main[mask].copy()
    df_plot, cmap = display.add_display_data(df_plot, family_order)

    # Remove "use_dataset_labelset=True" to have uniform metric name:
    df_plot.Metric = df_plot.Metric.str.replace(
        "use_dataset_labelset=True,", "", regex=False)

    # Add "optimal temperature" as another rescaling method, so that seaborn
    # can plot it as a third row:
    df_tau = df_plot[df_plot.rescaling_method == "temperature_scaling"].copy()
    df_tau.rescaling_method = "tau"
    df_tau.MetricValue = df_tau.tau_on_eval_data
    df_plot = pd.concat([df_plot, df_tau])

    def subplot_fn(data, x, y, **kwargs):
        """Scatters `y` vs `x` for one facet, one call per marker style."""
        del kwargs
        ax = plt.gca()
        for marker in data.family_marker.unique():
            data_sub = data[data.family_marker == marker]
            ax.scatter(
                data_sub[x],
                data_sub[y],
                s=plotting.model_to_scatter_size(data_sub.model_size),
                c=data_sub.family_index,
                cmap=cmap,
                vmin=0,
                vmax=len(family_order),
                marker=marker,
                linewidth=0,
                alpha=0.7,
                zorder=30,
                label=utils.assert_and_get_constant(data_sub.ModelFamily))

    g = plotting.FacetGrid(
        data=df_plot,
        sharex=False,
        sharey=False,
        dropna=False,
        col="DatasetName",
        col_order=dataset_order,
        row="rescaling_method",
        row_order=rescaling_methods,
        height=height,
        aspect=aspect,
        margin_titles=True)
    g.map_dataframe(subplot_fn, x="downstream_error", y="MetricValue")
    g.set_titles(col_template="{col_name}", row_template="",
                 size=mpl.rcParams["axes.titlesize"])

    for ax in g.axes.flat:
        row_method = rescaling_methods[plotting.row_num(ax)]
        dataset = dataset_order[plotting.col_num(ax)]

        plotting.show_spines(ax)
        ax.grid(True, axis="both")
        ax.grid(True, axis="both", which="minor")
        ax.set_xlim(left=0.1)
        ax.set_ylim(bottom=0.0)
        ax.xaxis.set_minor_locator(mpl.ticker.MultipleLocator(0.1))
        if row_method != "tau":
            ax.set_yticks([], minor=True)
            ax.yaxis.set_minor_locator(mpl.ticker.MultipleLocator(0.02))

        # Per-dataset axis limits and tick spacing:
        if dataset == "imagenet(split='validation[20%:]')":
            if row_method == "none":
                ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.04))
        if dataset == "imagenet_v2(variant='MATCHED_FREQUENCY')":
            ax.set_ylim(top=0.14)
            if row_method in ("none", "temperature_scaling"):
                ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.02))
            ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
        if dataset == "imagenet_r":
            ax.set_ylim(top=0.3)
            ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
        if dataset == "imagenet_a":
            if row_method != "tau":
                ax.set_ylim(top=0.6)
            ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.2))

        # Labels:
        if ax.is_first_col():
            if row_method == "none":
                ax.set_ylabel(display.YLABEL_ECE_UNSCALED)
            elif row_method == "temperature_scaling":
                ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED)
            elif row_method == "beta_scaling":
                ax.set_ylabel(display.YLABEL_ECE_BETA_SCALED)
            elif row_method == "tau":
                ax.set_ylabel(display.YLABEL_TEMP_FACTOR)
            else:
                ax.set_ylabel(row_method)
        if ax.is_last_row():
            ax.set_xlabel(display.XLABEL_CLASSIFICATION_ERROR)
        else:
            ax.set_xticklabels([])

    if add_metric_description:
        plotting.add_metric_description_title(df_plot, g.fig, y=1.05)
    plotting.apply_to_fig_text(g.fig, display.prettify)
    g.fig.tight_layout(pad=0)
    g.fig.subplots_adjust(wspace=0.2, hspace=0.2)

    # The temperature row is finalized after layout:
    for ax in g.axes.flat:
        if rescaling_methods[plotting.row_num(ax)] == "tau":
            ax.set_ylim(0.5, 2.5)
            plotting.annotate_confidence_plot(ax)

    if add_legend:
        handles, labels = plotting.get_model_family_legend(
            g.axes.flat[0], family_order)
        if legend_position == "below":
            # Model family legend below plot:
            placement = dict(loc="upper center", bbox_to_anchor=(0.5, 0.0),
                             ncol=len(family_order))
        else:  # "right" (validated at the top).
            # Model family legend next to plot:
            placement = dict(loc="center left", bbox_to_anchor=(1.00, 0.53),
                             ncol=1)
        legend = g.axes.flat[0].legend(
            handles=handles,
            labels=labels,
            title="Model family",
            frameon=True,
            bbox_transform=g.fig.transFigure,
            handletextpad=0.1,
            **placement)
        legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
        legend.get_frame().set_edgecolor("lightgray")
    plotting.apply_to_fig_text(g.fig, display.prettify)
    return g
def plot_alternative_metrics(
        df_main: pd.DataFrame,
        xs: List[str],
        ys: List[str],
        gce_prefix: str = plotting.STD_GCE_PREFIX,
        add_legend: bool = True,
        add_metric_description: bool = False,
) -> sns.axisgrid.FacetGrid:
    """Plots alternative metrics for ImageNet and natural OOD datasets.

    Uses temperature-scaled results only. Facet row k plots column `xs[k]`
    against column `ys[k]`, with one facet column per dataset. One of
    `xs`/`ys` must contain a single distinct value; the other one defines
    the rows. A y-column name ending in "_residual" plots the residual of a
    linear fit of that column against "downstream_error".

    Args:
      df_main: Results dataframe.
      xs: x-axis column names. A one-element list is repeated to match `ys`.
      ys: y-axis column names. A one-element list is repeated to match `xs`.
      gce_prefix: Only rows whose Metric starts with this prefix are used.
      add_legend: Whether to add the model family legend below the plot.
      add_metric_description: Whether to add a metric description title.

    Returns:
      The facet grid.

    Raises:
      ValueError: If `xs` or `ys` is empty or their lengths are incompatible.
    """
    # Settings:
    rescaling_method = "temperature_scaling"
    family_order = display.get_model_families_sorted()
    dataset_order = (["imagenet(split='validation[20%:]')"] +
                     display.OOD_DATASET_ORDER)

    # Broadcast one-element lists. Otherwise zip() below silently drops rows
    # and leaves facets without data.
    xs, ys = list(xs), list(ys)
    if not xs or not ys:
        raise ValueError("xs and ys must be non-empty.")
    if len(xs) == 1:
        xs = xs * len(ys)
    if len(ys) == 1:
        ys = ys * len(xs)
    if len(xs) != len(ys):
        raise ValueError(
            f"xs and ys must have equal length (or length 1), got "
            f"{len(xs)} and {len(ys)}.")
    assert not ((len(set(xs)) > 1) and (len(set(ys)) > 1)), (
        "One of x and y must contain a constant value.")

    # Select data:
    mask = df_main.Metric.str.startswith(gce_prefix)
    mask &= df_main.ModelName.isin(display.get_standard_model_list())
    mask &= df_main.DatasetName.isin(dataset_order)
    mask &= df_main.rescaling_method.isin([rescaling_method])
    mask &= df_main.ModelFamily.isin(family_order)
    # Keep ImageNet-A/-R rows only when use_dataset_labelset is True:
    for prefix in ("imagenet_a", "imagenet_r"):
        mask &= ~(df_main.DatasetName.str.startswith(prefix)
                  & (~df_main.use_dataset_labelset.eq(True)))
    df_plot = df_main[mask].copy()
    df_plot, cmap = display.add_display_data(df_plot, family_order)

    # Remove "use_dataset_labelset=True" to have uniform metric name:
    df_plot.Metric = df_plot.Metric.str.replace(
        "use_dataset_labelset=True,", "", regex=False)

    # Repeat the dataframe once per facet row, tagging the (x, y) columns:
    varying_set = xs if len(set(xs)) > 1 else ys
    df_orig = df_plot
    df_plot = pd.concat([
        df_orig.assign(x=x, y=y, row=varying)
        for x, y, varying in zip(xs, ys, varying_set)
    ])

    def get_residual(classification_error, metric):
        """Returns `metric` minus a linear fit on `classification_error`.

        The fit ignores NaN metric values; NaN inputs stay NaN.
        """
        reg = linear_model.LinearRegression()
        metric = metric.to_numpy()
        not_nan = ~np.isnan(metric)
        x = classification_error.to_numpy()[:, None]
        reg.fit(x[not_nan, :], metric[not_nan])
        return metric - reg.predict(x)

    residual_suffix = "_residual"

    def subplot_fn(data, **kwargs):
        """Scatters one facet and records its column names on the axis."""
        del kwargs
        ax = plt.gca()
        # Stashed on the axis so the label loop below can use them.
        x = ax.my_xlabel = utils.assert_and_get_constant(data.x)
        y = ax.my_ylabel = utils.assert_and_get_constant(data.y)
        if y.endswith(residual_suffix):
            # Plot residual w.r.t. accuracy:
            y = y[:-len(residual_suffix)]
            data = data.copy()
            data.loc[:, y] = get_residual(data["downstream_error"], data[y])
        for marker in data.family_marker.unique():
            data_sub = data[data.family_marker == marker]
            ax.scatter(
                data_sub[x],
                data_sub[y],
                s=plotting.model_to_scatter_size(data_sub.model_size),
                c=data_sub.family_index,
                cmap=cmap,
                vmin=0,
                vmax=len(family_order),
                marker=marker,
                alpha=0.7,
                zorder=30,
                label=utils.assert_and_get_constant(data_sub.ModelFamily),
                linewidth=0.0)

    g = plotting.FacetGrid(
        data=df_plot,
        sharex=False,
        sharey=False,
        dropna=False,
        col="DatasetName",
        col_order=dataset_order,
        row="row",
        row_order=varying_set,
        height=1.05,
        aspect=1.3,
        margin_titles=True)
    g.map_dataframe(subplot_fn)
    g.set_titles(col_template="{col_name}", row_template="",
                 size=mpl.rcParams["axes.titlesize"])

    for ax in g.axes.flat:
        plotting.show_spines(ax)
        ax.xaxis.set_minor_locator(mpl.ticker.MultipleLocator(0.1))
        dataset = dataset_order[plotting.col_num(ax)]
        if dataset == "imagenet_v2(variant='MATCHED_FREQUENCY')":
            ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
        if dataset in ("imagenet_r", "imagenet_a"):
            ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.2))

        # Facets without data never ran subplot_fn, so they have no labels.
        xlabel = getattr(ax, "my_xlabel", None)
        ylabel = getattr(ax, "my_ylabel", None)
        if varying_set == xs:
            if xlabel is not None:
                ax.set_xlabel(xlabel)
            if ax.is_first_col() and ylabel is not None:
                ax.set_ylabel(ylabel)
        elif varying_set == ys:
            if ax.is_last_row():
                if xlabel is not None:
                    ax.set_xlabel(xlabel)
            else:
                ax.set_xticklabels([])
            if ax.is_first_col() and ylabel is not None:
                ax.set_ylabel(ylabel)
        ax.grid(True, axis="x", which="minor")
        ax.grid(True, axis="both")

    if add_metric_description:
        plotting.add_metric_description_title(df_plot, g.fig, y=1.05)

    pretty_names = {
        "downstream_error": display.XLABEL_CLASSIFICATION_ERROR,
        "MetricValue": display.YLABEL_ECE_TEMP_SCALED_SHORT,
        "brier": "Brier score\n(temp.-scaled)",
        "brier_div_error": "Brier / class. error",
        "brier_residual": "Brier (residual)",
        "nll": "NLL\n(temp.-scaled)",
        "nll_div_error": "NLL / class. error",
        "nll_residual": "NLL (residual)",
    }

    def prettify(x):
        """`display.prettify`, plus names for the alternative metrics."""
        x = display.prettify(x)
        return pretty_names.get(x, x)

    plotting.apply_to_fig_text(g.fig, prettify)
    g.fig.tight_layout()
    g.fig.subplots_adjust(wspace=0.4,
                          hspace=0.1 if varying_set == ys else 0.5)

    # Model family legend below plot:
    if add_legend:
        handles, labels = plotting.get_model_family_legend(
            g.axes.flat[0], family_order)
        legend = g.axes.flat[0].legend(
            handles=handles,
            labels=labels,
            loc="upper center",
            title="Model family",
            bbox_to_anchor=(0.5, 0),
            frameon=True,
            bbox_transform=g.fig.transFigure,
            ncol=len(family_order),
            handletextpad=0.1)
        legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
        legend.get_frame().set_edgecolor("lightgray")
    plotting.apply_to_fig_text(g.fig, prettify)
    return g
def plot(df_main: pd.DataFrame,
         family_order: Optional[List[str]] = None,
         rescaling_method: str = "temperature_scaling",
         gce_prefix: str = plotting.STD_GCE_PREFIX) -> sns.axisgrid.FacetGrid:
    """Plots a regression of accuracy and calibration, split by model family.

    Each facet (column = model family, row = ImageNet-C vs. other datasets)
    shows `sns.regplot` of log10(MetricValue) against
    log10(downstream_error). The facet is annotated with `a` (10 ** median
    intercept) and `k` (median slope), computed from the `beta_boots`
    samples of `RegressionPlotter.get_params`. Intervals come from
    `sns_utils.ci`.

    Args:
      df_main: Results dataframe.
      family_order: Model families (columns). Defaults to
        `display.get_model_families_sorted()`.
      rescaling_method: Only rows with this rescaling method are used.
      gce_prefix: Only rows whose Metric starts with this prefix are used.

    Returns:
      The facet grid.
    """
    # Settings:
    if family_order is None:
        family_order = display.get_model_families_sorted()
    dataset_order = [
        "imagenet(split='validation[20%:]')",
        "imagenet_v2(variant='MATCHED_FREQUENCY')",
        "imagenet_r",
        "imagenet_a",
    ]

    # Select data:
    mask = df_main.Metric.str.startswith(gce_prefix)
    mask &= df_main.ModelName.isin(display.get_standard_model_list())
    mask &= df_main.rescaling_method == rescaling_method
    mask &= df_main.ModelFamily.isin(family_order)
    # Keep ImageNet-A/-R rows only when use_dataset_labelset is True:
    for prefix in ("imagenet_a", "imagenet_r"):
        mask &= ~(df_main.DatasetName.str.startswith(prefix)
                  & ~df_main.use_dataset_labelset.eq(True))
    df_plot = df_main[mask].copy()

    # Select OOD datasets:
    mask = df_plot.DatasetName.isin(dataset_order)
    mask |= df_plot.DatasetName.str.startswith("imagenet_c")
    df_plot = df_plot[mask].copy()

    # Add plotting-related data:
    df_plot, _ = display.add_display_data(df_plot, family_order)
    df_plot["dataset_group"] = "others"
    im_c_mask = df_plot.DatasetName.str.startswith("imagenet_c")
    df_plot.loc[im_c_mask, "dataset_group"] = "imagenet_c"
    dataset_group_order = ["imagenet_c", "others"]

    def subplot_fn(**kwargs):
        """Draws the regression for one facet and annotates its parameters."""
        x = kwargs["x"]
        y = kwargs["y"]
        data = kwargs["data"]
        utils.assert_no_duplicates_in_condition(
            data, group_by=["DatasetName", "ModelName"])
        ax = plt.gca()
        kwargs["color"] = utils.assert_and_get_constant(data.family_color)
        ax = sns.regplot(ax=ax, **kwargs)
        plotter = RegressionPlotter(x, y, data=data)

        # Get regression parameters and show in plot:
        grid = np.linspace(data[x].min(), data[x].max(), 100)
        _, beta_boots = plotter.get_params(grid)
        beta_boots = np.array(beta_boots)
        # The fit is in log10 space, so the intercept is back-transformed.
        intercept = 10 ** np.median(beta_boots[0, :])
        intercept_ci = 10 ** sns_utils.ci(beta_boots[0, :])
        slope = np.median(beta_boots[1, :])
        slope_ci = sns_utils.ci(beta_boots[1, :])
        s = (f"a = {intercept:1.2f} ({intercept_ci[0]:1.2f}, "
             f"{intercept_ci[1]:1.2f})\nk = {slope:1.2f} ({slope_ci[0]:1.2f}, "
             f"{slope_ci[1]:1.2f})")
        ax.text(0.04, 0.96, s, va="top", ha="left", transform=ax.transAxes,
                fontsize=4, color=(0.3, 0.3, 0.3),
                bbox=dict(facecolor="w", alpha=0.8,
                          boxstyle="square,pad=0.1"))

    def sparse_tick_labels(ticks, show):
        """Labels ticks that are (float-)close to a value in `show`, else ""."""
        return [f"{t:0.1f}" if np.any(np.isclose(t, show)) else ""
                for t in ticks]

    df_plot["MetricValue_log"] = np.log10(df_plot.MetricValue)
    df_plot["downstream_error_log"] = np.log10(df_plot.downstream_error)

    g = plotting.FacetGrid(
        data=df_plot,
        sharex=False,
        sharey=False,
        col="ModelFamily",
        col_order=family_order,
        hue="ModelFamily",
        hue_order=family_order,
        row="dataset_group",
        row_order=dataset_group_order,
        height=1.0,
        aspect=0.9,
    )
    g.map_dataframe(subplot_fn, x="downstream_error_log", y="MetricValue_log",
                    scatter_kws={"alpha": 0.5, "linewidths": 0.0, "s": 2})
    g.set_titles(template="{col_name}", size=mpl.rcParams["axes.titlesize"])

    for ax in g.axes.flat:
        plotting.show_spines(ax)

        # Log-scaled x axis, drawn in log10 data coordinates:
        ax.set_xlim(np.log10(0.1), np.log10(1.0))
        xticks = np.arange(0.1, 1.0 + 0.001, 0.1)
        ax.set_xticks(np.log10(xticks))
        if ax.is_last_row():
            ax.set_xticklabels(
                sparse_tick_labels(xticks, [0.1, 0.2, 0.4, 1.0]))
            ax.set_title("")
        else:
            ax.set_xticklabels([])

        # Log-scaled y axis, drawn in log10 data coordinates:
        ax.set_ylim(np.log10(0.01), np.log10(0.8))
        yticks = np.arange(0.05, 0.8 + 0.001, 0.05)
        ax.set_yticks(np.log10(yticks))
        if ax.is_first_col():
            ax.set_yticklabels(
                sparse_tick_labels(yticks, [0.1, 0.2, 0.4, 0.8]))
        else:
            ax.set_yticklabels([])
        ax.grid(True, axis="both")

        # Labels:
        if ax.is_last_row():
            ax.set_xlabel(display.XLABEL_CLASSIFICATION_ERROR)
        if ax.is_first_col():
            ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED_SHORT)

    plotting.apply_to_fig_text(g.fig, display.prettify)
    g.fig.tight_layout(w_pad=0)
    return g
def plot_error_increase_vs_model_size(
        df_main: pd.DataFrame,
        compact_layout: bool = True,
        gce_prefix=plotting.STD_GCE_PREFIX) -> sns.axisgrid.FacetGrid:
    """Plots acc/calib for ImageNet-C relative to each family's largest model.

    Rows are classification error (top) and temperature-scaled metric value
    (bottom, labelled ECE), each averaged over corruption types by
    `utils.average_imagenet_c_corruption_types`. Values are plotted as the
    difference to the family's largest model at the same severity. Columns
    are model families; the x axis is corruption severity.

    Args:
      df_main: Results dataframe, passed through to `_get_data`.
      compact_layout: If True, shows only the mixer, vit and bit families and
        uses tighter y limits.
      gce_prefix: Metric prefix, passed through to `_get_data`.

    Returns:
      The facet grid.
    """
    # Settings:
    if compact_layout:
        family_order = ["mixer", "vit", "bit"]
    else:
        # Remove AlexNet bc. it only has one size. (Filtered into a new list
        # so that a list object owned by `display` is never modified.)
        family_order = [family for family in display.get_model_families_sorted()
                        if family != "alexnet"]
    df_plot = _get_data(df_main, gce_prefix, family_order)
    df_plot = df_plot[
        df_plot.rescaling_method.isin(["temperature_scaling"])].copy()

    # Add downstream_error as another "rescaling method", so that seaborn can
    # plot it as a separate row:
    df_error = df_plot.copy()
    df_error.rescaling_method = "downstream_error"
    df_error.MetricValue = df_error.downstream_error
    df_plot = pd.concat([df_plot, df_error])

    df_plot = utils.average_imagenet_c_corruption_types(
        df_plot, group_by=["ModelName", "Metric", "severity", "rescaling_method"])

    # Normalize per severity: subtract the value of the family's largest model.
    severities = df_plot.severity.unique()
    methods = df_plot.rescaling_method.unique()
    datasets = df_plot.DatasetName.unique()
    df_plot["relative_metric_value"] = np.nan
    for family in family_order:
        for severity in severities:
            for method in methods:
                for dataset_ in datasets:
                    mask = df_plot.severity == severity
                    mask &= df_plot.rescaling_method == method
                    mask &= df_plot.DatasetName == dataset_
                    mask &= df_plot.ModelFamily == family
                    df_masked = df_plot[mask]
                    if df_masked.empty:
                        continue
                    size_mask = (
                        df_masked.model_size == df_masked.model_size.max())
                    # .item() requires exactly one row for the largest model.
                    largest_model_value = float(
                        df_masked.loc[size_mask, "MetricValue"].item())
                    df_plot.loc[mask, "relative_metric_value"] = (
                        df_masked.MetricValue - largest_model_value)

    # Explicit row order: the y-labels and y-limits below assume row 0 is the
    # classification error and row 1 is the ECE.
    row_order = ["downstream_error", "temperature_scaling"]

    cmap = sns.color_palette("flare", n_colors=df_plot.severity.nunique(),
                             as_cmap=True)

    def subplot_fn(data, x, y, **kwargs):
        """Draws one dotted line per model size, colored by severity."""
        del kwargs
        ax = plt.gca()
        for condition in np.unique(data["model_size"]):
            data_sub = data[data["model_size"] == condition].copy()
            data_sub = data_sub.sort_values(by=data_sub.columns.to_list())
            # Plot data:
            ax.scatter(data_sub[x], data_sub[y],
                       s=plotting.model_to_scatter_size(
                           data_sub.model_size, 0.75),
                       c=data_sub.severity, cmap=cmap, alpha=0.7, zorder=30,
                       label="data", vmin=0, vmax=5, linewidth=0)
            # White dots to mask lines:
            ax.scatter(data_sub[x], data_sub[y],
                       s=plotting.model_to_scatter_size(
                           data_sub.model_size, 0.75),
                       c="w", alpha=1.0, zorder=20, linewidth=1.5,
                       label="_hidden")
            # Lines to connect dots:
            ax.plot(data_sub[x], data_sub[y], "-", color="gray", alpha=0.7,
                    zorder=10, linewidth=0.75)

    g = plotting.FacetGrid(
        data=df_plot,
        sharex=True,
        sharey=False,
        dropna=False,
        col="ModelFamily",
        col_order=family_order,
        row="rescaling_method",
        row_order=row_order,
        height=1.1,
        margin_titles=True,
        aspect=0.8)
    g.map_dataframe(subplot_fn, x="severity", y="relative_metric_value")
    g.set_titles(col_template="{col_name}", row_template="",
                 size=mpl.rcParams["axes.titlesize"])

    # Format axes:
    for ax in g.axes.flatten():
        ax.set_xlim(-0.9, 5.9)
        plotting.show_spines(ax)
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(1))
        ax.yaxis.set_minor_locator(mpl.ticker.MultipleLocator(0.02))
        if ax.is_first_col() and plotting.row_num(ax) == 0:
            ax.set_ylabel("Classification error\n(Δ to largest model)")
        elif ax.is_first_col() and plotting.row_num(ax) == 1:
            ax.set_ylabel("ECE\n(Δ to largest model)")
        else:
            ax.set_yticklabels([])

        if ax.is_first_row():
            ax.set_title(family_order[plotting.col_num(ax)])
            ax.set_ylim(-0.05 if compact_layout else -0.1, 0.3)
            ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))

        if ax.is_last_row():
            ax.set_xlabel("Corruption\nseverity")
            ax.set_ylim(-0.01 if compact_layout else -0.04, 0.03)
            ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.02))

        # Explicit True: grid() without `visible` toggles the current state.
        ax.grid(True, axis="both")
        ax.grid(True, axis="y", which="minor")

    plotting.apply_to_fig_text(g.fig, display.prettify)
    g.fig.tight_layout()
    g.fig.subplots_adjust(wspace=0.15)
    return g
def plot(df_main: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_method: str = "temperature_scaling",
         add_guo: bool = False) -> mpl.figure.Figure:
    """Plots acc/calib on clean ImageNet: unscaled vs. rescaled ECE.

    Draws two side-by-side panels of ImageNet error vs. the metric value:
    the left panel is unscaled ("none"), the right panel uses
    `rescaling_method`. The model family legend sits in the right panel.

    Args:
      df_main: Results dataframe, passed through to `_get_data`.
      gce_prefix: Metric prefix, passed through to `_get_data`.
      rescaling_method: Rescaling method for the right panel.
      add_guo: If True, adds the DenseNet161/ResNet152 values reported by
        Guo et al. (2017).

    Returns:
      The matplotlib figure.
    """
    rescaling_methods = ["none", rescaling_method]
    family_order = display.get_model_families_sorted()
    if add_guo:
        # New list instead of an in-place append, so a list object owned by
        # `display` can never be modified.
        family_order = family_order + ["guo"]

    # Set up figure:
    fig = plt.figure(figsize=(display.FULL_WIDTH / 2, 1.4))
    spec = fig.add_gridspec(ncols=2, nrows=1)

    for ax_i, panel_method in enumerate(rescaling_methods):
        df_plot, cmap = _get_data(df_main, gce_prefix, family_order,
                                  rescaling_methods=[panel_method])
        ax = fig.add_subplot(spec[:, ax_i])
        big_ax = ax
        for i, family in enumerate(family_order):
            if family == "guo":
                continue
            data_sub = df_plot[df_plot.ModelFamily == family]
            if data_sub.empty:
                continue
            ax.scatter(
                data_sub["downstream_error"],
                data_sub["MetricValue"],
                s=plotting.model_to_scatter_size(data_sub.model_size),
                c=data_sub.family_index,
                cmap=cmap,
                vmin=0,
                vmax=len(family_order),
                marker=utils.assert_and_get_constant(data_sub.family_marker),
                alpha=0.7,
                linewidth=0.0,
                zorder=100 - i,  # Z-order is same as model family order.
                label=family)

        # Manually add Guo et al data:
        # From Table 1 and Table S2 in https://arxiv.org/pdf/1706.04599.pdf.
        # First model is DenseNet161, second is ResNet152.
        guo_ece = {
            "none": [0.0628, 0.0548],
            "temperature_scaling": [0.0199, 0.0186],
        }.get(panel_method)
        if add_guo and guo_ece is not None:
            ax.scatter([0.2257, 0.2231], guo_ece,
                       s=plotting.model_to_scatter_size(1),
                       c=[len(family_order) - 1] * 2,
                       # Same colormap and normalization as the model
                       # families above ("guo" is last in family_order).
                       cmap=cmap, vmin=0, vmax=len(family_order),
                       marker="x", alpha=0.7, label="guo")

        plotting.show_spines(ax)
        # Aspect ratios are tuned manually for display in the paper:
        ax.set_anchor("N")
        ax.grid(False, which="minor")
        ax.grid(True, axis="both")
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
        ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.01))
        ax.set_ylim(bottom=0.01, top=0.09)
        ax.set_xlim(0.05, 0.5)
        ax.set_xlabel(display.XLABEL_INET_ERROR)
        if panel_method == "none":
            ax.set_title("Unscaled")
        elif panel_method == "temperature_scaling":
            ax.set_title("Temperature-scaled")
            # Same y range as the unscaled panel, so hide duplicate labels.
            ax.set_yticklabels([])
        if ax.is_first_col():
            ax.set_ylabel("ECE")

    # Model family legend:
    handles, labels = plotting.get_model_family_legend(big_ax, family_order)
    legend = big_ax.legend(
        handles=handles,
        labels=labels,
        loc="upper right",
        frameon=True,
        labelspacing=0.3,
        handletextpad=0.1,
        borderpad=0.3,
        fontsize=4)
    legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
    legend.get_frame().set_edgecolor("lightgray")
    plotting.apply_to_fig_text(fig, display.prettify)
    plotting.apply_to_fig_text(
        fig, lambda x: x.replace("EfficientNet", "EffNet"))
    return fig