Example 1
    def run(self, cfg: Config):
        path = (
            f"results/jeanzay/results/sweeps/shared_ref_mnist/2021-04-16/13-15-58/"
            f"sigma:0-eta_lsa:0-eta_msa:1-eta_dsa:0-eta_ae:0-/params/step_49999/rank_{self.rank % 3}"
        )
        dataset = MNISTDataset()
        all_agents: List[AutoEncoder] = self._load_aes(path)

        mlps: List[MLP] = [MLP(30) for _ in all_agents]
        bsize = 512
        nsteps = 5000

        # probe each agent's latent space: train an MLP classifier on the encodings
        for mlp, agent in zip(mlps, all_agents):
            for i in range(nsteps):
                ims, targets = dataset.sample_with_label(bsize)
                encoding = agent.encode(ims)
                mlp.train(encoding, targets)
                acc = mlp.compute_acc(encoding, targets)
                self.writer.add(
                    (
                        self.rank,
                        i,
                        "MA" if agent.name != "baseline" else "Baseline",
                        acc,
                    ),
                    step=i,
                )
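The MLP probe class itself is not part of this listing; below is a minimal sketch of a compatible wrapper, assuming a 30-dimensional latent, 10 MNIST classes, and that `train` takes one gradient step per call. The name, layer sizes, and learning rate are assumptions for illustration, not the project's actual `MLP`.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ProbeMLP(nn.Module):
    """Hypothetical stand-in for the MLP(30) probe used above."""

    def __init__(self, latent_dim: int, n_classes: int = 10):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(latent_dim, 64), nn.ReLU(),
                                 nn.Linear(64, n_classes))
        self.opt = torch.optim.Adam(self.net.parameters(), lr=1e-3)

    def train(self, latent: torch.Tensor, targets: torch.Tensor) -> None:
        # one gradient step on a detached latent batch (shadows nn.Module.train,
        # mirroring how the probe is called in the snippet above)
        loss = F.cross_entropy(self.net(latent.detach()), targets)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

    def compute_acc(self, latent: torch.Tensor, targets: torch.Tensor) -> float:
        # classification accuracy on the given batch, no gradients needed
        with torch.no_grad():
            preds = self.net(latent).argmax(dim=-1)
        return (preds == targets).float().mean().item()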
Example 2
    def run(self, cfg: Config):
        path = (
            f"results/"
            f"sigma:0.936-eta_lsa:0.003-eta_msa:0.525-eta_dsa:0.707-eta_ae:0.631-/params/step_49999/rank_{self.rank % 3}"
        )
        dataset = MNISTDataset()
        all_agents: List[AutoEncoder] = self._load_aes(path)

        cnns: List[CNN] = [CNN() for _ in all_agents]
        bsize = 512
        nsteps = 1000

        # probe each agent's reconstructions: train a CNN classifier on the decoded images
        for cnn, agent in zip(cnns, all_agents):
            for i in range(nsteps):
                ims, targets = dataset.sample_with_label(bsize)
                reconstruction = agent(ims)
                cnn.train(reconstruction, targets)
                acc = cnn.compute_acc(reconstruction, targets)
                self.writer.add(
                    (
                        self.rank,
                        i,
                        "MA" if agent.name != "baseline" else "Baseline",
                        acc,
                    ),
                    step=i,
                )
Example 3
    def run(self, cfg: Config):

        if cfg.centralised:
            base_path = os.path.join(os.path.expandvars("$SCRATCH"),
                                     "results/sweeps/shared_ref_mnist/2021-05-20/21-26-45/")
            # base_path = "results/jeanzay/results/sweeps/shared_ref_mnist/2021-05-20/21-26-45/"

            ae_path = os.path.join(base_path, "eta_ae:1-eta_lsa:0.0-eta_msa:0.0-eta_dsa:0.0-sigma:0.67-")
            msa_path = os.path.join(base_path, "eta_ae:0.0-eta_lsa:0.0-eta_msa:1-eta_dsa:0.0-sigma:0.67-")
            lsa_path = os.path.join(base_path, "eta_ae:0.53-eta_lsa:0.01-eta_msa:0.74-eta_dsa:0.84-sigma:0.33-")
        else:
            base_path = os.path.expandvars("$SCRATCH")

            ae_path = os.path.join(base_path, "results/sweeps/shared_ref_mnist/2021-05-16/"
                                              "13-09-07/sigma:0.67-eta_ae:1.0-eta_msa:0.0-eta_lsa:0.0-eta_dsa:0.0-")
            msa_path = os.path.join(base_path, "results/sweeps/shared_ref_mnist/2021-05-16/"
                                               "13-09-07/sigma:0.67-eta_ae:0.0-eta_msa:1.0-eta_lsa:0.0-eta_dsa:0.0-")
            lsa_path = os.path.join(base_path, "results/sweeps/shared_ref_mnist/2021-05-15/"
                                               "13-21-57/sigma:0.33-eta_ae:0.67-eta_msa:0.67-eta_lsa:0.0-eta_dsa:0.0-")

        paths = {"AE": ae_path, "MTI": msa_path, "AE-MTM": lsa_path}

        for exp_name, path in paths.items():

            path = path + f"/params/step_39999/rank_{self.rank % 3}"
            dataset = MNISTDataset()
            all_agents: List[AutoEncoder] = self._load_aes(path)

            mlps: List[MLP] = [MLP(30) for _ in all_agents]

            for mlp, agent in zip(mlps, all_agents):
                for i in range(cfg.nsteps):
                    ims, targets = dataset.sample_with_label(cfg.bsize)
                    encoding = agent.encode(ims)
                    # perturb the latent with Gaussian noise of scale cfg.sigma
                    encoding = encoding + torch.randn_like(encoding) * cfg.sigma
                    mlp.train(encoding, targets)
                    acc = mlp.compute_acc(encoding, targets)
                    self.writer.add(
                        (
                            cfg.centralised,
                            exp_name,
                            self.rank,
                            i,
                            "MA" if agent.name != "baseline" else "Baseline",
                            acc,
                        ),
                        step=i,
                    )
                self.writer._write()
Example 4
def main(args):
    ns = [
        10,
        100,
        1000,
        10000,
    ]
    epsilons = [0.2, 1]

    ds = MNISTDataset()
    data_x, data_y = ds.test_set.data.unsqueeze(1) / 255.0, ds.test_set.targets
    results_path = "results/full_res_cifar.csv"

    if not os.path.exists(results_path):
        results = evaluate_experiment(
            args.dti_path,
            args.mtm_path,
            data_x,
            data_y,
        )
        results.to_csv(results_path)
    else:
        results = pd.read_csv(results_path)

    save_path = ("results/"
                 f"{args.name}"
                 f"_train{args.train_steps}"
                 f"_seed{args.seeds}"
                 f"_point{args.points}_cifar")

    plot_curves(results, ns, save_path)
    metrics_df = reprieve.compute_metrics(results, ns, epsilons)
    reprieve.render_latex(metrics_df, save_path=f"{save_path}metrics.tex")
Example 5
def plot_tsne(path, path_to_plot, tag):
    dataset = MNISTDataset()
    ims, labels = dataset.sample_with_label(10000)
    all_agents = _load_aes(path)

    results = []

    for ae in all_agents:
        encoded = ae.encode(ims)
        embedding = TSNE(n_components=2,
                         random_state=4444,
                         perplexity=50,
                         n_jobs=8).fit_transform(encoded.detach())
        for emb, label in zip(embedding[::5], labels[::5]):
            results.append((emb[0], emb[1], int(label.item()), ae.name))

    _generate_tsne_relplot(results, path_to_plot, tag)
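`_generate_tsne_relplot` is not included in this listing; the following is a minimal sketch of what such a helper might look like, assuming the `(x, y, digit, agent-name)` tuples built above and a seaborn figure. Column names, point size, and output paths are assumptions, not the project's implementation.

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

def _generate_tsne_relplot(results, path_to_plot, tag):
    # hypothetical sketch: tidy frame from the (x, y, digit, agent) tuples
    # collected in plot_tsne above
    df = pd.DataFrame(results, columns=["x", "y", "Digit", "Agent"])
    # one scatter panel per agent, points coloured by digit class
    grid = sns.relplot(data=df, x="x", y="y", hue="Digit", col="Agent",
                       kind="scatter", s=10)
    grid.set(xticks=[], yticks=[], xlabel="", ylabel="")
    plt_name = f"plots/{path_to_plot}/tsne_{tag}"
    grid.savefig(plt_name + ".pdf")
    grid.savefig(plt_name + ".svg")
    plt.close("all")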
Example 6
def plot_img_reconstructions(
    root_path: str,
    name_of_best_exp: str,
    path_to_plot: str,
    baseline: bool = False,
    epoch: int = 49999,
):
    dataset = MNISTDataset()
    ae = AutoEncoder(30, False, False, 0.001, "test")
    ae.load_state_dict(
        torch.load(
            os.path.join(
                root_path,
                name_of_best_exp,
                f"params/step_{epoch}/rank_0/{'A' if not baseline else 'baseline'}.pt",
            ),
            map_location=torch.device("cpu"),
        )
    )
    digits: torch.Tensor = dataset.sample(50)

    _, axes = plt.subplots(
        nrows=10,
        ncols=10,
        figsize=(10, 8),
        gridspec_kw=dict(
            wspace=0.0, hspace=0.0, top=0.95, bottom=0.05, left=0.17, right=0.845
        ),
    )
    axes = axes.reshape(50, 2)

    for digit, ax_column in zip(digits, axes):
        ax_column[0].imshow(digit.squeeze().detach())
        ax_column[0].set_axis_off()
        rec = ae(digit.reshape(1, 1, 28, 28))
        ax_column[1].imshow(rec.squeeze().detach())
        ax_column[1].set_axis_off()
    plt_path = f"plots/{path_to_plot}/reconstructions_baseline_{baseline}"
    plt.savefig(plt_path + ".pdf")
    plt.savefig(plt_path + ".svg")
    plt.close()
Example 7
    def run(self, cfg: Config):

        tb_path = f"{self.path}/tb/{self.rank}"
        self.tb = SummaryWriter(tb_path)

        self.dataset = MNISTDataset()
        agents = self.load_aes(
            os.path.join(
                self.path,
                "params",
                "step_39999",
            )
        )
        mlps = []
        for agent in agents:
            mlp = self.train_classifier(agent)
            mlps.append(mlp)
        self.compute_cross_agent_cls(agents, mlps)
Example 8
    def run(self, cfg: Config):
        path = (
            f"results/jeanzay/results/sweeps/shared_ref_mnist/2021-04-16/13-15-58/"
            f"sigma:0.757-eta_lsa:0.004-eta_msa:0.483-eta_dsa:0.623-eta_ae:0.153-/params/step_49999/rank_{self.rank}"
        )
        dataset = MNISTDataset()
        all_agents: List[AutoEncoder] = self._load_aes(path)

        mlps: List[MLP] = [MLP(30) for _ in all_agents]
        bsize = 512
        nsteps = 5000

        digits = list(range(10))

        for mlp, agent in zip(mlps, all_agents):
            for i in range(nsteps):
                ims_all_digits = []
                labels_all_digits = []
                # sample all digits. We need this to flexibly decide which
                # digits we want to classify.
                for digit in digits:
                    ims_all_digits.append(
                        dataset.sample_digit(digit, bsize // len(digits)))
                    labels_all_digits.append(
                        torch.empty(bsize // len(digits)).fill_(digit).long())
                # make batches out of the digits
                ims = torch.cat(ims_all_digits)
                targets = torch.cat(labels_all_digits)

                # shuffle the batch by zipping images and labels together;
                # probably suboptimal since it round-trips through Python lists
                # instead of staying in tensor form
                ims, targets = list(
                    zip(*random.sample(list(zip(ims, targets)), k=len(ims))))
                ims = torch.stack(ims)
                targets = torch.stack(targets)

                encoding = agent.encode(ims)
                mlp.train(encoding, targets)
                acc = mlp.compute_acc(encoding, targets)
                self.writer.add((agent.name, i, acc), step=i)
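As the comment above notes, round-tripping through Python lists just to shuffle is probably suboptimal; below is a short sketch of the same shuffle done with a single torch.randperm index, offered as a possible rewrite rather than code from the project.

import torch

def shuffle_batch(ims: torch.Tensor, targets: torch.Tensor):
    # draw one random permutation of the batch indices and index both tensors with it
    perm = torch.randperm(ims.size(0))
    return ims[perm], targets[perm]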
Example 9
def get_results():
    # note: relies on a module-level `args` (dti_path, mtm_path) defined elsewhere
    results_path = "results/full_res.csv"
    if not os.path.exists(results_path):
        ds = MNISTDataset()
        data_x = ds.test_set.data.unsqueeze(1) / 255.0
        data_y = ds.test_set.targets
        results = evaluate_experiment(
            args.dti_path,
            args.mtm_path,
            data_x,
            data_y,
        )
        return results
    else:
        return pd.read_csv(results_path)
Example 10
def plot_reconstruction_sim_measure(root_path: str, name_of_exp: str,
                                    path_to_plot: str):
    dataset = MNISTDataset()
    agents = _load_aes(
        os.path.join(root_path, name_of_exp, "params", f"step_{int(EPOCH)}",
                     "rank_0"))

    results: List[Tuple[str, Number]] = []
    for i in range(10):
        batch = dataset.sample_digit(i)
        for agent in agents:
            rec = agent(batch)
            rec = (rec - rec.mean()) / (rec.std() + 0.0001)
            for a, b in itertools.combinations(rec, r=2):
                diff = F.mse_loss(a, b)
                results.append(
                    ("MA" if agent.name != "baseline" else "baseline",
                     diff.item()))
    df = pd.DataFrame(results, columns=["Agent", "Difference"])
    sns.barplot(data=df, x="Agent", y="Difference")
    plt_name = f"plots/{path_to_plot}/decoding_space_diff"
    plt.savefig(plt_name + ".svg")
    plt.savefig(plt_name + ".pdf")
    plt.close()
Example 11
class Experiment(BaseExperiment):
    def run(self, cfg: Config):

        tb_path = f"{self.path}/tb/{self.rank}"
        self.tb = SummaryWriter(tb_path)

        self.dataset = MNISTDataset()
        agents = self.load_aes(
            os.path.join(
                self.path,
                "params",
                "step_39999",
            )
        )
        mlps = []
        for agent in agents:
            mlp = self.train_classifier(agent)
            mlps.append(mlp)
        self.compute_cross_agent_cls(agents, mlps)

    def load_aes(self, path: str) -> List[AutoEncoder]:
        autoencoders = [
            AutoEncoder(30, False, False, 0.001, name)
            for name in string.ascii_uppercase[:3]
        ]
        base1 = AutoEncoder(30, False, False, 0.001, "baseline1").to(self.dev)
        base2 = AutoEncoder(30, False, False, 0.001, "baseline2").to(self.dev)
        baselines = [base1, base2]

        for agent in autoencoders:
            agent.load_state_dict(
                torch.load(
                    f"{path}/rank_{int(self.rank) % 5}/{agent.name}.pt",
                    map_location=self.dev,
                )
            )
        for i, agent in enumerate(baselines):
            agent.load_state_dict(
                torch.load(
                    f"{path}/rank_{(int(self.rank) + i) % 5}/{agent.name}.pt",
                    map_location=self.dev,
                )
            )
        return autoencoders + baselines

    def train_classifier(self, agent: AutoEncoder):
        mlp = MLP(30).to(self.dev)
        agent.to(self.dev)

        for i in range(int(self.cfg.nsteps)):
            X, y = map(
                lambda x: x.to(self.dev),
                self.dataset.sample_with_label(int(self.cfg.bsize)),
            )
            latent = agent.encode(X)
            mlp.train(latent, y)
            acc = mlp.compute_acc(latent, y)
            self.tb.add_scalar("Accuracy-Post", acc, global_step=i)
            # self.writer.add((acc.item(), agent.name), step=i)
        return mlp

    def compute_cross_agent_cls(self, agents: List[AutoEncoder], mlps: List[MLP]):
        ma_aes, ma_mlps = agents[:3], mlps[:3]
        sa_aes, sa_mlps = agents[3:], mlps[3:]

        X, y = map(
            lambda x: x.to(self.dev),
            self.dataset.sample_with_label(int(self.cfg.bsize)),
        )
        self._compute_cross_acc(X, y, ma_aes, ma_mlps, "MA")
        self._compute_cross_acc(X, y, sa_aes, sa_mlps, "Base")

    def _compute_cross_acc(self, X, y, aes, mlps, tag, rot=1):
        # rotate the classifier list so each autoencoder's latent is scored by an
        # MLP that was trained on a different agent's encoder
        for i, (ae, mlp) in enumerate(zip(aes, mlps[rot:] + mlps[:rot])):
            latent = ae.encode(X)
            acc = mlp.compute_acc(latent, y)
            self.writer.add((tag, acc), step=i, tag="cross_agent_accuracy_override")
            self.tb.add_scalar(f"cross_agent_acc_{tag}", acc)

    @staticmethod
    def load_data(reader: TidyReader) -> Any:
        return reader.read(columns=["Rank", "Step", "Agent", "Accuracy"])

    @staticmethod
    def plot(df: DataFrame, plot_path: str) -> None:
        df.to_csv(plot_path + "/data.csv")
        sns.barplot(data=df, x="Agent", y="Accuracy")
        plt_name = f"{plot_path}/cross_agent_pred_acc_latent"
        plt.savefig(plt_name + ".svg")
        plt.savefig(plt_name + ".pdf")
        plt.savefig(plt_name + ".png")