def plot(
    df_main: pd.DataFrame,
    gce_prefix: str = plotting.STD_GCE_PREFIX,
    plot_confidence: bool = False,
) -> sns.axisgrid.FacetGrid:
    """Shows that bias can depend on accuracy.

    Scatters the calibration metric (y) against ImageNet error (x) for runs of
    the "jft-r50-x1" / "jft-r101-x3" models whose raw names encode a `size`
    and a `steps` value. Columns are the number of bins of the calibration
    metric (15, 100, 5000); rows are rescaling methods.

    Args:
        df_main: Results dataframe to select from.
        gce_prefix: Calibration-metric name prefix. Only the part before
            "num_bins" is matched, so metrics with any number of bins are
            selected.
        plot_confidence: If True, plot three rows ("none",
            "temperature_scaling" and the optimal-temperature row "tau");
            otherwise only the "temperature_scaling" row.

    Returns:
        The populated FacetGrid.
    """
    # Select data:
    mask = df_main.Metric.str.startswith(gce_prefix.split("num_bins")[0])
    mask &= (df_main.ModelName.str.startswith("jft-r50-x1")
             | df_main.ModelName.str.startswith("jft-r101-x3"))
    mask &= df_main.DatasetName == "imagenet(split='validation[20%:]')"
    mask &= df_main.rescaling_method.isin(["none", "temperature_scaling"])
    mask &= df_main.RawModelName.str.contains("size")
    mask &= df_main.RawModelName.str.contains("steps")
    df_plot = df_main[mask].copy()

    # Parse the integer values encoded in the model and metric names:
    df_plot["size"] = df_plot.RawModelName.map(
        lambda name: int(re.search(r"size=(\d+)", name).group(1)))
    df_plot["steps"] = df_plot.RawModelName.map(
        lambda name: int(re.search(r"steps=(\d+)", name).group(1)))
    df_plot["num_bins"] = df_plot.Metric.map(
        lambda metric: int(re.search(r"(?<=num_bins=)(\d+)", metric).group(1)))
    df_plot = plotting.add_optimal_temperature_as_rescaling_method(df_plot)

    # Remove outlier runs:
    df_plot = df_plot[df_plot.steps != 457032]

    color = sns.color_palette("colorblind", n_colors=1)[0]

    def subplot_fn_all_models_by_size(data, x, y, **kwargs):
        # Extra keyword arguments from `map_dataframe` are ignored; a single
        # fixed color is used for all points.
        del kwargs
        ax = plt.gca()
        ax.scatter(data[x],
                   data[y],
                   s=3,
                   c=color,
                   alpha=0.75,
                   zorder=30,
                   label="data",
                   linewidth=0)

    col_order = [15, 100, 5000]
    if plot_confidence:
        row_order = ["none", "temperature_scaling", "tau"]
    else:
        row_order = ["temperature_scaling"]

    g = plotting.FacetGrid(data=df_plot,
                           sharex=False,
                           sharey=False,
                           dropna=False,
                           col="num_bins",
                           col_order=col_order,
                           row="rescaling_method",
                           row_order=row_order,
                           height=1.2,
                           aspect=0.8,
                           margin_titles=True)

    g.map_dataframe(subplot_fn_all_models_by_size,
                    x="imagenet_error",
                    y="MetricValue")
    g.set_titles(col_template="{col_name:1.0f} bins", row_template="")

    # Number of evaluation examples, used for the points-per-bin annotation.
    # NOTE(review): presumably the size of split 'validation[20%:]' (80% of
    # the 50k ImageNet validation images) — confirm.
    num_examples = 40000

    # Format axes:
    for ax in g.axes.flatten():
        row_name = row_order[plotting.row_num(ax)]

        if ax.is_first_row():
            num_bins = col_order[plotting.col_num(ax)]
            ax.set_title(
                f"{num_bins} bins \n({num_examples/num_bins:1.0f} points/bin)",
                fontsize=mpl.rcParams["axes.labelsize"])

        if not ax.is_last_row():
            ax.set_xticklabels([])

        # ECE rows: give every panel the same y-span (0.065) centered on its
        # data, so that panels are visually comparable. Decided by row name
        # rather than row index, and skipped for facets without data (seaborn
        # draws nothing there, so `ax.collections` is empty).
        if row_name != "tau" and ax.collections:
            y = ax.collections[0].get_offsets()[:, 1]
            if y.size:
                midpoint = np.mean([y.min(), y.max()])
                ax.set_ylim(midpoint - 0.065 / 2, midpoint + 0.065 / 2)
                ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.02))

        if row_name == "none":
            if ax.is_first_col():
                ax.set_ylabel(display.YLABEL_ECE_UNSCALED)

        elif row_name == "temperature_scaling":
            if ax.is_first_col():
                ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED_SHORT)

        else:  # This is the over/underconfidence ("tau") row.
            ax.set_ylim(0.85, 1.05)
            ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.05))
            plotting.annotate_confidence_plot(ax)
            if ax.is_first_col():
                ax.set_ylabel("Optimal\ntemp. factor")

        ax.set_xlim(0.15, 0.35)
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
        plotting.show_spines(ax)
        if ax.is_last_row():
            ax.set_xlabel("Classification\nerror")
        ax.grid(axis="both")

    g.fig.tight_layout()
    return g
