def subplot_fn(data, **kwargs):
    """Scatters the rows of `data` on the current axes, one call per marker.

    The x/y column names come from the constant `data.x` / `data.y` columns
    and are also stored on the axes as `my_xlabel` / `my_ylabel`. If the y
    column name contains "_residual", that suffix is stripped and the plotted
    values are `get_residual(data["downstream_error"], data[<stripped y>])`,
    computed on a copy of `data`.

    Colour is `family_index` mapped through the enclosing `cmap` with
    vmin=0 and vmax=len(family_order); point size comes from `model_size`.
    Extra keyword arguments are ignored.
    """
    del kwargs
    axes = plt.gca()
    x_col = utils.assert_and_get_constant(data.x)
    axes.my_xlabel = x_col
    y_col = utils.assert_and_get_constant(data.y)
    axes.my_ylabel = y_col

    if y_col.endswith("_residual"):
        # Plot the residual w.r.t. downstream error instead of the raw metric;
        # work on a copy so the caller's frame is not modified.
        y_col = y_col.replace("_residual", "")
        data = data.copy()
        data.loc[:, y_col] = get_residual(data["downstream_error"], data[y_col])

    for marker in data.family_marker.unique():
        group = data[data.family_marker == marker]
        x_values = group[x_col]
        y_values = group[y_col]
        sizes = plotting.model_to_scatter_size(group.model_size)
        axes.scatter(x_values,
                     y_values,
                     s=sizes,
                     c=group.family_index,
                     cmap=cmap,
                     vmin=0,
                     vmax=len(family_order),
                     marker=marker,
                     alpha=0.7,
                     zorder=30,
                     label=utils.assert_and_get_constant(group.ModelFamily),
                     linewidth=0.0)
  def subplot_fn(**kwargs):
    """Draws a seaborn regression fit and annotates it with fit parameters.

    Reads `x`, `y` and `data` from `kwargs`. All kwargs are forwarded to
    `sns.regplot`, with `color` overridden by the group's constant
    `family_color`. The median and confidence interval of the bootstrapped
    intercept and slope are written into the top-left corner of the axes as
    "a = ..." and "k = ...".

    The intercept is reported as 10**intercept. NOTE(review): this suggests
    the regression is done in log10 space; confirm with the caller's scales.
    """
    x = kwargs["x"]
    y = kwargs["y"]
    data = kwargs["data"]
    utils.assert_no_duplicates_in_condition(
        data, group_by=["DatasetName", "ModelName"])
    ax = plt.gca()
    kwargs["color"] = utils.assert_and_get_constant(data.family_color)
    ax = sns.regplot(ax=ax, **kwargs)
    plotter = RegressionPlotter(x, y, data=data)

    # Get regression parameters and show in plot. Only the bootstrap samples
    # are used; the point estimate returned first is not needed here.
    grid = np.linspace(data[x].min(), data[x].max(), 100)
    _, beta_boots = plotter.get_params(grid)
    beta_boots = np.array(beta_boots)
    # Row 0 holds the intercept samples, row 1 the slope samples.
    intercept = 10 ** np.median(beta_boots[0, :])
    intercept_ci = 10 ** sns_utils.ci(beta_boots[0, :])
    slope = np.median(beta_boots[1, :])
    slope_ci = sns_utils.ci(beta_boots[1, :])
    s = (f"a = {intercept:1.2f} ({intercept_ci[0]:1.2f}, "
         f"{intercept_ci[1]:1.2f})\nk = {slope:1.2f} ({slope_ci[0]:1.2f}, "
         f"{slope_ci[1]:1.2f})")
    ax.text(0.04, 0.96, s, va="top", ha="left", transform=ax.transAxes,
            fontsize=4, color=(0.3, 0.3, 0.3),
            bbox=dict(facecolor="w", alpha=0.8, boxstyle="square,pad=0.1"))
