def test_saving_and_loading(save_path):
    """Round-trip several models through ``save``/``load`` and compare outputs."""

    def test_save_load_model(cls, adata, save_path):
        """Train ``cls`` for one epoch, save, reload, and compare results."""
        model = cls(adata, latent_distribution="normal")
        model.train(1, train_size=0.2)
        latent_before = model.get_latent_representation(adata)
        val_idx_before = model.validation_indices
        model.save(save_path, overwrite=True, save_anndata=True)

        # Reload without passing an AnnData (it was written next to the model).
        model = cls.load(save_path)
        model.get_latent_representation()

        # Loading against a 200-gene dataset must be rejected.
        # NOTE(review): presumably 200 differs from synthetic_iid's default
        # gene count -- confirm.
        mismatched_adata = scvi.data.synthetic_iid(n_genes=200)
        with pytest.raises(ValueError):
            cls.load(save_path, mismatched_adata)

        model = cls.load(save_path, adata)
        latent_after = model.get_latent_representation()
        val_idx_after = model.validation_indices
        np.testing.assert_array_equal(latent_before, latent_after)
        np.testing.assert_array_equal(val_idx_before, val_idx_after)
        assert model.is_trained is True

    save_path = os.path.join(save_path, "tmp")
    adata = synthetic_iid()

    for cls in [SCVI, LinearSCVI, TOTALVI]:
        print(cls)
        test_save_load_model(cls, adata, save_path)

    # AUTOZI
    model = AUTOZI(adata, latent_distribution="normal")
    model.train(1, train_size=0.5)
    alphas_betas_before = model.get_alphas_betas()
    model.save(save_path, overwrite=True, save_anndata=True)
    model = AUTOZI.load(save_path)
    model.get_latent_representation()
    mismatched_adata = scvi.data.synthetic_iid(n_genes=200)
    with pytest.raises(ValueError):
        AUTOZI.load(save_path, mismatched_adata)
    model = AUTOZI.load(save_path, adata)
    alphas_betas_after = model.get_alphas_betas()
    np.testing.assert_array_equal(
        alphas_betas_before["alpha_posterior"],
        alphas_betas_after["alpha_posterior"],
    )
    np.testing.assert_array_equal(
        alphas_betas_before["beta_posterior"],
        alphas_betas_after["beta_posterior"],
    )
    assert model.is_trained is True

    # SCANVI
    model = SCANVI(adata, "label_0")
    model.train(max_epochs=1, train_size=0.5)
    predictions_before = model.predict()
    model.save(save_path, overwrite=True, save_anndata=True)
    model = SCANVI.load(save_path)
    model.get_latent_representation()
    mismatched_adata = scvi.data.synthetic_iid(n_genes=200)
    with pytest.raises(ValueError):
        SCANVI.load(save_path, mismatched_adata)
    model = SCANVI.load(save_path, adata)
    predictions_after = model.predict()
    np.testing.assert_array_equal(predictions_before, predictions_after)
    assert model.is_trained is True