def plot(df_main: pd.DataFrame,
         df_reliability: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_method: str = "temperature_scaling",
         add_guo: bool = False) -> mpl.figure.Figure:
    """Plots acc/calib and reliability diagrams on clean ImageNet (Figure 1).

    The first panel(s) scatter ImageNet error against the calibration metric
    for all model families. The remaining grid columns show, for a few
    selected models, a confidence panel (top row) and a reliability diagram
    (bottom row).

    Args:
        df_main: Results dataframe.
        df_reliability: Reliability data, passed to
            `_get_binned_reliability_data`.
        gce_prefix: Calibration-metric name prefix, passed to `_get_data`.
        rescaling_method: "none", "temperature_scaling", or "both" (one
            accuracy/calibration panel per method, and fewer reliability
            panels).
        add_guo: Whether to add the two models reported by Guo et al. (2017).

    Returns:
        The matplotlib figure.
    """
    family_order = display.get_model_families_sorted()
    if add_guo:
        # Build a new list rather than appending in place, in case the helper
        # returns a shared (e.g. module-level) list.
        family_order = family_order + ["guo"]

    if rescaling_method == "both":
        rescaling_methods = ["none", "temperature_scaling"]
        widths = [1.75, 1.75, 1, 1, 1]
    else:
        rescaling_methods = [rescaling_method]
        widths = [1.8, 1, 1, 1, 1, 1]
    heights = [0.3, 1]

    # Manually added Guo et al. data, from Table 1 and Table S2 in
    # https://arxiv.org/pdf/1706.04599.pdf. First model is DenseNet161,
    # second is ResNet152.
    guo_x = [0.2257, 0.2231]
    guo_y = {
        "none": [0.0628, 0.0548],
        "temperature_scaling": [0.0199, 0.0186],
    }

    # Set up figure:
    fig = plt.figure(figsize=(display.FULL_WIDTH, 1.6))
    spec = fig.add_gridspec(ncols=len(widths),
                            nrows=len(heights),
                            width_ratios=widths,
                            height_ratios=heights)

    # First panels: acc vs calib ImageNet. The loop variable deliberately does
    # not reuse the name `rescaling_method`, which is read again further below.
    for ax_i, method in enumerate(rescaling_methods):

        df_plot, cmap = _get_data(df_main,
                                  gce_prefix,
                                  family_order,
                                  rescaling_methods=[method])
        ax = fig.add_subplot(spec[0:2, ax_i], box_aspect=1.0)
        big_ax = ax
        for i, family in enumerate(family_order):
            if family == "guo":
                continue
            data_sub = df_plot[df_plot.ModelFamily == family]
            if data_sub.empty:
                continue
            ax.scatter(
                data_sub["downstream_error"],
                data_sub["MetricValue"],
                s=plotting.model_to_scatter_size(data_sub.model_size),
                c=data_sub.family_index,
                cmap=cmap,
                vmin=0,
                vmax=len(family_order),
                marker=utils.assert_and_get_constant(data_sub.family_marker),
                alpha=0.7,
                linewidth=0.0,
                zorder=100 - i,  # Z-order is same as model family order.
                label=family)

        if add_guo and method in guo_y:
            ax.scatter(guo_x, guo_y[method],
                       s=plotting.model_to_scatter_size(1),
                       c=[len(family_order) - 1] * 2,
                       marker="x",
                       alpha=0.7,
                       label="guo")

        plotting.show_spines(ax)
        ax.set_anchor("N")
        ax.grid(False, which="minor")
        ax.grid(True, axis="both")
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
        ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.01))
        ax.set_ylim(bottom=0.01, top=0.09)
        ax.set_xlim(0.05, 0.55)
        ax.set_xlabel(display.XLABEL_INET_ERROR)
        if len(rescaling_methods) == 1:  # Showing just one rescaling method.
            if method == "none":
                ax.set_ylabel(display.YLABEL_ECE_UNSCALED)
            elif method == "temperature_scaling":
                ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED)
        else:  # Showing both rescaling methods.
            if method == "none":
                ax.set_title("Unscaled")
            elif method == "temperature_scaling":
                ax.set_title("Temperature-scaled")
            if ax.is_first_col():
                ax.set_ylabel("ECE")

    # Remaining panels: reliability diagrams, in the grid columns that follow
    # the accuracy/calibration panels.
    first_rel_col = len(rescaling_methods)
    model_names = [
        "mixer/jft-300m/H/14",
        "vit-h/14",
        "bit-jft-r152-x4-480",
    ]
    if first_rel_col == 1:
        model_names += ["wsl_32x48d", "simclr-4x-fine-tuned-100"]
    dataset_name = "imagenet(split='validation[20%:]')"

    for i, model_name in enumerate(model_names):
        # Get predictions:
        mask = df_main.ModelName == model_name
        mask &= df_main.rescaling_method == "none"
        mask &= df_main.Metric == "accuracy"
        mask &= df_main.DatasetName == dataset_name
        raw_model_name = df_main[mask].RawModelName
        assert len(raw_model_name) <= 1, df_main[mask]
        if raw_model_name.empty:
            continue

        binned = _get_binned_reliability_data(
            df_reliability, utils.assert_and_get_constant(raw_model_name),
            dataset_name)

        rel_ax = fig.add_subplot(spec[1, i + first_rel_col])
        conf_ax = fig.add_subplot(spec[0, i + first_rel_col])
        _plot_confidence_and_reliability(conf_ax=conf_ax,
                                         rel_ax=rel_ax,
                                         binned=binned,
                                         model_name=model_name,
                                         first_col=first_rel_col)

        # Previously compared against the non-existent value
        # "temperature_scaled", so the temp-scaled label was never set.
        # NOTE(review): the binned data is looked up without a rescaling
        # method; confirm it matches `rescaling_method` so the label is right.
        if rescaling_method == "none":
            rel_ax.set_xlabel("Confidence\n(unscaled)")
        elif rescaling_method == "temperature_scaling":
            rel_ax.set_xlabel("Confidence\n(temp. scaled)")

    # Model family legend:
    handles, labels = plotting.get_model_family_legend(big_ax, family_order)
    legend = big_ax.legend(handles=handles,
                           labels=labels,
                           loc="upper right",
                           frameon=True,
                           labelspacing=0.25,
                           handletextpad=0.1,
                           borderpad=0.3,
                           fontsize=4)
    legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
    legend.get_frame().set_edgecolor("lightgray")

    plotting.apply_to_fig_text(fig, display.prettify)

    # Shift every axes except the first to the right (figure coordinates):
    x_shift = 0.05
    for ax in fig.axes[1:]:
        box = ax.get_position()
        box.x0 += x_shift
        box.x1 += x_shift
        ax.set_position(box)

    return fig
# Example #3
# 0
def plot(df_main: pd.DataFrame,
         upstream_datasets: Sequence[str] = ("imagenet21k", "jft"),
         gce_prefix: str = plotting.STD_GCE_PREFIX) -> sns.axisgrid.FacetGrid:
    """Plot acc. and calibration with changing pretraining data size and steps.

    Facet rows are upstream (pretraining) datasets; facet columns are the
    hyperparameter that varies ("size" or "steps"). Only temperature-scaled
    results are shown.

    Args:
        df_main: Results dataframe, passed to `_get_data`.
        upstream_datasets: Pretraining datasets, one facet row each.
        gce_prefix: Calibration-metric name prefix, passed to `_get_data`.

    Returns:
        The populated FacetGrid.
    """
    # In-panel annotation for each known pretraining dataset:
    pretraining_notes = {
        "jft": "Pretraining: JFT",
        "imagenet21k": "Pretraining: ImageNet-21k",
    }

    all_data = _get_data(df_main, upstream_datasets, gce_prefix)
    df_plot = all_data[all_data.rescaling_method == "temperature_scaling"]

    sns.set_style("ticks")
    grid = sns.FacetGrid(
        data=df_plot,
        row="upstream_dataset",
        row_order=upstream_datasets,
        col="varying_key",
        col_order=["size", "steps"],
        sharex=True,
        sharey=False,
        dropna=False,
        height=1.0,
        aspect=1.5,
        margin_titles=True,
    )
    grid.map_dataframe(subplot_fn, x="imagenet_error", y="MetricValue")
    grid.set_titles(row_template="", col_template="{col_name} varies")

    for facet_ax in grid.axes.flat:
        row_index = plotting.row_num(facet_ax)
        upstream = upstream_datasets[row_index]

        plotting.show_spines(facet_ax)

        # Limits (the y-limit applies to the first two rows only):
        facet_ax.set_xlim(0.1, 0.35)
        facet_ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.05))
        if row_index < 2:
            facet_ax.set_ylim(0, 0.05)

        # Labels:
        if facet_ax.is_first_col() and row_index in (0, 1):
            facet_ax.set_ylabel("ECE\n(temp.-scaled)")
        if facet_ax.is_last_row():
            facet_ax.set_xlabel("Classification error")
        facet_ax.grid(True, axis="both", zorder=-2000)

        note = pretraining_notes.get(upstream)
        if note is not None:
            facet_ax.text(0.105,
                          0.002,
                          note,
                          ha="left",
                          va="bottom",
                          color="gray",
                          fontsize=5)

        # Legend:
        if facet_ax.is_first_row():
            _add_legend(facet_ax, df_plot)

    plotting.apply_to_fig_text(grid.fig, display.prettify)
    grid.fig.tight_layout()
    return grid
