def run(self, cfg: Config):
        """Train a CNN classifier on each autoencoder's reconstructions.

        Loads the autoencoders stored in a fixed sweep directory (checkpoint
        step 49999, rank folder ``self.rank % 3``) and pairs each one with a
        freshly built ``CNN``. Each CNN is trained for 1000 steps on
        reconstructions of 512-sample batches drawn from ``MNISTDataset``.
        After every step, the accuracy on the current batch is logged through
        ``self.writer`` as ``(rank, step, "MA" | "Baseline", accuracy)``.
        ``cfg`` is not read.
        """
        checkpoint_dir = (
            "results/"
            f"sigma:0.936-eta_lsa:0.003-eta_msa:0.525-eta_dsa:0.707-eta_ae:0.631-/params/step_49999/rank_{self.rank % 3}"
        )
        mnist = MNISTDataset()
        agents: List[AutoEncoder] = self._load_aes(checkpoint_dir)
        classifiers: List[CNN] = [CNN() for _ in agents]
        batch_size, num_steps = 512, 1000

        for classifier, agent in zip(classifiers, agents):
            # Only the agent literally named "baseline" is reported as the baseline.
            group = "Baseline" if agent.name == "baseline" else "MA"
            for step in range(num_steps):
                images, labels = mnist.sample_with_label(batch_size)
                recon = agent(images)
                classifier.train(recon, labels)
                accuracy = classifier.compute_acc(recon, labels)
                self.writer.add((self.rank, step, group, accuracy), step=step)
Ejemplo n.º 2
0
    def run(self, cfg: Config):
        """Train an MLP(30) probe on each autoencoder's latent encodings.

        Agents are loaded from a fixed sweep run (checkpoint step 49999, rank
        folder ``self.rank % 3``). One fresh ``MLP(30)`` per agent is trained
        for 5000 steps on ``agent.encode`` of 512-sample ``MNISTDataset``
        batches. The per-step batch accuracy is logged through ``self.writer``
        as ``(rank, step, "MA" | "Baseline", accuracy)``. ``cfg`` is not read.
        """
        checkpoint_dir = (
            "results/jeanzay/results/sweeps/shared_ref_mnist/2021-04-16/13-15-58/"
            f"sigma:0-eta_lsa:0-eta_msa:1-eta_dsa:0-eta_ae:0-/params/step_49999/rank_{self.rank % 3}"
        )
        mnist = MNISTDataset()
        agents: List[AutoEncoder] = self._load_aes(checkpoint_dir)
        probes: List[MLP] = [MLP(30) for _ in agents]
        batch_size, num_steps = 512, 5000

        for probe, agent in zip(probes, agents):
            group = "Baseline" if agent.name == "baseline" else "MA"
            for step in range(num_steps):
                images, labels = mnist.sample_with_label(batch_size)
                latent = agent.encode(images)
                probe.train(latent, labels)
                accuracy = probe.compute_acc(latent, labels)
                self.writer.add((self.rank, step, group, accuracy), step=step)