def test_saving_and_loading(save_path):
    """Save and reload each model, checking that outputs are unchanged."""

    def test_save_load_model(cls, adata, save_path):
        """Train ``cls`` for one epoch, save, reload, and compare results."""
        model = cls(adata, latent_distribution="normal")
        model.train(1)
        latent_before = model.get_latent_representation(adata)
        test_idx_before = model.test_indices
        model.save(save_path, overwrite=True)

        model = cls.load(adata, save_path)
        latent_after = model.get_latent_representation()
        test_idx_after = model.test_indices
        np.testing.assert_array_equal(latent_before, latent_after)
        np.testing.assert_array_equal(test_idx_before, test_idx_after)
        assert model.is_trained is True

    save_path = os.path.join(save_path, "tmp")
    adata = synthetic_iid()

    for cls in [SCVI, LinearSCVI, TOTALVI]:
        print(cls)
        test_save_load_model(cls, adata, save_path)

    # AUTOZI
    model = AUTOZI(adata, latent_distribution="normal")
    model.train(1)
    alphas_betas_before = model.get_alphas_betas()
    model.save(save_path, overwrite=True)
    model = AUTOZI.load(adata, save_path)
    alphas_betas_after = model.get_alphas_betas()
    np.testing.assert_array_equal(
        alphas_betas_before["alpha_posterior"],
        alphas_betas_after["alpha_posterior"],
    )
    np.testing.assert_array_equal(
        alphas_betas_before["beta_posterior"],
        alphas_betas_after["beta_posterior"],
    )
    assert model.is_trained is True

    # SCANVI
    model = SCANVI(adata, "undefined_0")
    model.train(n_epochs_unsupervised=1, n_epochs_semisupervised=1)
    predictions_before = model.predict()
    model.save(save_path, overwrite=True)
    model = SCANVI.load(adata, save_path)
    predictions_after = model.predict()
    np.testing.assert_array_equal(predictions_before, predictions_after)
    assert model.is_trained is True

    # GIMVI
    model = GIMVI(adata, adata)
    model.train(1)
    # Two consecutive calls on the same model must agree before saving.
    latent_first = model.get_latent_representation([adata])
    latent_second = model.get_latent_representation([adata])
    np.testing.assert_array_equal(latent_first, latent_second)
    model.save(save_path, overwrite=True)
    model = GIMVI.load(adata, adata, save_path)
    latent_second = model.get_latent_representation([adata])
    np.testing.assert_array_equal(latent_first, latent_second)
    assert model.is_trained is True
def test_multiple_covariates_scvi(save_path):
    """Train SCVI, SCANVI and TOTALVI with extra continuous/categorical covariates."""
    adata = synthetic_iid()
    n_obs = adata.shape[0]

    # Two random continuous and two random categorical (0..4) covariates.
    adata.obs["cont1"] = np.random.normal(size=(n_obs,))
    adata.obs["cont2"] = np.random.normal(size=(n_obs,))
    adata.obs["cat1"] = np.random.randint(0, 5, size=(n_obs,))
    adata.obs["cat2"] = np.random.randint(0, 5, size=(n_obs,))

    SCVI.setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
        continuous_covariate_keys=["cont1", "cont2"],
        categorical_covariate_keys=["cat1", "cat2"],
    )
    m = SCVI(adata)
    m.train(1)

    m = SCANVI(adata, unlabeled_category="Unknown")
    m.train(1)

    TOTALVI.setup_anndata(
        adata,
        batch_key="batch",
        protein_expression_obsm_key="protein_expression",
        protein_names_uns_key="protein_names",
        continuous_covariate_keys=["cont1", "cont2"],
        categorical_covariate_keys=["cat1", "cat2"],
    )
    m = TOTALVI(adata)
    m.train(1)
def test_scanvi(save_path):
    """Exercise SCANVI training, logging, prediction, DE and ``from_scvi_model``."""
    adata = synthetic_iid()
    model = SCANVI(adata, "label_0", n_latent=10)
    model.train(1, train_size=0.5, check_val_every_n_epoch=1)

    # Every expected metric must have been logged during training.
    logged_keys = model.history.keys()
    expected_keys = (
        "elbo_validation",
        "reconstruction_loss_validation",
        "kl_local_validation",
        "elbo_train",
        "reconstruction_loss_train",
        "kl_local_train",
        "classification_loss_validation",
    )
    for key in expected_keys:
        assert key in logged_keys

    adata2 = synthetic_iid()
    predictions = model.predict(adata2, indices=[1, 2, 3])
    assert len(predictions) == 3
    model.predict()
    model.predict(adata2, soft=True)
    model.predict(adata2, soft=True, indices=[1, 2, 3])
    model.get_normalized_expression(adata2)
    model.differential_expression(groupby="labels", group1="label_1")
    model.differential_expression(
        groupby="labels", group1="label_1", group2="label_2"
    )

    # test that all data labeled runs
    unknown_label = "asdf"
    a = scvi.data.synthetic_iid()
    scvi.data.setup_anndata(a, batch_key="batch", labels_key="labels")
    m = scvi.model.SCANVI(a, unknown_label)
    m.train(1)

    # test mix of labeled and unlabeled data
    unknown_label = "label_0"
    a = scvi.data.synthetic_iid()
    scvi.data.setup_anndata(a, batch_key="batch", labels_key="labels")
    m = scvi.model.SCANVI(a, unknown_label)
    m.train(1, train_size=0.9)

    # test from_scvi_model
    a = scvi.data.synthetic_iid()
    m = scvi.model.SCVI(a, use_observed_lib_size=False)
    a2 = scvi.data.synthetic_iid()
    scanvi_model = scvi.model.SCANVI.from_scvi_model(m, "label_0", adata=a2)
    scanvi_model = scvi.model.SCANVI.from_scvi_model(
        m, "label_0", use_labels_groups=False
    )
    scanvi_model.train(1)