def plot(df_main: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_method: str = "temperature_scaling",
         add_guo: bool = False) -> mpl.figure.Figure:
    """Plots acc/calib on clean ImageNet separately per BiT upstream variant.

    Produces a 2x3 grid of ImageNet-error vs. calibration scatter plots. Each
    column (BiT-ImageNet, BiT-ImageNet21k, BiT-JFT) shows one BiT upstream
    variant. For that column, the model list is replaced by all keys of
    `display.MODEL_SIZE` except the models of the other two BiT variants. The
    top row is unscaled; the bottom row uses `rescaling_method`.

    The replacement is done by temporarily overriding
    `display.get_standard_model_list` (presumably consumed by `_get_data`;
    verify). The original function is restored before returning, even if
    plotting raises, so later plots are not affected.

    Args:
        df_main: Results dataframe, passed to `_get_data`.
        gce_prefix: Calibration-metric name prefix, passed to `_get_data`.
        rescaling_method: Rescaling method shown in the bottom row.
        add_guo: Whether to add the two models reported by Guo et al. (2017).

    Returns:
        The matplotlib figure.
    """
    rescaling_methods = ["none", rescaling_method]
    family_order = display.get_model_families_sorted(
        ["mixer", "vit", "bit", "simclr"])
    if add_guo:
        # Build a new list rather than appending in place, in case the helper
        # returns a shared list.
        family_order = family_order + ["guo"]

    # (column title, model-name prefixes hidden in that column):
    columns = (
        ("BiT-ImageNet", ("bit-imagenet21k-", "bit-jft-")),
        ("BiT-ImageNet21k", ("bit-imagenet-", "bit-jft-")),
        ("BiT-JFT", ("bit-imagenet-", "bit-imagenet21k-")),
    )

    # Manually added Guo et al data, from Table 1 and Table S2 in
    # https://arxiv.org/pdf/1706.04599.pdf. First model is DenseNet161,
    # second is ResNet152.
    guo_x = [0.2257, 0.2231]
    guo_y = {
        "none": [0.0628, 0.0548],
        "temperature_scaling": [0.0199, 0.0186],
    }

    def make_model_list_fn(hidden_prefixes):
        """Returns a model-list function that hides `hidden_prefixes`."""
        return lambda: [
            m for m in display.MODEL_SIZE.keys()
            if not m.startswith(hidden_prefixes)
        ]

    # Set up figure:
    fig = plt.figure(figsize=(display.FULL_WIDTH / 2, 2))
    spec = fig.add_gridspec(ncols=3, nrows=2)

    original_model_list_fn = display.get_standard_model_list
    try:
        for col, (bit_version, hidden_prefixes) in enumerate(columns):
            display.get_standard_model_list = make_model_list_fn(
                hidden_prefixes)

            for row, method in enumerate(rescaling_methods):

                df_plot, cmap = _get_data(df_main,
                                          gce_prefix,
                                          family_order,
                                          rescaling_methods=[method])
                ax = fig.add_subplot(spec[row, col])
                big_ax = ax
                for i, family in enumerate(family_order):
                    if family == "guo":
                        continue
                    data_sub = df_plot[df_plot.ModelFamily == family]
                    if data_sub.empty:
                        continue
                    ax.scatter(
                        data_sub["downstream_error"],
                        data_sub["MetricValue"],
                        s=plotting.model_to_scatter_size(data_sub.model_size),
                        c=data_sub.family_index,
                        cmap=cmap,
                        vmin=0,
                        vmax=len(family_order),
                        marker=utils.assert_and_get_constant(
                            data_sub.family_marker),
                        linewidth=0.5,
                        alpha=1.0 if "bit" in family else 0.5,
                        zorder=100 - i,  # Z-order is same as family order.
                        label=family)

                if add_guo and method in guo_y:
                    ax.scatter(guo_x, guo_y[method],
                               s=plotting.model_to_scatter_size(1),
                               c=[len(family_order) - 1] * 2,
                               marker="x",
                               alpha=0.7,
                               label="guo")

                plotting.show_spines(ax)
                # Aspect ratios are tuned manually for display in the paper:
                ax.set_anchor("N")
                ax.grid(False, which="minor")
                ax.grid(True, axis="both")
                ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.01))
                ax.set_ylim(bottom=0.0, top=0.09)
                ax.set_xlim(0.05, 0.3)
                ax.set_xlabel(display.XLABEL_INET_ERROR)
                if ax.is_first_row():
                    ax.set_title(bit_version, fontsize=6)
                    ax.set_xlabel("")
                    ax.set_xticklabels("")
                else:
                    ax.set_ylim(bottom=0.0, top=0.05)
                if ax.is_first_col():
                    if method == "none":
                        ax.set_ylabel(display.YLABEL_ECE_UNSCALED)
                    elif method == "temperature_scaling":
                        ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED)
                else:
                    ax.set_yticklabels("")

        fig.tight_layout(pad=0.5)

        # Model family legend:
        handles, labels = plotting.get_model_family_legend(big_ax, family_order)

        plotting.apply_to_fig_text(fig, display.prettify)
        plotting.apply_to_fig_text(
            fig, lambda x: x.replace("EfficientNet", "EffNet"))

        legend = fig.axes[0].legend(handles=handles,
                                    labels=labels,
                                    loc="upper center",
                                    title="Model family",
                                    bbox_to_anchor=(0.55, -0.025),
                                    frameon=True,
                                    bbox_transform=fig.transFigure,
                                    ncol=len(family_order),
                                    handletextpad=0.1)
        legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
        legend.get_frame().set_edgecolor("lightgray")
        # Second pass so that the newly added legend text is prettified too:
        plotting.apply_to_fig_text(fig, display.prettify)

        return fig
    finally:
        # Undo the override so other plots see the original model list.
        display.get_standard_model_list = original_model_list_fn
