def subplot_fn(data, **kwargs):
    """Scatters the `data.y` column against the `data.x` column, one call per marker.

    The names of the x/y columns are read from the (constant) `x` and `y`
    columns of `data` and stored on the current axes as `my_xlabel` /
    `my_ylabel`. If the y column name ends in "_residual", that suffix is
    stripped and the plotted y values are replaced by
    `get_residual(data["downstream_error"], data[y])` on a copy of `data`.

    Uses `cmap` and `family_order` from the enclosing scope (not visible here).

    Args:
      data: DataFrame with `x`, `y`, `family_marker`, `model_size`,
        `family_index`, `ModelFamily` columns, plus the columns named by `x`/`y`
        (and `downstream_error` when plotting a residual).
      **kwargs: Ignored.
    """
    del kwargs
    ax = plt.gca()
    x = ax.my_xlabel = utils.assert_and_get_constant(data.x)
    y = ax.my_ylabel = utils.assert_and_get_constant(data.y)
    residual_suffix = "_residual"
    if y.endswith(residual_suffix):
        # Plot residual w.r.t. accuracy.
        # Strip only the trailing suffix; str.replace would also remove any
        # earlier occurrence of "_residual" inside the column name.
        y = y[:-len(residual_suffix)]
        data = data.copy()  # Don't modify the caller's DataFrame.
        data.loc[:, y] = get_residual(data["downstream_error"], data[y])

    for marker in data.family_marker.unique():
        data_sub = data[data.family_marker == marker]
        ax.scatter(
            data_sub[x],
            data_sub[y],
            s=plotting.model_to_scatter_size(data_sub.model_size),
            c=data_sub.family_index,
            cmap=cmap,
            vmin=0,
            vmax=len(family_order),
            marker=marker,
            alpha=0.7,
            zorder=30,
            label=utils.assert_and_get_constant(data_sub.ModelFamily),
            linewidth=0.0)
 def subplot_fn(data, x, y, **kwargs):
     """Plots corruption-averaged results colored by severity, joined per severity.

     Results are first averaged with `utils.average_imagenet_c_corruption_types`
     grouped by severity and model size. Points are colored by severity using a
     "flare" palette with one color per distinct severity. White points with a
     1.5-wide outline are drawn beneath them so the gray connecting lines stop
     short of each point. Within each severity level, the lines connect the rows
     in the order obtained by sorting on all columns.

     Args:
       data: DataFrame to average; the averaged frame must provide `severity`,
         `model_size` and the columns named by `x` and `y`.
       x: Name of the column plotted on the x-axis.
       y: Name of the column plotted on the y-axis.
       **kwargs: Ignored.
     """
     del kwargs
     axes = plt.gca()
     averaged = utils.average_imagenet_c_corruption_types(
         data, group_by=["severity", "model_size"])
     severity_cmap = sns.color_palette(
         "flare", n_colors=averaged.severity.nunique(), as_cmap=True)
     x_vals = averaged[x]
     y_vals = averaged[y]
     marker_sizes = plotting.model_to_scatter_size(averaged.model_size, 0.75)

     # Top layer: the colored data points.
     axes.scatter(x_vals, y_vals, s=marker_sizes, c=averaged.severity,
                  cmap=severity_cmap, linewidth=0, alpha=0.7, zorder=30,
                  label="data")
     # Middle layer: white points that hide the lines around each point.
     axes.scatter(x_vals, y_vals, s=marker_sizes, c="w", alpha=1.0,
                  zorder=20, linewidth=1.5, label="_hidden")
     # Bottom layer: one gray line per severity level.
     for level in np.unique(averaged["severity"]):
         level_rows = averaged[averaged["severity"] == level]
         level_rows = level_rows.sort_values(by=list(level_rows.columns))
         axes.plot(level_rows[x], level_rows[y], "-", color="gray",
                   alpha=0.7, zorder=10, linewidth=1)
 def subplot_fn(data, x, y, **kwargs):
     """Draws one connected series of points per distinct `model_size`.

     For each model size, the matching rows are sorted by all of their
     columns. They are then drawn as three layers:
     - colored points, with color given by `severity` on the enclosing `cmap`
       and the color range fixed to 0..5;
     - white points with a 1.5-wide outline underneath;
     - a thin gray line through the points in sorted order.

     Uses `cmap` from the enclosing scope (not visible here).

     Args:
       data: DataFrame with `model_size`, `severity` and the `x`/`y` columns.
       x: Name of the column plotted on the x-axis.
       y: Name of the column plotted on the y-axis.
       **kwargs: Ignored.
     """
     del kwargs
     axes = plt.gca()
     for size in np.unique(data["model_size"]):
         rows = data[data["model_size"] == size]
         rows = rows.sort_values(by=list(rows.columns))
         x_vals = rows[x]
         y_vals = rows[y]
         marker_sizes = plotting.model_to_scatter_size(rows.model_size, 0.75)

         # Top layer: colored data points.
         axes.scatter(x_vals, y_vals, s=marker_sizes, c=rows.severity,
                      cmap=cmap, vmin=0, vmax=5, alpha=0.7, zorder=30,
                      linewidth=0, label="data")
         # Middle layer: white points masking the line near each point.
         axes.scatter(x_vals, y_vals, s=marker_sizes, c="w", alpha=1.0,
                      zorder=20, linewidth=1.5, label="_hidden")
         # Bottom layer: thin gray connecting line.
         axes.plot(x_vals, y_vals, "-", color="gray", alpha=0.7, zorder=10,
                   linewidth=0.75)
 def subplot_fn(data, x, y, **kwargs):
   """Scatters `y` against `x` with one scatter call per distinct family marker.

   Points are colored by `family_index` on the enclosing `cmap`, with the color
   range fixed to 0..len(family_order). Each call is labelled with the
   (asserted constant) `ModelFamily` of its rows.

   Uses `cmap` and `family_order` from the enclosing scope (not visible here).

   Args:
     data: DataFrame with `family_marker`, `model_size`, `family_index`,
       `ModelFamily` and the `x`/`y` columns.
     x: Name of the column plotted on the x-axis.
     y: Name of the column plotted on the y-axis.
     **kwargs: Ignored.
   """
   del kwargs
   axes = plt.gca()
   shared_style = dict(cmap=cmap, vmin=0, vmax=len(family_order),
                       linewidth=0, alpha=0.7, zorder=30)
   for family_marker in data.family_marker.unique():
     subset = data[data.family_marker == family_marker]
     axes.scatter(subset[x],
                  subset[y],
                  s=plotting.model_to_scatter_size(subset.model_size),
                  c=subset.family_index,
                  marker=family_marker,
                  label=utils.assert_and_get_constant(subset.ModelFamily),
                  **shared_style)