Ejemplo n.º 3
0
    def run(self, cfg: Config):
        """Train MLP probes on noisy latent codes of three pre-trained AE setups.

        The three setups ("AE", "MTI", "AE-MTM") are read from sweep
        directories under ``$SCRATCH``. Which sweep is used depends on
        ``cfg.centralised``. For each setup, the agents saved at
        ``params/step_39999/rank_{self.rank % 3}`` are loaded, and one
        ``MLP(30)`` per agent is trained for ``cfg.nsteps`` steps on
        ``cfg.bsize``-sample batches.

        Before training, each encoding is perturbed with standard-normal noise
        scaled by ``cfg.sigma``. Every step's batch accuracy is logged through
        ``self.writer`` as
        ``(centralised, exp_name, rank, step, "MA" | "Baseline", accuracy)``.

        Raises:
            RuntimeError: if the ``SCRATCH`` environment variable is not set.
        """
        scratch = os.environ.get("SCRATCH")
        if scratch is None:
            # os.path.expandvars("$SCRATCH") would silently return the literal
            # "$SCRATCH", deferring the failure to checkpoint loading.
            raise RuntimeError("SCRATCH environment variable is not set")

        if cfg.centralised:
            base_path = os.path.join(scratch, "results/sweeps/shared_ref_mnist/2021-05-20/21-26-45/")

            ae_path = os.path.join(base_path, "eta_ae:1-eta_lsa:0.0-eta_msa:0.0-eta_dsa:0.0-sigma:0.67-")
            msa_path = os.path.join(base_path, "eta_ae:0.0-eta_lsa:0.0-eta_msa:1-eta_dsa:0.0-sigma:0.67-")
            lsa_path = os.path.join(base_path, "eta_ae:0.53-eta_lsa:0.01-eta_msa:0.74-eta_dsa:0.84-sigma:0.33-")
        else:
            base_path = scratch

            ae_path = os.path.join(base_path, "results/sweeps/shared_ref_mnist/2021-05-16/"
                                              "13-09-07/sigma:0.67-eta_ae:1.0-eta_msa:0.0-eta_lsa:0.0-eta_dsa:0.0-")
            msa_path = os.path.join(base_path, "results/sweeps/shared_ref_mnist/2021-05-16/"
                                               "13-09-07/sigma:0.67-eta_ae:0.0-eta_msa:1.0-eta_lsa:0.0-eta_dsa:0.0-")
            # NOTE(review): in this branch "lsa_path" points at a run with eta_lsa:0.0.
            lsa_path = os.path.join(base_path, "results/sweeps/shared_ref_mnist/2021-05-15/"
                                               "13-21-57/sigma:0.33-eta_ae:0.67-eta_msa:0.67-eta_lsa:0.0-eta_dsa:0.0-")

        paths = {"AE": ae_path, "MTI": msa_path, "AE-MTM": lsa_path}
        # The dataset does not depend on the experiment, so build it only once.
        dataset = MNISTDataset()

        for exp_name, exp_path in paths.items():
            params_path = os.path.join(exp_path, "params", "step_39999", f"rank_{self.rank % 3}")
            all_agents: List[AutoEncoder] = self._load_aes(params_path)

            mlps: List[MLP] = [MLP(30) for _ in all_agents]

            for mlp, agent in zip(mlps, all_agents):
                agent_kind = "MA" if agent.name != "baseline" else "Baseline"
                for i in range(cfg.nsteps):
                    ims, targets = dataset.sample_with_label(cfg.bsize)
                    encoding = agent.encode(ims)
                    # Inject standard-normal noise scaled by cfg.sigma into the latent code.
                    encoding = encoding + torch.randn_like(encoding) * cfg.sigma
                    mlp.train(encoding, targets)
                    acc = mlp.compute_acc(encoding, targets)
                    self.writer.add(
                        (
                            cfg.centralised,
                            exp_name,
                            self.rank,
                            i,
                            agent_kind,
                            acc,
                        ),
                        step=i,
                    )
                # NOTE(review): private writer API, presumably flushing buffered rows per agent.
                self.writer._write()
Ejemplo n.º 4
0
def plot_tsne(path, path_to_plot, tag):
    """Plot 2-D t-SNE embeddings of every loaded autoencoder's latent codes.

    Draws 10000 labelled samples from ``MNISTDataset`` and encodes them with
    each agent returned by ``_load_aes(path)``. Each encoding is projected to
    2-D with t-SNE (``perplexity=50``, fixed ``random_state=4444``). Every 5th
    projected point is collected as ``(x, y, label, agent_name)`` and passed
    to ``_generate_tsne_relplot`` together with ``path_to_plot`` and ``tag``.
    """
    import torch

    dataset = MNISTDataset()
    ims, labels = dataset.sample_with_label(10000)
    all_agents = _load_aes(path)

    results = []

    for ae in all_agents:
        # Inference only: no autograd graph is needed for these encodings.
        with torch.no_grad():
            encoded = ae.encode(ims)
        # sklearn needs a host array; a bare .detach() fails for CUDA tensors.
        features = encoded.detach().cpu().numpy()
        embedding = TSNE(n_components=2,
                         random_state=4444,
                         perplexity=50,
                         n_jobs=8).fit_transform(features)
        # Only every 5th point is plotted (presumably to keep the scatter readable).
        for emb, label in zip(embedding[::5], labels[::5]):
            results.append((emb[0], emb[1], int(label.item()), ae.name))

    _generate_tsne_relplot(results, path_to_plot, tag)
