Esempio n. 1
0
def _dense_matrix(matrix):
    """Return ``matrix`` unchanged, or densified with ``toarray()`` if scipy-sparse."""
    if sparse.issparse(matrix):
        # ``toarray()`` works for every SciPy sparse type; the ``.A``
        # shorthand is deprecated (and removed for newer sparse classes).
        return matrix.toarray()
    return matrix


def _save_latent_umap(latent_matrix, train, condition_key, cell_type_key):
    """Save a UMAP of ``latent_matrix`` coloured by condition and cell type.

    Row ``i`` of ``latent_matrix`` is labelled with row ``i`` of
    ``train.obs``; the figure is saved by scanpy with the ``_latent`` suffix.
    """
    latent = sc.AnnData(X=latent_matrix,
                        obs={
                            condition_key:
                            train.obs[condition_key].tolist(),
                            cell_type_key:
                            train.obs[cell_type_key].tolist()
                        })
    sc.pp.neighbors(latent)
    sc.tl.umap(latent)
    sc.pl.umap(latent,
               color=[condition_key, cell_type_key],
               save="_latent",
               show=False)


def _save_prediction_plots(cell_type_data, pred, conditions, condition_key,
                           path_to_save, plot_umap, plot_reg):
    """Compare real cells of one cell type with the model's predicted cells.

    Predicted cells are labelled ``"pred"`` in ``obs[condition_key]`` and
    concatenated with ``cell_type_data``. Genes are ranked per condition
    (top 100) and the ranking for ``conditions["stim"]`` drives the plots:
    mean/variance regression plots (``plot_reg``), UMAPs of real vs.
    predicted cells (``plot_reg`` and ``plot_umap``), and always a violin
    plot of the top-ranked gene.
    """
    pred_adata = anndata.AnnData(
        pred,
        obs={condition_key: ["pred"] * len(pred)},
        var={"var_names": cell_type_data.var_names})
    all_adata = cell_type_data.concatenate(pred_adata)

    sc.tl.rank_genes_groups(cell_type_data,
                            groupby=condition_key,
                            n_genes=100)
    diff_genes = cell_type_data.uns["rank_genes_groups"]["names"][
        conditions["stim"]]
    diff_gene_names = diff_genes.tolist()
    axis_keys = {"x": "pred", "y": conditions["stim"]}

    if plot_reg:
        # (filename suffix, genes to keep); None means "all genes".
        gene_subsets = (("all_genes", None),
                        ("top_100_genes", diff_gene_names),
                        ("top_50_genes", diff_gene_names[:50]))
        plotted = []
        for suffix, genes in gene_subsets:
            # Subsets are built lazily, in the same order as the plots, and
            # as real copies so later in-place scanpy calls don't hit a view.
            subset = all_adata if genes is None else all_adata[:, genes].copy()
            scgen.plotting.reg_mean_plot(subset,
                                         condition_key=condition_key,
                                         axis_keys=axis_keys,
                                         gene_list=diff_genes[:5],
                                         path_to_save=os.path.join(
                                             path_to_save,
                                             f"reg_mean_{suffix}.pdf"))
            scgen.plotting.reg_var_plot(subset,
                                        condition_key=condition_key,
                                        axis_keys=axis_keys,
                                        gene_list=diff_genes[:5],
                                        path_to_save=os.path.join(
                                            path_to_save,
                                            f"reg_var_{suffix}.pdf"))
            plotted.append((suffix, subset))

        if plot_umap:
            for suffix, subset in plotted:
                sc.pp.neighbors(subset)
                sc.tl.umap(subset)
                sc.pl.umap(subset,
                           color=condition_key,
                           save=f"pred_{suffix}",
                           show=False)

    top_gene = diff_gene_names[0]
    sc.pl.violin(all_adata,
                 keys=top_gene,
                 groupby=condition_key,
                 save=f"_{top_gene}",
                 show=False)


