def create_trackers(runs):
    """Build one StatTracker per exported confusion-matrix file across *runs*.

    Every run file whose name starts with "export" and contains "/confusion"
    is downloaded into VER_DIR, loaded, converted to float32 and fed (as
    percentages) into a StatTracker keyed by the file name. Plain confusion
    matrices are normalized row-wise first; "confusion_difference" exports
    are assumed to be relative already and are only scaled by 100.
    """
    by_name = {}
    for r in runs:
        for fl in r.files(per_page=10000):
            wanted = fl.name.startswith("export") and "/confusion" in fl.name
            if not wanted:
                continue
            if fl.name not in by_name:
                by_name[fl.name] = StatTracker()
            full_name = os.path.join(VER_DIR, fl.name)
            print(f"Downloading {full_name}")
            fl.download(root=VER_DIR, replace=True)
            mat = torch.load(full_name).astype(np.float32)
            # Difference matrices are already relative; others get each row
            # normalized to sum to 1 before conversion to percent.
            if "confusion_difference" not in fl.name:
                mat = mat / np.sum(mat, axis=1, keepdims=True)
            by_name[fl.name].add(mat * 100)
    return by_name
def run(name: str, shape: List[int], y_lim=None, coord=None):
    """Plot the per-layer shared-weight percentage over permuted-MNIST tasks.

    Gathers "permuted_mnist/perm_*/mask_stat/shared/*_weight" summary values
    from the runs named *name*, groups them per layer over the odd
    permutations 1..9, and saves a grouped bar chart (mean with std error
    bars) to <basedir>/<name>.pdf.
    """
    trackers = {}
    for r in lib.get_runs([name]):
        for k, s in r.summary.items():
            # Keep only per-permutation shared-mask stats on weight tensors.
            if ("/mask_stat/shared/" in k and k.endswith("_weight")
                    and k.startswith("permuted_mnist/perm_")):
                if k not in trackers:
                    trackers[k] = StatTracker()
                trackers[k].add(s * 100)

    # One tracker list per layer, ordered by permutation 1, 3, 5, 7, 9.
    perm_by_layer = {}
    for perm in range(1, 10, 2):
        prefix = f"permuted_mnist/perm_{perm}/mask_stat/shared/"
        for k, v in trackers.items():
            if k.startswith(prefix):
                perm_by_layer.setdefault(k.split("/")[-1], []).append(v)

    fig = plt.figure(figsize=shape)
    n_col = 5
    d = n_col + 1
    names = OrderedDict([
        ("Layer 1", "layers_0_weight"),
        ("Layer 2", "layers_1_weight"),
        ("Layer 3", "layers_2_weight"),
        ("Layer 4", "layers_3_weight"),
    ])
    # One bar group per layer; within a group, one bar per permutation.
    for j in range(n_col):
        stats = [perm_by_layer[n][j].get() for n in names.values()]
        plt.bar([i * d + j for i in range(len(names))],
                [s.mean for s in stats],
                yerr=[s.std for s in stats], align='center')
    plt.xticks([d * x + n_col / 2 - 0.5 for x in range(len(names))],
               names.keys())
    plt.legend([f"T{c*2+2}" for c in range(n_col)], ncol=2, loc="upper center")
    plt.ylabel("Weights shared [\\%]")
    if y_lim is not None:
        plt.ylim(*y_lim)
    if coord is not None:
        fig.axes[0].yaxis.set_label_coords(*coord)
    fig.savefig(f"{basedir}/{name}.pdf", bbox_inches='tight', pad_inches=0.01)