def plot(df_main: pd.DataFrame,
         df_reliability: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_method: str = "temperature_scaling",
         add_guo: bool = False) -> mpl.figure.Figure:
    """Plots acc/calib and reliability diagrams on clean ImageNet (Figure 1).

    Args:
      df_main: Main results DataFrame. It is passed to `_get_data` and also
        used to look up each model's `RawModelName`.
      df_reliability: DataFrame passed to `_get_binned_reliability_data` to
        build the reliability diagrams.
      gce_prefix: Passed through to `_get_data`.
      rescaling_method: Rescaling shown in the ECE-vs-error panel(s). With
        "both", two panels are drawn ("none" and "temperature_scaling")
        followed by three reliability diagrams. Any other value produces one
        panel for that method followed by five reliability diagrams.
      add_guo: Whether to overlay the DenseNet161/ResNet152 results of
        Guo et al. (2017).

    Returns:
      The created figure.
    """
    # Copy so that appending "guo" never mutates a list owned by `display`.
    family_order = list(display.get_model_families_sorted())
    if add_guo:
        family_order.append("guo")

    # Guo et al. ECE values, keyed by rescaling method. They are from Table 1
    # and Table S2 in https://arxiv.org/pdf/1706.04599.pdf. The first model is
    # DenseNet161, the second is ResNet152; x-values are their top-1 errors.
    guo_error = [0.2257, 0.2231]
    guo_ece = {
        "none": [0.0628, 0.0548],
        "temperature_scaling": [0.0199, 0.0186],
    }

    if rescaling_method == "both":
        rescaling_methods = ["none", "temperature_scaling"]
        widths = [1.75, 1.75, 1, 1, 1]
    else:
        rescaling_methods = [rescaling_method]
        widths = [1.8, 1, 1, 1, 1, 1]
    heights = [0.3, 1]

    # Set up figure:
    fig = plt.figure(figsize=(display.FULL_WIDTH, 1.6))
    spec = fig.add_gridspec(ncols=len(widths),
                            nrows=len(heights),
                            width_ratios=widths,
                            height_ratios=heights)

    # First panels: acc vs calib ImageNet.
    # The loop variable is deliberately not named `rescaling_method`, so the
    # function argument is not silently overwritten.
    for ax_i, method in enumerate(rescaling_methods):

        df_plot, cmap = _get_data(df_main,
                                  gce_prefix,
                                  family_order,
                                  rescaling_methods=[method])
        ax = fig.add_subplot(spec[0:2, ax_i], box_aspect=1.0)
        big_ax = ax
        for i, family in enumerate(family_order):
            if family == "guo":
                continue
            data_sub = df_plot[df_plot.ModelFamily == family]
            if data_sub.empty:
                continue
            ax.scatter(
                data_sub["downstream_error"],
                data_sub["MetricValue"],
                s=plotting.model_to_scatter_size(data_sub.model_size),
                c=data_sub.family_index,
                cmap=cmap,
                vmin=0,
                vmax=len(family_order),
                marker=utils.assert_and_get_constant(data_sub.family_marker),
                alpha=0.7,
                linewidth=0.0,
                zorder=100 - i,  # Z-order is same as model family order.
                label=family)

        # Manually add Guo et al data.
        if add_guo and method in guo_ece:
            ax.scatter(guo_error,
                       guo_ece[method],
                       s=plotting.model_to_scatter_size(1),
                       # "guo" is the last entry of family_order. Map it through
                       # the same colormap/normalization as the other families;
                       # without cmap/vmin/vmax, matplotlib would use its default
                       # colormap here.
                       c=[len(family_order) - 1] * 2,
                       cmap=cmap,
                       vmin=0,
                       vmax=len(family_order),
                       marker="x",
                       alpha=0.7,
                       label="guo")

        plotting.show_spines(ax)
        ax.set_anchor("N")
        ax.grid(False, which="minor")
        ax.grid(True, axis="both")
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
        ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.01))
        ax.set_ylim(bottom=0.01, top=0.09)
        ax.set_xlim(0.05, 0.55)
        ax.set_xlabel(display.XLABEL_INET_ERROR)
        if len(rescaling_methods) == 1:  # Showing just one rescaling method.
            if method == "none":
                ax.set_ylabel(display.YLABEL_ECE_UNSCALED)
            elif method == "temperature_scaling":
                ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED)
        else:  # Showing both rescaling methods.
            if method == "none":
                ax.set_title("Unscaled")
            elif method == "temperature_scaling":
                ax.set_title("Temperature-scaled")
            # Each panel occupies grid column `ax_i`. This avoids
            # Axes.is_first_col(), which was removed in Matplotlib 3.6.
            if ax_i == 0:
                ax.set_ylabel("ECE")

    # Remaining panels: Reliability diagrams.
    offset = len(rescaling_methods)  # Grid column of the first diagram.
    model_names = [
        "mixer/jft-300m/H/14",
        "vit-h/14",
        "bit-jft-r152-x4-480",
    ]
    if offset == 1:
        model_names += ["wsl_32x48d", "simclr-4x-fine-tuned-100"]
    dataset_name = "imagenet(split='validation[20%:]')"

    # The reliability x-labels follow the rescaling method of the last ECE
    # panel, as in the original code.
    # NOTE(review): confirm that `df_reliability` is computed with that rescaling.
    label_method = rescaling_methods[-1]

    for i, model_name in enumerate(model_names):
        # Look up the raw model name (a single accuracy row is expected):
        mask = df_main.ModelName == model_name
        mask &= df_main.rescaling_method == "none"
        mask &= df_main.Metric == "accuracy"
        mask &= df_main.DatasetName == dataset_name
        raw_model_name = df_main[mask].RawModelName
        assert len(raw_model_name) <= 1, df_main[mask]
        if raw_model_name.empty:
            continue

        binned = _get_binned_reliability_data(
            df_reliability, utils.assert_and_get_constant(raw_model_name),
            dataset_name)

        rel_ax = fig.add_subplot(spec[1, i + offset])
        _plot_confidence_and_reliability(conf_ax=fig.add_subplot(spec[0, i +
                                                                      offset]),
                                         rel_ax=rel_ax,
                                         binned=binned,
                                         model_name=model_name,
                                         first_col=offset)

        if label_method == "none":
            rel_ax.set_xlabel("Confidence\n(unscaled)")
        # The original code compared against "temperature_scaled", a value
        # that is never used, so this label was never applied.
        elif label_method == "temperature_scaling":
            rel_ax.set_xlabel("Confidence\n(temp. scaled)")

    # Model family legend:
    handles, labels = plotting.get_model_family_legend(big_ax, family_order)
    legend = big_ax.legend(handles=handles,
                           labels=labels,
                           loc="upper right",
                           frameon=True,
                           labelspacing=0.25,
                           handletextpad=0.1,
                           borderpad=0.3,
                           fontsize=4)
    legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
    legend.get_frame().set_edgecolor("lightgray")

    plotting.apply_to_fig_text(fig, display.prettify)

    # Shift every axes except the first one to the right by a fixed amount.
    x_shift = 0.05
    for ax in fig.axes[1:]:
        box = ax.get_position()
        box.x0 += x_shift
        box.x1 += x_shift
        ax.set_position(box)

    return fig