def test_scanvi():
    """Train SCANVI once and check history lengths, prediction and DE calls."""
    adata = synthetic_iid()
    model = SCANVI(adata, "label_0", n_latent=10)
    model.train(1, train_size=0.5, frequency=1)

    history = model.history
    assert len(history["unsupervised_trainer_history"]) == 2
    assert len(history["semisupervised_trainer_history"]) == 3

    adata2 = synthetic_iid()
    selected = [1, 2, 3]
    predictions = model.predict(adata2, indices=selected)
    assert len(predictions) == 3

    model.predict()
    model.predict(adata2, soft=True)
    model.predict(adata2, soft=True, indices=[1, 2, 3])
    model.get_normalized_expression(adata2)
    model.differential_expression(groupby="labels", group1="label_1")
    model.differential_expression(
        groupby="labels", group1="label_1", group2="label_2"
    )
def test_scanvi_online_update(save_path):
    """Check SCANVI query-data loading, classifier freezing and save/load.

    Covers a reference with semi-observed labels, a reference with
    fully-observed labels, frozen versus unfrozen classifier weights after
    query training, and saving/loading of a query model.
    """

    def first_classifier_weight(scanvi_model):
        """Return ``classifier.classifier[0].fc_layers[0][0].weight`` as a NumPy array."""
        return (
            scanvi_model.module.classifier.classifier[0]
            .fc_layers[0][0]
            .weight.detach()
            .cpu()
            .numpy()
        )

    # ref has semi-observed labels
    n_latent = 5
    adata1 = synthetic_iid(run_setup_anndata=False)
    new_labels = adata1.obs.labels.to_numpy()
    new_labels[0] = "Unknown"
    adata1.obs["labels"] = pd.Categorical(new_labels)
    setup_anndata(adata1, batch_key="batch", labels_key="labels")
    model = SCANVI(adata1, "Unknown", n_latent=n_latent, encode_covariates=True)
    model.train(max_epochs=1, check_val_every_n_epoch=1)
    dir_path = os.path.join(save_path, "saved_model/")
    model.save(dir_path, overwrite=True)

    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(
        ["batch_2", "batch_3"]
    )
    adata2.obs["labels"] = "Unknown"

    model = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)
    model.train(max_epochs=1)
    model.get_latent_representation()
    model.predict()

    # ref has fully-observed labels
    n_latent = 5
    adata1 = synthetic_iid(run_setup_anndata=False)
    new_labels = adata1.obs.labels.to_numpy()
    adata1.obs["labels"] = pd.Categorical(new_labels)
    setup_anndata(adata1, batch_key="batch", labels_key="labels")
    model = SCANVI(adata1, "Unknown", n_latent=n_latent, encode_covariates=True)
    model.train(max_epochs=1, check_val_every_n_epoch=1)
    dir_path = os.path.join(save_path, "saved_model/")
    model.save(dir_path, overwrite=True)

    # query has one new label
    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(
        ["batch_2", "batch_3"]
    )
    new_labels = adata2.obs.labels.to_numpy()
    new_labels[0] = "Unknown"
    adata2.obs["labels"] = pd.Categorical(new_labels)

    model2 = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)
    model2._unlabeled_indices = np.arange(adata2.n_obs)
    model2._labeled_indices = []
    # weight_decay is disabled because it would change the frozen weights.
    model2.train(max_epochs=1, plan_kwargs=dict(weight_decay=0.0))
    model2.get_latent_representation()
    model2.predict()

    # test classifier frozen
    class_query_weight = first_classifier_weight(model2)
    class_ref_weight = first_classifier_weight(model)
    # weight decay makes difference
    np.testing.assert_allclose(class_query_weight, class_ref_weight, atol=1e-07)

    # test classifier unfrozen
    model2 = SCANVI.load_query_data(adata2, dir_path, freeze_classifier=False)
    model2._unlabeled_indices = np.arange(adata2.n_obs)
    model2._labeled_indices = []
    model2.train(max_epochs=1)
    class_query_weight = first_classifier_weight(model2)
    class_ref_weight = first_classifier_weight(model)
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(class_query_weight, class_ref_weight, atol=1e-07)

    # test saving and loading of online scanvi
    a = scvi.data.synthetic_iid(run_setup_anndata=False)
    ref = a[a.obs["labels"] != "label_2"].copy()  # only has labels 0 and 1
    scvi.data.setup_anndata(ref, batch_key="batch", labels_key="labels")
    m = SCANVI(ref, "label_2")
    m.train(max_epochs=1)
    m.save(save_path, overwrite=True)

    # NOTE(review): the query is a fresh synthetic dataset. A subset of ``a``
    # without "label_0" used to be assigned to ``query`` and was overwritten on
    # the very next statement; that dead store has been removed. Confirm the
    # fresh dataset (rather than the subset) is the intended query.
    query = scvi.data.synthetic_iid()
    m_q = SCANVI.load_query_data(query, save_path)
    m_q.save(save_path, overwrite=True)
    m_q = SCANVI.load(save_path, query)
    m_q.predict()
    m_q.get_elbo()