def get_relative_drop(dir, runs):
    """Return per-class stats of the relative accuracy drop (in percent).

    For every run, the diagonal of the row-normalized reference confusion
    matrix serves as the baseline; each class-removal difference matrix is
    divided by it and the (negated) diagonal entry of the removed class is
    accumulated. Returns a list of StatTracker results in cifar10_classes
    order.
    """
    per_class = {c: StatTracker() for c in cifar10_classes}
    for r in runs:
        run_dir = os.path.join(BASE_DIR, dir, r.id)
        ref = torch.load(os.path.join(
            run_dir, "export/class_removal/confusion_reference.pth")).astype(np.float32)
        # Baseline: diagonal of the row-normalized reference confusion matrix.
        ref = np.diag(ref / np.sum(ref, axis=1, keepdims=True))
        for idx, cls in enumerate(cifar10_classes):
            diff = torch.load(os.path.join(
                run_dir, "export/class_removal/confusion_difference/", cls + ".pth"))
            rel_drop = np.diag(diff) / ref * 100
            # Negated so a drop in accuracy shows up as a positive number.
            per_class[cls].add(-rel_drop[idx])
    return [per_class[c].get() for c in cifar10_classes]
# NOTE(review): script fragment — `run`, `grp`, `tsum`, `ssum`, `all_stats`
# and `runs` are bound by enclosing code outside this excerpt; the
# indentation of the first section relative to that code is reconstructed.
# Accumulate, over all "mask_stat/.../n_1" summary keys (excluding "/all/"
# aggregates), the total count and the shared-weighted count for this run.
for k, v in run.summary.items():
    kparts = k.split("/")
    if kparts[-1] != "n_1" or "/all/" in k or not k.startswith(
            "mask_stat/"):
        continue
    print(k, v)
    # The sibling "shared_1" key holds the shared fraction for this stat.
    shared = run.summary["/".join(kparts[:-1] + ["shared_1"])]
    print("SHARED", shared, v * shared, v)
    tsum += v
    ssum += v * shared
# Record the overall shared fraction for this group of runs.
if grp not in all_stats:
    all_stats[grp] = StatTracker()
all_stats[grp].add(ssum / tsum)

# Bar chart of total shared percentage for the four network sizes,
# relabeled small/medium/big/huge in increasing-size order.
order = [
    "layer_sizes_400,400,200",
    "layer_sizes_800,800,800,800",
    "layer_sizes_2000,2000,2000,2000",
    "layer_sizes_4000,4000,4000,4000"
]
stats = [all_stats[o].get() for o in order]
fig = plt.figure(figsize=[6, 2])
plt.bar([x for x in range(len(order))], [s.mean * 100 for s in stats],
        yerr=[s.std * 100 for s in stats], align='center')
plt.xticks([x for x in range(len(order))], ["small", "medium", "big", "huge"])
plt.ylabel("Total shared [\\%]")
# NOTE(review): script fragment — the matching `if` branch and the enclosing
# per-run loop binding `r`, `accuracies`, `trackers` and `run_table` are
# outside this excerpt; indentation is reconstructed.
else:
    # Fall back to the validation accuracy logged at the final step
    # (config "stop_after"); the scan window [step-1, step+1] is expected
    # to contain exactly one history row (asserted below).
    step = r.config["stop_after"]
    hist = r.scan_history(keys=["validation/hard/accuracy/total"],
                          min_step=step - 1, max_step=step + 1)
    accuracies["train"] = None
    for h in hist:
        assert accuracies["train"] is None
        accuracies["train"] = h["validation/hard/accuracy/total"]
accuracies["control"] = r.summary[
    "analysis_results/verify/hard/accuracy/total"]
accuracies["hard"] = r.summary["analysis_results/hard/hard/accuracy/total"]
# One StatTracker per accuracy kind, grouped by sweep; values stored in %.
if r.sweep.id not in trackers:
    trackers[r.sweep.id] = {k: StatTracker() for k in accuracies.keys()}
for k, v in accuracies.items():
    trackers[r.sweep.id][k].add(v * 100)