def plot(df_main: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_methods: Optional[List[str]] = None,
         plot_confidence: bool = False,
         add_legend: bool = True,
         legend_position: str = "right",
         height: float = 1.05,
         aspect: float = 1.2,
         add_metric_description: bool = False,
         ) -> sns.axisgrid.FacetGrid:
  """Plots acc/calib for ImageNet and natural OOD datasets.

  One facet column per dataset in `display.OOD_DATASET_ORDER` and one facet
  row per rescaling method. Each facet scatters classification error against
  the calibration metric, colored by model family.

  Args:
    df_main: Results dataframe.
    gce_prefix: Calibration-metric name prefix used to select rows.
    rescaling_methods: Rescaling methods to show as rows. Defaults to
      ["none", "temperature_scaling"]. The caller's list is not modified.
    plot_confidence: If True, append a "tau" row showing `tau_on_eval_data`
      of the temperature-scaled results.
    add_legend: Whether to add a model-family legend.
    legend_position: "below" or "right"; only used if `add_legend` is True.
    height: Facet height passed to the FacetGrid.
    aspect: Facet aspect ratio passed to the FacetGrid.
    add_metric_description: Whether to add a title describing the metric.

  Returns:
    The populated FacetGrid.

  Raises:
    ValueError: If `add_legend` is True and `legend_position` is neither
      "below" nor "right".
  """
  # Fail fast instead of hitting an unbound `legend` at the very end:
  if add_legend and legend_position not in ("below", "right"):
    raise ValueError(
        f"legend_position must be 'below' or 'right', got {legend_position!r}.")

  # Settings. Copy the list so that adding "tau" never mutates the caller's
  # argument (or accumulates "tau" entries across calls):
  rescaling_methods = list(rescaling_methods or ["none", "temperature_scaling"])
  if plot_confidence:
    rescaling_methods.append("tau")
  family_order = display.get_model_families_sorted()
  dataset_order = display.OOD_DATASET_ORDER

  # Select data:
  mask = df_main.Metric.str.startswith(gce_prefix)
  mask &= df_main.ModelName.isin(display.get_standard_model_list())
  mask &= df_main.DatasetName.isin(dataset_order)
  mask &= df_main.rescaling_method.isin(rescaling_methods)
  mask &= df_main.ModelFamily.isin(family_order)
  # Keep ImageNet-A/-R results only where use_dataset_labelset is True:
  mask &= ~(df_main.DatasetName.str.startswith("imagenet_a") &
            (~df_main.use_dataset_labelset.eq(True)))
  mask &= ~(df_main.DatasetName.str.startswith("imagenet_r") &
            (~df_main.use_dataset_labelset.eq(True)))
  df_plot = df_main[mask].copy()
  df_plot, cmap = display.add_display_data(df_plot, family_order)

  # Remove "use_dataset_labelset=True" to have uniform metric name:
  df_plot.Metric = df_plot.Metric.str.replace("use_dataset_labelset=True,", "")

  # Add "optimal temperature" as another rescaling method, so that seaborn can
  # plot it as a third row:
  df_tau = df_plot[df_plot.rescaling_method == "temperature_scaling"].copy()
  df_tau.rescaling_method = "tau"
  df_tau.MetricValue = df_tau.tau_on_eval_data
  df_plot = pd.concat([df_plot, df_tau])

  def subplot_fn(data, x, y, **kwargs):
    del kwargs  # Styling comes from the display columns added above.
    ax = plt.gca()
    # One scatter call per marker, since a scatter call takes a single marker:
    for marker in data.family_marker.unique():
      data_sub = data[data.family_marker == marker]
      ax.scatter(
          data_sub[x],
          data_sub[y],
          s=plotting.model_to_scatter_size(data_sub.model_size),
          c=data_sub.family_index,
          cmap=cmap,
          vmin=0,
          vmax=len(family_order),
          marker=marker,
          linewidth=0,
          alpha=0.7,
          zorder=30,
          label=utils.assert_and_get_constant(data_sub.ModelFamily))

  g = plotting.FacetGrid(
      data=df_plot,
      sharex=False,
      sharey=False,
      dropna=False,
      col="DatasetName",
      col_order=dataset_order,
      row="rescaling_method",
      row_order=rescaling_methods,
      height=height,
      aspect=aspect,
      margin_titles=True)

  g.map_dataframe(subplot_fn, x="downstream_error", y="MetricValue")

  g.set_titles(col_template="{col_name}", row_template="",
               size=mpl.rcParams["axes.titlesize"])

  for ax in g.axes.flat:
    row_method = rescaling_methods[plotting.row_num(ax)]
    dataset = dataset_order[plotting.col_num(ax)]

    plotting.show_spines(ax)
    ax.grid(True, axis="both")
    ax.grid(True, axis="both", which="minor")

    ax.set_xlim(left=0.1)
    ax.set_ylim(bottom=0.0)
    ax.xaxis.set_minor_locator(mpl.ticker.MultipleLocator(0.1))
    if row_method != "tau":
      ax.set_yticks([], minor=True)
      ax.yaxis.set_minor_locator(mpl.ticker.MultipleLocator(0.02))
      if dataset == "imagenet(split='validation[20%:]')":
        if row_method == "none":
          ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.04))
      if dataset == "imagenet_v2(variant='MATCHED_FREQUENCY')":
        ax.set_ylim(top=0.14)
        if row_method in ("none", "temperature_scaling"):
          ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.02))
      if dataset == "imagenet_r":
        ax.set_ylim(top=0.3)
        ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))

    if dataset == "imagenet_v2(variant='MATCHED_FREQUENCY')":
      ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
    if dataset == "imagenet_a":
      if row_method != "tau":
        ax.set_ylim(top=0.6)
      ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.2))

    # Labels:
    if ax.is_first_col():
      if row_method == "none":
        ax.set_ylabel(display.YLABEL_ECE_UNSCALED)
      elif row_method == "temperature_scaling":
        ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED)
      elif row_method == "beta_scaling":
        ax.set_ylabel(display.YLABEL_ECE_BETA_SCALED)
      elif row_method == "tau":
        ax.set_ylabel(display.YLABEL_TEMP_FACTOR)
      else:
        ax.set_ylabel(row_method)

    if ax.is_last_row():
      ax.set_xlabel(display.XLABEL_CLASSIFICATION_ERROR)
    else:
      ax.set_xticklabels("")

  if add_metric_description:
    plotting.add_metric_description_title(df_plot, g.fig, y=1.05)

  plotting.apply_to_fig_text(g.fig, display.prettify)
  g.fig.tight_layout(pad=0)
  g.fig.subplots_adjust(wspace=0.2, hspace=0.2)

  # The confidence ("tau") rows are limited and annotated after the layout:
  for ax in g.axes.flat:
    if rescaling_methods[plotting.row_num(ax)] == "tau":
      ax.set_ylim(0.5, 2.5)
      plotting.annotate_confidence_plot(ax)

  if add_legend:
    if legend_position == "below":
      # Model family legend below plot:
      placement = dict(loc="upper center", bbox_to_anchor=(0.5, 0.0),
                       ncol=len(family_order))
    else:  # "right" (validated above): model family legend next to plot.
      placement = dict(loc="center left", bbox_to_anchor=(1.00, 0.53), ncol=1)
    handles, labels = plotting.get_model_family_legend(
        g.axes.flat[0], family_order)
    legend = g.axes.flat[0].legend(
        handles=handles, labels=labels, title="Model family", frameon=True,
        bbox_transform=g.fig.transFigure, handletextpad=0.1, **placement)
    legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
    legend.get_frame().set_edgecolor("lightgray")
    plotting.apply_to_fig_text(g.fig, display.prettify)
  return g
# Example #6
# 0
def plot_with_confidence(
        df_main: pd.DataFrame,
        upstream_dataset: str = "jft",
        gce_prefix: str = plotting.STD_GCE_PREFIX) -> sns.axisgrid.FacetGrid:
    """Plots calibration and optimal temperature as pretraining varies.

    Facet rows show unscaled ECE, temperature-scaled ECE and the optimal
    temperature factor ("tau"). Facet columns show the hyperparameter that
    varies ("size" or "steps"). All facets plot against ImageNet error, for
    models pretrained on `upstream_dataset`.

    NOTE(review): an earlier docstring said "ImageNet-C". The evaluation data
    is selected inside `_get_data` — confirm which dataset it uses.

    Args:
        df_main: Results dataframe, passed to `_get_data`.
        upstream_dataset: Pretraining dataset to show.
        gce_prefix: Calibration-metric name prefix, passed to `_get_data`.

    Returns:
        The populated FacetGrid.
    """
    df_plot = _get_data(df_main, [upstream_dataset], gce_prefix)
    methods = ["none", "temperature_scaling", "tau"]
    ylabel_by_method = {
        "none": display.YLABEL_ECE_UNSCALED,
        "temperature_scaling": display.YLABEL_ECE_TEMP_SCALED_SHORT,
        "tau": display.YLABEL_TEMP_FACTOR_SHORT,
    }

    sns.set_style("ticks")
    grid = sns.FacetGrid(
        data=df_plot,
        row="rescaling_method",
        row_order=methods,
        col="varying_key",
        col_order=["size", "steps"],
        sharex=True,
        sharey=False,
        dropna=False,
        height=1.0,
        aspect=1.5,
        margin_titles=True,
    )
    grid.map_dataframe(subplot_fn, x="imagenet_error", y="MetricValue")
    grid.set_titles(row_template="", col_template="{col_name} varies")

    for facet_ax in grid.axes.flat:
        method = methods[plotting.row_num(facet_ax)]
        plotting.show_spines(facet_ax)

        # Limits (the "tau" row is re-limited after tight_layout below):
        facet_ax.set_xlim(0.15, 0.35)
        facet_ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.05))
        facet_ax.set_ylim(0, 0.05)

        # Labels:
        if facet_ax.is_first_col():
            facet_ax.set_ylabel(ylabel_by_method[method])
        if facet_ax.is_last_row():
            facet_ax.set_xlabel("Classification error")
        facet_ax.grid(True, axis="both", zorder=-2000)

        # Legend:
        if facet_ax.is_first_row():
            _add_legend(facet_ax, df_plot)

    grid.fig.suptitle(
        f"Pretraining dataset: {display.prettify(upstream_dataset)}", x=0.55)

    plotting.apply_to_fig_text(grid.fig, display.prettify)
    grid.fig.tight_layout()

    tau_axes = [
        facet_ax for facet_ax in grid.axes.flat
        if methods[plotting.row_num(facet_ax)] == "tau"
    ]
    for facet_ax in tau_axes:
        facet_ax.set_ylim(0.85, 1.05)
        facet_ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.05))
        plotting.annotate_confidence_plot(facet_ax)
    return grid