def test_scanvi(save_path):
    """Exercise SCANVI training and verify the train/validation/test split.

    Besides the basic train/predict/DE calls, this checks that the split
    sizes add up to ``n_obs`` and that the training set keeps roughly the
    same unlabeled-to-labeled ratio as the full dataset.
    """
    adata = synthetic_iid()
    model = SCANVI(adata, "label_0", n_latent=10)
    model.train(1, train_size=0.5, check_val_every_n_epoch=1)
    logged_keys = model.history.keys()
    assert "elbo_validation" in logged_keys
    assert "reconstruction_loss_validation" in logged_keys
    assert "kl_local_validation" in logged_keys
    assert "elbo_train" in logged_keys
    assert "reconstruction_loss_train" in logged_keys
    assert "kl_local_train" in logged_keys

    adata2 = synthetic_iid()
    predictions = model.predict(adata2, indices=[1, 2, 3])
    assert len(predictions) == 3
    model.predict()
    model.predict(adata2, soft=True)
    model.predict(adata2, soft=True, indices=[1, 2, 3])
    model.get_normalized_expression(adata2)
    model.differential_expression(groupby="labels", group1="label_1")
    model.differential_expression(
        groupby="labels", group1="label_1", group2="label_2"
    )

    # test that all data labeled runs
    unknown_label = "asdf"
    a = scvi.data.synthetic_iid()
    scvi.data.setup_anndata(a, batch_key="batch", labels_key="labels")
    m = scvi.model.SCANVI(a, unknown_label)
    m.train(1)

    # check the number of indices
    n_train_idx = len(m.train_indices)
    n_validation_idx = len(m.validation_indices)
    n_test_idx = len(m.test_indices)
    assert n_train_idx + n_validation_idx + n_test_idx == a.n_obs
    assert np.isclose(n_train_idx / a.n_obs, 0.9)
    assert np.isclose(n_validation_idx / a.n_obs, 0.1)
    assert np.isclose(n_test_idx / a.n_obs, 0)

    # test mix of labeled and unlabeled data
    unknown_label = "label_0"
    a = scvi.data.synthetic_iid()
    scvi.data.setup_anndata(a, batch_key="batch", labels_key="labels")
    m = scvi.model.SCANVI(a, unknown_label)
    m.train(1, train_size=0.9)

    # check the number of indices
    n_train_idx = len(m.train_indices)
    n_validation_idx = len(m.validation_indices)
    n_test_idx = len(m.test_indices)
    assert n_train_idx + n_validation_idx + n_test_idx == a.n_obs
    assert np.isclose(n_train_idx / a.n_obs, 0.9, rtol=0.05)
    assert np.isclose(n_validation_idx / a.n_obs, 0.1, rtol=0.05)
    assert np.isclose(n_test_idx / a.n_obs, 0, rtol=0.05)

    # check that training indices have proper mix of labeled and unlabeled data
    labelled_idx = np.where(a.obs["labels"] != unknown_label)[0]
    unlabelled_idx = np.where(a.obs["labels"] == unknown_label)[0]

    # Vectorized membership test: counts the training indices that fall in each
    # group without scanning the index array once per training index.
    train_idx = np.asarray(m.train_indices)
    # labeled training idx
    n_labeled_train = int(np.isin(train_idx, labelled_idx).sum())
    # unlabeled training idx
    n_unlabeled_train = int(np.isin(train_idx, unlabelled_idx).sum())

    n_labeled_idx = len(m._labeled_indices)
    n_unlabeled_idx = len(m._unlabeled_indices)

    # unlabeled vs labeled ratio in adata
    adata_ratio = n_unlabeled_idx / n_labeled_idx
    # unlabeled vs labeled ratio in train set (plain ints, so an empty labeled
    # set still raises ZeroDivisionError instead of silently producing inf)
    train_ratio = n_unlabeled_train / n_labeled_train
    assert np.isclose(adata_ratio, train_ratio, atol=0.05)
