def test_new_setup_compat():
    """Compare field registries across current, legacy and transferred setups.

    An AnnData set up through ``SCVI.setup_anndata`` is used as the reference.
    Its field registries (restricted to ``LEGACY_REGISTRY_KEYS``) must equal
    those of a manager built from ``LEGACY_SETUP_DICT``, and its full field
    registries must equal those obtained via ``transfer_setup`` on a copy.
    """
    cat_keys = ["cat1", "cat2"]
    cont_keys = ["cont1", "cont2"]

    adata = synthetic_iid()
    n_cells = adata.shape[0]
    # Integer-valued columns first, then normally distributed ones; the
    # draw order matches the column order cat1, cat2, cont1, cont2.
    for cat_key in cat_keys:
        adata.obs[cat_key] = np.random.randint(0, 5, size=(n_cells,))
    for cont_key in cont_keys:
        adata.obs[cont_key] = np.random.normal(size=(n_cells,))

    # Copies are taken before setup_anndata is called on the reference object.
    adata2, adata3 = adata.copy(), adata.copy()

    SCVI.setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
        categorical_covariate_keys=cat_keys,
        continuous_covariate_keys=cont_keys,
    )
    model = SCVI(adata)
    adata_manager = model.adata_manager
    model.view_anndata_setup(hide_state_registries=True)

    registries_key = _constants._FIELD_REGISTRIES_KEY
    field_registries = adata_manager.registry[registries_key]
    field_registries_legacy_subset = {
        name: field_registries[name]
        for name in field_registries
        if name in LEGACY_REGISTRY_KEYS
    }

    # Backwards compatibility test.
    adata2_manager = manager_from_setup_dict(SCVI, adata2, LEGACY_SETUP_DICT)
    legacy_registries = adata2_manager.registry[registries_key]
    np.testing.assert_equal(field_registries_legacy_subset, legacy_registries)

    # Test transfer.
    adata3_manager = adata_manager.transfer_setup(adata3)
    transferred_registries = adata3_manager.registry[registries_key]
    np.testing.assert_equal(field_registries, transferred_registries)