예제 #3
0
def add_metric_description_title(df_plot: pd.DataFrame,
                                 fig: mpl.figure.Figure,
                                 y: float = 1.0):
    """Sets a figure suptitle that describes the ECE variant in `df_plot`.

    The title lists the binning scheme, number of bins and norm. Each value is
    read as a single constant from its column. The internal identifiers
    "adaptive", "even", "l1" and "l2" in the title are then replaced with
    display names.

    Args:
      df_plot: Data with exactly one distinct `Metric` and constant
        `binning_scheme`, `num_bins` and `norm` columns.
      fig: Figure that receives the suptitle.
      y: Vertical position of the suptitle; the title is bottom-aligned there.
    """
    assert df_plot.Metric.nunique() == 1, "More than one metric in DataFrame."
    scheme, bins, norm = (
        utils.assert_and_get_constant(getattr(df_plot, column))
        for column in ("binning_scheme", "num_bins", "norm"))

    title = (f"ECE variant: {scheme} binning, "
             f"{bins:.0f} bins, "
             f"{norm} norm")

    # Applied in this order, as plain substring replacements.
    replacements = (
        ("adaptive", "equal-mass"),
        ("even", "equal-width"),
        ("l1", "L1"),
        ("l2", "L2"),
    )
    for internal_name, pretty_name in replacements:
        title = title.replace(internal_name, pretty_name)

    fig.suptitle(title, y=y, verticalalignment="bottom")
 def subplot_fn(data, x, y, **kwargs):
   """Scatters `data[y]` against `data[x]` on the current axes.

   Issues one scatter call per distinct `family_marker`. Point size comes
   from `model_size`. Colour is `family_index` mapped through the enclosing
   `cmap` with vmin=0 and vmax=len(family_order). Each call is labelled with
   the group's constant `ModelFamily`. Extra keyword arguments are ignored.
   """
   del kwargs
   axes = plt.gca()
   for marker in data.family_marker.unique():
     group = data[data.family_marker == marker]
     sizes = plotting.model_to_scatter_size(group.model_size)
     axes.scatter(group[x],
                  group[y],
                  s=sizes,
                  c=group.family_index,
                  cmap=cmap,
                  vmin=0,
                  vmax=len(family_order),
                  marker=marker,
                  linewidth=0,
                  alpha=0.7,
                  zorder=30,
                  label=utils.assert_and_get_constant(group.ModelFamily))