def test_scanvi():
    """Train SCANVI for one epoch and run prediction, expression and DE calls."""
    adata = synthetic_iid()
    model = SCANVI(adata, "undefined_0", n_latent=10)
    model.train(1)

    adata2 = synthetic_iid()
    selected = [1, 2, 3]
    predictions = model.predict(adata2, indices=selected)
    assert len(predictions) == 3

    model.predict()
    model.predict(adata2, soft=True)
    model.predict(adata2, soft=True, indices=[1, 2, 3])
    model.get_normalized_expression(adata2)
    model.differential_expression(groupby="labels", group1="undefined_1")
    model.differential_expression(
        groupby="labels", group1="undefined_1", group2="undefined_2"
    )
def test_save_load_scanvi(legacy=False):
    """Save a trained SCANVI model, reload it and compare predictions.

    ``adata``, ``save_path`` and ``legacy_save`` are taken from the enclosing
    scope. When ``legacy`` is true the model is written with ``legacy_save``
    instead of ``model.save``.
    """
    prefix = "SCANVI_"
    model = SCANVI(adata, "label_0")
    model.train(max_epochs=1, train_size=0.5)
    predictions_before = model.predict()

    if not legacy:
        model.save(save_path, overwrite=True, save_anndata=True, prefix=prefix)
    else:
        legacy_save(
            model, save_path, overwrite=True, save_anndata=True, prefix=prefix
        )

    model = SCANVI.load(save_path, prefix=prefix)
    model.get_latent_representation()

    # Loading against a 200-gene dataset must be rejected.
    mismatched_adata = scvi.data.synthetic_iid(n_genes=200)
    with pytest.raises(ValueError):
        SCANVI.load(save_path, adata=mismatched_adata, prefix=prefix)

    model = SCANVI.load(save_path, adata=adata, prefix=prefix)
    registry = adata.uns["_scvi"]["data_registry"]
    assert "test" in registry
    assert adata.uns["_scvi"]["data_registry"]["test"] == dict(
        attr_name="obs", attr_key="cont1"
    )

    predictions_after = model.predict()
    np.testing.assert_array_equal(predictions_before, predictions_after)
    assert model.is_trained is True