def test_scvi(save_path):
    """Smoke-test the ``SCVI`` model API on synthetic data.

    The test trains models under three library-size configurations (explicit
    size factor, observed library size, ``use_observed_lib_size=False``) and
    then exercises the getter methods, posterior predictive sampling, feature
    correlation matrices, likelihood parameters, differential expression,
    ``transform_batch`` handling and training callbacks, both on the training
    AnnData and on freshly generated AnnData objects.

    Parameters
    ----------
    save_path
        Not used in the body of this test.
        NOTE(review): presumably a pytest fixture kept for signature
        compatibility -- confirm against conftest.
    """
    n_latent = 5

    # Test with size factor.
    adata = synthetic_iid()
    adata.obs["size_factor"] = np.random.randint(1, 5, size=(adata.shape[0],))
    SCVI.setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
        size_factor_key="size_factor",
    )
    model = SCVI(adata, n_latent=n_latent)
    model.train(1, check_val_every_n_epoch=1, train_size=0.5)

    # Test with observed lib size.
    adata = synthetic_iid()
    SCVI.setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
    )
    model = SCVI(adata, n_latent=n_latent)
    model.train(1, check_val_every_n_epoch=1, train_size=0.5)

    # Test without observed lib size.
    model = SCVI(
        adata,
        n_latent=n_latent,
        var_activation=Softplus(),
        use_observed_lib_size=False,
    )
    # Two one-epoch calls: the history-length assertion below expects 2.
    model.train(1, check_val_every_n_epoch=1, train_size=0.5)
    model.train(1, check_val_every_n_epoch=1, train_size=0.5)

    # tests __repr__
    print(model)
    # test view_registry
    model.view_anndata_setup()
    model.view_anndata_setup(hide_state_registries=True)

    assert model.is_trained is True
    z = model.get_latent_representation()
    assert z.shape == (adata.shape[0], n_latent)
    assert len(model.history["elbo_train"]) == 2
    model.get_elbo()
    model.get_marginal_ll(n_mc_samples=3)
    model.get_reconstruction_error()
    model.get_normalized_expression(transform_batch="batch_1")

    adata2 = synthetic_iid()
    # test view_registry with different anndata before transfer setup
    # NOTE(review): both calls share one ``pytest.raises`` block, so the
    # second call only runs if the first one does not raise -- confirm
    # whether each call should get its own block.
    with pytest.raises(ValueError):
        model.view_anndata_setup(adata=adata2)
        model.view_anndata_setup(adata=adata2, hide_state_registries=True)

    # test get methods with different anndata
    model.get_elbo(adata2)
    model.get_marginal_ll(adata2, n_mc_samples=3)
    model.get_reconstruction_error(adata2)
    latent = model.get_latent_representation(adata2, indices=[1, 2, 3])
    assert latent.shape == (3, n_latent)
    denoised = model.get_normalized_expression(adata2)
    assert denoised.shape == adata.shape

    # test view_registry with different anndata after transfer setup
    model.view_anndata_setup(adata=adata2)
    model.view_anndata_setup(adata=adata2, hide_state_registries=True)

    # Single target batch: only checked for running without error.
    model.get_normalized_expression(
        adata2, indices=[1, 2, 3], transform_batch="batch_1"
    )
    denoised = model.get_normalized_expression(
        adata2, indices=[1, 2, 3], transform_batch=["batch_0", "batch_1"]
    )
    assert denoised.shape == (3, adata2.n_vars)

    sample = model.posterior_predictive_sample(adata2)
    assert sample.shape == adata2.shape
    sample = model.posterior_predictive_sample(
        adata2, indices=[1, 2, 3], gene_list=["1", "2"]
    )
    assert sample.shape == (3, 2)
    sample = model.posterior_predictive_sample(
        adata2, indices=[1, 2, 3], gene_list=["1", "2"], n_samples=3
    )
    assert sample.shape == (3, 2, 3)

    model.get_feature_correlation_matrix(correlation_type="pearson")
    model.get_feature_correlation_matrix(
        adata2,
        indices=[1, 2, 3],
        correlation_type="spearman",
        rna_size_factor=500,
        n_samples=5,
    )
    model.get_feature_correlation_matrix(
        adata2,
        indices=[1, 2, 3],
        correlation_type="spearman",
        rna_size_factor=500,
        n_samples=5,
        transform_batch=["batch_0", "batch_1"],
    )

    params = model.get_likelihood_parameters()
    assert params["mean"].shape == adata.shape
    assert (
        params["mean"].shape
        == params["dispersions"].shape
        == params["dropout"].shape
    )
    params = model.get_likelihood_parameters(adata2, indices=[1, 2, 3])
    assert params["mean"].shape == (3, adata.n_vars)
    params = model.get_likelihood_parameters(
        adata2, indices=[1, 2, 3], n_samples=3, give_mean=True
    )
    assert params["mean"].shape == (3, adata.n_vars)

    model.get_latent_library_size()
    model.get_latent_library_size(adata2, indices=[1, 2, 3])

    # test transfer_anndata_setup
    adata2 = synthetic_iid()
    model._validate_anndata(adata2)
    model.get_elbo(adata2)

    # test automatic transfer_anndata_setup + on a view
    adata = synthetic_iid()
    SCVI.setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
    )
    model = SCVI(adata)
    adata2 = synthetic_iid()
    model.get_elbo(adata2[:10])

    # test mismatched categories raises ValueError
    adata2 = synthetic_iid()
    # Assign the renamed column back instead of passing ``inplace=True``:
    # that keyword was deprecated in pandas 1.3 and removed in pandas 2.0.
    adata2.obs["labels"] = adata2.obs["labels"].cat.rename_categories(
        ["a", "b", "c"]
    )
    with pytest.raises(ValueError):
        model.get_elbo(adata2)

    # test differential expression
    model.differential_expression(groupby="labels", group1="label_1")
    model.differential_expression(
        groupby="labels", group1="label_1", group2="label_2", mode="change"
    )
    model.differential_expression(groupby="labels")
    model.differential_expression(idx1=[0, 1, 2], idx2=[3, 4, 5])
    model.differential_expression(idx1=[0, 1, 2])

    # transform batch works with all different types
    a = synthetic_iid()
    # Numeric batch column: the first 64 cells get 1.0, the rest 0.0.
    batch = np.zeros(a.n_obs)
    batch[:64] += 1
    a.obs["batch"] = batch
    SCVI.setup_anndata(
        a,
        batch_key="batch",
    )
    m = SCVI(a)
    m.train(1, train_size=0.5)
    m.get_normalized_expression(transform_batch=1)
    m.get_normalized_expression(transform_batch=[0, 1])

    # test get_likelihood_parameters() when dispersion=='gene-cell'
    model = SCVI(adata, dispersion="gene-cell")
    model.get_likelihood_parameters()

    # test train callbacks work
    a = synthetic_iid()
    SCVI.setup_anndata(
        a,
        batch_key="batch",
        labels_key="labels",
    )
    m = scvi.model.SCVI(a)
    lr_monitor = LearningRateMonitor()
    m.train(
        callbacks=[lr_monitor],
        max_epochs=10,
        check_val_every_n_epoch=1,
        log_every_n_steps=1,
        plan_kwargs={"reduce_lr_on_plateau": True},
    )
    assert "lr-Adam" in m.history.keys()