Ejemplo n.º 5
0
class Experiment(BaseExperiment):
    """Probe AE latent spaces and measure cross-agent classifier transfer.

    Five autoencoders are loaded: agents "A", "B", "C" and the baselines
    "baseline1"/"baseline2". An ``MLP(30)`` classifier is trained on each
    agent's encodings. Within each group, every classifier is then evaluated
    on the encodings of a different agent.
    """

    def run(self, cfg: Config):
        """Train one probe per agent, then log cross-agent accuracies.

        Checkpoints are read from ``{self.path}/params/step_39999``.
        TensorBoard scalars are written to ``{self.path}/tb/{self.rank}``.
        """
        self.tb = SummaryWriter(f"{self.path}/tb/{self.rank}")

        self.dataset = MNISTDataset()
        agents = self.load_aes(os.path.join(self.path, "params", "step_39999"))
        mlps = [self.train_classifier(agent) for agent in agents]
        self.compute_cross_agent_cls(agents, mlps)
        # Make sure the final scalars reach disk before the run returns.
        self.tb.flush()

    def load_aes(self, path: str) -> List[AutoEncoder]:
        """Build the five agents on ``self.dev`` and load their weights.

        Weights come from ``{path}/rank_{k}/{name}.pt``. Agents "A", "B" and
        "C" use ``k = rank % 5``. Baseline number ``i`` (0-based) uses
        ``k = (rank + i) % 5``, so the two baselines come from different rank
        folders.

        Returns:
            ``[A, B, C, baseline1, baseline2]``. ``compute_cross_agent_cls``
            relies on this order.
        """
        rank = int(self.rank)
        autoencoders = [
            AutoEncoder(30, False, False, 0.001, name).to(self.dev)
            for name in string.ascii_uppercase[:3]
        ]
        baselines = [
            AutoEncoder(30, False, False, 0.001, name).to(self.dev)
            for name in ("baseline1", "baseline2")
        ]

        for agent in autoencoders:
            agent.load_state_dict(
                torch.load(
                    f"{path}/rank_{rank % 5}/{agent.name}.pt",
                    map_location=self.dev,
                )
            )
        for offset, agent in enumerate(baselines):
            agent.load_state_dict(
                torch.load(
                    f"{path}/rank_{(rank + offset) % 5}/{agent.name}.pt",
                    map_location=self.dev,
                )
            )
        return autoencoders + baselines

    def train_classifier(self, agent: AutoEncoder):
        """Train a fresh ``MLP(30)`` on ``agent``'s encodings and return it.

        Runs ``self.cfg.nsteps`` steps on batches of ``self.cfg.bsize``
        samples. Each step's batch accuracy is logged to TensorBoard under
        "Accuracy-Post".
        """
        mlp = MLP(30).to(self.dev)
        agent.to(self.dev)

        for i in range(int(self.cfg.nsteps)):
            X, y = (t.to(self.dev) for t in self.dataset.sample_with_label(int(self.cfg.bsize)))
            latent = agent.encode(X)
            mlp.train(latent, y)
            acc = mlp.compute_acc(latent, y)
            # NOTE(review): all agents share this tag, so their curves overlap in TensorBoard.
            self.tb.add_scalar("Accuracy-Post", acc, global_step=i)
        return mlp

    def compute_cross_agent_cls(self, agents: List[AutoEncoder], mlps: List[MLP]):
        """Evaluate each probe on another agent's latents, separately per group.

        The first three entries are the "MA" agents (A, B, C) and the rest are
        the baselines, following the ordering produced by ``load_aes``. Both
        groups are evaluated on the same batch of ``self.cfg.bsize`` samples.
        """
        ma_aes, ma_mlps = agents[:3], mlps[:3]
        sa_aes, sa_mlps = agents[3:], mlps[3:]

        X, y = (t.to(self.dev) for t in self.dataset.sample_with_label(int(self.cfg.bsize)))
        self._compute_cross_acc(X, y, ma_aes, ma_mlps, "MA")
        self._compute_cross_acc(X, y, sa_aes, sa_mlps, "Base")

    def _compute_cross_acc(self, X, y, aes, mlps, tag, rot=1):
        """Log the accuracy of rotated probes on each autoencoder's latents.

        ``mlps`` is rotated left by ``rot`` positions. With the default
        ``rot=1``, ``aes[i]``'s latents are classified by ``mlps[i + 1]``,
        wrapping around at the end of the list.
        """
        rotated = mlps[rot:] + mlps[:rot]
        for i, (ae, mlp) in enumerate(zip(aes, rotated)):
            latent = ae.encode(X)
            acc = mlp.compute_acc(latent, y)
            self.writer.add((tag, acc), step=i, tag="cross_agent_accuracy_override")
            # Give each pair its own step instead of logging every value with no step.
            self.tb.add_scalar(f"cross_agent_acc_{tag}", acc, global_step=i)

    @staticmethod
    def load_data(reader: TidyReader) -> Any:
        """Read the Rank/Step/Agent/Accuracy columns from ``reader``."""
        return reader.read(columns=["Rank", "Step", "Agent", "Accuracy"])

    @staticmethod
    def plot(df: DataFrame, plot_path: str) -> None:
        """Dump ``df`` to CSV and save a per-agent accuracy bar plot.

        Writes ``data.csv`` and ``cross_agent_pred_acc_latent.{svg,pdf,png}``
        into ``plot_path``.
        """
        df.to_csv(plot_path + "/data.csv")
        # Draw on a fresh figure so repeated calls do not stack bars.
        fig, ax = plt.subplots()
        sns.barplot(data=df, x="Agent", y="Accuracy", ax=ax)
        plt_name = f"{plot_path}/cross_agent_pred_acc_latent"
        for ext in (".svg", ".pdf", ".png"):
            fig.savefig(plt_name + ext)
        plt.close(fig)