Beispiel #1
0
    def create_trackers(runs):
        """Accumulate confusion-matrix exports of ``runs`` into StatTrackers.

        Downloads every run file whose name starts with "export" and contains
        "/confusion" into VER_DIR, loads it, converts it to percentages
        (row-normalizing along axis 1 unless it is a difference matrix), and
        adds it to one StatTracker per file name.

        Returns a dict mapping file name -> StatTracker.
        """
        trackers = {}

        for run in runs:
            for wfile in run.files(per_page=10000):
                # Only confusion-matrix exports are of interest.
                is_export = wfile.name.startswith("export")
                if not (is_export and "/confusion" in wfile.name):
                    continue

                tracker = trackers.get(wfile.name)
                if tracker is None:
                    tracker = trackers[wfile.name] = StatTracker()

                full_name = os.path.join(VER_DIR, wfile.name)
                print(f"Downloading {full_name}")
                wfile.download(root=VER_DIR, replace=True)

                matrix = torch.load(full_name)
                matrix = matrix.astype(np.float32)
                # Difference matrices are already relative; plain confusion
                # matrices are normalized per row (axis 1) first.
                if "confusion_difference" not in wfile.name:
                    matrix = matrix / np.sum(matrix, axis=1, keepdims=True)
                tracker.add(matrix * 100)

        return trackers
Beispiel #2
0
def run(name: str, shape: List[int], y_lim=None, coord=None):
    """Plot per-layer weight-sharing statistics of a permuted-MNIST sweep.

    Collects all "permuted_mnist/perm_*/mask_stat/shared/*_weight" summary
    values (as percentages), groups them per layer and per permutation index,
    and draws a grouped bar chart saved to ``{basedir}/{name}.pdf``.

    :param name: sweep name passed to lib.get_runs.
    :param shape: figure size (width, height).
    :param y_lim: optional (ymin, ymax) for plt.ylim.
    :param coord: optional (x, y) for the y-axis label position.
    """
    runs = lib.get_runs([name])

    trackers = {}
    for r in runs:
        for key, value in r.summary.items():
            wanted = ("/mask_stat/shared/" in key
                      and key.startswith("permuted_mnist/perm_")
                      and key.endswith("_weight"))
            if not wanted:
                continue
            if key not in trackers:
                trackers[key] = StatTracker()
            trackers[key].add(value * 100)

    # Group trackers by layer; within a layer they end up ordered by the
    # permutation index (tasks 1, 3, 5, 7, 9).
    perm_by_layer = {}
    for perm in range(1, 10, 2):
        prefix = f"permuted_mnist/perm_{perm}/mask_stat/shared/"
        for key, tracker in trackers.items():
            if key.startswith(prefix):
                layer = key.split("/")[-1]
                perm_by_layer.setdefault(layer, []).append(tracker)

    fig = plt.figure(figsize=shape)
    n_col = 5
    # One empty slot between the bar groups of consecutive layers.
    group_width = n_col + 1

    # Display label -> summary-key layer name, in plotting order.
    names = OrderedDict(
        (f"Layer {i + 1}", f"layers_{i}_weight") for i in range(4))

    for col in range(n_col):
        stats = [perm_by_layer[layer][col].get() for layer in names.values()]
        plt.bar([group_width * row + col for row in range(len(names))],
                [s.mean for s in stats],
                yerr=[s.std for s in stats],
                align='center')

    # Tick in the middle of each layer's group of bars.
    plt.xticks([group_width * x + n_col / 2 - 0.5 for x in range(len(names))],
               names.keys())
    plt.legend([f"T{c*2+2}" for c in range(n_col)], ncol=2, loc="upper center")
    plt.ylabel("Weights shared [\\%]")
    if y_lim is not None:
        plt.ylim(*y_lim)

    if coord is not None:
        fig.axes[0].yaxis.set_label_coords(*coord)
    fig.savefig(f"{basedir}/{name}.pdf", bbox_inches='tight', pad_inches=0.01)