def plot(df_main: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_method: str = "temperature_scaling",
         add_guo: bool = False) -> mpl.figure.Figure:
    """Plots ECE vs. ImageNet error separately for each BiT pretraining variant.

    The grid has 2 rows and 3 columns:
    - rows: unscaled ECE and ECE after `rescaling_method`;
    - columns: BiT-ImageNet, BiT-ImageNet21k and BiT-JFT.

    For each column, `display.get_standard_model_list` is temporarily
    replaced so that it omits the models of the other two BiT variants.
    Presumably `_get_data` consults it; verify. The original function is
    restored before returning, including when an exception is raised.

    Args:
      df_main: Main results DataFrame, passed to `_get_data`.
      gce_prefix: Passed through to `_get_data`.
      rescaling_method: Rescaling method shown in the second row.
      add_guo: Whether to overlay the DenseNet161/ResNet152 results of
        Guo et al. (2017).

    Returns:
      The created figure.
    """
    rescaling_methods = ["none", rescaling_method]
    # Copy so that appending "guo" never mutates a list owned by `display`.
    family_order = list(display.get_model_families_sorted(
        ["mixer", "vit", "bit", "simclr"]))
    if add_guo:
        family_order.append("guo")

    # Model-name prefix of each BiT variant (one grid column each).
    bit_prefixes = {
        "BiT-ImageNet": "bit-imagenet-",
        "BiT-ImageNet21k": "bit-imagenet21k-",
        "BiT-JFT": "bit-jft-",
    }

    # Guo et al. ECE values, keyed by rescaling method. They are from Table 1
    # and Table S2 in https://arxiv.org/pdf/1706.04599.pdf. The first model is
    # DenseNet161, the second is ResNet152; x-values are their top-1 errors.
    guo_error = [0.2257, 0.2231]
    guo_ece = {
        "none": [0.0628, 0.0548],
        "temperature_scaling": [0.0199, 0.0186],
    }

    def make_model_list_fn(excluded_prefixes):
        """Returns a zero-arg function listing MODEL_SIZE keys minus the prefixes."""
        return lambda: [
            m for m in display.MODEL_SIZE.keys()
            if not m.startswith(excluded_prefixes)
        ]

    # This function monkeypatches a module-level function; keep the original
    # so the patch does not leak into later calls.
    original_model_list_fn = display.get_standard_model_list
    try:
        # Set up figure:
        fig = plt.figure(figsize=(display.FULL_WIDTH / 2, 2))
        spec = fig.add_gridspec(ncols=3, nrows=2)

        for col, (bit_version, own_prefix) in enumerate(bit_prefixes.items()):
            # Hide the other two BiT variants so only `bit_version` is shown.
            display.get_standard_model_list = make_model_list_fn(tuple(
                prefix for prefix in bit_prefixes.values()
                if prefix != own_prefix))

            for row, method in enumerate(rescaling_methods):

                df_plot, cmap = _get_data(df_main,
                                          gce_prefix,
                                          family_order,
                                          rescaling_methods=[method])
                ax = fig.add_subplot(spec[row, col])
                big_ax = ax
                for i, family in enumerate(family_order):
                    if family == "guo":
                        continue
                    data_sub = df_plot[df_plot.ModelFamily == family]
                    if data_sub.empty:
                        continue
                    ax.scatter(
                        data_sub["downstream_error"],
                        data_sub["MetricValue"],
                        s=plotting.model_to_scatter_size(data_sub.model_size),
                        c=data_sub.family_index,
                        cmap=cmap,
                        vmin=0,
                        vmax=len(family_order),
                        marker=utils.assert_and_get_constant(
                            data_sub.family_marker),
                        linewidth=0.5,
                        # BiT families are drawn opaque, all others faded.
                        alpha=1.0 if "bit" in family else 0.5,
                        zorder=100 - i,  # Z-order is same as model family order.
                        label=family)

                # Manually add Guo et al data.
                if add_guo and method in guo_ece:
                    ax.scatter(guo_error,
                               guo_ece[method],
                               s=plotting.model_to_scatter_size(1),
                               # "guo" is the last entry of family_order. Map it
                               # through the same colormap/normalization as the
                               # other families; without cmap/vmin/vmax,
                               # matplotlib would use its default colormap.
                               c=[len(family_order) - 1] * 2,
                               cmap=cmap,
                               vmin=0,
                               vmax=len(family_order),
                               marker="x",
                               alpha=0.7,
                               label="guo")

                plotting.show_spines(ax)
                # Aspect ratios are tuned manually for display in the paper:
                ax.set_anchor("N")
                ax.grid(False, which="minor")
                ax.grid(True, axis="both")
                ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.01))
                ax.set_ylim(bottom=0.0, top=0.09)
                ax.set_xlim(0.05, 0.3)
                ax.set_xlabel(display.XLABEL_INET_ERROR)
                # Grid indices replace Axes.is_first_row()/is_first_col(),
                # which were removed in Matplotlib 3.6.
                if row == 0:
                    ax.set_title(bit_version, fontsize=6)
                    ax.set_xlabel("")
                    ax.tick_params(axis="x", labelbottom=False)
                else:
                    ax.set_ylim(bottom=0.0, top=0.05)
                if col == 0:
                    if method == "none":
                        ax.set_ylabel(display.YLABEL_ECE_UNSCALED)
                    elif method == "temperature_scaling":
                        ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED)
                else:
                    ax.tick_params(axis="y", labelleft=False)

        fig.tight_layout(pad=0.5)

        # Model family legend:
        handles, labels = plotting.get_model_family_legend(big_ax, family_order)

        plotting.apply_to_fig_text(fig, display.prettify)
        plotting.apply_to_fig_text(
            fig, lambda x: x.replace("EfficientNet", "EffNet"))

        legend = fig.axes[0].legend(handles=handles,
                                    labels=labels,
                                    loc="upper center",
                                    title="Model family",
                                    bbox_to_anchor=(0.55, -0.025),
                                    frameon=True,
                                    bbox_transform=fig.transFigure,
                                    ncol=len(family_order),
                                    handletextpad=0.1)
        legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
        legend.get_frame().set_edgecolor("lightgray")
        # Apply again so text of the just-created legend is also prettified.
        plotting.apply_to_fig_text(fig, display.prettify)

        return fig
    finally:
        display.get_standard_model_list = original_model_list_fn