def visualize_trained_network_results(network,
                                      train,
                                      cell_type,
                                      conditions=None,
                                      condition_key="condition",
                                      cell_type_key="cell_type",
                                      path_to_save="./figures/",
                                      plot_umap=True,
                                      plot_reg=True):
    """Save diagnostic figures for a trained scGen model.

    Handles ``scgen.VAEArithKeras``, ``scgen.VAEArith`` and ``scgen.CVAE``.
    For any other ``network`` the figure directory is prepared but nothing
    is plotted.

    Args:
        network: trained scGen model.
        train: AnnData used for training; ``obs`` must contain
            ``condition_key`` and ``cell_type_key``.
        cell_type: value of ``obs[cell_type_key]`` whose response is
            predicted and compared against the real cells.
        conditions: mapping whose ``"stim"`` value names the stimulated
            condition in ``obs[condition_key]``; also forwarded to
            ``network.predict`` for the VAE models. Defaults to
            ``{"ctrl": "control", "stim": "stimulated"}``.
        condition_key: ``obs`` column with condition labels.
        cell_type_key: ``obs`` column with cell-type labels.
        path_to_save: output directory; created if missing and set as
            ``sc.settings.figdir``.
        plot_umap: save a UMAP of the latent space; together with
            ``plot_reg`` also save UMAPs of real vs. predicted cells.
        plot_reg: save mean/variance regression plots for all genes and for
            the top 100 / top 50 ranked genes.

    Returns:
        None. Figures are written to disk and matplotlib figures are closed.
    """
    if conditions is None:
        # Fresh dict per call instead of a shared mutable default argument.
        conditions = {"ctrl": "control", "stim": "stimulated"}
    plt.close("all")
    os.makedirs(path_to_save, exist_ok=True)
    sc.settings.figdir = os.path.abspath(path_to_save)

    if isinstance(network, (scgen.VAEArithKeras, scgen.VAEArith)):
        if plot_umap:
            _save_latent_umap(network.to_latent(_dense_matrix(train.X)),
                              train, condition_key, cell_type_key)
        cell_type_data = train[train.obs[cell_type_key] == cell_type]
        pred, _ = network.predict(adata=cell_type_data,
                                  conditions=conditions,
                                  cell_type_key=cell_type_key,
                                  condition_key=condition_key,
                                  celltype_to_predict=cell_type)
    elif isinstance(network, scgen.CVAE):
        if plot_umap:
            true_labels, _ = scgen.label_encoder(train)
            _save_latent_umap(
                network.to_latent(_dense_matrix(train.X), labels=true_labels),
                train, condition_key, cell_type_key)
        cell_type_data = train[train.obs[cell_type_key] == cell_type]
        # Every cell is decoded with label 1; presumably the encoded
        # stimulated condition — TODO confirm against scgen.label_encoder.
        fake_labels = np.ones(shape=(cell_type_data.shape[0], 1))
        pred = network.predict(data=cell_type_data, labels=fake_labels)
    else:
        # Unsupported model type: nothing to plot.
        return

    _save_prediction_plots(cell_type_data, pred, conditions, condition_key,
                           path_to_save, plot_umap, plot_reg)
    plt.close("all")
Esempio n. 2
0
import scgen
import scanpy as sc
import numpy as np

# Plot the latent space of a CVAE trained with stimulated CD4T cells held out.
from scipy import sparse

train = sc.read("./data/train.h5ad")
# Drop cells that are both CD4T and stimulated. To keep only CD4T cells
# instead, use: train = train[train.obs["cell_type"] == "CD4T"]
train = train[~((train.obs["cell_type"] == "CD4T") &
                (train.obs["condition"] == "stimulated"))]
z_dim = 20
network = scgen.CVAE(x_dimension=train.X.shape[1],
                     z_dimension=z_dim,
                     alpha=0.1)
# Presumably loads weights from an earlier run; to retrain instead, call
# network.train(train, n_epochs=100)
network.restore_model()

labels, _ = scgen.label_encoder(train)
# Densify once and reuse. ``toarray()`` works for every SciPy sparse type
# (unlike the deprecated ``.A``), and a dense X is passed through unchanged.
train_x = train.X.toarray() if sparse.issparse(train.X) else train.X
latent = network.to_latent(train_x, labels=labels)
adata = sc.AnnData(X=latent,
                   obs={
                       "condition": train.obs["condition"].tolist(),
                       "cell_type": train.obs["cell_type"].tolist()
                   })
sc.pp.neighbors(adata)
sc.tl.umap(adata)
sc.pl.umap(adata, color=["condition", "cell_type"], save=f"train_{z_dim}")
mmd = network.to_mmd_layer(train_x, labels=labels)
adata_mmd = sc.AnnData(X=mmd,
                       obs={
                           "condition": train.obs["condition"].tolist(),
                           "cell_type": train.obs["cell_type"].tolist()
                       })