def plot_alternative_metrics(
    df_main: pd.DataFrame,
    xs: List[str], ys: List[str],
    gce_prefix: str = plotting.STD_GCE_PREFIX,
    add_legend: bool = True,
    add_metric_description: bool = False,
) -> sns.axisgrid.FacetGrid:
  """Plots alternative metrics for ImageNet and natural OOD datasets.

  Facet row k plots column `ys[k]` against column `xs[k]` (temperature-scaled
  results only), with one facet column per dataset. At most one of `xs` and
  `ys` may vary across rows; the varying one provides the row names.

  Args:
    df_main: Results dataframe.
    xs: x-axis column names, one per facet row.
    ys: y-axis column names, one per facet row. A name ending in "_residual"
      plots the residual of the named column (without the suffix) after a
      linear regression on "downstream_error", fitted separately per facet.
    gce_prefix: Metric-name prefix used to select rows.
    add_legend: Whether to add a model-family legend below the plot.
    add_metric_description: Whether to add a title describing the metric.

  Returns:
    The populated FacetGrid.

  Raises:
    ValueError: If `xs` is empty or `xs` and `ys` differ in length. `zip`
      would otherwise silently drop rows, and the later label code would
      fail on the facets that never got data.
    AssertionError: If both `xs` and `ys` contain more than one distinct value.
  """
  if not xs or len(xs) != len(ys):
    raise ValueError(
        "xs and ys must be non-empty and of equal length, got "
        f"{len(xs)} and {len(ys)}.")

  # Settings:
  rescaling_method = "temperature_scaling"
  family_order = display.get_model_families_sorted()
  # dict.fromkeys de-duplicates while keeping order, so ImageNet does not get
  # two columns should OOD_DATASET_ORDER already contain it.
  dataset_order = list(dict.fromkeys(
      ["imagenet(split='validation[20%:]')"] + list(display.OOD_DATASET_ORDER)))

  # Select data:
  mask = df_main.Metric.str.startswith(gce_prefix)
  mask &= df_main.ModelName.isin(display.get_standard_model_list())
  mask &= df_main.DatasetName.isin(dataset_order)
  mask &= df_main.rescaling_method.isin([rescaling_method])
  mask &= df_main.ModelFamily.isin(family_order)
  # Keep ImageNet-A/-R results only where use_dataset_labelset is True:
  mask &= ~(df_main.DatasetName.str.startswith("imagenet_a") &
            (~df_main.use_dataset_labelset.eq(True)))
  mask &= ~(df_main.DatasetName.str.startswith("imagenet_r") &
            (~df_main.use_dataset_labelset.eq(True)))
  df_plot = df_main[mask].copy()
  df_plot, cmap = display.add_display_data(df_plot, family_order)

  # Remove "use_dataset_labelset=True" to have uniform metric name:
  df_plot.Metric = df_plot.Metric.str.replace("use_dataset_labelset=True,", "")

  # Repeat the dataframe once per facet row. Each copy is tagged with the
  # columns to plot and its row name. A single concat replaces the quadratic
  # concat-in-a-loop onto an empty DataFrame.
  assert not ((len(set(xs)) > 1) and (len(set(ys)) > 1)), (
      "One of x and y must contain a constant value.")
  varying_set = xs if (len(set(xs)) > 1) else ys
  df_plot = pd.concat([
      df_plot.assign(x=x, y=y, row=varying)
      for x, y, varying in zip(xs, ys, varying_set)
  ])

  def get_residual(classification_error, metric):
    """Residual of `metric` after a linear fit on `classification_error`."""
    reg = linear_model.LinearRegression()
    metric = metric.to_numpy()
    not_nan = ~np.isnan(metric)
    x = classification_error.to_numpy()[:, None]
    reg.fit(x[not_nan, :], metric[not_nan])
    return metric - reg.predict(x)

  def subplot_fn(data, **kwargs):
    del kwargs
    ax = plt.gca()
    # Remember the column names on the axes for labeling below:
    x = ax.my_xlabel = utils.assert_and_get_constant(data.x)
    y = ax.my_ylabel = utils.assert_and_get_constant(data.y)
    if y.endswith("_residual"):
      # Plot residual w.r.t. accuracy:
      y = y.replace("_residual", "")
      data = data.copy()
      data.loc[:, y] = get_residual(data["downstream_error"], data[y])

    # One scatter call per marker, since a scatter call takes a single marker:
    for marker in data.family_marker.unique():
      data_sub = data[data.family_marker == marker]
      ax.scatter(
          data_sub[x],
          data_sub[y],
          s=plotting.model_to_scatter_size(data_sub.model_size),
          c=data_sub.family_index,
          cmap=cmap,
          vmin=0,
          vmax=len(family_order),
          marker=marker,
          alpha=0.7,
          zorder=30,
          label=utils.assert_and_get_constant(data_sub.ModelFamily),
          linewidth=0.0)

  g = plotting.FacetGrid(
      data=df_plot,
      sharex=False,
      sharey=False,
      dropna=False,
      col="DatasetName",
      col_order=dataset_order,
      row="row",
      row_order=varying_set,
      height=1.05,
      aspect=1.3,
      margin_titles=True)

  g.map_dataframe(subplot_fn)

  g.set_titles(col_template="{col_name}", row_template="",
               size=mpl.rcParams["axes.titlesize"])

  for ax in g.axes.flat:
    plotting.show_spines(ax)
    ax.xaxis.set_minor_locator(mpl.ticker.MultipleLocator(0.1))

    col = plotting.col_num(ax)
    if dataset_order[col] == "imagenet_v2(variant='MATCHED_FREQUENCY')":
      ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
    if dataset_order[col] in ("imagenet_r", "imagenet_a"):
      ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.2))

    if varying_set == xs:
      ax.set_xlabel(ax.my_xlabel)
      if ax.is_first_col():
        ax.set_ylabel(ax.my_ylabel)
    elif varying_set == ys:
      if ax.is_last_row():
        ax.set_xlabel(ax.my_xlabel)
      else:
        ax.set_xticklabels([])
      if ax.is_first_col():
        ax.set_ylabel(ax.my_ylabel)
    ax.grid(True, axis="x", which="minor")
    ax.grid(True, axis="both")

  if add_metric_description:
    plotting.add_metric_description_title(df_plot, g.fig, y=1.05)

  # Display labels for the raw column names used as axis labels:
  pretty_labels = {
      "downstream_error": display.XLABEL_CLASSIFICATION_ERROR,
      "MetricValue": display.YLABEL_ECE_TEMP_SCALED_SHORT,
      "brier": "Brier score\n(temp.-scaled)",
      "brier_div_error": "Brier / class. error",
      "brier_residual": "Brier (residual)",
      "nll": "NLL\n(temp.-scaled)",
      "nll_div_error": "NLL / class. error",
      "nll_residual": "NLL (residual)",
  }

  def prettify(x):
    x = display.prettify(x)
    return pretty_labels.get(x, x)

  plotting.apply_to_fig_text(g.fig, prettify)
  g.fig.tight_layout()
  g.fig.subplots_adjust(wspace=0.4, hspace=0.1 if varying_set == ys else 0.5)

  # Model family legend below plot:
  if add_legend:
    handles, labels = plotting.get_model_family_legend(
        g.axes.flat[0], family_order)
    legend = g.axes.flat[0].legend(
        handles=handles, labels=labels, loc="upper center",
        title="Model family", bbox_to_anchor=(0.5, 0), frameon=True,
        bbox_transform=g.fig.transFigure, ncol=len(family_order),
        handletextpad=0.1)
    legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
    legend.get_frame().set_edgecolor("lightgray")
    plotting.apply_to_fig_text(g.fig, prettify)

  return g