def get_relative_drop(dir, runs):
    """Return per-class relative accuracy-drop statistics over ``runs``.

    For each run, loads the reference confusion matrix, row-normalizes it and
    takes its diagonal; then, for every CIFAR-10 class, loads the matching
    confusion *difference* matrix and records its diagonal drop relative to
    the reference, in percent (negated so that a drop is positive).

    Returns a list of StatTracker.get() results in cifar10_classes order.
    """
    trackers = {k: StatTracker() for k in cifar10_classes}

    for run in runs:
        run_dir = os.path.join(BASE_DIR, dir, run.id)

        reference = torch.load(os.path.join(
            run_dir, "export/class_removal/confusion_reference.pth"))
        reference = reference.astype(np.float32)
        # Per-class accuracy: diagonal of the row-normalized confusion matrix.
        reference = np.diag(
            reference / np.sum(reference, axis=1, keepdims=True))

        for i, cls in enumerate(cifar10_classes):
            diff = torch.load(os.path.join(
                run_dir, "export/class_removal/confusion_difference/",
                cls + ".pth"))
            relative = np.diag(diff) / reference * 100
            # Negated so a performance drop shows up as a positive value.
            trackers[cls].add(-relative[i])

    return [trackers[k].get() for k in cifar10_classes]
Beispiel #4
0
        # NOTE(review): fragment of a larger loop — `run`, `grp`, `tsum`,
        # `ssum` and `all_stats` are defined above this chunk (not visible).
        for k, v in run.summary.items():
            # Keep only per-layer "mask_stat/.../n_1" counters; skip the
            # "/all/" aggregate entries.
            kparts = k.split("/")
            if kparts[-1] != "n_1" or "/all/" in k or not k.startswith(
                    "mask_stat/"):
                continue

            print(k, v)

            # The "shared_1" entry under the same prefix as this "n_1" key.
            shared = run.summary["/".join(kparts[:-1] + ["shared_1"])]

            print("SHARED", shared, v * shared, v)
            tsum += v
            ssum += v * shared

        if grp not in all_stats:
            all_stats[grp] = StatTracker()

        # Count-weighted average sharing ratio over all layers of this run.
        all_stats[grp].add(ssum / tsum)

# Fixed plotting order of the model-size groups, smallest to largest.
order = [
    "layer_sizes_400,400,200", "layer_sizes_800,800,800,800",
    "layer_sizes_2000,2000,2000,2000", "layer_sizes_4000,4000,4000,4000"
]
stats = [all_stats[o].get() for o in order]

# Bar chart of total sharing per model size (percent), with std error bars.
fig = plt.figure(figsize=[6, 2])
positions = list(range(len(order)))
plt.bar(positions, [s.mean * 100 for s in stats],
        yerr=[s.std * 100 for s in stats],
        align='center')
plt.xticks(positions, ["small", "medium", "big", "huge"])
plt.ylabel("Total shared [\\%]")
Beispiel #5
0
    # NOTE(review): fragment — the matching `if` branch and the definitions of
    # `r`, `accuracies` and `trackers` precede this chunk (not visible here).
    else:
        # Pull the validation accuracy logged around step `stop_after`.
        step = r.config["stop_after"]
        hist = r.scan_history(keys=["validation/hard/accuracy/total"],
                              min_step=step - 1,
                              max_step=step + 1)
        accuracies["train"] = None
        for h in hist:
            # Exactly one history row is expected in the scanned step window.
            assert accuracies["train"] is None
            accuracies["train"] = h["validation/hard/accuracy/total"]

    accuracies["control"] = r.summary[
        "analysis_results/verify/hard/accuracy/total"]
    accuracies["hard"] = r.summary["analysis_results/hard/hard/accuracy/total"]

    # One StatTracker per accuracy kind, grouped by the run's sweep id.
    if r.sweep.id not in trackers:
        trackers[r.sweep.id] = {k: StatTracker() for k in accuracies.keys()}

    # Accumulate as percentages.
    for k, v in accuracies.items():
        trackers[r.sweep.id][k].add(v * 100)

# Re-key the trackers by the display names from run_table.
renamed = {}
for label, run_name in run_table.items():
    renamed[label] = trackers[lib.source.sweep_table[run_name]]
trackers = renamed

# For each accuracy column, collect per-run statistics in run_table order.
cols = ["train", "control", "hard"]
p = {col: [trackers[label][col].get() for label in run_table] for col in cols}