def plot(df_main: pd.DataFrame,
         df_reliability: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_method: str = "temperature_scaling",
         add_guo: bool = False) -> mpl.figure.Figure:
    """Plots acc/calib and reliability diagrams on clean ImageNet (Figure 1).

    The first column(s) scatter ImageNet error (x) against `MetricValue`
    (y, labelled as ECE) for every model family. The remaining columns show a
    confidence panel (top row) and a reliability diagram (bottom row) for a
    fixed list of models. These are drawn by `_plot_confidence_and_reliability`.

    Args:
      df_main: Main results table. Must provide at least the `ModelName`,
        `rescaling_method`, `Metric`, `DatasetName` and `RawModelName`
        columns used below.
      df_reliability: Reliability data, passed to
        `_get_binned_reliability_data`.
      gce_prefix: Forwarded to `_get_data`.
      rescaling_method: "none", "temperature_scaling", or "both". With "both",
        one acc/calib panel is drawn per method, plus three reliability
        columns. Otherwise there is one acc/calib panel and five reliability
        columns.
      add_guo: If True, also marks the two models reported by Guo et al.
        (2017) with "x" markers.

    Returns:
      The created matplotlib figure.
    """
    # Copy so that appending "guo" cannot mutate a list owned by `display`.
    family_order = list(display.get_model_families_sorted())
    if add_guo:
        family_order.append("guo")

    if rescaling_method == "both":
        rescaling_methods = ["none", "temperature_scaling"]
    else:
        rescaling_methods = [rescaling_method]

    # Guo et al. data, from Table 1 and Table S2 in
    # https://arxiv.org/pdf/1706.04599.pdf. First model is DenseNet161,
    # second is ResNet152. Only these two rescaling methods have values.
    guo_x = [0.2257, 0.2231]
    guo_y = {
        "none": [0.0628, 0.0548],
        "temperature_scaling": [0.0199, 0.0186],
    }

    # Set up figure:
    fig = plt.figure(figsize=(display.FULL_WIDTH, 1.6))
    if rescaling_method == "both":
        widths = [1.75, 1.75, 1, 1, 1]
    else:
        widths = [1.8, 1, 1, 1, 1, 1]
    heights = [0.3, 1]
    spec = fig.add_gridspec(ncols=len(widths),
                            nrows=len(heights),
                            width_ratios=widths,
                            height_ratios=heights)

    # First panels: acc vs calib ImageNet. The loop variable is deliberately
    # not named `rescaling_method`, so the argument is not shadowed.
    for ax_i, method in enumerate(rescaling_methods):

        df_plot, cmap = _get_data(df_main,
                                  gce_prefix,
                                  family_order,
                                  rescaling_methods=[method])
        ax = fig.add_subplot(spec[0:2, ax_i], box_aspect=1.0)
        big_ax = ax
        for i, family in enumerate(family_order):
            if family == "guo":
                continue
            data_sub = df_plot[df_plot.ModelFamily == family]
            if data_sub.empty:
                continue
            ax.scatter(
                data_sub["downstream_error"],
                data_sub["MetricValue"],
                s=plotting.model_to_scatter_size(data_sub.model_size),
                c=data_sub.family_index,
                cmap=cmap,
                vmin=0,
                vmax=len(family_order),
                marker=utils.assert_and_get_constant(data_sub.family_marker),
                alpha=0.7,
                linewidth=0.0,
                zorder=100 - i,  # Z-order is same as model family order.
                label=family)

        if add_guo and method in guo_y:
            # "guo" is the last family, so its index is len(family_order) - 1.
            # Use the same colormap/normalization as the other families;
            # without them the index would be mapped through the default
            # colormap instead.
            ax.scatter(guo_x,
                       guo_y[method],
                       s=plotting.model_to_scatter_size(1),
                       c=[len(family_order) - 1] * 2,
                       cmap=cmap,
                       vmin=0,
                       vmax=len(family_order),
                       marker="x",
                       alpha=0.7,
                       label="guo")

        plotting.show_spines(ax)
        ax.set_anchor("N")
        ax.grid(False, which="minor")
        ax.grid(True, axis="both")
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
        ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.01))
        ax.set_ylim(bottom=0.01, top=0.09)
        ax.set_xlim(0.05, 0.55)
        ax.set_xlabel(display.XLABEL_INET_ERROR)
        if len(rescaling_methods) == 1:  # Showing just one rescaling method.
            if method == "none":
                ax.set_ylabel(display.YLABEL_ECE_UNSCALED)
            elif method == "temperature_scaling":
                ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED)
        else:  # Showing both rescaling methods.
            if method == "none":
                ax.set_title("Unscaled")
            elif method == "temperature_scaling":
                ax.set_title("Temperature-scaled")
            # This panel sits in grid column `ax_i`. (Axes.is_first_col() was
            # removed in Matplotlib 3.6.)
            if ax_i == 0:
                ax.set_ylabel("ECE")

    # Remaining panels: Reliability diagrams:
    offset = len(rescaling_methods)
    model_names = [
        "mixer/jft-300m/H/14",
        "vit-h/14",
        "bit-jft-r152-x4-480",
    ]
    if offset == 1:
        model_names += ["wsl_32x48d", "simclr-4x-fine-tuned-100"]
    dataset_name = "imagenet(split='validation[20%:]')"

    # The reliability x-label follows the last acc/calib panel's method.
    label_method = rescaling_methods[-1]

    for i, model_name in enumerate(model_names):
        # Get predictions:
        mask = df_main.ModelName == model_name
        mask &= df_main.rescaling_method == "none"
        mask &= df_main.Metric == "accuracy"
        mask &= df_main.DatasetName == dataset_name
        raw_model_name = df_main[mask].RawModelName
        assert len(raw_model_name) <= 1, df_main[mask]
        if len(raw_model_name) == 0:  # pylint: disable=g-explicit-length-test
            continue

        binned = _get_binned_reliability_data(
            df_reliability, utils.assert_and_get_constant(raw_model_name),
            dataset_name)

        rel_ax = fig.add_subplot(spec[1, i + offset])
        _plot_confidence_and_reliability(conf_ax=fig.add_subplot(spec[0, i +
                                                                      offset]),
                                         rel_ax=rel_ax,
                                         binned=binned,
                                         model_name=model_name,
                                         first_col=offset)

        if label_method == "none":
            rel_ax.set_xlabel("Confidence\n(unscaled)")
        # NOTE(review): "temperature_scaled" matches no rescaling method used
        # in this file (elsewhere it is "temperature_scaling"), so this branch
        # never fires. It is kept unchanged because the reliability data above
        # is looked up via the "none" row. Confirm which label is correct
        # before changing the string.
        elif label_method == "temperature_scaled":
            rel_ax.set_xlabel("Confidence\n(temp. scaled)")

    # Model family legend:
    handles, labels = plotting.get_model_family_legend(big_ax, family_order)
    legend = big_ax.legend(handles=handles,
                           labels=labels,
                           loc="upper right",
                           frameon=True,
                           labelspacing=0.25,
                           handletextpad=0.1,
                           borderpad=0.3,
                           fontsize=4)
    legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
    legend.get_frame().set_edgecolor("lightgray")

    plotting.apply_to_fig_text(fig, display.prettify)

    # Shift every axes except the first acc/calib panel right by 0.05
    # (figure coordinates).
    x_shift = 0.05
    for ax in fig.axes[1:]:
        box = ax.get_position()
        box.x0 += x_shift
        box.x1 += x_shift
        ax.set_position(box)

    return fig