def plot(df_main: pd.DataFrame,
         family_order: Optional[List[str]] = None,
         rescaling_method: str = "temperature_scaling",
         gce_prefix: str = plotting.STD_GCE_PREFIX) -> sns.axisgrid.FacetGrid:
  """Plots a regression of accuracy and calibration, split by model family.

  Classification error and the calibration metric are both log10-transformed
  and `sns.regplot` draws a regression in each panel. Columns are model
  families; rows separate ImageNet-C ("imagenet_c") from the other
  distribution-shift datasets ("others"). Each panel is annotated with the
  intercept `a` (mapped back to linear scale via 10**) and slope `k`, taken
  as the median over `beta_boots` from `RegressionPlotter.get_params`, with
  intervals from `sns_utils.ci`.

  Args:
    df_main: Results dataframe with one row per model/dataset/metric.
    family_order: Model families to show, in column order. Defaults to
      `display.get_model_families_sorted()`.
    rescaling_method: Value of the `rescaling_method` column to keep.
    gce_prefix: Prefix of the calibration metric name in the `Metric` column.

  Returns:
    The seaborn FacetGrid holding the figure.
  """
  # Settings:
  if family_order is None:
    family_order = display.get_model_families_sorted()
  dataset_order = [
      "imagenet(split='validation[20%:]')",
      "imagenet_v2(variant='MATCHED_FREQUENCY')",
      "imagenet_r",
      "imagenet_a",
  ]

  # Select data:
  mask = df_main.Metric.str.startswith(gce_prefix)
  mask &= df_main.ModelName.isin(display.get_standard_model_list())
  mask &= df_main.rescaling_method == rescaling_method
  mask &= df_main.ModelFamily.isin(family_order)
  # For ImageNet-A and ImageNet-R keep only rows where use_dataset_labelset is
  # True:
  mask &= ~(df_main.DatasetName.str.startswith("imagenet_a") &
            ~df_main.use_dataset_labelset.eq(True))
  mask &= ~(df_main.DatasetName.str.startswith("imagenet_r") &
            ~df_main.use_dataset_labelset.eq(True))
  df_plot = df_main[mask].copy()

  # Select OOD datasets:
  mask = df_plot.DatasetName.isin(dataset_order)
  mask |= df_plot.DatasetName.str.startswith("imagenet_c")
  df_plot = df_plot[mask].copy()

  # Add plotting-related data:
  df_plot, _ = display.add_display_data(df_plot, family_order)
  df_plot["dataset_group"] = "others"
  im_c_mask = df_plot.DatasetName.str.startswith("imagenet_c")
  df_plot.loc[im_c_mask, "dataset_group"] = "imagenet_c"
  dataset_group_order = ["imagenet_c", "others"]

  def subplot_fn(**kwargs):
    x = kwargs["x"]
    y = kwargs["y"]
    data = kwargs["data"]
    utils.assert_no_duplicates_in_condition(
        data, group_by=["DatasetName", "ModelName"])
    ax = plt.gca()
    kwargs["color"] = utils.assert_and_get_constant(data.family_color)
    ax = sns.regplot(ax=ax, **kwargs)
    plotter = RegressionPlotter(x, y, data=data)

    # Get regression parameters and show in plot. Only `beta_boots` is used;
    # the first return value is not needed.
    grid = np.linspace(data[x].min(), data[x].max(), 100)
    _, beta_boots = plotter.get_params(grid)
    beta_boots = np.array(beta_boots)
    # Row 0 holds intercepts and row 1 slopes, both in log10 space; the
    # intercept is mapped back to linear scale.
    intercept = 10 ** np.median(beta_boots[0, :])
    intercept_ci = 10 ** sns_utils.ci(beta_boots[0, :])
    slope = np.median(beta_boots[1, :])
    slope_ci = sns_utils.ci(beta_boots[1, :])
    s = (f"a = {intercept:1.2f} ({intercept_ci[0]:1.2f}, "
         f"{intercept_ci[1]:1.2f})\nk = {slope:1.2f} ({slope_ci[0]:1.2f}, "
         f"{slope_ci[1]:1.2f})")
    ax.text(0.04, 0.96, s, va="top", ha="left", transform=ax.transAxes,
            fontsize=4, color=(0.3, 0.3, 0.3),
            bbox=dict(facecolor="w", alpha=0.8, boxstyle="square,pad=0.1"))

  def sparse_tick_labels(ticks, show):
    # Label only the ticks matching a value in `show`. np.isclose is used
    # instead of `in` because np.arange accumulates floating-point error.
    return [f"{t:0.1f}" if np.isclose(t, show).any() else "" for t in ticks]

  df_plot["MetricValue_log"] = np.log10(df_plot.MetricValue)
  df_plot["downstream_error_log"] = np.log10(df_plot.downstream_error)

  g = plotting.FacetGrid(
      data=df_plot,
      sharex=False,
      sharey=False,
      col="ModelFamily",
      col_order=family_order,
      hue="ModelFamily",
      hue_order=family_order,
      row="dataset_group",
      row_order=dataset_group_order,
      height=1.0,
      aspect=0.9,)

  g.map_dataframe(subplot_fn, x="downstream_error_log", y="MetricValue_log",
                  scatter_kws={"alpha": 0.5, "linewidths": 0.0, "s": 2})

  g.set_titles(template="{col_name}", size=mpl.rcParams["axes.titlesize"])

  for ax in g.axes.flat:
    plotting.show_spines(ax)
    # Axes are in log10 space; ticks are placed at log10 of linear values.
    ax.set_xlim(np.log10(0.1), np.log10(1.0))
    xticks = np.arange(0.1, 1.0 + 0.001, 0.1)
    ax.set_xticks(np.log10(xticks))
    if ax.is_last_row():
      ax.set_xticklabels(sparse_tick_labels(xticks, [0.1, 0.2, 0.4, 1.0]))
      ax.set_title("")
    else:
      ax.set_xticklabels([])

    ax.set_ylim(np.log10(0.01), np.log10(0.8))
    yticks = np.arange(0.05, 0.8 + 0.001, 0.05)
    ax.set_yticks(np.log10(yticks))
    if ax.is_first_col():
      ax.set_yticklabels(sparse_tick_labels(yticks, [0.1, 0.2, 0.4, 0.8]))
    else:
      ax.set_yticklabels([])

    ax.grid(True, axis="both")

    # Labels:
    if ax.is_last_row():
      ax.set_xlabel(display.XLABEL_CLASSIFICATION_ERROR)
    if ax.is_first_col():
      ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED_SHORT)
  plotting.apply_to_fig_text(g.fig, display.prettify)
  g.fig.tight_layout(w_pad=0)
  return g
