def test_destvi(save_path):
    """Smoke-test DestVI on synthetic data under each amortization scheme.

    A CondSCVI model is trained for one epoch on a synthetic dataset. For
    every amortization option a DestVI model is then built from it with
    ``DestVI.from_rna_model`` and trained for one epoch. Each resulting
    model must report a non-NaN first ``elbo_train`` history entry and
    return arrays of the expected shape from ``get_proportions``,
    ``get_gamma`` and ``get_scale_for_ct``.

    ``save_path`` is accepted but never read in this body (presumably a
    pytest fixture -- confirm against the test suite's conftest).
    """
    latent_dim = 2
    label_count = 5
    layer_count = 2

    # Stage 1: fit the single-cell reference model.
    dataset = synthetic_iid(n_labels=label_count)
    sc_model = CondSCVI(dataset, n_latent=latent_dim, n_layers=layer_count)
    sc_model.train(1, train_size=1)

    # Stage 2: build and fit DestVI once per amortization scheme.
    for scheme in ("both", "none", "proportion", "latent"):
        spatial_model = DestVI.from_rna_model(
            dataset,
            sc_model,
            amortization=scheme,
        )
        spatial_model.train(max_epochs=1)

        # Training must have produced a usable ELBO value.
        first_elbo = spatial_model.history["elbo_train"].values[0][0]
        assert not np.isnan(first_elbo)

        # One row per observation, one column per label.
        proportions = spatial_model.get_proportions()
        assert proportions.shape == (dataset.n_obs, label_count)

        gamma = spatial_model.get_gamma(return_numpy=True)
        assert gamma.shape == (dataset.n_obs, latent_dim, label_count)

        # Query the first 50 observation indices for the "label_0" label.
        scale = spatial_model.get_scale_for_ct("label_0", np.arange(50))
        assert scale.shape == (50, dataset.n_vars)
def from_rna_model(
    cls,
    st_adata: AnnData,
    sc_model: CondSCVI,
    vamp_prior_p: int = 50,
    layer: Optional[str] = None,
    **module_kwargs,
):
    """
    Alternate constructor that reuses a pre-trained single-cell RNA model.

    The decoder weights, ``px_r``, label mapping and architecture sizes
    (``n_hidden``, ``n_latent``, ``n_layers``) are read from ``sc_model``
    and handed to the ``cls`` constructor together with ``st_adata``.

    Parameters
    ----------
    st_adata
        AnnData object for the new model. It is registered inside this
        method via ``cls.setup_anndata(st_adata, layer=layer)``.
    sc_model
        trained CondSCVI model
    vamp_prior_p
        number of mixture parameter for VampPrior calculations, passed as
        ``p`` to ``sc_model.get_vamp_prior``. If ``None``, that call is
        skipped and ``mean_vprior`` / ``var_vprior`` are passed as ``None``.
    layer
        Forwarded as ``layer`` to ``cls.setup_anndata``.
    **module_kwargs
        Extra keyword arguments forwarded unchanged to the ``cls``
        constructor.

    Returns
    -------
    The object produced by calling ``cls`` with the collected arguments.
    """
    rna_module = sc_model.module

    # Parameters learned by the single-cell model that the new model reuses.
    decoder_weights = rna_module.decoder.state_dict()
    px_decoder_weights = rna_module.px_decoder.state_dict()
    px_r = rna_module.px_r.detach().cpu().numpy()

    # Categorical mapping registered for the labels field of sc_model.
    labels_registry = sc_model.adata_manager.get_state_registry(
        REGISTRY_KEYS.LABELS_KEY
    )
    mapping = labels_registry.categorical_mapping

    # The VampPrior is optional: only compute it when a component count is given.
    mean_vprior = var_vprior = None
    if vamp_prior_p is not None:
        mean_vprior, var_vprior = sc_model.get_vamp_prior(
            sc_model.adata, p=vamp_prior_p
        )

    # Register st_adata before constructing the model on it.
    cls.setup_anndata(st_adata, layer=layer)

    return cls(
        st_adata,
        mapping,
        decoder_weights,
        px_decoder_weights,
        px_r,
        rna_module.n_hidden,
        rna_module.n_latent,
        rna_module.n_layers,
        mean_vprior=mean_vprior,
        var_vprior=var_vprior,
        **module_kwargs,
    )
def destvi_raw(adata, test=False):
    """Run the CondSCVI -> DestVI pipeline and store predicted proportions.

    A copy of ``adata.uns["sc_reference"]`` is used to train a CondSCVI
    model (labels taken from the ``"label"`` key, ``weight_obs=False``).
    A DestVI model is then created from it for ``adata`` and trained with
    ``max_epochs=2500``. The output of ``get_proportions()`` is written to
    ``adata.obsm["proportions_pred"]``.

    Parameters
    ----------
    adata
        Object exposing ``uns["sc_reference"]`` and ``obsm``; modified in
        place.
    test
        Accepted but not read anywhere in this function.

    Returns
    -------
    The same ``adata`` object that was passed in.
    """
    from scvi.model import CondSCVI, DestVI

    # Work on a copy so the reference stored in ``uns`` is left untouched.
    reference = adata.uns["sc_reference"].copy()
    CondSCVI.setup_anndata(reference, labels_key="label", layer=None)
    reference_model = CondSCVI(reference, weight_obs=False)
    reference_model.train()

    DestVI.setup_anndata(adata, layer=None)
    spatial_model = DestVI.from_rna_model(adata, reference_model)
    spatial_model.train(max_epochs=2500)

    adata.obsm["proportions_pred"] = spatial_model.get_proportions()
    return adata
def test_condscvi(save_path):
    """Smoke-test CondSCVI with default settings and with ``weight_obs=True``.

    For each configuration a model is built on the same synthetic dataset,
    trained for one epoch, and asked for its latent representation and its
    VampPrior. The test passes if none of these calls raises.

    ``save_path`` is accepted but never read in this body.
    """
    dataset = synthetic_iid(n_labels=5)

    # First pass uses the constructor defaults, second enables weight_obs.
    for extra_kwargs in ({}, {"weight_obs": True}):
        model = CondSCVI(dataset, **extra_kwargs)
        model.train(1, train_size=1)
        model.get_latent_representation()
        model.get_vamp_prior(dataset)