def plot(df_main: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_method: str = "temperature_scaling",
         add_guo: bool = False) -> mpl.figure.Figure:
  """Plots ECE vs. ImageNet error on clean ImageNet, in two side-by-side panels.

  The left panel shows unscaled ECE ("none"). The right panel shows ECE after
  `rescaling_method`; its y tick labels are hidden when that method is
  "temperature_scaling".

  Args:
    df_main: Main results DataFrame, passed to `_get_data`.
    gce_prefix: Passed through to `_get_data`.
    rescaling_method: Rescaling method shown in the right panel.
    add_guo: Whether to overlay the DenseNet161/ResNet152 results of
      Guo et al. (2017).

  Returns:
    The created figure.
  """
  rescaling_methods = ["none", rescaling_method]
  # Copy so that appending "guo" never mutates a list owned by `display`.
  family_order = list(display.get_model_families_sorted())
  if add_guo:
    family_order.append("guo")

  # Guo et al. ECE values, keyed by rescaling method. They are from Table 1 and
  # Table S2 in https://arxiv.org/pdf/1706.04599.pdf. The first model is
  # DenseNet161, the second is ResNet152; x-values are their top-1 errors.
  guo_error = [0.2257, 0.2231]
  guo_ece = {
      "none": [0.0628, 0.0548],
      "temperature_scaling": [0.0199, 0.0186],
  }

  # Set up figure:
  fig = plt.figure(figsize=(display.FULL_WIDTH/2, 1.4))
  spec = fig.add_gridspec(ncols=2, nrows=1)

  for ax_i, method in enumerate(rescaling_methods):

    df_plot, cmap = _get_data(df_main, gce_prefix, family_order,
                              rescaling_methods=[method])
    ax = fig.add_subplot(spec[:, ax_i])
    big_ax = ax
    for i, family in enumerate(family_order):
      if family == "guo":
        continue
      data_sub = df_plot[df_plot.ModelFamily == family]
      if data_sub.empty:
        continue
      ax.scatter(
          data_sub["downstream_error"],
          data_sub["MetricValue"],
          s=plotting.model_to_scatter_size(data_sub.model_size),
          c=data_sub.family_index,
          cmap=cmap,
          vmin=0,
          vmax=len(family_order),
          marker=utils.assert_and_get_constant(data_sub.family_marker),
          alpha=0.7,
          linewidth=0.0,
          zorder=100 - i,  # Z-order is same as model family order.
          label=family)

    # Manually add Guo et al data.
    if add_guo and method in guo_ece:
      # "guo" is the last entry of family_order. Map it through the same
      # colormap/normalization as the other families; without cmap/vmin/vmax,
      # matplotlib would use its default colormap here.
      ax.scatter(guo_error, guo_ece[method],
                 s=plotting.model_to_scatter_size(1),
                 c=[len(family_order) - 1] * 2,
                 cmap=cmap, vmin=0, vmax=len(family_order),
                 marker="x", alpha=0.7, label="guo")

    plotting.show_spines(ax)
    # Aspect ratios are tuned manually for display in the paper:
    ax.set_anchor("N")
    ax.grid(False, which="minor")
    ax.grid(True, axis="both")
    ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
    ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.01))
    ax.set_ylim(bottom=0.01, top=0.09)
    ax.set_xlim(0.05, 0.5)
    ax.set_xlabel(display.XLABEL_INET_ERROR)
    if method == "none":
      ax.set_title("Unscaled")
    elif method == "temperature_scaling":
      ax.set_title("Temperature-scaled")
      ax.tick_params(axis="y", labelleft=False)
    # Each panel occupies grid column `ax_i`. This avoids Axes.is_first_col(),
    # which was removed in Matplotlib 3.6.
    if ax_i == 0:
      ax.set_ylabel("ECE")

  # Model family legend:
  handles, labels = plotting.get_model_family_legend(big_ax, family_order)
  legend = big_ax.legend(
      handles=handles, labels=labels, loc="upper right", frameon=True,
      labelspacing=0.3, handletextpad=0.1, borderpad=0.3, fontsize=4)
  legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
  legend.get_frame().set_edgecolor("lightgray")

  plotting.apply_to_fig_text(fig, display.prettify)
  plotting.apply_to_fig_text(fig, lambda x: x.replace("EfficientNet", "EffNet"))

  return fig