def plot(
    df_main: pd.DataFrame,
    gce_prefix: str = plotting.STD_GCE_PREFIX,
    rescaling_methods: Optional[List[str]] = None,
    family_order: Optional[List[str]] = None,
    plot_confidence: bool = False,
    add_legend: bool = True,
    add_metric_description: bool = False,
) -> sns.axisgrid.FacetGrid:
    """Plots accuracy and calibration for ImageNet-C.

    Each panel shows MetricValue against downstream_error, averaged over
    ImageNet-C corruption types. There is one point per severity and model
    size, colored by severity and sized by model size. For each severity, a
    gray line connects the points. Rows are rescaling methods; columns are
    model families.

    Args:
        df_main: Results dataframe.
        gce_prefix: Prefix of the calibration metric name.
        rescaling_methods: Rescaling methods to show as rows. Defaults to
            ["none", "temperature_scaling"]. The argument is not modified.
        family_order: Model families to show, in column order. Defaults to
            ["mixer", "vit", "bit", "simclr", "wsl"].
        plot_confidence: If True, append a "tau" (optimal temperature) row.
        add_legend: If True, add a severity legend below the grid.
        add_metric_description: If True, add the metric description title.

    Returns:
        The seaborn FacetGrid holding the figure.
    """
    # Settings. Copy the arguments so that adding "tau" never mutates the
    # caller's list (list `+=` would extend it in place).
    rescaling_methods = list(rescaling_methods or
                             ["none", "temperature_scaling"])
    if plot_confidence:
        rescaling_methods.append("tau")
    family_order = list(family_order or
                        ["mixer", "vit", "bit", "simclr", "wsl"])

    df_plot = _get_data(df_main, gce_prefix, family_order)

    # Add "optimal temperature" as another rescaling method, so that seaborn can
    # plot it as a third row:
    df_plot = plotting.add_optimal_temperature_as_rescaling_method(df_plot)

    df_plot["model_size"] = df_plot.ModelName.map(display.MODEL_SIZE)

    # Shared color limits, so that a severity gets the same color in every
    # panel (and in the legend) even if a panel lacks some severities:
    severity_min = df_plot.severity.min()
    severity_max = df_plot.severity.max()

    def subplot_fn(data, x, y, **kwargs):
        del kwargs
        ax = plt.gca()
        data = utils.average_imagenet_c_corruption_types(
            data, group_by=["severity", "model_size"])
        cmap = sns.color_palette("flare",
                                 n_colors=data.severity.nunique(),
                                 as_cmap=True)
        marker_sizes = plotting.model_to_scatter_size(data.model_size, 0.75)
        # Plot data:
        ax.scatter(data[x],
                   data[y],
                   s=marker_sizes,
                   c=data.severity,
                   linewidth=0,
                   cmap=cmap,
                   vmin=severity_min,
                   vmax=severity_max,
                   alpha=0.7,
                   zorder=30,
                   label="data")
        # White dots to mask lines:
        ax.scatter(data[x],
                   data[y],
                   s=marker_sizes,
                   c="w",
                   alpha=1.0,
                   zorder=20,
                   linewidth=1.5,
                   label="_hidden")
        # Lines to connect dots:
        for condition in np.unique(data["severity"]):
            data_sub = data[data["severity"] == condition].copy()
            data_sub = data_sub.sort_values(by=data_sub.columns.to_list())
            ax.plot(data_sub[x],
                    data_sub[y],
                    "-",
                    color="gray",
                    alpha=0.7,
                    zorder=10,
                    linewidth=1)

    g = plotting.FacetGrid(data=df_plot,
                           sharex=True,
                           sharey=False,
                           dropna=False,
                           row="rescaling_method",
                           row_order=rescaling_methods,
                           col="ModelFamily",
                           col_order=family_order,
                           height=1,
                           aspect=1,
                           margin_titles=True)

    g.map_dataframe(subplot_fn, x="downstream_error", y="MetricValue")

    g.set_titles(col_template="{col_name}",
                 row_template="",
                 size=mpl.rcParams["axes.titlesize"])

    # Format axes:
    for ax in g.axes.flat:
        method = rescaling_methods[plotting.row_num(ax)]
        if method == "none":
            ax.set_ylim(0, 0.25)
            ax.set_yticks(np.arange(0, 0.272, 0.05))
            if ax.is_first_col():
                ax.set_ylabel(display.YLABEL_ECE_UNSCALED)

        elif method == "temperature_scaling":
            ax.set_ylim(0, 0.15)
            ax.set_yticks(np.arange(0, 0.172, 0.05))
            if ax.is_first_col():
                ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED)

        elif method == "tau":
            ax.set_ylim(0.6, 1.6)
            plotting.annotate_confidence_plot(ax)
            if ax.is_first_col():
                ax.set_ylabel(display.YLABEL_TEMP_FACTOR)

        ax.set_xlim(0.0, 0.8)
        ax.set_xticks(np.arange(0, 0.81, 0.2))
        plotting.show_spines(ax)
        if ax.is_last_row():
            ax.set_xlabel(display.XLABEL_INET_C_ERROR)
        ax.grid(axis="both")

        if not ax.is_first_col():
            ax.set_yticklabels("")

    g.fig.tight_layout()

    # Severity legend at bottom, built from the colored "data" scatter of the
    # second panel:
    if add_legend:
        scatter_objects = plotting.find_path_collection(g.axes.flat[1],
                                                        label="data")
        handles, labels = scatter_objects[0].legend_elements(prop="colors")
        severity_legend = plt.legend(handles=handles,
                                     labels=labels,
                                     loc="upper center",
                                     title="ImageNet-C corruption severity",
                                     bbox_to_anchor=(0.5, 0),
                                     frameon=True,
                                     bbox_transform=g.fig.transFigure,
                                     ncol=6,
                                     handletextpad=0.1)
        severity_legend.get_frame().set_linewidth(
            mpl.rcParams["axes.linewidth"])
        severity_legend.get_frame().set_edgecolor("lightgray")
        g.fig.add_artist(severity_legend)

    if add_metric_description:
        plotting.add_metric_description_title(df_plot, g.fig)

    plotting.apply_to_fig_text(g.fig, display.prettify)

    return g
