# Example #1 (score: 0)
# File: test_scgen.py — Project: abuchin/scgen
def test_scgen():
    """Smoke-test the SCGEN workflow end to end on a small synthetic dataset.

    Registers ``batch``/``labels`` on a synthetic AnnData, trains for a single
    epoch, then exercises ``batch_removal``, ``predict`` and ``reg_mean_plot``
    (with saving and display disabled). The test passes if no step raises; it
    makes no assertions about model quality.
    """
    adata = scvi.data.synthetic_iid(run_setup_anndata=False)
    setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
    )

    model = SCGEN(adata)
    # A single epoch keeps the test fast; with max_epochs=1 the early-stopping
    # patience of 25 can never be reached.
    model.train(
        max_epochs=1, batch_size=32, early_stopping=True, early_stopping_patience=25
    )

    # batch removal
    model.batch_removal()

    # predict the stim ("batch_1") response of label_0 cells from ctrl ("batch_0")
    pred, delta = model.predict(
        ctrl_key="batch_0", stim_key="batch_1", celltype_to_predict="label_0"
    )
    pred.obs["batch"] = "pred"

    # reg mean and reg var: stack observed ctrl/stim label_0 cells with predictions
    is_label_0 = adata.obs["labels"] == "label_0"
    ctrl_adata = adata[is_label_0 & (adata.obs["batch"] == "batch_0")]
    stim_adata = adata[is_label_0 & (adata.obs["batch"] == "batch_1")]
    eval_adata = ctrl_adata.concatenate(stim_adata, pred, batch_key="concat_batches")

    # rank_genes_groups writes results into .uns, so run it on a real copy rather
    # than a view; modifying a view makes AnnData implicitly convert it (with a
    # warning).
    label_0 = adata[is_label_0].copy()
    sc.tl.rank_genes_groups(label_0, groupby="batch", method="wilcoxon")
    diff_genes = label_0.uns["rank_genes_groups"]["names"]["batch_1"]

    model.reg_mean_plot(
        eval_adata,
        axis_keys={"x": "pred", "y": "batch_1"},
        gene_list=diff_genes[:10],
        labels={"x": "predicted", "y": "ground truth"},
        save=False,
        show=False,
        legend=False,
    )
# Example #2 (score: 0)
import os

# All inputs and outputs live under one data root; derive every path from it
# instead of repeating the absolute prefix.
base_path = '/Users/zhongyuanke/data/'
file_rna = os.path.join(base_path, 'dann_vae/multimodal/rna.h5ad')
file_atac = os.path.join(base_path, 'dann_vae/multimodal/atac.h5ad')
# NOTE(review): not read anywhere in this script.
seurat_celltype_path = os.path.join(base_path, 'multimodal/atac_pbmc_10k/celltype_filt.csv')
# NOTE(review): not used below — SCGEN training passes batch_size=32 explicitly.
# Confirm which value is intended before wiring this constant in.
batch_size = 128
out_path = os.path.join(base_path, 'scgen/scgen_multimodal.h5ad')

# Load both modalities; ATAC is passed first and RNA second to preprocessing.
adata1 = sc.read_h5ad(file_atac)
adata2 = sc.read_h5ad(file_rna)
print(adata1)
print(adata2)

# Combine the two datasets into one AnnData with the project's DAVAE helper.
# Presumably hvg=False / lognorm=False skip HVG selection and log-normalisation
# — confirm in tl.davae_preprocessing.
adata_all = tl.davae_preprocessing([adata1, adata2], n_top_genes=2000, hvg=False, lognorm=False)
adata_all.obs_names_make_unique()

# Register the batch column with scGen (returning a copy) and train on CPU.
adata_all = scgen.setup_anndata(adata_all, batch_key="batch_label", copy=True)
model = scgen.SCGEN(adata_all)
model.train(
    max_epochs=15,
    batch_size=32,
    early_stopping=True,
    early_stopping_patience=25,
    use_gpu=False,
)

# Batch-correct, then build a UMAP on the corrected latent representation.
corrected_adata = model.batch_removal()
sc.pp.neighbors(corrected_adata, use_rep='corrected_latent')
sc.tl.umap(corrected_adata)
# NOTE(review): scGen was registered with batch_key="batch_label"; confirm that
# the preprocessed data also carries a "batch" obs column to colour by.
sc.pl.umap(corrected_adata, color='batch')

# Make sure the output directory exists so the write cannot fail after the
# (comparatively expensive) training run.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
corrected_adata.write(out_path)