コード例 #1
0
def test_destvi(save_path):
    """Fit a CondSCVI reference, then DestVI once per amortization scheme.

    For every scheme the test checks that the first recorded ``elbo_train``
    value is not NaN and that the proportions, gamma and per-cell-type scale
    outputs have the expected shapes. ``save_path`` is accepted (pytest
    fixture) but not used here.
    """
    latent_dim = 2
    label_count = 5
    layer_count = 2
    adata = synthetic_iid(n_labels=label_count)

    # Step 1: train the single-cell reference model for a single epoch.
    CondSCVI.setup_anndata(adata, labels_key="labels")
    reference = CondSCVI(adata, n_latent=latent_dim, n_layers=layer_count)
    reference.train(1, train_size=1)

    # Step 2: build and train DestVI from the reference for each scheme.
    subset_size = 50
    for scheme in ("both", "none", "proportion", "latent"):
        spatial = DestVI.from_rna_model(
            adata,
            reference,
            amortization=scheme,
        )
        spatial.train(max_epochs=1)

        first_elbo = spatial.history["elbo_train"].values[0][0]
        assert not np.isnan(first_elbo)

        proportions = spatial.get_proportions()
        assert proportions.shape == (adata.n_obs, label_count)

        gamma = spatial.get_gamma(return_numpy=True)
        assert gamma.shape == (adata.n_obs, latent_dim, label_count)

        # Scale for one cell type, restricted to the first `subset_size` indices.
        ct_scale = spatial.get_scale_for_ct("label_0", np.arange(subset_size))
        assert ct_scale.shape == (subset_size, adata.n_vars)
コード例 #2
0
def destvi_raw(adata, test=False, max_epochs_sc=None, max_epochs_sp=None):
    """Predict spot-level cell-type proportions with DestVI.

    A CondSCVI model is trained on a copy of the single-cell reference stored
    in ``adata.uns["sc_reference"]``. A DestVI model is then built from it on
    ``adata`` itself, and its proportions are written to
    ``adata.obsm["proportions_pred"]``.

    Parameters
    ----------
    adata
        Spatial AnnData. Must hold the single-cell reference AnnData under
        ``adata.uns["sc_reference"]``. That reference needs a ``"label"``
        key usable as ``labels_key`` (presumably an ``.obs`` column; confirm
        against the scvi-tools version in use).
    test
        If True, train both models for only a few epochs (10 each, unless
        overridden) so the method runs quickly.
    max_epochs_sc
        Optional epoch count for the CondSCVI (single-cell) model. If None,
        test mode uses 10. Otherwise ``CondSCVI.train``'s own default is used.
    max_epochs_sp
        Optional epoch count for the DestVI (spatial) model. If None, the
        value is 10 in test mode and 2500 otherwise.

    Returns
    -------
    The same ``adata`` object, with ``obsm["proportions_pred"]`` set in place.
    """
    from scvi.model import CondSCVI
    from scvi.model import DestVI

    # `test` used to be ignored, so test runs still trained for thousands of
    # epochs; shorten training there unless the caller overrides it.
    if test:
        if max_epochs_sc is None:
            max_epochs_sc = 10
        if max_epochs_sp is None:
            max_epochs_sp = 10
    elif max_epochs_sp is None:
        max_epochs_sp = 2500

    # Work on a copy so setup_anndata does not mutate the stored reference.
    adata_sc = adata.uns["sc_reference"].copy()
    CondSCVI.setup_anndata(adata_sc, labels_key="label", layer=None)
    sc_model = CondSCVI(adata_sc, weight_obs=False)
    if max_epochs_sc is None:
        # Keep the historical behaviour: rely on CondSCVI's default epochs.
        sc_model.train()
    else:
        sc_model.train(max_epochs=max_epochs_sc)

    DestVI.setup_anndata(adata, layer=None)
    st_model = DestVI.from_rna_model(adata, sc_model)
    st_model.train(max_epochs=max_epochs_sp)
    adata.obsm["proportions_pred"] = st_model.get_proportions()
    return adata
コード例 #3
0
ファイル: test_models.py プロジェクト: vitkl/scvi-tools
def test_condscvi(save_path):
    """Smoke-test CondSCVI with the default constructor and with ``weight_obs``.

    Each configuration is trained for one epoch with ``train_size=1``. Then
    its latent representation and VampPrior are queried. ``save_path`` is
    accepted (pytest fixture) but not used here.
    """
    adata = synthetic_iid(n_labels=5, run_setup_anndata=False)
    CondSCVI.setup_anndata(
        adata,
        "labels",
    )

    # First the default constructor, then with observation weighting enabled.
    for extra_kwargs in ({}, {"weight_obs": True}):
        model = CondSCVI(adata, **extra_kwargs)
        model.train(1, train_size=1)
        model.get_latent_representation()
        model.get_vamp_prior(adata)