def plot_reliability_diagrams(
        df_main: pd.DataFrame,
        df_reliability: pd.DataFrame,
        family: str,
        rescaling_method: str = "temperature_scaling",
        dataset_name: str = "imagenet(split='validation[20%:]')",
        gce_prefix=plotting.STD_GCE_PREFIX) -> mpl.figure.Figure:
    """Plots confidence panels and reliability diagrams for one model family.

    Draws one column per model of `family`, ordered by `model_size`: a
    confidence panel on top and a reliability diagram below, both drawn by
    `_plot_confidence_and_reliability`. At least five columns are always laid
    out; unused columns get invisible placeholder axes for formatting.

    Args:
      df_main: Main results table; filtered by `ModelName`,
        `rescaling_method`, `Metric` and `DatasetName` to find each model's
        `RawModelName`. Exactly one match per model is asserted.
      df_reliability: Reliability data, passed to
        `_get_binned_reliability_data`.
      family: Model family whose models are plotted.
      rescaling_method: Rescaling method to select. Also sets the x-label
        when it is "none" or "temperature_scaling".
      dataset_name: Dataset to select.
      gce_prefix: Forwarded to `_get_data`.

    Returns:
      The created matplotlib figure.
    """
    df_plot, _ = _get_data(df_main,
                           gce_prefix, [family],
                           rescaling_methods=[rescaling_method],
                           dataset_name=dataset_name)
    model_names = (df_plot.drop_duplicates(subset=["ModelName"])
                   .sort_values(by="model_size")
                   .ModelName.to_list())

    # Figure layout: at least five columns, 0.8 width units each.
    num_cols = max(5, len(model_names))
    fig = plt.figure(figsize=(num_cols * 0.80, 1.4))
    spec = fig.add_gridspec(ncols=num_cols, nrows=2, height_ratios=[0.4, 1])

    xlabels = {
        "none": "Confidence\n(unscaled)",
        "temperature_scaling": "Confidence\n(temp. scaled)",
    }

    for col, model_name in enumerate(model_names):
        selected = ((df_main.ModelName == model_name)
                    & (df_main.rescaling_method == rescaling_method)
                    & (df_main.Metric == "accuracy")
                    & (df_main.DatasetName == dataset_name))
        raw_model_name = df_main[selected].RawModelName
        assert len(raw_model_name) == 1
        binned = _get_binned_reliability_data(
            df_reliability, utils.assert_and_get_constant(raw_model_name),
            dataset_name)

        # Bottom axes is created before the top one.
        rel_ax = fig.add_subplot(spec[1, col])
        conf_ax = fig.add_subplot(spec[0, col])
        _plot_confidence_and_reliability(conf_ax=conf_ax,
                                         rel_ax=rel_ax,
                                         binned=binned,
                                         model_name=model_name)
        if rescaling_method in xlabels:
            rel_ax.set_xlabel(xlabels[rescaling_method])

    # Invisible placeholder axes fill the remaining columns.
    for col in range(len(model_names), num_cols):
        fig.add_subplot(spec[0, col]).set_visible(False)
        fig.add_subplot(spec[1, col]).set_visible(False)

    def _prettify(text):
        return display.prettify(text).replace("MLP-Mixer-", "MLP-Mixer\n")

    plotting.apply_to_fig_text(fig, _prettify)
    plotting.apply_to_fig_text(
        fig, lambda text: text.replace("EfficientNet", "EffNet"))

    fig.subplots_adjust(hspace=-0.05)

    return fig
