def test_autozi():
    """Smoke-test AUTOZI training and evaluation for every dispersion mode.

    Each ``dispersion``/``zero_inflation`` mode is exercised twice: first with
    the constructor defaults, then with ``use_observed_lib_size=False``.  In
    the second configuration the module must also expose the library-size
    prior attributes ``library_log_means`` and ``library_log_vars``.
    """
    adata = synthetic_iid(n_batches=1, run_setup_anndata=False)
    AUTOZI.setup_anndata(adata, batch_key="batch", labels_key="labels")

    modes = ("gene", "gene-label")
    # Each entry: (extra constructor kwargs, must library-size priors exist?)
    configurations = (
        ({}, False),
        # Model library size.
        ({"use_observed_lib_size": False}, True),
    )

    for extra_kwargs, expects_library_priors in configurations:
        for mode in modes:
            model = AUTOZI(
                adata,
                dispersion=mode,
                zero_inflation=mode,
                **extra_kwargs,
            )
            model.train(1, plan_kwargs={"lr": 1e-2}, check_val_every_n_epoch=1)

            if expects_library_priors:
                module = model.module
                assert hasattr(module, "library_log_means") and hasattr(
                    module, "library_log_vars"
                )

            # One epoch with per-epoch validation -> exactly one logged value each.
            for history_key in ("elbo_train", "elbo_validation"):
                assert len(model.history[history_key]) == 1

            held_out = model.validation_indices
            model.get_elbo(indices=held_out)
            model.get_reconstruction_error(indices=held_out)
            model.get_marginal_ll(indices=held_out, n_mc_samples=3)
            model.get_alphas_betas()
