Ejemplo n.º 1
0
    def run(self, cfg: Config):
        """Train one MLP probe per loaded autoencoder on its MNIST encodings.

        For every agent returned by ``self._load_aes`` a fresh ``MLP(30)`` is
        trained for ``nsteps`` steps on class-balanced, shuffled MNIST batches
        encoded by that agent. After each step the probe's accuracy on that
        same batch is logged via ``self.writer.add``.

        Args:
            cfg: Run configuration. Currently unused by this method; the
                checkpoint path and hyper-parameters below are hard-coded.
        """
        # Checkpoint of one specific sweep run; only the rank suffix varies.
        path = (
            "results/jeanzay/results/sweeps/shared_ref_mnist/2021-04-16/13-15-58/"
            f"sigma:0.757-eta_lsa:0.004-eta_msa:0.483-eta_dsa:0.623-eta_ae:0.153-/params/step_49999/rank_{self.rank}"
        )
        dataset = MNISTDataset()
        all_agents: List[AutoEncoder] = self._load_aes(path)

        mlps: List[MLP] = [MLP(30) for _ in all_agents]
        bsize = 512
        nsteps = 5000

        digits = list(range(10))
        # Equal number of samples per digit. With 10 digits a batch holds
        # 10 * (512 // 10) = 510 images, i.e. slightly fewer than ``bsize``.
        per_digit = bsize // len(digits)

        for mlp, agent in zip(mlps, all_agents):
            for i in range(nsteps):
                # Sample every digit separately so the set of classes to
                # classify can be chosen flexibly via ``digits``.
                ims = torch.cat(
                    [dataset.sample_digit(digit, per_digit) for digit in digits])
                targets = torch.cat([
                    torch.full((per_digit,), digit, dtype=torch.long)
                    for digit in digits
                ])

                # Shuffle images and labels with the same permutation. Tensor
                # indexing avoids round-tripping every sample through Python
                # tuples and re-stacking them.
                perm = torch.randperm(ims.shape[0])
                ims = ims[perm]
                targets = targets[perm]

                # The autoencoder is only used as a feature extractor here
                # (assumes mlp.train optimizes only the MLP's own parameters —
                # confirm), so skip building an autograd graph through it.
                with torch.no_grad():
                    encoding = agent.encode(ims)
                mlp.train(encoding, targets)
                # NOTE(review): accuracy is measured on the batch that was just
                # trained on, i.e. this is training accuracy.
                acc = mlp.compute_acc(encoding, targets)
                self.writer.add((agent.name, i, acc), step=i)
Ejemplo n.º 2
0
def plot_reconstruction_sim_measure(root_path: str, name_of_exp: str,
                                    path_to_plot: str):
    """Plot how much an agent's reconstructions of the same digit differ.

    For each digit 0-9 a batch is sampled from MNIST and reconstructed by
    every agent loaded from
    ``<root_path>/<name_of_exp>/params/step_<EPOCH>/rank_0``. Each batch of
    reconstructions is standardized with its overall mean and std, and the
    mean-squared error between every unordered pair of reconstructions in the
    batch is recorded. Agents named ``"baseline"`` are grouped as "baseline",
    all others as "MA". The resulting bar plot is saved as
    ``plots/<path_to_plot>/decoding_space_diff.svg`` and ``.pdf``.

    Args:
        root_path: Directory containing the experiment folders.
        name_of_exp: Experiment folder name below ``root_path``.
        path_to_plot: Sub-directory of ``plots/`` to write the figures to;
            created if it does not exist yet.
    """
    import torch  # local import: only needed for ``no_grad`` here

    dataset = MNISTDataset()
    agents = _load_aes(
        os.path.join(root_path, name_of_exp, "params", f"step_{int(EPOCH)}",
                     "rank_0"))

    results: List[Tuple[str, Number]] = []
    # Pure evaluation: no gradients are needed for the reconstructions.
    with torch.no_grad():
        for digit in range(10):
            batch = dataset.sample_digit(digit)
            for agent in agents:
                label = "MA" if agent.name != "baseline" else "baseline"
                rec = agent(batch)
                rec = (rec - rec.mean()) / (rec.std() + 0.0001)
                # One row per sample; a pair's MSE is the mean over all of its
                # elements, the same quantity as ``F.mse_loss(a, b)``.
                flat = rec.reshape(rec.shape[0], -1)
                # Compare sample i with all later samples in one vectorized
                # step. The pair order matches
                # ``itertools.combinations(rec, 2)``: (0,1), (0,2), ..., (1,2), ...
                for i in range(flat.shape[0] - 1):
                    pair_mse = ((flat[i + 1:] - flat[i]) ** 2).mean(dim=1)
                    results.extend(
                        (label, value) for value in pair_mse.tolist())
    df = pd.DataFrame(results, columns=["Agent", "Difference"])
    sns.barplot(data=df, x="Agent", y="Difference")
    plt_name = f"plots/{path_to_plot}/decoding_space_diff"
    # savefig does not create missing parent directories.
    os.makedirs(os.path.dirname(plt_name), exist_ok=True)
    plt.savefig(plt_name + ".svg")
    plt.savefig(plt_name + ".pdf")
    plt.close()