def plot(df_main: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_method: str = "temperature_scaling",
         add_guo: bool = False) -> mpl.figure.Figure:
    """Plots ImageNet error vs. ECE in a 2x3 grid, one column per BiT variant.

    Columns are "BiT-ImageNet", "BiT-ImageNet21k" and "BiT-JFT". For each
    column, `display.get_standard_model_list` is replaced with a function
    that omits the models of the other two BiT variants. NOTE(review): this
    presumably affects which models `_get_data` returns; confirm. The original
    function is restored before returning, including on error.

    Rows show the "none" rescaling (top) and `rescaling_method` (bottom).

    Args:
      df_main: Main results table, forwarded to `_get_data`.
      gce_prefix: Forwarded to `_get_data`.
      rescaling_method: Rescaling method for the bottom row.
      add_guo: If True, also marks the two models reported by Guo et al.
        (2017) with "x" markers.

    Returns:
      The created matplotlib figure.
    """
    rescaling_methods = ["none", rescaling_method]
    # Copy so that appending "guo" cannot mutate a list owned by `display`.
    family_order = list(display.get_model_families_sorted(
        ["mixer", "vit", "bit", "simclr"]))
    if add_guo:
        family_order.append("guo")

    # Model-name prefixes hidden in each column (the other BiT variants).
    # Note "bit-imagenet-" does not match "bit-imagenet21k-" names.
    excluded_prefixes_by_version = [
        ("BiT-ImageNet", ("bit-imagenet21k-", "bit-jft-")),
        ("BiT-ImageNet21k", ("bit-imagenet-", "bit-jft-")),
        ("BiT-JFT", ("bit-imagenet-", "bit-imagenet21k-")),
    ]

    def make_model_list_fn(excluded_prefixes):
        """Returns a no-arg model-list function skipping `excluded_prefixes`."""
        return lambda: [
            m for m in display.MODEL_SIZE.keys()
            if not m.startswith(excluded_prefixes)
        ]

    # Guo et al. data, from Table 1 and Table S2 in
    # https://arxiv.org/pdf/1706.04599.pdf. First model is DenseNet161,
    # second is ResNet152. Only these two rescaling methods have values.
    guo_x = [0.2257, 0.2231]
    guo_y = {
        "none": [0.0628, 0.0548],
        "temperature_scaling": [0.0199, 0.0186],
    }

    # Set up figure:
    fig = plt.figure(figsize=(display.FULL_WIDTH / 2, 2))
    spec = fig.add_gridspec(ncols=3, nrows=2)

    # The monkey-patch below must not leak out of this function.
    original_model_list_fn = display.get_standard_model_list
    try:
        for col, (bit_version, excluded) in enumerate(
                excluded_prefixes_by_version):
            display.get_standard_model_list = make_model_list_fn(excluded)

            for row, method in enumerate(rescaling_methods):

                df_plot, cmap = _get_data(df_main,
                                          gce_prefix,
                                          family_order,
                                          rescaling_methods=[method])
                ax = fig.add_subplot(spec[row, col])
                big_ax = ax
                for i, family in enumerate(family_order):
                    if family == "guo":
                        continue
                    data_sub = df_plot[df_plot.ModelFamily == family]
                    if data_sub.empty:
                        continue
                    ax.scatter(
                        data_sub["downstream_error"],
                        data_sub["MetricValue"],
                        s=plotting.model_to_scatter_size(data_sub.model_size),
                        c=data_sub.family_index,
                        cmap=cmap,
                        vmin=0,
                        vmax=len(family_order),
                        marker=utils.assert_and_get_constant(
                            data_sub.family_marker),
                        linewidth=0.5,
                        alpha=1.0 if "bit" in family else 0.5,
                        zorder=100 - i,  # Z-order is same as family order.
                        label=family)

                if add_guo and method in guo_y:
                    # "guo" is the last family; colour it through the same
                    # colormap/normalization as the other families.
                    ax.scatter(guo_x,
                               guo_y[method],
                               s=plotting.model_to_scatter_size(1),
                               c=[len(family_order) - 1] * 2,
                               cmap=cmap,
                               vmin=0,
                               vmax=len(family_order),
                               marker="x",
                               alpha=0.7,
                               label="guo")

                plotting.show_spines(ax)
                # Aspect ratios are tuned manually for display in the paper:
                ax.set_anchor("N")
                ax.grid(False, which="minor")
                ax.grid(True, axis="both")
                ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.01))
                ax.set_ylim(bottom=0.0, top=0.09)
                ax.set_xlim(0.05, 0.3)
                ax.set_xlabel(display.XLABEL_INET_ERROR)
                # Grid position is (row, col). (Axes.is_first_row/col() were
                # removed in Matplotlib 3.6.)
                if row == 0:
                    ax.set_title(bit_version, fontsize=6)
                    ax.set_xlabel("")
                    ax.set_xticklabels("")
                else:
                    ax.set_ylim(bottom=0.0, top=0.05)
                if col == 0:
                    if method == "none":
                        ax.set_ylabel(display.YLABEL_ECE_UNSCALED)
                    elif method == "temperature_scaling":
                        ax.set_ylabel(display.YLABEL_ECE_TEMP_SCALED)
                else:
                    ax.set_yticklabels("")

        fig.tight_layout(pad=0.5)

        # Model family legend:
        handles, labels = plotting.get_model_family_legend(big_ax,
                                                           family_order)

        plotting.apply_to_fig_text(fig, display.prettify)
        plotting.apply_to_fig_text(
            fig, lambda x: x.replace("EfficientNet", "EffNet"))

        legend = fig.axes[0].legend(handles=handles,
                                    labels=labels,
                                    loc="upper center",
                                    title="Model family",
                                    bbox_to_anchor=(0.55, -0.025),
                                    frameon=True,
                                    bbox_transform=fig.transFigure,
                                    ncol=len(family_order),
                                    handletextpad=0.1)
        legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
        legend.get_frame().set_edgecolor("lightgray")
        # Applied again after the legend exists, presumably so its text is
        # prettified too.
        plotting.apply_to_fig_text(fig, display.prettify)
    finally:
        display.get_standard_model_list = original_model_list_fn

    return fig