# Re-key the per-sweep trackers by friendly run name via the sweep table.
trackers = {
    k: trackers[lib.source.sweep_table[n]]
    for k, n in run_table.items()
}
cols = ["train", "control", "hard"]
p = {n: [trackers[k][n].get() for k in run_table.keys()] for n in cols}
fig = plt.figure(figsize=[4, 0.9])
# Grouped bars: one group per run, one bar per accuracy kind.
for i, n in enumerate(cols):
    plt.bar([(len(cols) + 1) * x + i for x in range(len(run_table))],
            # NOTE(review): statement continues beyond this excerpt.
# Collect, per run group, a StatTracker for every "mask_stat/.../n_*" summary
# key across the group's runs. NOTE(review): `runs` (a mapping of group name
# to run list) is defined outside this excerpt.
all_stats = {}
for grp, rn in runs.items():
    if grp not in all_stats:
        all_stats[grp] = {}
    stats = all_stats[grp]
    for r in rn:
        for k, v in r.summary.items():
            if not k.startswith("mask_stat/") or "/n_" not in k:
                continue
            if k not in stats:
                stats[k] = StatTracker()
            stats[k].add(v)
    # Drop groups that produced no matching stats at all.
    if not all_stats[grp]:
        del all_stats[grp]

def friendly_name(name: str) -> str:
    # Strip the "mask_" prefix and "_weight" suffix/infix from a stat key to
    # get a display name.
    # NOTE(review): function continues beyond this excerpt — the visible part
    # has no return statement.
    if name.startswith("mask_"):
        name = name[5:]
    if name.endswith("_weight"):
        name = name[:-7]
    name = name.replace("_weight_", "_")
def add_tracker(trackers, name, data):
    """Histogram *data* into N_POINTS bins over [0, 1] and feed the bin
    counts into the StatTracker stored under *name* in *trackers*, creating
    the tracker on first use."""
    tracker = trackers.get(name)
    if tracker is None:
        tracker = trackers[name] = StatTracker()
    counts, _edges = np.histogram(data, N_POINTS, [0, 1])
    tracker.add(counts)
# NOTE(review): script fragment — `k`, `v`, `kparts`, `tsum`, `ssum`, `run`,
# `sharing_stats` and `accuracy_stats` are bound by enclosing loops outside
# this excerpt; indentation is reconstructed.
print(k, v)
# The sibling "shared_1" key holds the shared fraction for this stat.
shared = run.summary["/".join(kparts[:-1] + ["shared_1"])]
print("SHARED", shared, v * shared, v)
tsum += v
ssum += v * shared
# Nested stat dicts keyed by (layer_sizes, mask_loss_weight); values in %.
stat_name = [run.config["layer_sizes"], run.config["mask_loss_weight"]]
if stat_name[0] not in sharing_stats:
    sharing_stats[stat_name[0]] = {}
    accuracy_stats[stat_name[0]] = {}
if stat_name[1] not in sharing_stats[stat_name[0]]:
    sharing_stats[stat_name[0]][stat_name[1]] = StatTracker()
    accuracy_stats[stat_name[0]][stat_name[1]] = StatTracker()
accuracy_stats[stat_name[0]][stat_name[1]].add(
    run.summary["analyzer/baseline/validation/iid/accuracy"] * 100)
sharing_stats[stat_name[0]][stat_name[1]].add(ssum / tsum * 100)

def plot(accuracy_stats, sharing_stats):
    # Plot mean sharing vs. mean accuracy over mask-loss-weight settings.
    # NOTE(review): function continues beyond this excerpt.
    sharing = list(sorted(sharing_stats.keys()))

    def get_col():
        # Mean shared percentage per mask-loss weight, in key order.
        return [sharing_stats[s].get().mean for s in sharing]

    def get_y():
        # Mean accuracy per mask-loss weight, in key order.
        return [accuracy_stats[s].get().mean for s in sharing]
import lib
from lib import StatTracker

# Aggregate every pairwise mask-stability summary value across all
# "cifar10_mask_stability" runs into a single StatTracker and report it.
stat = StatTracker()
for run in lib.get_runs(["cifar10_mask_stability"]):
    for key, value in run.summary.items():
        if key.startswith("masks_stability/") and "/pair_" in key:
            stat.add(value)

print(stat)
print("Count", stat.n)