def test_scanvi_online_update(save_path):
    """Load query data into saved SCANVI references and continue training."""
    # ref has semi-observed labels
    n_latent = 5
    adata1 = synthetic_iid(run_setup_anndata=False)
    new_labels = adata1.obs.labels.to_numpy()
    new_labels[0] = "Unknown"
    adata1.obs["labels"] = pd.Categorical(new_labels)
    setup_anndata(adata1, batch_key="batch", labels_key="labels")
    model = SCANVI(adata1, "Unknown", n_latent=n_latent, encode_covariates=True)
    model.train(n_epochs_unsupervised=1, n_epochs_semisupervised=1, frequency=1)
    dir_path = os.path.join(save_path, "saved_model/")
    model.save(dir_path, overwrite=True)

    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(
        ["batch_2", "batch_3"]
    )
    adata2.obs["labels"] = "Unknown"

    model = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)
    model.train(
        n_epochs_unsupervised=1,
        n_epochs_semisupervised=1,
        train_base_model=False,
    )
    model.get_latent_representation()
    model.predict()

    # ref has fully-observed labels
    n_latent = 5
    adata1 = synthetic_iid(run_setup_anndata=False)
    new_labels = adata1.obs.labels.to_numpy()
    adata1.obs["labels"] = pd.Categorical(new_labels)
    setup_anndata(adata1, batch_key="batch", labels_key="labels")
    model = SCANVI(adata1, "Unknown", n_latent=n_latent, encode_covariates=True)
    model.train(n_epochs_unsupervised=1, n_epochs_semisupervised=1, frequency=1)
    dir_path = os.path.join(save_path, "saved_model/")
    model.save(dir_path, overwrite=True)

    # query has one new label
    adata2 = synthetic_iid(run_setup_anndata=False)
    adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(
        ["batch_2", "batch_3"]
    )
    new_labels = adata2.obs.labels.to_numpy()
    new_labels[0] = "Unknown"
    adata2.obs["labels"] = pd.Categorical(new_labels)

    model = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)
    # Treat every query cell as unlabeled.
    model._unlabeled_indices = np.arange(adata2.n_obs)
    model._labeled_indices = []
    model.train(
        n_epochs_unsupervised=1,
        n_epochs_semisupervised=1,
        train_base_model=False,
    )
    model.get_latent_representation()
    model.predict()
def test_scanvi(save_path):
    """Exercise SCANVI training, prediction, DE and ``from_scvi_model`` variants."""
    adata = synthetic_iid()
    SCANVI.setup_anndata(
        adata,
        "label_0",
        batch_key="batch",
        labels_key="labels",
    )
    model = SCANVI(adata, n_latent=10)
    model.train(1, train_size=0.5, check_val_every_n_epoch=1)

    # Every expected metric must have been logged during training.
    logged_keys = model.history.keys()
    expected_keys = (
        "elbo_validation",
        "reconstruction_loss_validation",
        "kl_local_validation",
        "elbo_train",
        "reconstruction_loss_train",
        "kl_local_train",
        "classification_loss_validation",
    )
    for key in expected_keys:
        assert key in logged_keys

    adata2 = synthetic_iid()
    predictions = model.predict(adata2, indices=[1, 2, 3])
    assert len(predictions) == 3
    model.predict()
    soft_predictions = model.predict(adata2, soft=True)
    assert isinstance(soft_predictions, pd.DataFrame)
    model.predict(adata2, soft=True, indices=[1, 2, 3])
    model.get_normalized_expression(adata2)
    model.differential_expression(groupby="labels", group1="label_1")
    model.differential_expression(
        groupby="labels", group1="label_1", group2="label_2"
    )

    # test that all data labeled runs
    unknown_label = "asdf"
    a = scvi.data.synthetic_iid()
    scvi.model.SCANVI.setup_anndata(
        a, unknown_label, batch_key="batch", labels_key="labels"
    )
    m = scvi.model.SCANVI(a)
    m.train(1)

    # test mix of labeled and unlabeled data
    unknown_label = "label_0"
    a = scvi.data.synthetic_iid()
    scvi.model.SCANVI.setup_anndata(
        a, unknown_label, batch_key="batch", labels_key="labels"
    )
    m = scvi.model.SCANVI(a)
    m.train(1, train_size=0.9)

    # test from_scvi_model
    a = scvi.data.synthetic_iid()
    SCVI.setup_anndata(
        a,
        batch_key="batch",
        labels_key="labels",
    )
    m = SCVI(a, use_observed_lib_size=False)
    a2 = scvi.data.synthetic_iid()
    scanvi_model = scvi.model.SCANVI.from_scvi_model(m, "label_0", adata=a2)

    # make sure the state_dicts are different objects for the two models
    assert scanvi_model.module.state_dict() is not m.module.state_dict()
    scanvi_pxr = scanvi_model.module.state_dict().get("px_r", None)
    scvi_pxr = m.module.state_dict().get("px_r", None)
    assert scanvi_pxr is not None and scvi_pxr is not None
    assert scanvi_pxr is not scvi_pxr
    scanvi_model.train(1)

    # Test without label groups
    scanvi_model = scvi.model.SCANVI.from_scvi_model(
        m, "label_0", use_labels_groups=False
    )
    scanvi_model.train(1)

    # test from_scvi_model with size_factor
    a = scvi.data.synthetic_iid()
    a.obs["size_factor"] = np.random.randint(1, 5, size=(a.shape[0],))
    SCVI.setup_anndata(
        a,
        batch_key="batch",
        labels_key="labels",
        size_factor_key="size_factor",
    )
    m = SCVI(a, use_observed_lib_size=False)
    a2 = scvi.data.synthetic_iid()
    a2.obs["size_factor"] = np.random.randint(1, 5, size=(a2.shape[0],))
    scanvi_model = scvi.model.SCANVI.from_scvi_model(m, "label_0", adata=a2)
    scanvi_model.train(1)