def plot(df_main: pd.DataFrame,
         gce_prefix: str = plotting.STD_GCE_PREFIX,
         rescaling_method: str = "temperature_scaling",
         add_guo: bool = False) -> mpl.figure.Figure:
  """Plots ImageNet error vs. ECE side by side for two rescaling methods.

  The left panel uses the "none" rescaling and the right panel uses
  `rescaling_method`. Each scatters `downstream_error` (x) against
  `MetricValue` (y) for every model family. A model-family legend is placed
  in the last panel.

  Args:
    df_main: Main results table, forwarded to `_get_data`.
    gce_prefix: Forwarded to `_get_data`.
    rescaling_method: Rescaling method for the right panel.
    add_guo: If True, also marks the two models reported by Guo et al.
      (2017) with "x" markers.

  Returns:
    The created matplotlib figure.
  """
  rescaling_methods = ["none", rescaling_method]
  # Copy so that appending "guo" cannot mutate a list owned by `display`.
  family_order = list(display.get_model_families_sorted())
  if add_guo:
    family_order.append("guo")

  # Guo et al. data, from Table 1 and Table S2 in
  # https://arxiv.org/pdf/1706.04599.pdf. First model is DenseNet161,
  # second is ResNet152. Only these two rescaling methods have values.
  guo_x = [0.2257, 0.2231]
  guo_y = {
      "none": [0.0628, 0.0548],
      "temperature_scaling": [0.0199, 0.0186],
  }

  # Set up figure:
  fig = plt.figure(figsize=(display.FULL_WIDTH/2, 1.4))
  spec = fig.add_gridspec(ncols=2, nrows=1)

  # The loop variable is not named `rescaling_method`, so the argument is
  # not shadowed.
  for ax_i, method in enumerate(rescaling_methods):

    df_plot, cmap = _get_data(df_main, gce_prefix, family_order,
                              rescaling_methods=[method])
    ax = fig.add_subplot(spec[:, ax_i])
    big_ax = ax
    for i, family in enumerate(family_order):
      if family == "guo":
        continue
      data_sub = df_plot[df_plot.ModelFamily == family]
      if data_sub.empty:
        continue
      ax.scatter(
          data_sub["downstream_error"],
          data_sub["MetricValue"],
          s=plotting.model_to_scatter_size(data_sub.model_size),
          c=data_sub.family_index,
          cmap=cmap,
          vmin=0,
          vmax=len(family_order),
          marker=utils.assert_and_get_constant(data_sub.family_marker),
          alpha=0.7,
          linewidth=0.0,
          zorder=100 - i,  # Z-order is same as model family order.
          label=family)

    if add_guo and method in guo_y:
      # "guo" is the last family; colour it through the same colormap and
      # normalization as the other families.
      ax.scatter(guo_x, guo_y[method],
                 s=plotting.model_to_scatter_size(1),
                 c=[len(family_order) - 1] * 2,
                 cmap=cmap, vmin=0, vmax=len(family_order),
                 marker="x", alpha=0.7, label="guo")

    plotting.show_spines(ax)
    # Aspect ratios are tuned manually for display in the paper:
    ax.set_anchor("N")
    ax.grid(False, which="minor")
    ax.grid(True, axis="both")
    ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
    ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.01))
    ax.set_ylim(bottom=0.01, top=0.09)
    ax.set_xlim(0.05, 0.5)
    ax.set_xlabel(display.XLABEL_INET_ERROR)
    if method == "none":
      ax.set_title("Unscaled")
    elif method == "temperature_scaling":
      ax.set_title("Temperature-scaled")
      ax.set_yticklabels("")
    # This panel sits in grid column `ax_i`. (Axes.is_first_col() was removed
    # in Matplotlib 3.6.)
    if ax_i == 0:
      ax.set_ylabel("ECE")

  # Model family legend:
  handles, labels = plotting.get_model_family_legend(big_ax, family_order)
  legend = big_ax.legend(
      handles=handles, labels=labels, loc="upper right", frameon=True,
      labelspacing=0.3, handletextpad=0.1, borderpad=0.3, fontsize=4)
  legend.get_frame().set_linewidth(mpl.rcParams["axes.linewidth"])
  legend.get_frame().set_edgecolor("lightgray")

  plotting.apply_to_fig_text(fig, display.prettify)
  plotting.apply_to_fig_text(fig, lambda x: x.replace("EfficientNet", "EffNet"))

  return fig