def test_saving_and_loading(save_path):
    """Round-trip save/load for several models in both legacy and current formats.

    For SCVI, LinearSCVI, TOTALVI and PEAKVI (generic helper), and for AUTOZI and
    SCANVI (dedicated helpers), each model is briefly trained, saved (via the
    hand-written ``legacy_save`` below or via ``model.save``) and reloaded. The
    test then checks that:

    - loading into an AnnData with mismatched genes raises ``ValueError``;
    - the reloaded model has the expected "batch" registry entry;
    - its outputs match those of the model before saving.

    Finally it checks that ``load`` prefers the newer on-disk format over files
    written later by ``legacy_save`` with the same prefix.
    """

    def legacy_save(
        model,
        dir_path,
        prefix=None,
        overwrite=False,
        save_anndata=False,
        **anndata_write_kwargs,
    ):
        """Write ``model`` to ``dir_path`` using the legacy on-disk layout.

        Produces ``{prefix}model_params.pt`` (module state dict),
        ``{prefix}attr.pkl`` (pickled public user attributes),
        ``{prefix}var_names.csv`` and, if ``save_anndata``,
        ``{prefix}adata.h5ad``.

        Raises:
            ValueError: if ``dir_path`` exists and ``overwrite`` is False.
        """
        if not os.path.exists(dir_path) or overwrite:
            # NOTE: with overwrite=True this only tolerates an existing
            # directory; files already in it (e.g. from model.save) remain.
            os.makedirs(dir_path, exist_ok=overwrite)
        else:
            raise ValueError(
                "{} already exists. Please provide an unexisting directory for saving."
                .format(dir_path))

        file_name_prefix = prefix or ""

        if save_anndata:
            model.adata.write(
                os.path.join(dir_path, f"{file_name_prefix}adata.h5ad"),
                **anndata_write_kwargs,
            )

        model_save_path = os.path.join(dir_path, f"{file_name_prefix}model_params.pt")
        attr_save_path = os.path.join(dir_path, f"{file_name_prefix}attr.pkl")
        varnames_save_path = os.path.join(dir_path, f"{file_name_prefix}var_names.csv")

        torch.save(model.module.state_dict(), model_save_path)

        var_names = model.adata.var_names.astype(str)
        var_names = var_names.to_numpy()
        np.savetxt(varnames_save_path, var_names, fmt="%s")

        # get all the user attributes
        user_attributes = model._get_user_attributes()
        # only save the public attributes with _ at the very end
        user_attributes = {
            a[0]: a[1] for a in user_attributes if a[0][-1] == "_"
        }

        with open(attr_save_path, "wb") as f:
            pickle.dump(user_attributes, f)

    def test_save_load_model(cls, adata, save_path, prefix=None, legacy=False):
        """Train ``cls`` on ``adata``, save it, reload it and compare outputs.

        ``legacy`` selects ``legacy_save`` instead of ``model.save``. Note that
        ``setup_anndata`` is re-run on the shared ``adata`` on every call.
        """
        if cls is TOTALVI:
            # TOTALVI additionally needs the protein data registered.
            cls.setup_anndata(
                adata,
                batch_key="batch",
                labels_key="labels",
                protein_expression_obsm_key="protein_expression",
                protein_names_uns_key="protein_names",
            )
        else:
            cls.setup_anndata(adata, batch_key="batch", labels_key="labels")
        model = cls(adata, latent_distribution="normal")
        model.train(1, train_size=0.2)
        z1 = model.get_latent_representation(adata)
        test_idx1 = model.validation_indices
        if legacy:
            legacy_save(model, save_path, overwrite=True, save_anndata=True, prefix=prefix)
        else:
            model.save(save_path, overwrite=True, save_anndata=True, prefix=prefix)
        # Load using the AnnData that was saved alongside the model.
        model = cls.load(save_path, prefix=prefix)
        model.get_latent_representation()

        # Load with mismatched genes.
        tmp_adata = synthetic_iid(n_genes=200, )
        with pytest.raises(ValueError):
            cls.load(save_path, adata=tmp_adata, prefix=prefix)

        # Load with different batches.
        tmp_adata = synthetic_iid()
        tmp_adata.obs["batch"] = tmp_adata.obs["batch"].cat.rename_categories(
            ["batch_2", "batch_3"])
        with pytest.raises(ValueError):
            cls.load(save_path, adata=tmp_adata, prefix=prefix)

        model = cls.load(save_path, adata=adata, prefix=prefix)
        assert "batch" in model.adata_manager.data_registry
        assert model.adata_manager.data_registry["batch"] == dict(
            attr_name="obs", attr_key="_scvi_batch")

        # The reloaded model must reproduce the latent space and data split.
        z2 = model.get_latent_representation()
        test_idx2 = model.validation_indices
        np.testing.assert_array_equal(z1, z2)
        np.testing.assert_array_equal(test_idx1, test_idx2)
        assert model.is_trained is True

    save_path = os.path.join(save_path, "tmp")

    adata = synthetic_iid()

    for cls in [SCVI, LinearSCVI, TOTALVI, PEAKVI]:
        test_save_load_model(cls, adata, save_path, prefix=f"{cls.__name__}_", legacy=True)
        test_save_load_model(cls, adata, save_path, prefix=f"{cls.__name__}_")
        # Test load prioritizes newer save paradigm and thus mismatches legacy save.
        # The files written by model.save in the previous call are still on
        # disk (legacy_save does not delete anything), so the reloaded model is
        # not the one just trained and the equality asserts fail.
        with pytest.raises(AssertionError):
            test_save_load_model(cls, adata, save_path, prefix=f"{cls.__name__}_", legacy=True)

    # AUTOZI
    def test_save_load_autozi(legacy=False):
        """Save/load round trip for AUTOZI, comparing alpha/beta posteriors.

        Uses ``adata`` and ``save_path`` from the enclosing test.
        NOTE(review): unlike ``test_save_load_model`` this does not check
        loading into an AnnData with different batch categories.
        """
        prefix = "AUTOZI_"
        model = AUTOZI(adata, latent_distribution="normal")
        model.train(1, train_size=0.5)
        ab1 = model.get_alphas_betas()
        if legacy:
            legacy_save(model, save_path, overwrite=True, save_anndata=True, prefix=prefix)
        else:
            model.save(save_path, overwrite=True, save_anndata=True, prefix=prefix)
        model = AUTOZI.load(save_path, prefix=prefix)
        model.get_latent_representation()
        # Loading into an AnnData with a different gene count must fail.
        tmp_adata = scvi.data.synthetic_iid(n_genes=200)
        with pytest.raises(ValueError):
            AUTOZI.load(save_path, adata=tmp_adata, prefix=prefix)
        model = AUTOZI.load(save_path, adata=adata, prefix=prefix)
        assert "batch" in model.adata_manager.data_registry
        assert model.adata_manager.data_registry["batch"] == dict(
            attr_name="obs", attr_key="_scvi_batch")

        ab2 = model.get_alphas_betas()
        np.testing.assert_array_equal(ab1["alpha_posterior"], ab2["alpha_posterior"])
        np.testing.assert_array_equal(ab1["beta_posterior"], ab2["beta_posterior"])
        assert model.is_trained is True

    AUTOZI.setup_anndata(adata, batch_key="batch", labels_key="labels")
    test_save_load_autozi(legacy=True)
    test_save_load_autozi()
    # Test load prioritizes newer save paradigm and thus mismatches legacy save.
    with pytest.raises(AssertionError):
        test_save_load_autozi(legacy=True)

    # SCANVI
    def test_save_load_scanvi(legacy=False):
        """Save/load round trip for SCANVI, comparing ``predict()`` outputs.

        Uses ``adata`` and ``save_path`` from the enclosing test.
        NOTE(review): unlike ``test_save_load_model`` this does not check
        loading into an AnnData with different batch categories.
        """
        prefix = "SCANVI_"
        model = SCANVI(adata)
        model.train(max_epochs=1, train_size=0.5)
        p1 = model.predict()
        if legacy:
            legacy_save(model, save_path, overwrite=True, save_anndata=True, prefix=prefix)
        else:
            model.save(save_path, overwrite=True, save_anndata=True, prefix=prefix)
        model = SCANVI.load(save_path, prefix=prefix)
        model.get_latent_representation()
        # Loading into an AnnData with a different gene count must fail.
        tmp_adata = scvi.data.synthetic_iid(n_genes=200)
        with pytest.raises(ValueError):
            SCANVI.load(save_path, adata=tmp_adata, prefix=prefix)
        model = SCANVI.load(save_path, adata=adata, prefix=prefix)
        assert "batch" in model.adata_manager.data_registry
        assert model.adata_manager.data_registry["batch"] == dict(
            attr_name="obs", attr_key="_scvi_batch")

        p2 = model.predict()
        np.testing.assert_array_equal(p1, p2)
        assert model.is_trained is True

    # "label_0" is passed positionally to SCANVI.setup_anndata
    # (presumably the unlabeled category — confirm against SCANVI's signature).
    SCANVI.setup_anndata(adata, "label_0", batch_key="batch", labels_key="labels")
    test_save_load_scanvi(legacy=True)
    test_save_load_scanvi()
    # Test load prioritizes newer save paradigm and thus mismatches legacy save.
    with pytest.raises(AssertionError):
        test_save_load_scanvi(legacy=True)