def test_saving_and_loading(save_path):
    """Round-trip models through both the legacy and the current save format.

    For each model the legacy save and the current save are each checked with
    a full save/load round trip. A further legacy save on top of the current
    one is then expected to fail with ``AssertionError``.
    """

    def legacy_save(
        model,
        dir_path,
        prefix=None,
        overwrite=False,
        save_anndata=False,
        **anndata_write_kwargs,
    ):
        """Write ``model`` to ``dir_path`` as separate params/attr/var-name files.

        Raises ``ValueError`` if ``dir_path`` exists and ``overwrite`` is false.
        """
        if os.path.exists(dir_path) and not overwrite:
            raise ValueError(
                "{} already exists. Please provide an unexisting directory for saving."
                .format(dir_path))
        os.makedirs(dir_path, exist_ok=overwrite)

        file_name_prefix = prefix or ""

        if save_anndata:
            model.adata.write(
                os.path.join(dir_path, f"{file_name_prefix}adata.h5ad"),
                **anndata_write_kwargs,
            )

        model_save_path = os.path.join(dir_path, f"{file_name_prefix}model_params.pt")
        attr_save_path = os.path.join(dir_path, f"{file_name_prefix}attr.pkl")
        varnames_save_path = os.path.join(dir_path, f"{file_name_prefix}var_names.csv")

        torch.save(model.module.state_dict(), model_save_path)

        var_names = model.adata.var_names.astype(str).to_numpy()
        np.savetxt(varnames_save_path, var_names, fmt="%s")

        # get all the user attributes
        all_attributes = model._get_user_attributes()
        # only save the public attributes with _ at the very end
        saved_attributes = {
            attr[0]: attr[1] for attr in all_attributes if attr[0][-1] == "_"
        }

        with open(attr_save_path, "wb") as f:
            pickle.dump(saved_attributes, f)

    def test_save_load_model(cls, adata, save_path, prefix=None, legacy=False):
        """Set up, train, save and reload ``cls``; compare latents and indices."""
        if cls is not TOTALVI:
            cls.setup_anndata(adata, batch_key="batch", labels_key="labels")
        else:
            cls.setup_anndata(
                adata,
                batch_key="batch",
                labels_key="labels",
                protein_expression_obsm_key="protein_expression",
                protein_names_uns_key="protein_names",
            )
        model = cls(adata, latent_distribution="normal")
        model.train(1, train_size=0.2)
        latent_before = model.get_latent_representation(adata)
        val_idx_before = model.validation_indices

        if not legacy:
            model.save(save_path, overwrite=True, save_anndata=True, prefix=prefix)
        else:
            legacy_save(
                model, save_path, overwrite=True, save_anndata=True, prefix=prefix
            )

        model = cls.load(save_path, prefix=prefix)
        model.get_latent_representation()

        # Load with mismatched genes.
        tmp_adata = synthetic_iid(
            n_genes=200,
        )
        with pytest.raises(ValueError):
            cls.load(save_path, adata=tmp_adata, prefix=prefix)

        # Load with different batches.
        tmp_adata = synthetic_iid()
        tmp_adata.obs["batch"] = tmp_adata.obs["batch"].cat.rename_categories(
            ["batch_2", "batch_3"]
        )
        with pytest.raises(ValueError):
            cls.load(save_path, adata=tmp_adata, prefix=prefix)

        model = cls.load(save_path, adata=adata, prefix=prefix)
        assert "batch" in model.adata_manager.data_registry
        assert model.adata_manager.data_registry["batch"] == dict(
            attr_name="obs", attr_key="_scvi_batch"
        )

        latent_after = model.get_latent_representation()
        val_idx_after = model.validation_indices
        np.testing.assert_array_equal(latent_before, latent_after)
        np.testing.assert_array_equal(val_idx_before, val_idx_after)
        assert model.is_trained is True

    save_path = os.path.join(save_path, "tmp")
    adata = synthetic_iid()

    for cls in [SCVI, LinearSCVI, TOTALVI, PEAKVI]:
        cls_prefix = f"{cls.__name__}_"
        test_save_load_model(cls, adata, save_path, prefix=cls_prefix, legacy=True)
        test_save_load_model(cls, adata, save_path, prefix=cls_prefix)
        # Test load prioritizes newer save paradigm and thus mismatches legacy save.
        with pytest.raises(AssertionError):
            test_save_load_model(
                cls, adata, save_path, prefix=cls_prefix, legacy=True
            )

    # AUTOZI
    def test_save_load_autozi(legacy=False):
        """Train, save and reload AUTOZI; compare the alpha/beta posteriors."""
        prefix = "AUTOZI_"
        model = AUTOZI(adata, latent_distribution="normal")
        model.train(1, train_size=0.5)
        alphas_betas_before = model.get_alphas_betas()

        if not legacy:
            model.save(save_path, overwrite=True, save_anndata=True, prefix=prefix)
        else:
            legacy_save(
                model, save_path, overwrite=True, save_anndata=True, prefix=prefix
            )

        model = AUTOZI.load(save_path, prefix=prefix)
        model.get_latent_representation()
        tmp_adata = scvi.data.synthetic_iid(n_genes=200)
        with pytest.raises(ValueError):
            AUTOZI.load(save_path, adata=tmp_adata, prefix=prefix)

        model = AUTOZI.load(save_path, adata=adata, prefix=prefix)
        assert "batch" in model.adata_manager.data_registry
        assert model.adata_manager.data_registry["batch"] == dict(
            attr_name="obs", attr_key="_scvi_batch"
        )

        alphas_betas_after = model.get_alphas_betas()
        np.testing.assert_array_equal(
            alphas_betas_before["alpha_posterior"],
            alphas_betas_after["alpha_posterior"],
        )
        np.testing.assert_array_equal(
            alphas_betas_before["beta_posterior"],
            alphas_betas_after["beta_posterior"],
        )
        assert model.is_trained is True

    AUTOZI.setup_anndata(adata, batch_key="batch", labels_key="labels")
    test_save_load_autozi(legacy=True)
    test_save_load_autozi()
    # Test load prioritizes newer save paradigm and thus mismatches legacy save.
    with pytest.raises(AssertionError):
        test_save_load_autozi(legacy=True)

    # SCANVI
    def test_save_load_scanvi(legacy=False):
        """Train, save and reload SCANVI; compare predictions."""
        prefix = "SCANVI_"
        model = SCANVI(adata)
        model.train(max_epochs=1, train_size=0.5)
        predictions_before = model.predict()

        if not legacy:
            model.save(save_path, overwrite=True, save_anndata=True, prefix=prefix)
        else:
            legacy_save(
                model, save_path, overwrite=True, save_anndata=True, prefix=prefix
            )

        model = SCANVI.load(save_path, prefix=prefix)
        model.get_latent_representation()
        tmp_adata = scvi.data.synthetic_iid(n_genes=200)
        with pytest.raises(ValueError):
            SCANVI.load(save_path, adata=tmp_adata, prefix=prefix)

        model = SCANVI.load(save_path, adata=adata, prefix=prefix)
        assert "batch" in model.adata_manager.data_registry
        assert model.adata_manager.data_registry["batch"] == dict(
            attr_name="obs", attr_key="_scvi_batch"
        )

        predictions_after = model.predict()
        np.testing.assert_array_equal(predictions_before, predictions_after)
        assert model.is_trained is True

    SCANVI.setup_anndata(adata, "label_0", batch_key="batch", labels_key="labels")
    test_save_load_scanvi(legacy=True)
    test_save_load_scanvi()
    # Test load prioritizes newer save paradigm and thus mismatches legacy save.
    with pytest.raises(AssertionError):
        test_save_load_scanvi(legacy=True)