def plot_error_increase_vs_model_size(
        df_main: pd.DataFrame,
        compact_layout: bool = True,
        gce_prefix=plotting.STD_GCE_PREFIX) -> sns.axisgrid.FacetGrid:
    """Plots ImageNet-C error and ECE relative to each family's largest model.

    Values are averaged over corruption types. For each model family, severity,
    row and dataset, every value is shown as its difference to the value of the
    family's largest model. Row 0 shows classification error and row 1 shows
    temperature-scaled calibration error. Points are colored by severity and
    sized by model size; a gray line connects the points of one model size.

    Args:
        df_main: Results dataframe.
        compact_layout: If True, show only the mixer, vit and bit families with
            tighter y-limits; otherwise show all families except alexnet.
        gce_prefix: Prefix of the calibration metric name.

    Returns:
        The seaborn FacetGrid holding the figure.

    Raises:
        ValueError: If a group does not have exactly one largest-model value.
    """
    # Settings:
    if compact_layout:
        family_order = ["mixer", "vit", "bit"]
    else:
        # Remove AlexNet bc. it only has one size. A filtered copy leaves the
        # helper's list untouched and tolerates a missing "alexnet" entry.
        family_order = [
            family for family in display.get_model_families_sorted()
            if family != "alexnet"
        ]

    df_plot = _get_data(df_main, gce_prefix, family_order)
    df_plot = df_plot[df_plot.rescaling_method == "temperature_scaling"].copy()

    # Add downstream_error as another "rescaling method", so that seaborn can
    # plot it as a separate row:
    df_error = df_plot.copy()
    df_error["rescaling_method"] = "downstream_error"
    df_error["MetricValue"] = df_error["downstream_error"]
    df_plot = pd.concat([df_plot, df_error], ignore_index=True)

    df_plot = utils.average_imagenet_c_corruption_types(
        df_plot,
        group_by=["ModelName", "Metric", "severity", "rescaling_method"])
    # A unique index lets per-group results be written back by label below.
    df_plot = df_plot.reset_index(drop=True)

    # Explicit row order: the axis formatting below labels row 0 as
    # classification error and row 1 as ECE.
    rescaling_methods = ["downstream_error", "temperature_scaling"]

    # Normalize each value to the largest model of the same family, for each
    # severity, row and dataset:
    df_plot["relative_metric_value"] = np.nan
    group_keys = ["ModelFamily", "severity", "rescaling_method", "DatasetName"]
    in_families = df_plot.ModelFamily.isin(family_order)
    for _, group in df_plot[in_families].groupby(group_keys):
        if group.empty:
            continue
        is_largest = group.model_size == group.model_size.max()
        # .item() raises ValueError unless exactly one row is the largest model.
        largest_model_value = group.loc[is_largest, "MetricValue"].item()
        df_plot.loc[group.index, "relative_metric_value"] = (
            group.MetricValue - largest_model_value)

    cmap = sns.color_palette("flare",
                             n_colors=df_plot.severity.nunique(),
                             as_cmap=True)

    def subplot_fn(data, x, y, **kwargs):
        del kwargs
        ax = plt.gca()
        for condition in np.unique(data["model_size"]):
            data_sub = data[data["model_size"] == condition].copy()
            data_sub = data_sub.sort_values(by=data_sub.columns.to_list())
            # Plot data:
            ax.scatter(data_sub[x],
                       data_sub[y],
                       s=plotting.model_to_scatter_size(
                           data_sub.model_size, 0.75),
                       c=data_sub.severity,
                       cmap=cmap,
                       alpha=0.7,
                       zorder=30,
                       label="data",
                       vmin=0,
                       vmax=5,
                       linewidth=0)
            # White dots to mask lines:
            ax.scatter(data_sub[x],
                       data_sub[y],
                       s=plotting.model_to_scatter_size(
                           data_sub.model_size, 0.75),
                       c="w",
                       alpha=1.0,
                       zorder=20,
                       linewidth=1.5,
                       label="_hidden")
            # Lines to connect dots:
            ax.plot(data_sub[x],
                    data_sub[y],
                    "-",
                    color="gray",
                    alpha=0.7,
                    zorder=10,
                    linewidth=0.75)

    g = plotting.FacetGrid(data=df_plot,
                           sharex=True,
                           sharey=False,
                           dropna=False,
                           col="ModelFamily",
                           col_order=family_order,
                           row="rescaling_method",
                           row_order=rescaling_methods,
                           height=1.1,
                           margin_titles=True,
                           aspect=0.8)
    g.map_dataframe(subplot_fn, x="severity", y="relative_metric_value")
    g.set_titles(col_template="{col_name}",
                 row_template="",
                 size=mpl.rcParams["axes.titlesize"])

    # Format axes:
    for ax in g.axes.flatten():
        ax.set_xlim(-0.9, 5.9)
        plotting.show_spines(ax)
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(1))
        ax.yaxis.set_minor_locator(mpl.ticker.MultipleLocator(0.02))
        if ax.is_first_col() and plotting.row_num(ax) == 0:
            ax.set_ylabel("Classification error\n(Δ to largest model)")
        elif ax.is_first_col() and plotting.row_num(ax) == 1:
            ax.set_ylabel("ECE\n(Δ to largest model)")
        else:
            ax.set_yticklabels([])

        if ax.is_first_row():
            ax.set_title(family_order[plotting.col_num(ax)])
            ax.set_ylim(-0.05 if compact_layout else -0.1, 0.3)
            ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))

        if ax.is_last_row():
            ax.set_xlabel("Corruption\nseverity")
            ax.set_ylim(-0.01 if compact_layout else -0.04, 0.03)
            ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.02))
        ax.grid(axis="both")
        ax.grid(True, axis="y", which="minor")

    plotting.apply_to_fig_text(g.fig, display.prettify)
    g.fig.tight_layout()
    g.fig.subplots_adjust(wspace=0.15)
    return g
def plot(df_main: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_method: str = "temperature_scaling",
         add_guo: bool = False) -> mpl.figure.Figure:
  """Plots ECE vs. clean ImageNet error, unscaled and rescaled (Figure 1).

  The figure has two side-by-side panels. The left panel uses unscaled
  predictions ("none"); the right panel uses `rescaling_method`. Each model is
  a point colored and shaped by model family and sized by model size.

  Args:
    df_main: Results dataframe.
    gce_prefix: Prefix of the calibration metric name.
    rescaling_method: Rescaling method shown in the right panel.
    add_guo: If True, also show the DenseNet161/ResNet152 values reported by
      Guo et al. as an extra "guo" family. These points only exist for the
      "none" and "temperature_scaling" panels.

  Returns:
    The matplotlib figure.
  """
  rescaling_methods = ["none", rescaling_method]
  # Copy so that appending "guo" never mutates the list returned by display.
  family_order = list(display.get_model_families_sorted())
  if add_guo:
    family_order.append("guo")

  # Guo et al. data, from Table 1 and Table S2 in
  # https://arxiv.org/pdf/1706.04599.pdf. First model is DenseNet161, second is
  # ResNet152. Errors are shared; ECE depends on the rescaling method.
  guo_errors = [0.2257, 0.2231]
  guo_ece = {
      "none": [0.0628, 0.0548],
      "temperature_scaling": [0.0199, 0.0186],
  }

  # Set up figure:
  fig = plt.figure(figsize=(display.FULL_WIDTH/2, 1.4))
  spec = fig.add_gridspec(ncols=2, nrows=1)

  for ax_i, method in enumerate(rescaling_methods):

    df_plot, cmap = _get_data(df_main, gce_prefix, family_order,
                              rescaling_methods=[method])
    ax = fig.add_subplot(spec[:, ax_i])
    big_ax = ax
    for i, family in enumerate(family_order):
      if family == "guo":
        continue
      data_sub = df_plot[df_plot.ModelFamily == family]
      if data_sub.empty:
        continue
      ax.scatter(
          data_sub["downstream_error"],
          data_sub["MetricValue"],
          s=plotting.model_to_scatter_size(data_sub.model_size),
          c=data_sub.family_index,
          cmap=cmap,
          vmin=0,
          vmax=len(family_order),
          marker=utils.assert_and_get_constant(data_sub.family_marker),
          alpha=0.7,
          linewidth=0.0,
          zorder=100 - i,  # Z-order is same as model family order.
          label=family)

    # Manually add Guo et al data. "guo" is the last family, so it uses the
    # same cmap/vmin/vmax mapping as the families above; without these, the
    # color would come from matplotlib's default colormap instead.
    if add_guo and method in guo_ece:
      ax.scatter(guo_errors, guo_ece[method],
                 s=plotting.model_to_scatter_size(1),
                 c=[len(family_order) - 1] * len(guo_errors),
                 cmap=cmap,
                 vmin=0,
                 vmax=len(family_order),
                 marker="x",
                 alpha=0.7,
                 label="guo")

    plotting.show_spines(ax)
    # Aspect ratios are tuned manually for display in the paper:
    ax.set_anchor("N")
    ax.grid(False, which="minor")
    ax.grid(True, axis="both")
    ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
    ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.01))
    ax.set_ylim(bottom=0.01, top=0.09)
    ax.set_xlim(0.05, 0.5)
    ax.set_xlabel(display.XLABEL_INET_ERROR)
    if method == "none":
      ax.set_title("Unscaled")
    elif method == "temperature_scaling":
      ax.set_title("Temperature-scaled")
      ax.set_yticklabels("")
    if ax.is_first_col():
      ax.set_ylabel("ECE")

  # Model family legend, placed on the last (right) panel:
  handles, labels = plotting.get_model_family_legend(big_ax, family_order)
  legend = big_ax.legend(
      handles=handles, labels=labels, loc="upper right", frameon=True,
      labelspacing=0.3, handletextpad=0.1, borderpad=0.3, fontsize=4)
  legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
  legend.get_frame().set_edgecolor("lightgray")

  plotting.apply_to_fig_text(fig, display.prettify)
  plotting.apply_to_fig_text(fig, lambda x: x.replace("EfficientNet", "EffNet"))

  return fig