fig = plt.figure(figsize=[4, 0.9])
for i, n in enumerate(cols):
    plt.bar([(len(cols) + 1) * x + i for x in range(len(run_table))],
# Per-group StatTrackers over every "mask_stat/.../n_*" summary entry.
all_stats = {}

for group_name, group_runs in runs.items():
    group_stats = all_stats.setdefault(group_name, {})

    for run in group_runs:
        for key, value in run.summary.items():
            # Only mask statistics that carry an "/n_" counter are tracked.
            if key.startswith("mask_stat/") and "/n_" in key:
                if key not in group_stats:
                    group_stats[key] = StatTracker()
                group_stats[key].add(value)

    # Drop groups that contributed no matching entries at all.
    if not group_stats:
        del all_stats[group_name]

def friendly_name(name: str) -> str:
    """Turn an internal parameter name into a short human-readable label.

    Strips a leading "mask_" prefix and a trailing "_weight" suffix, and
    collapses an internal "_weight_" into a single underscore.

    :param name: raw parameter/statistic name.
    :return: the cleaned-up name.
    """
    if name.startswith("mask_"):
        name = name[5:]

    if name.endswith("_weight"):
        name = name[:-7]

    name = name.replace("_weight_", "_")
    # BUG FIX: the transformed name was never returned, so the function
    # implicitly returned None despite its "-> str" annotation.
    return name
def add_tracker(trackers, name, data):
    """Histogram ``data`` into N_POINTS bins over [0, 1] and record it.

    Lazily creates the StatTracker for ``name`` in ``trackers`` on first use.

    :param trackers: dict mapping name -> StatTracker (mutated in place).
    :param name: key identifying which tracker receives the histogram.
    :param data: values to histogram (anything np.histogram accepts).
    """
    if name not in trackers:
        trackers[name] = StatTracker()

    counts, _edges = np.histogram(data, bins=N_POINTS, range=[0, 1])
    trackers[name].add(counts)
            # NOTE(review): fragment of a larger loop — `run`, `k`, `v`,
            # `kparts`, `tsum`, `ssum`, `sharing_stats` and `accuracy_stats`
            # are defined above this chunk (not visible here).
            print(k, v)

            # The "shared_1" entry under the same prefix as the current key.
            shared = run.summary["/".join(kparts[:-1] + ["shared_1"])]

            print("SHARED", shared, v * shared, v)
            tsum += v
            ssum += v * shared

        # Statistics are keyed by (layer_sizes, mask_loss_weight) config.
        stat_name = [run.config["layer_sizes"], run.config["mask_loss_weight"]]

        if stat_name[0] not in sharing_stats:
            sharing_stats[stat_name[0]] = {}
            accuracy_stats[stat_name[0]] = {}

        if stat_name[1] not in sharing_stats[stat_name[0]]:
            sharing_stats[stat_name[0]][stat_name[1]] = StatTracker()
            accuracy_stats[stat_name[0]][stat_name[1]] = StatTracker()

        # Record baseline IID accuracy and the count-weighted sharing ratio,
        # both in percent.
        accuracy_stats[stat_name[0]][stat_name[1]].add(
            run.summary["analyzer/baseline/validation/iid/accuracy"] * 100)
        sharing_stats[stat_name[0]][stat_name[1]].add(ssum / tsum * 100)


def plot(accuracy_stats, sharing_stats):
    """Plot mean accuracy against mean sharing per setting.

    NOTE(review): the body appears truncated in this chunk — only the series
    helpers are visible; the plotting code presumably follows.
    """
    # Settings sorted for a deterministic axis order.
    sharing = list(sorted(sharing_stats.keys()))

    def get_col():
        # Mean sharing value per setting, in sorted-key order.
        return [sharing_stats[s].get().mean for s in sharing]

    def get_y():
        # Mean accuracy per setting, aligned with get_col().
        return [accuracy_stats[s].get().mean for s in sharing]
Beispiel #9
0
import lib
from lib import StatTracker

# Aggregate every pairwise mask-stability summary value across the sweep.
runs = lib.get_runs(["cifar10_mask_stability"])

stat = StatTracker()
for run in runs:
    for key, value in run.summary.items():
        is_pair_stability = (key.startswith("masks_stability/")
                             and "/pair_" in key)
        if is_pair_stability:
            stat.add(value)

print(stat)
print("Count", stat.n)