# Imports assumed by all four examples below; helpers such as
# get_complexity_losses_per_hp, get_environment_losses, get_all_losses,
# triangle_cdf_plots_get_losses, load_data, pretty_measure, pretty_hps,
# DATA_PATH, and ENVIRONMENT_CACHE_PATH come from the surrounding project.
import pickle

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.lines import Line2D


def make_figure(datasets, min_ess=12, aggregate=np.mean, aggregate_name="Mean",
                color="blue", no_xlabel=False):
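    # min_ess presumably sets the minimum effective sample size an environment
    # must reach to be included; environments that fall short appear as the
    # "No data" markers in the later examples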
    losses_per_hp_with_noise = get_complexity_losses_per_hp(datasets, min_ess, filter_noise=False)
    losses_per_hp_without_noise = get_complexity_losses_per_hp(datasets, min_ess, filter_noise=True)

    # Ordered by name
    ordered_measures = sorted(losses_per_hp_with_noise.keys())

    # Collect plot data
    d = [aggregate(losses_per_hp_without_noise[c]['all']) - aggregate(losses_per_hp_with_noise[c]['all'])
         for c in ordered_measures]

    # Make plot
    plt.clf()
    sns.barplot(x=ordered_measures, y=d, color=color)
    plt.axhline(0, color='black', linewidth=0.5, linestyle='--')
    plt.ylim(-np.max(np.abs(d)), np.max(np.abs(d)))
    plt.ylabel("%s(w/ filter) - %s(w/o filter)" % (aggregate_name, aggregate_name), fontsize=10)
    if not no_xlabel:
        plt.gca().set_xticklabels([pretty_measure(c) for c in ordered_measures], rotation=45, fontsize=10, ha="right")
    else:
        plt.gca().set_xticklabels([""] * len(ordered_measures))
    plt.gcf().set_size_inches(w=8, h=2.5)
    plt.savefig("figure_monte_carlo_noise_ablation__%s__ds_%s__mess_%f_cdf.pdf" %
                (aggregate_name.lower().replace("$", "").replace("_", "").replace("{", "").replace("}", ""),
                 "_".join(datasets), min_ess), bbox_inches="tight")
Example #2
def make_figure(datasets, min_ess=12, filter_noise=True):
    data_key = "_".join(datasets)

    # Load the raw data (used to get losses in combined environments)
    data = load_data(DATA_PATH)
    data = data.loc[data["hp.dataset"].isin(datasets)]  # Select rows matching the requested datasets

    # Load the precomputations
    with open(ENVIRONMENT_CACHE_PATH +
              "/precomputations__filternoise%s__%s.pkl" %
              (str(filter_noise).lower(), data_key), "rb") as f:
        precomp = pickle.load(f)

    # Get the list of losses for each generalization measure
    # We will later report statistics of the distribution of losses for each measure
    losses = {}
    for c in list(precomp["env_losses"].keys()):
        print(c)

        # Use the FFT version of spectral measures if available
        if "_fft" not in c and c + "_fft" in precomp["env_losses"].keys():
            print("Skipping", c, "in favor of", c + "_fft")
            continue

        losses[c] = get_environment_losses(data,
                                           precomp=precomp,
                                           measure=c,
                                           min_ess=min_ess)

    # Order measures by maximum sign error over all HPs
    ordered_measures = sorted(losses.keys(), key=lambda c: np.max(losses[c]))

    # Make plot
    f = plt.gcf()
    ax = plt.gca()

    bins = np.linspace(0, 1, 100)
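    # Colorbar axis shared by the heatmap, in figure coordinates
    # [left, bottom, width, height]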
    cbar_ax = f.add_axes([.91, .127, .02, .75])

    z = np.zeros((len(bins), len(ordered_measures)))
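    # z[j, i] will hold the empirical CDF of measure i's losses at bin j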
    for i, c in enumerate(ordered_measures):
        # Get losses
        l = losses[c]

        # Plot mean and max
        ax.axvline(i, linestyle="-", color="white", linewidth=3, zorder=999)

        if len(l) > 0:
            ax.plot([i, i + 1],
                    [np.mean(l) * 100, np.mean(l) * 100],
                    color="orange",
                    zorder=2,
                    linewidth=1.5,
                    linestyle=":")
            ax.plot(
                [i, i + 1],
                [np.percentile(l, q=90) * 100,
                 np.percentile(l, q=90) * 100],
                color="magenta",
                zorder=2,
                linewidth=1.5,
                linestyle="--")
            ax.plot([i, i + 1],
                    [np.max(l) * 100, np.max(l) * 100],
                    color="limegreen",
                    zorder=1,
                    linewidth=1.5)

            # Calculate CDF
            for j, b in enumerate(bins):
                z[j, i] = (l <= b).sum() / len(l)
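            # Equivalently, vectorized (a sketch; assumes l is a 1-D array):
            # z[:, i] = np.searchsorted(np.sort(l), bins, side="right") / len(l)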
        else:
            # No data = no environment had a total weight ≥ min weight
            ax.scatter([i + 0.5], [50], marker="x", color="red")

    # Draw the heatmap once the CDF grid z is fully populated (drawing it
    # inside the loop above would redraw a partial grid for every measure)
    if z.sum() > 0:
        heatmap = sns.heatmap(z,
                              cmap="Blues_r",
                              vmin=0.5,
                              vmax=1,
                              rasterized=True,
                              ax=ax,
                              cbar_ax=cbar_ax)
        heatmap.collections[0].colorbar.ax.tick_params(labelsize=8)

    ax.invert_yaxis()  # Seaborn ranges the y-axis from 100 to 0 by default
    ax.set_yticks([0, 50, 100])
    ax.set_yticklabels([0, 0.5, 1], fontsize=6, rotation=0)
    ax.set_ylabel("Sign-error distribution\n(%d)" %
                  len(list(losses.values())[0]),
                  fontsize=8)
    ax.set_ylim(0, 101)

    ax.set_xticks(np.arange(len(ordered_measures)) + 0.5)
    ax.set_xticklabels([pretty_measure(c) for c in ordered_measures],
                       rotation=45,
                       fontsize=8,
                       ha="right")

    lines = [(Line2D([0], [0], color='limegreen', linewidth=1.5,
                     linestyle='-'), 'max'),
             (Line2D([0], [0], color='magenta', linewidth=1.5,
                     linestyle='--'), '90th percentile'),
             (Line2D([0], [0], color='orange', linewidth=1.5,
                     linestyle=':'), 'mean')]
    plt.legend(*zip(*lines),
               loc='upper center',
               ncol=len(lines),
               bbox_to_anchor=(-19.5, 1.1),
               labelspacing=10,
               fontsize=8)

    f.set_size_inches(w=10, h=4.8)
    plt.savefig(
        "figure__signerror_cdf_per_hp_easy_envs__ds_%s__mess_%f__filternoise_%s_cdf_per_hp.pdf"
        % (data_key, min_ess, str(filter_noise).lower()),
        bbox_inches="tight")
Example #3
def make_figure(datasets, min_ess=12, filter_noise=True):
    data_key = "_".join(datasets)

    # Load precomputations
    with open(ENVIRONMENT_CACHE_PATH +
              "/precomputations__filternoise%s__%s.pkl" %
              (str(filter_noise).lower(), data_key), "rb") as f:
        precomp = pickle.load(f)

    # Get the losses for each generalization measure (also called complexity measure here), per hp
    complexity_losses_per_hp = {}
    for c in precomp["env_losses"].keys():

        # Use the FFT version of spectral measures if available
        if "_fft" not in c and c + "_fft" in precomp["env_losses"].keys():
            print("Skipping", c, "in favor of", c + "_fft")
            continue

        # For the current generalization measure, get the sign-errors in each environment where an
        # HP varies and store this per HP
        complexity_losses_per_hp[c] = {}
        for hp in precomp["hps"]:
            complexity_losses_per_hp[c][hp] = get_all_losses(hp,
                                                             precomp,
                                                             measure=c,
                                                             min_ess=min_ess)
        complexity_losses_per_hp[c]["all"] = np.hstack([
            complexity_losses_per_hp[c][h] for h in complexity_losses_per_hp[c]
        ])

        # Sanity check
        assert complexity_losses_per_hp[c]["all"].shape[0] == \
            sum(complexity_losses_per_hp[c][hp].shape[0] for hp in precomp["hps"])

    # Order measures by mean sign error over all HPs
    ordered_measures = sorted(complexity_losses_per_hp.keys(),
                              key=lambda c: np.mean(complexity_losses_per_hp[c]["all"]))
    # ordered_measures = sorted(complexity_losses_per_hp.keys())  # Uncomment to order by name

    # Ordering and rendering used in the plot
    ordered_hps = [
        "all", "hp.lr", "hp.model_depth", "hp.model_width", "hp.train_size",
        "hp.dataset"
    ]
    pretty_hps = {
        "all": "All",
        "hp.lr": "LR",
        "hp.model_depth": "Depth",
        "hp.model_width": "Width",
        "hp.train_size": "Train size",
        "hp.dataset": "Dataset"
    }

    # Don't plot dataset axis if there is only a single one
    if len(datasets) == 1:
        ordered_hps.remove("hp.dataset")
        precomp["hps"].remove("hp.dataset")

    bins = np.linspace(0, 1, 100)
    f, axes = plt.subplots(ncols=1,
                           nrows=len(ordered_hps),
                           sharex=True,
                           sharey=True)
    cbar_ax = f.add_axes([.91, .127, .02, .75])
    for ax, hp in zip(axes, ordered_hps):
        z = np.zeros((len(bins), len(ordered_measures)))
        for i, c in enumerate(ordered_measures):
            # Get losses
            losses = complexity_losses_per_hp[c][hp]

            # Plot mean and max
            ax.axvline(i,
                       linestyle="-",
                       color="white",
                       linewidth=3,
                       zorder=999)

            if len(losses) > 0:
                # We need to multiply by 100 for the lines to appear in the correct place, because the heatmap's y-axis
                # goes from 0 to 100 (number of bins) and loss goes from 0 to 1.
                ax.plot([i, i + 1],
                        [np.mean(losses) * 100,
                         np.mean(losses) * 100],
                        color="orange",
                        zorder=2,
                        linewidth=1.5,
                        linestyle=":")
                ax.plot([i, i + 1], [
                    np.percentile(losses, q=90) * 100,
                    np.percentile(losses, q=90) * 100
                ],
                        color="magenta",
                        zorder=2,
                        linewidth=1.5,
                        linestyle="--")
                ax.plot([i, i + 1],
                        [np.max(losses) * 100,
                         np.max(losses) * 100],
                        color="limegreen",
                        zorder=1,
                        linewidth=1.5)

                # Calculate CDF
                for j, b in enumerate(bins):
                    z[j, i] = (losses <= b).sum() / len(losses)
            else:
                # No data = no environment had a total weight ≥ min weight
                ax.scatter([i + 0.5], [50], marker="x", color="red")

        if z.sum() > 0:
            heatmap = sns.heatmap(z,
                                  cmap="Blues_r",
                                  vmin=0.5,
                                  vmax=1,
                                  rasterized=True,
                                  ax=ax,
                                  cbar_ax=cbar_ax)
            heatmap.collections[0].colorbar.ax.tick_params(labelsize=8)

        ax.invert_yaxis()  # Seaborn ranges the y-axis from 100 to 0 by default
        ax.set_yticks([0, 50, 100])
        ax.set_yticklabels([0, 0.5, 1], fontsize=6, rotation=0)
        ax.set_ylabel(pretty_hps[hp] + "\n(%d)" %
                      (len(complexity_losses_per_hp[list(
                          complexity_losses_per_hp.keys())[0]][hp])),
                      fontsize=8)
        ax.set_ylim(-1, 102)

    axes[-1].set_xticks(np.arange(len(ordered_measures)) + 0.5)
    axes[-1].set_xticklabels([pretty_measure(c) for c in ordered_measures],
                             rotation=45,
                             fontsize=8,
                             ha="right")

    lines = [(Line2D([0], [0], color='limegreen', linewidth=1.5,
                     linestyle='-'), 'max'),
             (Line2D([0], [0], color='magenta', linewidth=1.5,
                     linestyle='--'), '90th percentile'),
             (Line2D([0], [0], color='orange', linewidth=1.5,
                     linestyle=':'), 'mean')]
    plt.legend(*zip(*lines),
               loc='upper center',
               ncol=len(lines),
               bbox_to_anchor=(-19.5, 1.1),
               labelspacing=10,
               fontsize=8)

    f.set_size_inches(w=10, h=4.8)
    plt.savefig(
        "figure__signerror_cdf_per_hp__ds_%s__mess_%f__filternoise_%s_cdf_per_hp.pdf"
        % (data_key, min_ess, str(filter_noise).lower()),
        bbox_inches="tight")
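
A hypothetical call for the per-HP sign-error CDF figure above (dataset keys are placeholders):

# make_figure(["cifar10", "svhn"], min_ess=12, filter_noise=True)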
Example #4
def make_figure(datasets, measure, hp, min_ess=12, filter_noise=True):
    data_key = "_".join(datasets)

    with open(ENVIRONMENT_CACHE_PATH +
              "/precomputations__filternoise%s__%s.pkl" %
              (str(filter_noise).lower(), data_key), "rb") as f:
        precomp = pickle.load(f)

    box_losses = triangle_cdf_plots_get_losses(hp,
                                               precomp,
                                               measure,
                                               min_ess=min_ess)

    values = np.unique([vals[0] for vals in box_losses])
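    # (box_losses is keyed by (v1, v2) pairs of the varying hyperparameter's
    # values; iterating the dict yields those keys, hence vals[0] above)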

    f, axes = plt.subplots(ncols=len(values),
                           nrows=len(values),
                           sharex=True,
                           sharey=True)
    cbar_ax = f.add_axes([.77, .127, .05, .55])
    for i, v1 in enumerate(values):
        for j, v2 in enumerate(values):
            if j >= i:
                axes[i, j].set_visible(False)
                continue

            bins = np.linspace(0, 1, 100)

            if len(box_losses[(v1, v2)]) != 0:
                # Calculate CDF
                z = np.zeros((len(bins), 1))
                for k, b in enumerate(bins):
                    z[k] = (box_losses[(v1, v2)] <= b).sum() / len(
                        box_losses[(v1, v2)])

                heatmap = sns.heatmap(z,
                                      cmap="Blues_r",
                                      vmin=0.5,
                                      vmax=1,
                                      rasterized=True,
                                      ax=axes[i, j],
                                      cbar_ax=cbar_ax)

                axes[i, j].axhline(np.percentile(box_losses[(v1, v2)], q=90) *
                                   100,
                                   color="magenta",
                                   linestyle="--",
                                   linewidth=1.5,
                                   zorder=2)
                axes[i, j].axhline(np.mean(box_losses[(v1, v2)]) * 100,
                                   color="orange",
                                   linestyle=":",
                                   linewidth=1.5,
                                   zorder=2)
                axes[i, j].axhline(np.max(box_losses[(v1, v2)]) * 100,
                                   color="limegreen",
                                   linewidth=1.5,
                                   zorder=1)
            else:
                # No data = no environment had a total weight ≥ min weight
                axes[i, j].scatter([0.5], [50], color="red", marker="x", s=30)

            axes[i, j].invert_yaxis()
            axes[i, j].set_ylim([-1, 102])
            axes[i, j].set_yticks([0, 50, 100])
            axes[i, j].set_yticklabels([0, 0.5, 1], fontsize=5, rotation=0)
            axes[i, j].set_xticklabels([], fontsize=6, rotation=0)
            axes[i, j].xaxis.set_visible(False)
            axes[i, j].yaxis.set_visible(False)

    # Set axis labels
    for i, v1 in enumerate(values):
        if hp != "hp.dataset":
            if isinstance(v1, float):
                axes[i, 0].set_ylabel("%s=%f" % (pretty_hps[hp], v1),
                                      fontsize=6)
                axes[-1, i].set_xlabel("%s=%f" % (pretty_hps[hp], v1),
                                       fontsize=6,
                                       rotation=45,
                                       ha="right")
            else:
                font_size = 5 if hp == "hp.model_width" else 6
                axes[i, 0].set_ylabel("%s=%d" % (pretty_hps[hp], v1),
                                      fontsize=font_size)
                axes[-1, i].set_xlabel("%s=%d" % (pretty_hps[hp], v1),
                                       fontsize=font_size,
                                       rotation=45,
                                       ha="right")
        else:
            axes[i, 0].set_ylabel("%s=%s" % (pretty_hps[hp], v1), fontsize=6)
            axes[-1, i].set_xlabel("%s=%s" % (pretty_hps[hp], v1),
                                   fontsize=6,
                                   rotation=45,
                                   ha="right")
        axes[i, 0].yaxis.set_visible(True)
        axes[-1, i].xaxis.set_visible(True)

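    # Assumes at least one (v1, v2) cell had data, so `heatmap` is bound here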
    heatmap.collections[0].colorbar.ax.tick_params(labelsize=6)
    plt.subplots_adjust(wspace=0.2, hspace=0.2)

    lines = [(Line2D([0], [0], color='limegreen', linewidth=1.5,
                     linestyle='-'), 'max'),
             (Line2D([0], [0], color='magenta', linewidth=1.5,
                     linestyle='--'), '90th percentile'),
             (Line2D([0], [0], color='orange', linewidth=1.5,
                     linestyle=':'), 'mean')]
    plt.legend(*zip(*lines),
               loc='upper center',
               ncol=len(lines),
               bbox_to_anchor=(-6.7, 1.22),
               columnspacing=1,
               fontsize=4.5)

    f.set_size_inches(w=1.3, h=2.7)
    plt.savefig(
        "figure_triangle_cdf__ds_%s__mess_%f__gm_%s__filternoise_%s_hp_%s.pdf"
        % (data_key, min_ess, pretty_measure(measure),
           str(filter_noise).lower(), hp),
        bbox_inches="tight")
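
A hypothetical call for the triangle-CDF figure above; the measure name is an illustrative placeholder, while "hp.lr" is one of the hyperparameters listed in Example #3:

# make_figure(["cifar10"], measure="some_measure_fft", hp="hp.lr",
#             min_ess=12, filter_noise=True)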